Skip to content

Commit

Permalink
some fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
gykovacs committed Jan 1, 2024
1 parent 60ff996 commit c26cdd9
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -59,25 +59,27 @@ def sample(self, X, y):
n_maj, n_min = np.bincount(y)

X_samp, _ = self.oversampler.sample(X, y)
X_new = X_samp[X.shape[0] :]
X_new = X_samp[X.shape[0]:]
X_min = X[y == 1]
X_maj = X[y == 0]

dists = X_new - X_maj[:, None]

dists = np.mean(np.sqrt(np.sum((dists) ** 2, axis=2)), axis=1)
dists = np.min(np.sqrt(np.sum((dists) ** 2, axis=2)), axis=1)

if self.mode == "random":
inv_dists = 1.0 / np.where(dists < 1e-8, 1e-8, dists)
p = dists / np.sum(inv_dists)
dists = (dists - np.min(dists)) / (np.max(dists) - np.min(dists))
inv_dists = 1.0 - dists

p = inv_dists / np.sum(inv_dists)

mask = self.random_state.choice(np.arange(n_maj), n_min, p=p, replace=False)
X_maj = X_maj[mask]
elif self.mode == "farthest":
sorting = np.argsort(dists)
X_maj = X_maj[sorting[:n_min]]
X_maj = X_maj[sorting][:n_min]

X_res = np.vstack([X_maj, X_min]) # pylint: disable=invalid-name
y_res = np.hstack([np.repeat(0, n_min), np.repeat(1, n_min)])
X_res = np.vstack([X_maj, X_min]).copy() # pylint: disable=invalid-name
y_res = np.hstack([np.repeat(0, X_maj.shape[0]), np.repeat(1, X_min.shape[0])])

return X_res, y_res
7 changes: 6 additions & 1 deletion smote_variants/undersampling/_random_undersampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,12 @@ def sample(self, X, y):
np.array, np.array: the undersampled feature vectors and class labels
"""
n_maj, n_min = np.bincount(y)
#mask = self.random_state.choice(np.arange(n_maj), n_min, p=np.repeat(1.0/n_maj, n_maj), replace=False)
mask = self.random_state.choice(np.arange(n_maj), n_min, replace=False)
X_maj = X[y == 0][mask]
X_min = X[y == 1]

return np.vstack([X_maj, X[y == 1]]), np.hstack([np.repeat(0, n_min), np.repeat(1, n_min)])
X_res = np.vstack([X_maj, X_min]).copy() # pylint: disable=invalid-name
y_res = np.hstack([np.repeat(0, X_maj.shape[0]), np.repeat(1, X_min.shape[0])])

return X_res, y_res
24 changes: 23 additions & 1 deletion tests/undersampling/test_undersampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
RandomUndersampling,
OversamplingDrivenUndersampling,
)
from smote_variants.oversampling import SMOTE

undersamplings = [RandomUndersampling, OversamplingDrivenUndersampling]

Expand All @@ -37,3 +36,26 @@ def test_undersampling_normal(undersampling):
_, y_samp = undersampling_obj.sample(X, y)

assert np.sum(y_samp) == len(y_samp) - np.sum(y_samp)

@pytest.mark.parametrize("undersampling", undersamplings)
def test_undersampling_normal_farthest(undersampling):
"""
Testing the undersampling farthest strategy
Args:
undersampling (cls): the undersampling class
"""

dataset = load_normal()

X = dataset["data"]
y = dataset["target"]

params = undersampling.parameter_combinations()

for param in params:
undersampling_obj = undersampling(**param)

_, y_samp = undersampling_obj.sample(X, y)

assert np.sum(y_samp) == len(y_samp) - np.sum(y_samp)

0 comments on commit c26cdd9

Please sign in to comment.