Skip to content

Commit

Permalink
Fixing branch pruning.
Browse files Browse the repository at this point in the history
  • Loading branch information
AG committed Mar 4, 2024
1 parent f6c97ef commit 945d4db
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 27 deletions.
31 changes: 17 additions & 14 deletions create_dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,27 +14,30 @@ def branch_pruner(trie):
if not isinstance(trie, dict):
return

# Iterate through each key in the trie
# Recursively prune the branches of each subtree
for key in list(trie.keys()):
# Recursively prune the branches of each subtree
branch_pruner(trie[key])
if key != '\ranked':
branch_pruner(trie[key])

# If the current level contains '\ranked', start pruning based on BRANCH_PRUNE_COUNT
if '\ranked' in trie:
ranked_keys = trie['\ranked']
# Determine the keys to keep: top BRANCH_PRUNE_COUNT keys based on the order in '\ranked'
keys_to_keep = set(ranked_keys[:BRANCH_PRUNE_COUNT])
# Sort the ranked words by their scores in descending order and keep only the top BRANCH_PRUNE_COUNT words
top_ranked_words = sorted(trie['\ranked'].items(), key=lambda item: item[1], reverse=True)[:BRANCH_PRUNE_COUNT]
# Convert the list of tuples back to a dictionary
top_ranked_dict = {word: score for word, score in top_ranked_words}

# print(top_ranked_dict)

# Add '\ranked' to the keys to keep to avoid pruning it
keys_to_keep.add('\ranked')
# Prune the trie to keep only the branches corresponding to the top ranked words
keys_to_keep = set(word for word, _ in top_ranked_words)
keys_to_keep.add('\ranked') # Ensure '\ranked' itself is kept

# Prune the keys not in keys_to_keep
for key in list(trie.keys()):
if key not in keys_to_keep:
del trie[key]
# Update the '\ranked' list to reflect the pruned keys
trie['\ranked'] = ranked_keys[:BRANCH_PRUNE_COUNT]
if key not in keys_to_keep and key != '\ranked':
del trie[key] # Remove branches not among the top ranked

# Update the '\ranked' dictionary to reflect only the top ranked words and their scores
trie['\ranked'] = top_ranked_dict

def convert_to_array(obj):
"""
Expand Down
2 changes: 1 addition & 1 deletion dictionary.js

Large diffs are not rendered by default.

24 changes: 12 additions & 12 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,19 +56,19 @@ def _slugify(text):
# Define a function to update the trie structure with predictive words
def update_trie(trie, predictive_words):
    r"""Record a word sequence in the nested trie, counting each occurrence.

    The words are walked in order. At every level the current word is given
    a sub-trie (created when missing) and its score in that level's
    ``'\ranked'`` dictionary is incremented by one. The walk then descends
    into the word's sub-trie so the next word is recorded one level deeper.

    Args:
        trie: Nested dictionary, mutated in place. Each level maps words to
            their sub-tries and holds a ``'\ranked'`` entry mapping
            ``word -> score`` for the words seen at that level.
        predictive_words: Iterable of words forming one sequence.

    Returns:
        None. The trie is updated in place.
    """
    for word in predictive_words:
        # Ensure each word has a sub-trie if it does not exist.
        trie.setdefault(word, {})

        # Scores are kept as a {word: count} dictionary; branch_pruner sorts
        # this mapping with .items(), so it must never be a list.
        # NOTE: '\r' is the carriage-return escape, so the key is really
        # CR + "anked". It is left as-is because the rest of the project
        # uses the same literal for lookups.
        ranked = trie.setdefault('\ranked', {})

        # Update the score at the current level for the current word.
        ranked[word] = ranked.get(word, 0) + 1

        # Descend so the next word is recorded under this one, while the
        # score just written stays at the parent level.
        trie = trie[word]

# Define a function to load or initialize the trie from memory
Expand Down Expand Up @@ -163,7 +163,7 @@ def main():
words = row.split()

# Every now and then save our progress.
# print(f"Saving the current position of %s" % current_position)
print(f"Saving the current position of %s" % current_position)
# Save the current progress (file position)
with open(progress_file, 'w') as f:
f.write(str(current_position))
Expand Down

0 comments on commit 945d4db

Please sign in to comment.