Skip to content

Commit

Permalink
Fixed usage of updated ScoredAnalysis.
Browse files Browse the repository at this point in the history
  • Loading branch information
owo committed Mar 2, 2023
1 parent 7a0753a commit 794d18e
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 6 deletions.
7 changes: 6 additions & 1 deletion camel_tools/disambig/bert/unfactored.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,12 @@ def _scored_analyses(self, word_dd, prediction):
if len(analyses) == 0:
# If the word is not found in the analyzer,
# return the predictions from BERT
return [ScoredAnalysis(0, bert_analysis)]
return [ScoredAnalysis(0, # score
bert_analysis, # analysis
bert_analysis['diac'], # diac
-99, # pos_lex_logprob
-99, # lex_logprob
)]

scored = [(self._scorer(a,
bert_analysis,
Expand Down
16 changes: 11 additions & 5 deletions camel_tools/disambig/mle.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,17 @@ def _scored_analyses(self, word_dd):

max_score = max([s[0] for s in scored])

scored_analyses = [ScoredAnalysis(s[0] / max_score, s[1])
for s in scored]
if max_score == 0:
max_score = 1

scored_analyses = [
ScoredAnalysis(
s / max_score, # score
a, # analysis
a['diac'], # diac
a.get('pos_lex_logprob', -99), # pos_lex_logprob
a.get('lex_logprob', -99), # lex_logprob
) for s, a in scored]

return scored_analyses[0:self._top]

Expand All @@ -202,9 +211,6 @@ def _scored_analyses(self, word_dd):
probabilities = [10 ** _get_pos_lex_logprob(a) for a in analyses]
max_prob = max(probabilities)

scored_analyses = [ScoredAnalysis(p / max_prob, a)
for a, p in zip(analyses, probabilities)]

scored_analyses = [
ScoredAnalysis(
p / max_prob, # score
Expand Down

0 comments on commit 794d18e

Please sign in to comment.