def counts_to_vector(counts):
    """
    Expand a count vector to a vector of repeated indices

    The index ``i`` appears ``counts[i]`` times in the result, in
    increasing order of ``i``, e.g. ``[2, 0, 1]`` becomes ``[0, 0, 2]``.

    Args:
        counts (np.array): count vector (non-negative integers)

    Returns:
        np.array: the expanded vector of indices
    """
    counts = np.asarray(counts)

    # an empty count vector expands to an empty index vector
    if counts.size == 0:
        return np.zeros(0, dtype=int)

    return np.repeat(np.arange(counts.shape[0]), counts)


def deterministic_sample(choices, n_to_sample, p):
    """
    Take a deterministic sample

    Each choice is selected ``ceil(n_to_sample * p[i])`` times, then the
    surplus introduced by the rounding is removed by decrementing the
    counts of evenly spaced choices with positive counts, so that exactly
    ``n_to_sample`` items are returned.

    Args:
        choices (list/np.array): the list of choices
        n_to_sample (int): the number of samples to take
        p (list/np.array): the distribution over the choices

    Returns:
        np.array: the choices, each repeated according to its final count

    Raises:
        ValueError: if the rounded counts add up to less than n_to_sample
    """
    # array conversion enables fancy indexing and elementwise scaling
    # even when plain lists are passed
    choices = np.asarray(choices)
    p = np.asarray(p, dtype=float)

    sample_counts = np.ceil(n_to_sample * p).astype(int)

    n_to_remove = int(np.sum(sample_counts)) - n_to_sample

    if n_to_remove < 0:
        raise ValueError("the distribution p yields fewer than "
                         f"n_to_sample ({n_to_sample}) samples")

    if n_to_remove > 0:
        # only choices with at least one sample can give one up
        positive_idx = np.flatnonzero(sample_counts > 0)

        # spread the removals evenly over the positive counts
        removal_positions = np.floor(np.linspace(0.0,
                                                 positive_idx.size,
                                                 n_to_remove,
                                                 endpoint=False)).astype(int)

        sample_counts[positive_idx[removal_positions]] -= 1

    # internal invariant: the adjusted counts add up to the request
    assert np.sum(sample_counts) == n_to_sample

    return choices[counts_to_vector(sample_counts)]
@@ -436,11 +484,17 @@ def simplices(self, weights = weights * node_weights - # sample the simplices choices = np.arange(all_simplices.shape[0]) - selected_indices = self.random_state.choice(choices, - n_to_sample, - p=weights/np.sum(weights)) + + if self.simplex_sampling == 'random': + # sample the simplices + selected_indices = self.random_state.choice(choices, + n_to_sample, + p=weights/np.sum(weights)) + elif self.simplex_sampling == 'deterministic': + selected_simplices = deterministic_sample(choices, + n_to_sample, + p=weights/np.sum(weights)) return all_simplices[selected_indices] def add_gaussian_noise(self, samples): diff --git a/tests/oversampling/test_simplex_deterministic.py b/tests/oversampling/test_simplex_deterministic.py index 66ffc3d..ee478d2 100644 --- a/tests/oversampling/test_simplex_deterministic.py +++ b/tests/oversampling/test_simplex_deterministic.py @@ -21,7 +21,8 @@ def test_simplex_deterministic(smote_class): Args: smote_class (class): an oversampler class. """ - ss_params = {'within_simplex_sampling': 'deterministic'} + ss_params = {'within_simplex_sampling': 'deterministic', + 'simplex_sampling': 'deterministic'} X, y = smote_class(ss_params=ss_params).sample(dataset['data'], dataset['target'])