Decrease allocation size for promote algorithm
Validark committed Jul 12, 2022
1 parent 101e7cb commit 345518f
Showing 2 changed files with 93 additions and 87 deletions.
157 changes: 82 additions & 75 deletions c#/PruningRadixTrie/PruningRadixTrie.cs
@@ -262,7 +262,7 @@ public class PruningRadixTrie
{ // sorry for the indentation, but we need to prevent right-ward shift as much as possible!
private readonly (Node node, int LCP)[] rootPeers = { default };

public int Count = 0;
public int termCount = 0;

// The min value of topK at which a Heap-based DEPQ is used instead of an insertion sorted array
public const int DEPQ_THRESHOLD = 125;
@@ -284,7 +284,7 @@ public void WriteTermsToFile(String path)
file.Write($"{term}\t{score}\n");

sw.Stop();
Console.WriteLine(Count.ToString("N0") + " terms written in " + sw.ElapsedMilliseconds.ToString("0,.##") + " seconds.");
Console.WriteLine(termCount.ToString("N0") + " terms written in " + sw.ElapsedMilliseconds.ToString("0,.##") + " seconds.");
isCacheValid = true;
}
catch (Exception e)
@@ -299,19 +299,19 @@ private static void GetAllTerms(List<(String term, score_int score)> results, No
foreach (var peer in node.peers) GetAllTerms(results, peer.node);
}

// We could just use GetTopkTermsForPrefix("", Count), but this should be faster
// We could just use GetTopkTermsForPrefix("", termCount), but this should be faster
public List<(String term, score_int score)> GetAllTerms()
{
var results = new List<(String term, score_int score)>(Count);
if (Count != 0) GetAllTerms(results, rootPeers[0].node);
var results = new List<(String term, score_int score)>(termCount);
if (termCount != 0) GetAllTerms(results, rootPeers[0].node);
results.Sort((a, b) => b.score.CompareTo(a.score));
return results;
}

public List<(String term, score_int score)> GetAllTermsUnsorted()
{
var results = new List<(String term, score_int score)>(Count);
if (Count != 0) GetAllTerms(results, rootPeers[0].node);
var results = new List<(String term, score_int score)>(termCount);
if (termCount != 0) GetAllTerms(results, rootPeers[0].node);
return results;
}

@@ -324,7 +324,7 @@ private static void GetAllTerms(List<(String term, score_int score)> results, No
public List<(String term, score_int score)> GetTopkTermsForPrefix(String prefix, int topK)
{
Node node = rootPeers[0].node;
var results = new List<(String, score_int)>(Math.Min(topK, Count));
var results = new List<(String, score_int)>(Math.Min(topK, termCount));

if (topK <= 0 || node.key == null) return results;
if (prefix == null) prefix = String.Empty;
@@ -580,7 +580,7 @@ private static void checkTrie(HashSet<(Node node, int LCP)[]> encounteredPeers,

public void validateTrie()
{
var nodeCount = Count;
var nodeCount = termCount;
var encounteredPeers = new HashSet<(Node node, int LCP)[]>(nodeCount);
var diffs = new HashSet<int>();

@@ -658,40 +658,47 @@ private static void ImmutableArrayRemove<T>(ref T[] arr, int index)
arr = newArray;
}

private static void ExtractSubKNodes(
// Moves every node in `source` whose LCP is below the given LCP into `destination`, dropping peers marked for death (negative LCP)
// Also grows `destination` (to the next power of two) if the resulting destinationCount would exceed its length
private static void ExtractLCPsBelowThreshold(
ref (Node node, int LCP)[] source,
int LCP,
(Node node, int LCP)[] destination,
ref (Node node, int LCP)[] destination,
ref int destinationCount
)
{
int sourceLen = source.Length;
var newNodePeersLen = 0;
var nextSourceLength = 0;
var nextDestLength = destinationCount;

for (int i = 0; i < sourceLen; i++)
foreach (var peer in source)
if (LCP <= peer.LCP) nextSourceLength++;
else if (peer.LCP >= 0) nextDestLength++;

if (nextDestLength > destination.Length) // Check this unconditionally: destinationCount may already have exceeded destination.Length when this function was called
{
var peer = source[i];
if (LCP <= peer.LCP) newNodePeersLen++;
else if (peer.LCP >= 0) destination[destinationCount++] = peer;
// compute the next highest power of 2 of 32-bit integer
nextDestLength--;
nextDestLength |= nextDestLength >> 1;
nextDestLength |= nextDestLength >> 2;
nextDestLength |= nextDestLength >> 4;
nextDestLength |= nextDestLength >> 8;
nextDestLength |= nextDestLength >> 16;
nextDestLength++;
Array.Resize(ref destination, nextDestLength);
}

if (newNodePeersLen != sourceLen)
if (nextSourceLength != source.Length) // only run if there needs to be some actual change
{
if (newNodePeersLen == 0) {
source = Array.Empty<(Node node, int LCP)>();
return;
}
var nextSource = nextSourceLength == 0 ? Array.Empty<(Node node, int LCP)>() : new (Node node, int LCP)[nextSourceLength];
nextSourceLength = 0;
nextDestLength = destinationCount;

var newNodePeers = new (Node node, int LCP)[newNodePeersLen];
foreach (var peer in source)
if (LCP <= peer.LCP) nextSource[nextSourceLength++] = peer;
else if (peer.LCP >= 0) destination[nextDestLength++] = peer;

do
{
var peer = source[--sourceLen];
if (LCP <= peer.LCP)
newNodePeers[--newNodePeersLen] = peer;
} while (newNodePeersLen != 0);

source = newNodePeers;
destinationCount = nextDestLength;
source = nextSource;
}
}
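
For reference, here is a standalone sketch of the growth step used above. It is not part of the commit and the helper name is illustrative; it simply isolates the bit-twiddling that rounds the required capacity up to the next power of two, so the destination array is resized at most a logarithmic number of times instead of being pre-allocated at its maximum possible size.

```csharp
// Illustrative helper only (not in the repository): round `required` up to the
// next power of two, mirroring the bit-twiddling in ExtractLCPsBelowThreshold.
static int NextPowerOfTwo(int required)
{
    int v = required - 1;   // subtract first so exact powers of two are returned unchanged
    v |= v >> 1;
    v |= v >> 2;
    v |= v >> 4;
    v |= v >> 8;
    v |= v >> 16;
    return v + 1;
}

// e.g. NextPowerOfTwo(9) == 16 and NextPowerOfTwo(16) == 16.
// The commit inlines this computation and only calls
// Array.Resize(ref destination, ...) when nextDestLength > destination.Length.
```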

@@ -755,7 +762,7 @@ bool Delete(String key)
}

// If we broke down here, it means we found a `node` that represents `key`
Count -= 1;
termCount -= 1;

var nodesToPush = node.peers; // All these nodes need to be placed back in the trie
if (nodesToPush.Length == 0)
@@ -844,8 +851,8 @@ private score_int GetScoreForString(String key)
#endif
void AddTerm(String term, score_int score)
{
var newScore = unchecked((System.UInt64)GetScoreForString(term) + (System.UInt64)score);
Set(term, unchecked((long)Math.Min(newScore, score_int.MaxValue)));
var newScore = (ulong)GetScoreForString(term) + (ulong)score;
Set(term, (long)Math.Min(newScore, (ulong)score_int.MaxValue));
}

#if FORCE_STATIC
@@ -868,7 +875,7 @@ void Set(String term, score_int score)
if (node.key == null) // Degenerate case: trie is empty
{
parentPeers[indexInParent].node = new Node(term, score);
Count = 1;
termCount = 1;
return;
}

@@ -903,12 +910,7 @@ void Set(String term, score_int score)
if (score > node.score)
{ // if score is higher than the observed node's, insert new Node in its place
// and then traverse `node` to find the peers for new Node
var newPeers = new (Node node, int LCP)[termLength + 1 - prevLCP]; // maximum possible capacity

var newPeersCount = GetPeers(newPeers, node, LCP, term, LCP == prevLCP);

if (newPeersCount != newPeers.Length)
Array.Resize(ref newPeers, newPeersCount);
var newPeers = FindBranchPointsInNode(node, prevLCP, LCP, term);

#if COMPRESS_STRINGS
term = term.Substring(prevLCP);
@@ -928,7 +930,7 @@ void Set(String term, score_int score)

if (index < 0)
{ // Found a leaf! This Node has a unique LCP! Just add it to the list!
Count += 1;
termCount += 1;

#if COMPRESS_STRINGS
term = term.Substring(prevLCP);
@@ -1087,7 +1089,7 @@ private static void PushNodes(
right -= 1; // we already checked localMaximumsLastIndex - 1
while (localMaximumsIndex <= right)
{
int mid = unchecked((int)(((uint)localMaximumsIndex + (uint)right) >> 1)); // divides by two regardless of overflow
int mid = (int)(((uint)localMaximumsIndex + (uint)right) >> 1); // divides by two regardless of overflow

if (LCP < localMaximums[mid].LCP)
right = mid - 1;
@@ -1146,22 +1148,23 @@ private static int findBranch((Node node, int LCP)[] peers, int LCP)

// Fill newPeers with the peers intended for `term`
// If a node representing `term` is found somewhere in the trie, delete it
private int GetPeers((Node node, int LCP)[] newPeers, Node node, int LCP, String term, bool nodeMatchesTheSameAsPrevious)
{
// if (LCP == termLen && LCP == node.key.Length) <- This is impossible
private (Node node, int LCP)[] FindBranchPointsInNode(Node node, int prevLCP, int LCP, String term)
{ // if (LCP == termLen && LCP == node.key.Length) <- This is impossible
var termLength = term.Length;
var maximumPossibleCapacityOfNewPeers = termLength + 1 - prevLCP; // [1, inf)
var newPeers = new (Node node, int LCP)[Math.Min(8, maximumPossibleCapacityOfNewPeers)]; // because max could be massive, just default to 8
var newPeersCount = 1;

// Add `node` to newPeers, and its peers with diffs less than LCP
// This is guaranteed to be in sorted order
if (!nodeMatchesTheSameAsPrevious)
ExtractSubKNodes(ref node.peers, LCP, newPeers, ref newPeersCount);
if (LCP != prevLCP)
ExtractLCPsBelowThreshold(ref node.peers, LCP, ref newPeers, ref newPeersCount);

// This code inserts `node` after dealing with `node.peers` so that
// the `node.peers` pointer doesn't need to be written twice
// the `node.peers` pointer doesn't need to be set twice
newPeers[0] = (node, LCP);
(Node node, int LCP)[] grandPeers = newPeers;
int indexInGrand = 0;
var termLength = term.Length;
var nodeKeyLength = node.key.Length;
#if COMPRESS_STRINGS
nodeKeyLength += LCP;
@@ -1179,19 +1182,19 @@ private int GetPeers((Node node, int LCP)[] newPeers, Node node, int LCP, String
// if we reached the end of this horizontal line, terminate
if (indexInParent < 0)
{
Count += 1;
return newPeersCount;
termCount += 1;
if (newPeersCount != newPeers.Length) Array.Resize(ref newPeers, newPeersCount);
return newPeers;
}

node = parentPeers[indexInParent].node;
nodeKeyLength = node.key.Length;

if (LCP != node.key.Length &&
#if COMPRESS_STRINGS
term[LCP] == node.key[0]
#else
term[LCP] == node.key[LCP]
#endif
)
#if COMPRESS_STRINGS
if ( 0 != nodeKeyLength && term[LCP] == node.key[0])
#else
if (LCP != nodeKeyLength && term[LCP] == node.key[LCP])
#endif
{
SupplantNodeFromParentWithNextBranchingNode(
grandPeers,
@@ -1207,10 +1210,9 @@ private int GetPeers((Node node, int LCP)[] newPeers, Node node, int LCP, String
indexInGrand = indexInParent;
}

nodeKeyLength = node.key.Length;
#if COMPRESS_STRINGS
nodeKeyLength += LCP;
var startLCP = LCP;
nodeKeyLength += startLCP;
#else
var startLCP = 0;
#endif
@@ -1222,7 +1224,7 @@ private int GetPeers((Node node, int LCP)[] newPeers, Node node, int LCP, String
{ // if the characters `term` does not exactly match node
var extractIndex = newPeersCount;
newPeersCount += 1;
ExtractSubKNodes(ref node.peers, LCP, newPeers, ref newPeersCount);
ExtractLCPsBelowThreshold(ref node.peers, LCP, ref newPeers, ref newPeersCount);
#if COMPRESS_STRINGS
node = new Node(node.key.Substring(LCP - startLCP), node.score, node.peers);
#endif
@@ -1234,7 +1236,7 @@ private int GetPeers((Node node, int LCP)[] newPeers, Node node, int LCP, String

// LCP now equals termLength

// if `node` doesn't represent `term` at this point, go look for the right `node`
// if `node` doesn't represent `term` at this point, go find the node that does
if (LCP != nodeKeyLength)
{
while (true)
@@ -1245,8 +1247,9 @@ private int GetPeers((Node node, int LCP)[] newPeers, Node node, int LCP, String
// if we reached the end of this horizontal line, terminate
if (indexInParent < 0)
{
Count += 1;
return newPeersCount;
termCount += 1;
if (newPeersCount != newPeers.Length) Array.Resize(ref newPeers, newPeersCount);
return newPeers;
}

node = parentPeers[indexInParent].node;
@@ -1270,21 +1273,25 @@ private int GetPeers((Node node, int LCP)[] newPeers, Node node, int LCP, String
indexInGrand = indexInParent;
}
}
// node.key now represents term

// node.key represents term
// Move all peers (that aren't marked for death) from node.peers to newPeers
var oldPeersCount = newPeersCount;
foreach (var peer in node.peers)
if (peer.LCP >= 0) newPeersCount++;

if (newPeersCount != newPeers.Length) Array.Resize(ref newPeers, newPeersCount);
if (newPeersCount > oldPeersCount)
{
// Move all peers (that aren't marked for death) from node.peers to newPeers
var place = newPeersCount;
var newPeersIndex = oldPeersCount;
foreach (var peer in node.peers)
if (peer.LCP >= 0) newPeers[newPeersCount++] = peer;
if (peer.LCP >= 0) newPeers[newPeersIndex++] = peer;

if (place < newPeersCount)
UnboundedInsertionSortElementsLeft(newPeersCount, newPeers, place);
UnboundedInsertionSortElementsLeft(newPeersCount, newPeers, oldPeersCount);
}

// `node` is the Node that represents `term`, so we can throw it away now
return newPeersCount;
return newPeers;
}

static int UnboundedInsertionSortElementsLeft(int newPeersCount, (Node node, int LCP)[] newPeers, int place)
@@ -1350,7 +1357,7 @@ void AddTerms(List<(String term, score_int score)> terms)
if (count == 0) return;

#if !FORCE_STATIC
if (Count != 0)
if (termCount != 0)
{
// Optimization idea: Depending on how large `terms` is and how small the structure is,
// it might be faster to just nuke the previous structure and do the equivalent of
@@ -1465,7 +1472,7 @@ void AddTerms(List<(String term, score_int score)> terms)
parent[positionInParent].node = node;
}

this.Count = count;
this.termCount = count;
#if FORCE_STATIC
FillDictionaryStructure();
#endif
@@ -1474,7 +1481,7 @@ void AddTerms(List<(String term, score_int score)> terms)
public bool ReadTermsFromFile(String path)
{
#if FORCE_STATIC
if (Count != 0)
if (this.termCount != 0)
throw new Exception("Cannot read terms into an already initialized static PruningRadixTrie");
#endif
if (!System.IO.File.Exists(path))
23 changes: 11 additions & 12 deletions c#/README.md
@@ -1,11 +1,17 @@
This directory holds the C# implementation of the [*Dynamic Score-Decomposed Trie*](https://validark.github.io/DynSDT/demo).

This code is designed to work as a drop-in replacement to the [*Pruning Radix Trie*](https://github.com/wolfgarbe/PruningRadixTrie)
(and be many times faster). Note that this version does not have a separate `Node.cs` file
(so we can have a single `using score_int = System.Int64;` directive), so if you literally drag and drop
the `PruningRadixTrie` files onto your [*Pruning Radix Trie*](https://github.com/wolfgarbe/PruningRadixTrie)
files, you will have to delete the old `Node.cs` file. Note: the only change to the `csproj` file is that
`PruningRadixTrie.csproj` has the `TargetFramework` updated from `netstandard2.0` to `netstandard2.1`.
(and be many times faster).

## Limitations of this implementation

- This implementation differs a bit from the one in [the paper](https://validark.github.io/DynSDT/) because Nodes are implemented as structs, and C# doesn't support raw pointers unless one wants to deal with unsafe code and "pointer pinning". This is why whenever a struct is updated all of its copies in various places are updated as well.
- ${\rm G{\small et}T {\small op}}k {\rm T{\small erms}F{\small or}P{\small refix}}(p,\ k)$ allocates a list of size ${\rm M{\small in}}(k,\ c)$, where $c$ is the number of total string terms in the data structure. So if you pass in `int.MaxValue` to $k$ you're going to allocate $c$ slots for the result array.
- There are some differences from the original PruningRadixTrie implementation:
- `termCount` is an `int`, not a `long`
- The following public members are not supported: `termCountLoaded`, `UpdateMaxCounts`, `FindAllChildTerms`, `BinarySearchComparer`
- There is no `Node.cs` file (so we can have a single `using score_int = System.Int64;` directive), so if you literally drag and drop the `PruningRadixTrie` files onto your [*Pruning Radix Trie*](https://github.com/wolfgarbe/PruningRadixTrie) files, you will have to delete the old `Node.cs` file.
- In the `csproj` file, `PruningRadixTrie.csproj` has the `TargetFramework` updated from `netstandard2.0` to `netstandard2.1`.
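
A minimal usage sketch, for orientation only: the class name and the `AddTerm`/`GetTopkTermsForPrefix` members are taken from `PruningRadixTrie.cs`, but the wrapper program, namespace setup, and example terms/scores are assumptions, and `AddTerm` is only accessible when `FORCE_STATIC` is not defined.

```csharp
using System;

class Demo
{
    static void Main()
    {
        // Assumes PruningRadixTrie.cs is compiled into this project with its default directives.
        var trie = new PruningRadixTrie();

        // AddTerm adds `score` onto the term's existing score (saturating at score_int.MaxValue).
        trie.AddTerm("microsoft", 1000);
        trie.AddTerm("microscope", 700);
        trie.AddTerm("microwave", 400);

        // Query the top 2 completions for the prefix "micro".
        foreach (var (term, score) in trie.GetTopkTermsForPrefix("micro", 2))
            Console.WriteLine($"{term}\t{score}");
    }
}
```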

## Preprocessor directives
This implementation has a few preprocessor directives which can toggle behavior:
@@ -16,10 +22,3 @@ This implementation has a few preprocessor directives which can toggle behavior:
|FORCE_STATIC|makes Set/Delete/AddTerm private, enables hashMaps which skip the first 2 characters, and enables string interning (when COMPRESS_STRINGS is also used)|
|DEBUG_METHODS|enables some methods that allow for testing that a given function produces the right output|
|INVARIANT_CHECKS|enables some runtime checks which are impossible to trigger unless there is a major flaw in logic somewhere. Mainly for debugging/fuzzing.|

The reason why this implementation differs a bit from the one in [the paper](https://validark.github.io/DynSDT/)
is that Nodes are implemented as structs, and C# doesn't support raw pointers unless one wants to
deal with unsafe code and "pointer pinning". This is why whenever a struct is updated all of its copies
in various places are updated as well.

#### TODO: Finish polishing tests and upload them
