Trie<TValue> Generic

This code is adapted from a C# trie source I found on the web, attributed to Kerry D. Wong. In my version, each string in the trie has a "payload" of generic type TValue. To use this trie to simply search for substrings, the payload could always be set to true, as illustrated with the demo program.

One thing I changed is that the node base class supports two types of nodes, both of which can be mixed in the same Trie. For the normal TrieNode, each node has an array of children which correspond exactly to the letters of the alphabet beginning with a movable base character. This is the classic Trie arrangement. The array approach might be wasteful if only a few widely dispersed letters are used. The SparseTrieNode uses the .NET dictionary class, although for all I know, it may be just as wasteful.

Finally, this trie automatically adapts to allow storage of arbitrary Unicode strings. The array at each node—which characterizes a trie—adjusts its base and length to accommodate the range of Unicode characters that need to be stored at that node. This allows for case-sensitive matching, for example.

The C# 3.0 initialization syntax is handy for this trie, but enabling it requires a dummy implementation of IEnumerable in order to compile. The CLR doesn't seem to call GetEnumerator() and I suggest that you don't try to enumerate with its result either.

First the demo program:
using System;
using System.Collections.Generic;
using System.Linq;  // only used in Main()

/// <summary>
/// Demo program: builds two small tries and scans sentences for the words they contain.
/// </summary>
class Program
{
    // Trie whose payload is a String describing each word.
    static Trie<String> value_trie = new Trie<String>
    {
        { "rabbit", "cute" },
        { "giraffe", "tall" },
        { "ape", "smart" },
        { "hippo", "large" },
    };

    // Degenerate case: no real payload, every stored word simply maps to 'true'.
    static Trie<bool> simple_trie = new Trie<bool>
    {
        { "rabbit", true },
        { "giraffe", true },
        { "ape", true },
        { "hippo", true },
    };

    /// <summary>
    /// Reports whether <paramref name="text"/> contains any word stored in simple_trie.
    /// Any() stops pulling from the lazy sequence as soon as one 'true' payload is seen,
    /// so the rest of the input is not traversed.
    /// </summary>
    static bool ContainsKnownWord(String text)
    {
        return simple_trie.AllSubstringValues(text).Any(found => found);
    }

    static void Main(String[] args)
    {
        String sentence = "Once upon a time, a rabbit met an ape in the woods.";

        // Print the payload of every stored word that occurs in the sentence.
        //
        // output:
        //      cute
        //      smart
        foreach (String description in value_trie.AllSubstringValues(sentence))
        {
            Console.WriteLine(description);
        }

        // Simply test the sentence for any of the stored words.
        //
        // output:
        //      True
        Console.WriteLine(ContainsKnownWord(sentence));

        sentence = "Four score and seven years ago.";

        // output:
        //      False
        Console.WriteLine(ContainsKnownWord(sentence));
    }
}
The trie class:
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
/// <summary>
/// A trie (prefix tree) that maps strings to payloads of type <typeparamref name="TValue"/>.
/// Two node representations can be mixed freely in one trie: <see cref="TrieNode"/> keeps its children in
/// an array indexed by character offset, and <see cref="SparseTrieNode"/> keeps them in a dictionary.
/// </summary>
/// <remarks>
/// Several members decide whether a node "has a value" with a <c>Value != null</c> test. For a non-nullable
/// value type that test is always true, so those members report every node they visit (including nodes that
/// still hold <c>default(TValue)</c>). <see cref="TrieNodeBase.HasValue"/> and <see cref="Contains"/> compare
/// against <c>default(TValue)</c> instead.
/// </remarks>
public class Trie<TValue> : System.Collections.IEnumerable, IEnumerable<Trie<TValue>.TrieNodeBase>
{
    ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
    ///
    /// Node base class: payload storage plus the child-access contract shared by both node kinds
    ///
    ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
    public abstract class TrieNodeBase
    {
        protected TValue m_value = default(TValue);

        /// <summary>Payload stored at this node; default(TValue) until assigned.</summary>
        public TValue Value
        {
            get { return m_value; }
            set { m_value = value; }
        }

        /// <summary>True when the payload differs from default(TValue).</summary>
        public bool HasValue { get { return !Object.Equals(m_value, default(TValue)); } }

        /// <summary>True when the node has no child storage at all.</summary>
        public abstract bool IsLeaf { get; }

        /// <summary>Child reached through character <paramref name="c"/>, or null if there is none.</summary>
        public abstract TrieNodeBase this[char c] { get; }

        /// <summary>Child storage as an array (may contain null slots), or null for a leaf.</summary>
        public abstract TrieNodeBase[] Nodes { get; }

        /// <summary>Discards all children, turning this node into a leaf.</summary>
        public abstract void SetLeaf();

        /// <summary>Number of children actually present.</summary>
        public abstract int ChildCount { get; }

        /// <summary>True when converting this node to a <see cref="SparseTrieNode"/> is considered worthwhile.</summary>
        public abstract bool ShouldOptimize { get; }

        /// <summary>Snapshot of the (character, child) pairs for the children that are present.</summary>
        public abstract KeyValuePair<Char, TrieNodeBase>[] CharNodePairs();

        /// <summary>
        /// Returns the child for <paramref name="c"/>, creating it (and incrementing
        /// <paramref name="node_count"/>) when it does not exist yet.
        /// </summary>
        public abstract TrieNodeBase AddChild(char c, ref int node_count);

        /// <summary>Puts <paramref name="n"/> in the child position for character <paramref name="c"/>.</summary>
        public abstract void ReplaceChild(Char c, TrieNodeBase n);

        /// <summary>
        /// Lazily yields the non-null values of this node and of every node below it (this node first).
        /// </summary>
        public IEnumerable<TValue> SubsumedValues()
        {
            if (Value != null)
                yield return Value;

            TrieNodeBase[] children = Nodes;
            if (children == null)
                yield break;

            foreach (TrieNodeBase child in children)
            {
                if (child == null)
                    continue;
                foreach (TValue below in child.SubsumedValues())
                    yield return below;
            }
        }

        /// <summary>
        /// Lazily yields this node followed by every node below it.
        /// </summary>
        public IEnumerable<TrieNodeBase> SubsumedNodes()
        {
            yield return this;
            foreach (TrieNodeBase below in SubsumedNodesExceptThis())
                yield return below;
        }

        /// <summary>
        /// Lazily yields every node below this one; the node itself is not included.
        /// </summary>
        public IEnumerable<TrieNodeBase> SubsumedNodesExceptThis()
        {
            TrieNodeBase[] children = Nodes;
            if (children == null)
                yield break;

            foreach (TrieNodeBase child in children)
            {
                if (child == null)
                    continue;
                foreach (TrieNodeBase below in child.SubsumedNodes())
                    yield return below;
            }
        }

        /// <summary>
        /// Replaces each descendant whose <see cref="ShouldOptimize"/> is true with a
        /// <see cref="SparseTrieNode"/> carrying the same children and value.
        /// Note: doesn't de-optimize optimized nodes if re-run later.
        /// </summary>
        public void OptimizeChildNodes()
        {
            if (IsLeaf)
                return;

            // CharNodePairs() returns a snapshot, so ReplaceChild may be called while looping over it.
            foreach (var pair in CharNodePairs())
            {
                TrieNodeBase child = pair.Value;
                if (child.ShouldOptimize)
                {
                    TrieNodeBase sparse = new SparseTrieNode(child.CharNodePairs());
                    sparse.m_value = child.m_value;
                    Trie<TValue>.c_sparse_nodes++;
                    ReplaceChild(pair.Key, sparse);

                    // Continue with the node that is now part of the tree. Recursing into the old,
                    // detached node would apply the grandchildren's replacements to a dead copy.
                    child = sparse;
                }
                child.OptimizeChildNodes();
            }
        }
    }

    ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
    ///
    /// Sparse Trie Node
    ///
    /// Children live in a dictionary. Sparse nodes are only created from nodes that already have children, and
    /// nothing in this file converts them back, so the dictionary is normally non-null and IsLeaf is normally
    /// false. SetLeaf() can still null it out, so every member tolerates that state.
    ///
    ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
    public class SparseTrieNode : TrieNodeBase
    {
        Dictionary<Char, TrieNodeBase> d;

        /// <summary>Creates a sparse node holding the given (character, child) pairs.</summary>
        /// <exception cref="ArgumentNullException"><paramref name="ie"/> is null.</exception>
        public SparseTrieNode(IEnumerable<KeyValuePair<Char, TrieNodeBase>> ie)
        {
            if (ie == null)
                throw new ArgumentNullException("ie");

            d = new Dictionary<char, TrieNodeBase>();
            foreach (var kvp in ie)
                d.Add(kvp.Key, kvp.Value);
        }

        public override TrieNodeBase this[Char c]
        {
            get
            {
                if (d == null)
                    return null;
                TrieNodeBase node;
                return d.TryGetValue(c, out node) ? node : null;
            }
        }

        /// <summary>A freshly allocated array of the children, or null once SetLeaf() has been called.</summary>
        public override TrieNodeBase[] Nodes { get { return (d != null) ? d.Values.ToArray() : null; } }

        /// <summary>
        /// Drops the child dictionary. Leaf sparse nodes are never converted back to the array form,
        /// so prefer running OptimizeSparseNodes *after* any pruning.
        /// </summary>
        public override void SetLeaf() { d = null; }

        public override int ChildCount { get { return (d != null) ? d.Count : 0; } }

        public override KeyValuePair<Char, TrieNodeBase>[] CharNodePairs()
        {
            if (d == null)
                return new KeyValuePair<Char, TrieNodeBase>[0];
            return d.ToArray();
        }

        public override TrieNodeBase AddChild(char c, ref int node_count)
        {
            // A node that was turned into a leaf gets its storage back on demand.
            if (d == null)
                d = new Dictionary<char, TrieNodeBase>();

            TrieNodeBase node;
            if (!d.TryGetValue(c, out node))
            {
                node = new TrieNode();
                node_count++;
                d.Add(c, node);
            }
            return node;
        }

        /// <exception cref="ArgumentOutOfRangeException">The node is a leaf, so there is nothing to replace.</exception>
        public override void ReplaceChild(Char c, TrieNodeBase n)
        {
            if (d == null)
                throw new ArgumentOutOfRangeException("c");
            d[c] = n;
        }

        public override bool ShouldOptimize { get { return false; } }
        public override bool IsLeaf { get { return d == null; } }
    }

    ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
    ///
    /// Non-sparse Trie Node
    ///
    /// Children live in an array; slot i belongs to character (m_base + i). The array grows in either
    /// direction as characters outside the current range are added.
    ///
    ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
    public class TrieNode : TrieNodeBase
    {
        private TrieNodeBase[] nodes = null;
        private Char m_base;

        public override int ChildCount { get { return (nodes != null) ? nodes.Count(e => e != null) : 0; } }

        /// <summary>Length of the child array, counting empty slots; 0 for a leaf.</summary>
        public int AllocatedChildCount { get { return (nodes != null) ? nodes.Length : 0; } }

        public override TrieNodeBase[] Nodes { get { return nodes; } }

        public override void SetLeaf() { nodes = null; }

        public override KeyValuePair<Char, TrieNodeBase>[] CharNodePairs()
        {
            // A leaf has no array to walk; report "no children" instead of dereferencing null.
            if (nodes == null)
                return new KeyValuePair<Char, TrieNodeBase>[0];

            KeyValuePair<Char, TrieNodeBase>[] pairs = new KeyValuePair<Char, TrieNodeBase>[ChildCount];
            int filled = 0;
            for (int slot = 0; slot < nodes.Length; slot++)
            {
                TrieNodeBase child = nodes[slot];
                if (child != null)
                    pairs[filled++] = new KeyValuePair<Char, TrieNodeBase>((Char)(m_base + slot), child);
            }
            return pairs;
        }

        public override TrieNodeBase this[char c]
        {
            get
            {
                if (nodes == null || c < m_base || c >= m_base + nodes.Length)
                    return null;
                return nodes[c - m_base];
            }
        }

        public override TrieNodeBase AddChild(char c, ref int node_count)
        {
            if (nodes == null)
            {
                // First child: the array starts out covering exactly this one character.
                m_base = c;
                nodes = new TrieNodeBase[1];
            }
            else if (c < m_base)
            {
                // Grow downward: shift the existing children up and rebase on the new character.
                int shift = m_base - c;
                TrieNodeBase[] grown = new TrieNodeBase[nodes.Length + shift];
                nodes.CopyTo(grown, shift);
                m_base = c;
                nodes = grown;
            }
            else if (c >= m_base + nodes.Length)
            {
                // Grow upward so that the last slot belongs to the new character.
                Array.Resize(ref nodes, c - m_base + 1);
            }

            int slot = c - m_base;
            TrieNodeBase node = nodes[slot];
            if (node == null)
            {
                node = new TrieNode();
                node_count++;
                nodes[slot] = node;
            }
            return node;
        }

        /// <exception cref="ArgumentOutOfRangeException">
        /// The node is a leaf, or <paramref name="c"/> lies outside the range covered by the child array.
        /// </exception>
        public override void ReplaceChild(Char c, TrieNodeBase n)
        {
            if (nodes == null || c < m_base || c >= m_base + nodes.Length)
                throw new ArgumentOutOfRangeException("c");
            nodes[c - m_base] = n;
        }

        public override bool ShouldOptimize
        {
            get
            {
                if (nodes == null)
                    return false;
                return (ChildCount * 9 < nodes.Length);      // empirically determined optimal value (space & time)
            }
        }

        public override bool IsLeaf { get { return nodes == null; } }
    }

    ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
    ///
    /// Trie proper begins here
    ///
    ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////

    private TrieNodeBase _root = new TrieNode();

    /// <summary>Number of nodes created through Add (the root is not counted).</summary>
    public int c_nodes = 0;

    /// <summary>Number of sparse conversions performed; static, so shared by all tries of this TValue.</summary>
    public static int c_sparse_nodes = 0;

    // In combination with Add(...), the IEnumerable implementation enables collection-initializer syntax.
    public System.Collections.IEnumerator GetEnumerator()
    {
        return _root.SubsumedNodes().GetEnumerator();
    }

    IEnumerator<TrieNodeBase> IEnumerable<TrieNodeBase>.GetEnumerator()
    {
        return _root.SubsumedNodes().GetEnumerator();
    }

    /// <summary>All non-null values stored in the trie, produced lazily.</summary>
    public IEnumerable<TValue> Values { get { return _root.SubsumedValues(); } }

    /// <summary>
    /// Converts every node (root included) whose ShouldOptimize is true into a SparseTrieNode.
    /// </summary>
    public void OptimizeSparseNodes()
    {
        if (_root.ShouldOptimize)
        {
            TrieNodeBase sparse_root = new SparseTrieNode(_root.CharNodePairs());
            // Keep the payload stored for the empty-string key.
            sparse_root.Value = _root.Value;
            _root = sparse_root;
            c_sparse_nodes++;
        }
        _root.OptimizeChildNodes();
    }

    /// <summary>The root node, i.e. the node for the empty string.</summary>
    public TrieNodeBase Root { get { return _root; } }

    /// <summary>
    /// Stores <paramref name="v"/> under key <paramref name="s"/>, creating nodes as needed, and returns
    /// the node for the final character. An existing value under the same key is overwritten.
    /// </summary>
    public TrieNodeBase Add(String s, TValue v)
    {
        TrieNodeBase node = _root;
        for (int i = 0; i < s.Length; i++)
            node = node.AddChild(s[i], ref c_nodes);

        node.Value = v;
        return node;
    }

    /// <summary>
    /// True when <paramref name="s"/> leads to a node whose value differs from default(TValue).
    /// </summary>
    public bool Contains(String s)
    {
        TrieNodeBase node = FindNode(s);
        return node != null && node.HasValue;
    }

    /// <summary>
    /// Depth-first search shared by the GetKey overloads. While descending, <paramref name="sofar"/> holds
    /// the characters on the current path; on success it is left holding the key of the matching node.
    /// </summary>
    private static bool TryBuildKey(TrieNodeBase cur, Predicate<TrieNodeBase> is_match, StringBuilder sofar)
    {
        if (cur.IsLeaf)
            return false;

        sofar.Append(' ');      // placeholder for the character tried at this depth
        foreach (var kvp in cur.CharNodePairs())
        {
            sofar[sofar.Length - 1] = kvp.Key;
            if (is_match(kvp.Value))
                return true;
            if (!kvp.Value.IsLeaf && TryBuildKey(kvp.Value, is_match, sofar))
                return true;
        }
        sofar.Length--;         // nothing matched below this node; back out
        return false;
    }

    /// <summary>
    /// Debug only; this searches the whole trie. Returns the key that leads to node
    /// <paramref name="seek"/> (compared by reference), or null if the node is not found below the root.
    /// </summary>
    public String GetKey(TrieNodeBase seek)
    {
        StringBuilder sofar = new StringBuilder();
        if (TryBuildKey(_root, n => Object.ReferenceEquals(n, seek), sofar))
            return sofar.ToString();
        return null;
    }

    /// <summary>
    /// Debug only; this searches the whole trie. Returns the key of the first node found (depth-first)
    /// whose value Equals <paramref name="seek"/>, or null if there is none.
    /// </summary>
    public String GetKey(TValue seek)
    {
        StringBuilder sofar = new StringBuilder();
        if (TryBuildKey(_root, n => n.Value != null && n.Value.Equals(seek), sofar))
            return sofar.ToString();
        return null;
    }

    /// <summary>
    /// Returns the node reached by following <paramref name="s_in"/> from the root, or null when the path
    /// does not exist. The node returned need not carry a value.
    /// </summary>
    public TrieNodeBase FindNode(String s_in)
    {
        TrieNodeBase node = _root;
        for (int i = 0; i < s_in.Length; i++)
        {
            node = node[s_in[i]];
            if (node == null)
                return null;
        }
        return node;
    }

    /// <summary>
    /// Like FindNode, but when the walk reaches a leaf before the input is used up, that leaf is returned
    /// with <paramref name="f_exact"/> set to false. If continuation from the terminal node is possible with
    /// a different input string, then that node is not returned as a 'last' node for the given input. In
    /// other words, 'last' nodes must be leaf nodes, where continuation possibility is truly unknown. The
    /// presence of a nodes array that we couldn't match to means the search fails (null is returned); it is
    /// not the design of the 'OrLast' feature to provide 'closest' or 'best' matching but rather to enable
    /// truncated tails still in the context of exact prefix matching.
    /// </summary>
    public TrieNodeBase FindNodeOrLast(String s_in, out bool f_exact)
    {
        f_exact = false;
        TrieNodeBase node = _root;
        foreach (Char c in s_in)
        {
            if (node.IsLeaf)
                return node;

            node = node[c];
            if (node == null)
                return null;
        }
        f_exact = true;
        return node;
    }

    /// <summary>
    /// Returns the value stored under <paramref name="s_in"/>, or default(TValue) when the path does not
    /// exist. Uses an index loop rather than an enumerator.
    /// </summary>
    public TValue Find(String s_in)
    {
        TrieNodeBase node = _root;
        for (int i = 0; i < s_in.Length; i++)
        {
            node = node[s_in[i]];
            if (node == null)
                return default(TValue);
        }
        return node.Value;
    }

    /// <summary>
    /// Same lookup as Find(String), reading <paramref name="cb_ctag"/> characters starting at
    /// <paramref name="p_tag"/>.
    /// </summary>
    public unsafe TValue Find(Char* p_tag, int cb_ctag)
    {
        TrieNodeBase node = _root;
        for (Char* p_end = p_tag + cb_ctag; p_tag < p_end; p_tag++)
        {
            node = node[*p_tag];
            if (node == null)
                return default(TValue);
        }
        return node.Value;
    }

    /// <summary>
    /// Lazily yields the non-null values found along the path spelled by <paramref name="s_in"/>, i.e. the
    /// values of all stored prefixes of the input, shortest first.
    /// </summary>
    public IEnumerable<TValue> FindAll(String s_in)
    {
        TrieNodeBase node = _root;
        for (int i = 0; i < s_in.Length; i++)
        {
            node = node[s_in[i]];
            if (node == null)
                yield break;
            if (node.Value != null)
                yield return node.Value;
        }
    }

    /// <summary>
    /// Non-null values at and below the node for <paramref name="s"/>; empty when the path does not exist.
    /// </summary>
    public IEnumerable<TValue> SubsumedValues(String s)
    {
        TrieNodeBase node = FindNode(s);
        return (node != null) ? node.SubsumedValues() : Enumerable.Empty<TValue>();
    }

    /// <summary>
    /// The node for <paramref name="s"/> and every node below it; empty when the path does not exist.
    /// </summary>
    public IEnumerable<TrieNodeBase> SubsumedNodes(String s)
    {
        TrieNodeBase node = FindNode(s);
        return (node != null) ? node.SubsumedNodes() : Enumerable.Empty<TrieNodeBase>();
    }

    /// <summary>
    /// Lazily yields the non-null value of every stored key that occurs as a substring of
    /// <paramref name="s"/>, ordered by start position and then by length.
    /// </summary>
    public IEnumerable<TValue> AllSubstringValues(String s)
    {
        for (int start = 0; start < s.Length; start++)
        {
            // Walk down from the root for as long as the text starting here keeps matching.
            TrieNodeBase node = _root;
            for (int i = start; i < s.Length; i++)
            {
                node = node[s[i]];
                if (node == null)
                    break;
                if (node.Value != null)
                    yield return node.Value;
            }
        }
    }

    /// <summary>
    /// Recursive worker for DepthFirstTraverse; <paramref name="path"/> holds the key of
    /// <paramref name="cur"/> on entry and is restored to it on exit.
    /// </summary>
    private static void VisitDepthFirst(TrieNodeBase cur, StringBuilder path, Action<String, TrieNodeBase> callback)
    {
        foreach (var kvp in cur.CharNodePairs())
        {
            TrieNodeBase n = kvp.Value;
            path.Append(kvp.Key);

            if (!n.IsLeaf)
                VisitDepthFirst(n, path, callback);
            else if (n.Value == null)       // leaf nodes should always have a value
                throw new InvalidOperationException("Trie leaf node has no value.");

            if (n.Value != null)
                callback(path.ToString(), n);

            path.Length--;
        }
    }

    /// <summary>
    /// Invokes <paramref name="callback"/> with (key, node) for every node below the root that has a
    /// non-null value. A node is reported after the nodes beneath it. An empty trie produces no callbacks.
    /// note: only returns nodes with non-null values
    /// </summary>
    /// <exception cref="InvalidOperationException">A leaf node with a null value was encountered.</exception>
    public void DepthFirstTraverse(Action<String,TrieNodeBase> callback)
    {
        if (_root.IsLeaf)
            return;
        VisitDepthFirst(_root, new StringBuilder(), callback);
    }

    /// <summary>
    /// Recursive worker for EnumerateLeafPaths; <paramref name="path"/> and <paramref name="trail"/>
    /// describe the route from the root to <paramref name="cur"/> and are restored on exit.
    /// </summary>
    private static void VisitLeafPaths(TrieNodeBase cur, StringBuilder path, Stack<TrieNodeBase> trail,
                                       Action<String, IEnumerable<TrieNodeBase>> callback)
    {
        foreach (var kvp in cur.CharNodePairs())
        {
            TrieNodeBase n = kvp.Value;
            path.Append(kvp.Key);
            trail.Push(n);

            if (!n.IsLeaf)
                VisitLeafPaths(n, path, trail, callback);
            else
            {
                if (n.Value == null)        // leaf nodes should always have a value
                    throw new InvalidOperationException("Trie leaf node has no value.");
                callback(path.ToString(), trail);
            }

            trail.Pop();
            path.Length--;
        }
    }

    /// <summary>
    /// Invokes <paramref name="callback"/> once per leaf node with the leaf's key and the nodes on the
    /// path to it (root excluded). The node sequence is the live working stack: it enumerates from the leaf
    /// back toward the root and must not be kept after the callback returns. An empty trie produces no
    /// callbacks.
    /// note: only returns nodes with non-null values
    /// </summary>
    /// <exception cref="InvalidOperationException">A leaf node with a null value was encountered.</exception>
    public void EnumerateLeafPaths(Action<String,IEnumerable<TrieNodeBase>> callback)
    {
        if (_root.IsLeaf)
            return;
        VisitLeafPaths(_root, new StringBuilder(), new Stack<TrieNodeBase>(), callback);
    }

    ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
    ///
    /// Convert a trie with one value type to another
    ///
    ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////

    /// <summary>
    /// Builds a new trie containing every key reported by DepthFirstTraverse, with each value passed
    /// through <paramref name="value_converter"/>.
    /// </summary>
    public Trie<TNew> ToTrie<TNew>(Func<TValue, TNew> value_converter)
    {
        Trie<TNew> converted = new Trie<TNew>();
        DepthFirstTraverse((key, node) => converted.Add(key, value_converter(node.Value)));
        return converted;
    }
}
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
///
/// Extension methods for building and querying tries
///
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
public static class TrieExtension
{
    /// <summary>
    /// Builds a trie keyed by the strings in <paramref name="src"/>. The value for each key comes from
    /// <paramref name="selector"/>, which receives the string and its zero-based position in the sequence.
    /// </summary>
    public static Trie<TValue> ToTrie<TValue>(this IEnumerable<String> src, Func<String, int, TValue> selector)
    {
        Trie<TValue> result = new Trie<TValue>();
        int position = 0;
        foreach (String key in src)
        {
            result.Add(key, selector(key, position));
            position++;
        }
        return result;
    }

    /// <summary>
    /// Builds a trie holding the same key/value pairs as the dictionary.
    /// </summary>
    public static Trie<TValue> ToTrie<TValue>(this Dictionary<String, TValue> src)
    {
        Trie<TValue> result = new Trie<TValue>();
        foreach (KeyValuePair<String, TValue> entry in src)
        {
            result.Add(entry.Key, entry.Value);
        }
        return result;
    }

    /// <summary>
    /// Convenience form of <c>trie.AllSubstringValues(s)</c> that reads from the string's side.
    /// </summary>
    public static IEnumerable<TValue> AllSubstringValues<TValue>(this String s, Trie<TValue> trie)
    {
        return trie.AllSubstringValues(s);
    }

    /// <summary>
    /// Adds <paramref name="v"/> to the set stored under <paramref name="k"/>, creating the set when the
    /// key is not present yet.
    /// </summary>
    public static void AddToValueHashset<TKey, TValue>(this Dictionary<TKey, HashSet<TValue>> d, TKey k, TValue v)
    {
        HashSet<TValue> bucket;
        if (!d.TryGetValue(k, out bucket))
        {
            bucket = new HashSet<TValue>();
            d.Add(k, bucket);
        }
        bucket.Add(v);
    }
}