Skip to content

Commit

Permalink
trim smote infinite loop fix
Browse files Browse the repository at this point in the history
  • Loading branch information
gykovacs committed Aug 22, 2022
1 parent 212b8c2 commit 6e7286b
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 4 deletions.
4 changes: 3 additions & 1 deletion smote_variants/evaluation/_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,9 @@ def do_evaluation(self):
'classifier_params': classifier[2],
'classifier_module': classifier[0],
'error': error,
'warning': str(warning_list)}
'warning': str(warning_list),
'oversampling_error': oversampling['error'],
'oversampling_warning': oversampling['warning']}

if self.cache_path is not None:
dump_dict(evaluation, target_filename, self.serialization,
Expand Down
6 changes: 3 additions & 3 deletions smote_variants/oversampling/_trim_smote.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,12 +337,9 @@ def generate_samples(self, X, y,
nnmt.fit(X_seed_min)
indices = nnmt.kneighbors(X_seed_min, return_distance=False)

#n_dim_orig = self.n_dim
#self.n_dim = np.min([self.n_dim, X_seed_min.shape[0]])
samples = self.sample_simplex(X=X_seed_min,
indices=indices,
n_to_sample=n_to_sample)
#self.n_dim = n_dim_orig

return samples

Expand All @@ -364,6 +361,9 @@ def sampling_algorithm(self, X, y):

seeds = self.trimming(X, y)

if len([s for s in seeds if self.precision(s[1]) > self.min_precision/10.0]) == 0:
return self.return_copies(X, y, "no seeds found")

# filtering the resulting set
filtered_seeds = [s for s in seeds if self.precision(s[1]) > self.min_precision]

Expand Down
6 changes: 6 additions & 0 deletions tests/oversampling/algorithm_level/test_trim_smote.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ def test_specific():

assert len(X_samp) > 0

obj.min_precision = 10000.0

X_samp, _ = obj.sample(dataset['data'], dataset['target'])

assert len(X_samp) > 0

assert len(obj.parameter_combinations(raw=True)) > 0

X = np.array([[1, 2], [2, 3], [3, 4], [4, 5],
Expand Down

0 comments on commit 6e7286b

Please sign in to comment.