Trie<TValue> Generic
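This code is adapted from a C# trie implementation I found on the web, attributed to Kerry D. Wong. In my version, each string in the
trie carries a "payload" of generic type TValue. To use this trie simply to search for substrings, set every payload to true, as
illustrated in the demo program. (Use true rather than false: a payload equal to default(TValue), such as false, 0 or null, makes
HasValue, and therefore Contains, report the string as absent.)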
One thing I changed is that the node base class has two concrete node types, which can be mixed in the same trie. In the normal
TrieNode, each node holds an array of children indexed by character, offset from a movable base character. This is the classic trie
arrangement, but the array can be wasteful if only a few widely dispersed characters are used. The SparseTrieNode stores its children
in a .NET Dictionary instead, although for all I know it may be just as wasteful. Array nodes become sparse nodes only when you call
OptimizeSparseNodes(). That method replaces every array node in which fewer than one in nine array slots is in use. Optimized nodes are
never converted back.
Finally, this trie automatically adapts to allow storage of arbitrary Unicode strings (one trie level per UTF-16 char).
The array at each node—which characterizes a trie—adjusts its base and length to accommodate the range of Unicode
characters that need to be stored at that node. This allows for case-sensitive matching, for example.
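The C# 3.0 collection-initializer syntax is handy for this trie. However, the compiler accepts it only on types that implement
IEnumerable (and provide a suitable Add method), so the class implements IEnumerable even though the initializer itself never uses it.
The code the compiler generates for a collection initializer calls only Add(), never GetEnumerator(). I suggest that you don't rely
on enumerating the trie through GetEnumerator() either.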
First, the demo program:
02 | using System.Collections.Generic; |
08 | static Trie<String> value_trie = new Trie<String> |
11 | { "giraffe" , "tall" }, |
17 | static Trie< bool > simple_trie = new Trie< bool > |
25 | static void Main(String[] args) |
27 | String s = "Once upon a time, a rabbit met an ape in the woods." ; |
34 | foreach (String word in value_trie.AllSubstringValues(s)) |
35 | Console.WriteLine(word); |
43 | Console.WriteLine(simple_trie.AllSubstringValues(s).Any(e=>e)); |
45 | s = "Four score and seven years ago." ; |
48 | Console.WriteLine(simple_trie.AllSubstringValues(s).Any(e => e)); |
The trie class:
002 | using System.Collections.Generic; |
003 | using System.Diagnostics; |
006 | public class Trie<TValue> : System.Collections.IEnumerable, IEnumerable<Trie<TValue>.TrieNodeBase> |
008 | public abstract class TrieNodeBase |
010 | protected TValue m_value = default (TValue); |
014 | get { return m_value; } |
015 | set { m_value = value; } |
018 | public bool HasValue { get { return !Object.Equals(m_value, default (TValue)); } } |
019 | public abstract bool IsLeaf { get ; } |
021 | public abstract TrieNodeBase this [ char c] { get ; } |
023 | public abstract TrieNodeBase[] Nodes { get ; } |
025 | public abstract void SetLeaf(); |
027 | public abstract int ChildCount { get ; } |
029 | public abstract bool ShouldOptimize { get ; } |
031 | public abstract KeyValuePair<Char, TrieNodeBase>[] CharNodePairs(); |
033 | public abstract TrieNodeBase AddChild( char c, ref int node_count); |
036 | /// Includes current node value |
038 | /// <returns></returns> |
039 | public IEnumerable<TValue> SubsumedValues() |
044 | foreach (TrieNodeBase child in Nodes) |
046 | foreach (TValue t in child.SubsumedValues()) |
051 | /// Includes current node |
053 | /// <returns></returns> |
054 | public IEnumerable<TrieNodeBase> SubsumedNodes() |
058 | foreach (TrieNodeBase child in Nodes) |
060 | foreach (TrieNodeBase n in child.SubsumedNodes()) |
065 | /// Doesn't include current node |
067 | /// <returns></returns> |
068 | public IEnumerable<TrieNodeBase> SubsumedNodesExceptThis() |
071 | foreach (TrieNodeBase child in Nodes) |
073 | foreach (TrieNodeBase n in child.SubsumedNodes()) |
078 | /// Note: doesn't de-optimize optimized nodes if re-run later |
080 | public void OptimizeChildNodes() |
083 | foreach (var q in CharNodePairs()) |
085 | TrieNodeBase n_old = q.Value; |
086 | if (n_old.ShouldOptimize) |
088 | TrieNodeBase n_new = new SparseTrieNode(n_old.CharNodePairs()); |
089 | n_new.m_value = n_old.m_value; |
090 | Trie<TValue>.c_sparse_nodes++; |
091 | ReplaceChild(q.Key, n_new); |
093 | n_old.OptimizeChildNodes(); |
097 | public abstract void ReplaceChild(Char c, TrieNodeBase n); |
101 | /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// |
105 | /// currently, this one's "nodes" value is never null, because we leave leaf nodes as the non-sparse type, |
106 | /// (with nodes==null) and they currently never get converted back. Consequently, IsLeaf should always be 'false'. |
107 | /// However, we're gonna do the check anyway. |
109 | /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// |
110 | public class SparseTrieNode : TrieNodeBase |
112 | Dictionary<Char, TrieNodeBase> d; |
114 | public SparseTrieNode(IEnumerable<KeyValuePair<Char, TrieNodeBase>> ie) |
116 | d = new Dictionary< char , TrieNodeBase>(); |
117 | foreach (var kvp in ie) |
118 | d.Add(kvp.Key, kvp.Value); |
121 | public override TrieNodeBase this [Char c] |
126 | return d.TryGetValue(c, out node) ? node : null ; |
130 | public override TrieNodeBase[] Nodes { get { return d.Values.ToArray(); } } |
133 | /// do not use in current form. This means, run OptimizeSparseNodes *after* any pruning |
135 | public override void SetLeaf() { d = null ; } |
137 | public override int ChildCount { get { return d.Count; } } |
139 | public override KeyValuePair<Char, TrieNodeBase>[] CharNodePairs() |
144 | public override TrieNodeBase AddChild( char c, ref int node_count) |
147 | if (!d.TryGetValue(c, out node)) |
149 | node = new TrieNode(); |
156 | public override void ReplaceChild(Char c, TrieNodeBase n) |
161 | public override bool ShouldOptimize { get { return false ; } } |
162 | public override bool IsLeaf { get { return d == null ; } } |
166 | /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// |
168 | /// Non-sparse Trie Node |
170 | /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// |
171 | public class TrieNode : TrieNodeBase |
173 | private TrieNodeBase[] nodes = null ; |
176 | public override int ChildCount { get { return (nodes != null ) ? nodes.Count(e => e != null ) : 0; } } |
177 | public int AllocatedChildCount { get { return (nodes != null ) ? nodes.Length : 0; } } |
179 | public override TrieNodeBase[] Nodes { get { return nodes; } } |
181 | public override void SetLeaf() { nodes = null ; } |
183 | public override KeyValuePair<Char, TrieNodeBase>[] CharNodePairs() |
185 | KeyValuePair<Char, TrieNodeBase>[] rg = new KeyValuePair< char , TrieNodeBase>[ChildCount]; |
188 | foreach (TrieNodeBase child in nodes) |
191 | rg[i++] = new KeyValuePair< char , TrieNodeBase>(ch, child); |
197 | public override TrieNodeBase this [ char c] |
201 | if (nodes != null && m_base <= c && c < m_base + nodes.Length) |
202 | return nodes[c - m_base]; |
207 | public override TrieNodeBase AddChild( char c, ref int node_count) |
212 | nodes = new TrieNodeBase[1]; |
214 | else if (c >= m_base + nodes.Length) |
216 | Array.Resize( ref nodes, c - m_base + 1); |
220 | Char c_new = (Char)(m_base - c); |
221 | TrieNodeBase[] tmp = new TrieNodeBase[nodes.Length + c_new]; |
222 | nodes.CopyTo(tmp, c_new); |
227 | TrieNodeBase node = nodes[c - m_base]; |
230 | node = new TrieNode(); |
232 | nodes[c - m_base] = node; |
237 | public override void ReplaceChild(Char c, TrieNodeBase n) |
239 | if (nodes == null || c >= m_base + nodes.Length || c < m_base) |
240 | throw new Exception(); |
241 | nodes[c - m_base] = n; |
244 | public override bool ShouldOptimize |
250 | return (ChildCount * 9 < nodes.Length); |
254 | public override bool IsLeaf { get { return nodes == null ; } } |
257 | /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// |
259 | /// Trie proper begins here |
261 | /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// |
263 | private TrieNodeBase _root = new TrieNode(); |
264 | public int c_nodes = 0; |
265 | public static int c_sparse_nodes = 0; |
268 | public System.Collections.IEnumerator GetEnumerator() |
270 | return _root.SubsumedNodes().GetEnumerator(); |
273 | IEnumerator<TrieNodeBase> IEnumerable<TrieNodeBase>.GetEnumerator() |
275 | return _root.SubsumedNodes().GetEnumerator(); |
278 | public IEnumerable<TValue> Values { get { return _root.SubsumedValues(); } } |
280 | public void OptimizeSparseNodes() |
282 | if (_root.ShouldOptimize) |
284 | _root = new SparseTrieNode(_root.CharNodePairs()); |
287 | _root.OptimizeChildNodes(); |
290 | public TrieNodeBase Root { get { return _root; } } |
292 | public TrieNodeBase Add(String s, TValue v) |
294 | TrieNodeBase node = _root; |
295 | foreach (Char c in s) |
296 | node = node.AddChild(c, ref c_nodes); |
302 | public bool Contains(String s) |
304 | TrieNodeBase node = _root; |
305 | foreach (Char c in s) |
311 | return node.HasValue; |
315 | /// Debug only; this is hideously inefficient |
317 | public String GetKey(TrieNodeBase seek) |
319 | String sofar = String.Empty; |
321 | GetKeyHelper fn = null ; |
322 | fn = (TrieNodeBase cur) => |
325 | foreach (var kvp in cur.CharNodePairs()) |
327 | Util.SetStringChar( ref sofar, sofar.Length - 1, kvp.Key); |
328 | if (kvp.Value == seek) |
330 | if (kvp.Value.Nodes != null && fn(kvp.Value)) |
333 | sofar = sofar.Substring(0, sofar.Length - 1); |
344 | /// Debug only; this is hideously inefficient |
346 | delegate bool GetKeyHelper(TrieNodeBase cur); |
347 | public String GetKey(TValue seek) |
349 | String sofar = String.Empty; |
351 | GetKeyHelper fn = null ; |
352 | fn = (TrieNodeBase cur) => |
355 | foreach (var kvp in cur.CharNodePairs()) |
357 | Util.SetStringChar( ref sofar, sofar.Length - 1, kvp.Key); |
358 | if (kvp.Value.Value != null && kvp.Value.Value.Equals(seek)) |
360 | if (kvp.Value.Nodes != null && fn(kvp.Value)) |
363 | sofar = sofar.Substring(0, sofar.Length - 1); |
372 | public TrieNodeBase FindNode(String s_in) |
374 | TrieNodeBase node = _root; |
375 | foreach (Char c in s_in) |
376 | if ((node = node[c]) == null ) |
382 | /// If continuation from the terminal node is possible with a different input string, then that node is not |
383 | /// returned as a 'last' node for the given input. In other words, 'last' nodes must be leaf nodes, where |
384 | /// continuation possibility is truly unknown. The presense of a nodes array that we couldn't match to |
385 | /// means the search fails; it is not the design of the 'OrLast' feature to provide 'closest' or 'best' |
386 | /// matching but rather to enable truncated tails still in the context of exact prefix matching. |
388 | public TrieNodeBase FindNodeOrLast(String s_in, out bool f_exact) |
390 | TrieNodeBase node = _root; |
391 | foreach (Char c in s_in) |
398 | if ((node = node[c]) == null ) |
410 | public unsafe TValue Find(String s_in) |
412 | TrieNodeBase node = _root; |
413 | fixed (Char* pin_s = s_in) |
416 | Char* p_end = p + s_in.Length; |
419 | if ((node = node[*p]) == null ) |
420 | return default (TValue); |
427 | public unsafe TValue Find(Char* p_tag, int cb_ctag) |
429 | TrieNodeBase node = _root; |
430 | Char* p_end = p_tag + cb_ctag; |
431 | while (p_tag < p_end) |
433 | if ((node = node[*p_tag]) == null ) |
434 | return default (TValue); |
440 | public IEnumerable<TValue> FindAll(String s_in) |
442 | TrieNodeBase node = _root; |
443 | foreach (Char c in s_in) |
445 | if ((node = node[c]) == null ) |
447 | if (node.Value != null ) |
448 | yield return node.Value; |
452 | public IEnumerable<TValue> SubsumedValues(String s) |
454 | TrieNodeBase node = FindNode(s); |
456 | return Enumerable.Empty<TValue>(); |
457 | return node.SubsumedValues(); |
460 | public IEnumerable<TrieNodeBase> SubsumedNodes(String s) |
462 | TrieNodeBase node = FindNode(s); |
464 | return Enumerable.Empty<TrieNodeBase>(); |
465 | return node.SubsumedNodes(); |
468 | public IEnumerable<TValue> AllSubstringValues(String s) |
471 | while (i_cur < s.Length) |
473 | TrieNodeBase node = _root; |
480 | if (node.Value != null ) |
481 | yield return node.Value; |
489 | /// note: only returns nodes with non-null values |
491 | public void DepthFirstTraverse(Action<String,TrieNodeBase> callback) |
493 | Char[] rgch = new Char[100]; |
496 | Action<TrieNodeBase> fn = null ; |
497 | fn = (TrieNodeBase cur) => |
499 | if (depth >= rgch.Length) |
501 | Char[] tmp = new Char[rgch.Length * 2]; |
502 | Buffer.BlockCopy(rgch, 0, tmp, 0, rgch.Length * sizeof (Char)); |
505 | foreach (var kvp in cur.CharNodePairs()) |
507 | rgch[depth] = kvp.Key; |
508 | TrieNodeBase n = kvp.Value; |
515 | else if (n.Value == null ) |
516 | throw new Exception(); |
519 | callback( new String(rgch, 0, depth+1), n); |
528 | /// note: only returns nodes with non-null values |
530 | public void EnumerateLeafPaths(Action<String,IEnumerable<TrieNodeBase>> callback) |
532 | Stack<TrieNodeBase> stk = new Stack<TrieNodeBase>(); |
533 | Char[] rgch = new Char[100]; |
535 | Action<TrieNodeBase> fn = null ; |
536 | fn = (TrieNodeBase cur) => |
538 | if (stk.Count >= rgch.Length) |
540 | Char[] tmp = new Char[rgch.Length * 2]; |
541 | Buffer.BlockCopy(rgch, 0, tmp, 0, rgch.Length * sizeof (Char)); |
544 | foreach (var kvp in cur.CharNodePairs()) |
546 | rgch[stk.Count] = kvp.Key; |
547 | TrieNodeBase n = kvp.Value; |
554 | throw new Exception(); |
555 | callback( new String(rgch, 0, stk.Count), stk); |
564 | /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// |
566 | /// Convert a trie with one value type to another |
568 | /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// |
569 | public Trie<TNew> ToTrie<TNew>(Func<TValue, TNew> value_converter) |
571 | Trie<TNew> t = new Trie<TNew>(); |
572 | DepthFirstTraverse((s,n)=>{ |
573 | t.Add(s,value_converter(n.Value)); |
579 | /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// |
583 | /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// |
584 | public static class TrieExtension |
586 | public static Trie<TValue> ToTrie<TValue>( this IEnumerable<String> src, Func<String, int , TValue> selector) |
588 | Trie<TValue> t = new Trie<TValue>(); |
590 | foreach (String s in src) |
591 | t.Add(s,selector(s,idx++)); |
595 | public static Trie<TValue> ToTrie<TValue>( this Dictionary<String, TValue> src) |
597 | Trie<TValue> t = new Trie<TValue>(); |
598 | foreach (var kvp in src) |
599 | t.Add(kvp.Key, kvp.Value); |
603 | public static IEnumerable<TValue> AllSubstringValues<TValue>( this String s, Trie<TValue> trie) |
605 | return trie.AllSubstringValues(s); |
608 | public static void AddToValueHashset<TKey, TValue>( this Dictionary<TKey, HashSet<TValue>> d, TKey k, TValue v) |
611 | if (d.TryGetValue(k, out hs)) |
614 | d.Add(k, new HashSet<TValue> { v }); |