diff --git a/benchmark/kmeans.py b/benchmark/kmeans.py
index 1bf3582..773d470 100644
--- a/benchmark/kmeans.py
+++ b/benchmark/kmeans.py
@@ -111,7 +111,6 @@ def timeit(self, name, engine_provider=None, is_slow=False):
         except NotSupportedByEngineError as e:
             print((repr(e) + "\n"))
             return
-
         t0 = perf_counter()
         estimator.fit(X, sample_weight=sample_weight)
         t1 = perf_counter()
diff --git a/sklearn_numba_dpex/kmeans/drivers.py b/sklearn_numba_dpex/kmeans/drivers.py
index 92cdd8c..926c30c 100644
--- a/sklearn_numba_dpex/kmeans/drivers.py
+++ b/sklearn_numba_dpex/kmeans/drivers.py
@@ -33,6 +33,8 @@
     make_select_samples_far_from_centroid_kernel,
     make_centroid_shifts_kernel,
     make_reduce_centroid_data_kernel,
+    make_is_same_clustering_kernel,
+    make_get_nb_distinct_clusters_kernel,
 )
 
 from sklearn_numba_dpex.common._utils import _square, _plus, _minus
@@ -462,20 +464,70 @@ def prepare_data_for_lloyd(X_t, init, tol, copy_x):
     return X_t, X_mean, init, tol
 
 
-def restore_data_after_lloyd(X_t, X_mean):
+def restore_data_after_lloyd(X_t, best_centers_t, X_mean, copy_x):
+    if X_mean is None:
+        return
+
     n_features, n_samples = X_t.shape
+    n_clusters = best_centers_t.shape[1]
 
     device = X_t.device.sycl_device
     max_work_group_size = device.max_work_group_size
 
-    X_t = dpt.asarray(X_t, copy=False)
-    broadcast_X_plus_X_mean = make_broadcast_ops_1d_2d_axis1_kernel(
-        n_features,
-        n_samples,
-        ops=_plus,
-        work_group_size=max_work_group_size,
+    best_centers_t = dpt.asarray(best_centers_t, copy=False)
+    broadcast_init_plus_X_mean = make_broadcast_ops_1d_2d_axis1_kernel(
+        n_features, n_clusters, ops=_plus, work_group_size=max_work_group_size
+    )
+    broadcast_init_plus_X_mean(best_centers_t, X_mean)
+
+    # NB: copy_x being set to False does not mean that no copy actually happened,
+    # only that no copy was forced if it was not necessary with respect to the
+    # device, dtype and order that are required at compute time. Nevertheless,
+    # there's no simple way to check if a copy happened without assumptions on the
+    # type of the raw input submitted by the user, but at the moment it is unknown
+    # what those assumptions could be. As a result, the following instructions are
+    # run every time, even if they aren't useful when a copy has been made.
+    # TODO: is there a set of assumptions that exhaustively describes the set of
+    # accepted inputs, and also enables checking if a copy happened or not in a
+    # simple way?
+    if not copy_x:
+        X_t = dpt.asarray(X_t, copy=False)
+        broadcast_X_plus_X_mean = make_broadcast_ops_1d_2d_axis1_kernel(
+            n_features, n_samples, ops=_plus, work_group_size=max_work_group_size
+        )
+        broadcast_X_plus_X_mean(X_t, X_mean)
+
+
+def is_same_clustering(labels1, labels2, n_clusters):
+    """Check if two arrays of labels are the same up to a permutation of the labels."""
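+    # For instance (illustrative values), labels1 = [0, 0, 1] and labels2 = [1, 1, 0]
+    # describe the same clustering up to the relabeling {0: 1, 1: 0}, while
+    # labels1 = [0, 1, 1] and labels2 = [0, 0, 1] do not.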
+    device = labels1.device.sycl_device
+
+    is_same_clustering_kernel = make_is_same_clustering_kernel(
+        n_samples=labels1.shape[0],
+        n_clusters=n_clusters,
+        work_group_size=device.max_work_group_size,
+        device=device,
     )
-    broadcast_X_plus_X_mean(X_t, X_mean)
+    return is_same_clustering_kernel(labels1, labels2)
+
+
+def get_nb_distinct_clusters(labels, n_clusters):
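+    """Count the number of distinct cluster labels that occur in `labels`.
+
+    For instance (illustrative values), labels = [0, 0, 3, 1, 3] contains 3 distinct
+    clusters, even if `n_clusters` is 4 or more.
+    """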
+    device = labels.device.sycl_device
+
+    get_nb_distinct_clusters_kernel = make_get_nb_distinct_clusters_kernel(
+        n_samples=labels.shape[0],
+        n_clusters=n_clusters,
+        work_group_size=device.max_work_group_size,
+        device=device,
+    )
+
+    clusters_seen = dpt.zeros(sh=(n_clusters,), dtype=np.int32, device=device)
+
+    nb_distinct_clusters = dpt.zeros(sh=(1,), dtype=np.int32, device=device)
+
+    get_nb_distinct_clusters_kernel(labels, clusters_seen, nb_distinct_clusters)
+
+    return nb_distinct_clusters[0]
 
 
 def get_labels_inertia(X_t, centroids_t, sample_weight, with_inertia):
diff --git a/sklearn_numba_dpex/kmeans/engine.py b/sklearn_numba_dpex/kmeans/engine.py
index b53dbb5..c86d578 100644
--- a/sklearn_numba_dpex/kmeans/engine.py
+++ b/sklearn_numba_dpex/kmeans/engine.py
@@ -1,6 +1,5 @@
 import numbers
 import contextlib
-import importlib
 
 import numpy as np
 import dpnp
@@ -9,7 +8,6 @@
 import sklearn
 import sklearn.utils.validation as sklearn_validation
 
-from sklearn.cluster._kmeans import KMeansCythonEngine
 from sklearn.utils import check_random_state, check_array
 from sklearn.utils.validation import _is_arraylike_not_scalar
 
@@ -21,6 +19,8 @@
     prepare_data_for_lloyd,
     lloyd,
     restore_data_after_lloyd,
+    is_same_clustering,
+    get_nb_distinct_clusters,
     get_labels_inertia,
     get_euclidean_distances,
     kmeans_plusplus,
@@ -31,10 +31,7 @@ class _DeviceUnset:
     pass
 
 
-# At the moment not all steps are implemented with numba_dpex, we inherit missing steps
-# from the default sklearn KMeansCythonEngine for convenience, this inheritance will be
-# removed later on when the other parts have been implemented.
-class KMeansEngine(KMeansCythonEngine):
+class KMeansEngine:
     """GPU optimized implementation of Lloyd's k-means.
 
     The current implementation is called "fused fixed", it consists in a sliding window
@@ -91,7 +88,7 @@ def __init__(self, estimator):
                 "https://github.com/IntelPython/numba-dpex/issues/767"
             )
         self.order = order
-        super().__init__(estimator)
+        self.estimator = estimator
 
     def prepare_fit(self, X, y=None, sample_weight=None):
         estimator = self.estimator
@@ -126,21 +123,7 @@ def unshift_centers(self, X, best_centers):
         if (X_mean := self.X_mean) is None:
             return
 
-        best_centers += dpt.asnumpy(X_mean.get_array())
-
-        # NB: self.estimator.copy_x being set to False does not mean that no copy
-        # actually happened, only that no copy was forced if it was not necessary
-        # with respect to what device, dtype and order that are required at compute
-        # time. Nevertheless, there's no simple way to check if a copy happened
-        # without assumptions on the type of the raw input submitted by the user,
-        # but at the moment it is unknown what those assumptions could be.
-        # As a result, the following instructions are ran every time, even if it
-        # isn't useful when a copy has been made.
-        # TODO: is there a set of assumptions that exhaustively describes the set
-        # of accepted inputs, and also enables checking if a copy happened or not
-        # in a simple way ?
-        if not self.estimator.copy_x:
-            restore_data_after_lloyd(X.T, X_mean)
+        restore_data_after_lloyd(X.T, best_centers.T, X_mean, self.estimator.copy_x)
 
     def init_centroids(self, X):
         init = self.init
@@ -157,8 +140,8 @@
             centers_t = self._check_init(centers, X)
         else:
-            # NB: sampling without replacement must be executed sequentially so
-            # it's better done on CPU
+            # NB: sampling without replacement must be executed sequentially so it's
+            # better done on CPU
             centers_idx = self.random_state.choice(
                 X.shape[0], size=n_clusters, replace=False
             )
@@ -197,22 +180,18 @@ def kmeans_single(self, X, sample_weight, centers_init_t):
             self.tol,
         )
 
-        # TODO: explore leveraging dpnp to benefit from USM to avoid moving centroids
-        # back and forth between device and host memory in case a subsequent `.predict`
-        # call is requested on the same GPU later.
-        return (
-            dpt.asnumpy(assignments_idx).astype(np.int32, copy=False),
-            inertia,
-            # XXX: having a C-contiguous centroid array is expected in sklearn in some
-            # unit test and by the cython engine.
-            # ???: rather that returning whatever dtype the driver returns (which might
-            # depends on device support for float64), shouldn't we cast to a dtype that
-            # is always consistent with the input ? (e.g. cast to float64 if the input
-            # was given as float64 ?) But what assumptions can we make on the input
-            # so we can infer its input dtype without risking triggering a copy of it ?
-            np.ascontiguousarray(dpt.asnumpy(best_centroids.T)),
-            n_iteration,
-        )
+        # ???: rather than returning whatever dtype the driver returns (which might
+        # depend on device support for float64), shouldn't we cast to a dtype that is
+        # always consistent with the input (e.g. cast to float64 if the input was
+        # given as float64)? But what assumptions can we make on the input so we can
+        # infer its dtype without risking triggering a copy of it?
+        return assignments_idx, inertia, best_centroids.T, n_iteration
+
+    def is_same_clustering(self, labels, best_labels, n_clusters):
+        return is_same_clustering(labels, best_labels, n_clusters)
+
+    def get_nb_distinct_clusters(self, best_labels):
+        return get_nb_distinct_clusters(best_labels, self.estimator.n_clusters)
 
     def prepare_prediction(self, X, sample_weight):
         X = self._validate_data(X, reset=False)
@@ -222,7 +201,7 @@ def prepare_prediction(self, X, sample_weight):
     def get_labels(self, X, sample_weight):
         # TODO: sample_weight actually not used for get_labels. Fix in sklearn ?
         labels, _ = self._get_labels_inertia(X, sample_weight, with_inertia=False)
-        return dpt.asnumpy(labels).astype(np.int32, copy=False)
+        return labels
 
     def get_score(self, X, sample_weight):
         _, inertia = self._get_labels_inertia(X, sample_weight, with_inertia=True)
@@ -254,7 +233,7 @@ def get_euclidean_distances(self, X):
             self.estimator.cluster_centers_, X, copy=False
         )
         euclidean_distances = get_euclidean_distances(X.T, cluster_centers)
-        return dpt.asnumpy(euclidean_distances)
+        return euclidean_distances
 
     def _validate_data(self, X, reset=True):
         if isinstance(X, dpnp.ndarray):
@@ -349,6 +328,39 @@ def _check_init(self, init, X, copy=False):
         return init_t
 
 
+class KMeansEngineDebug(KMeansEngine):
+    """KMeansEngine with debug features.
+
+    Fitted attributes and outputs are set or returned as numpy arrays. The fitted
+    attributes are converted on the fly when used. This engine enables running the
+    scikit-learn test pipeline."""
+
+    def kmeans_single(self, X, sample_weight, centers_init_t):
+        assignments_idx, inertia, best_centroids, n_iteration = super().kmeans_single(
+            X, sample_weight, centers_init_t
+        )
+
+        return (
+            dpt.asnumpy(assignments_idx),
+            inertia,
+            # XXX: having a C-contiguous centroid array is expected in sklearn in some
+            # unit tests and by the cython engine.
+            np.ascontiguousarray(dpt.asnumpy(best_centroids)),
+            n_iteration,
+        )
+
+    def get_nb_distinct_clusters(self, best_labels):
+        return int(super().get_nb_distinct_clusters(best_labels))
+
+    def get_labels(self, X, sample_weight):
+        labels = super().get_labels(X, sample_weight)
+        return dpt.asnumpy(labels)
+
+    def get_euclidean_distances(self, X):
+        euclidean_distances = super().get_euclidean_distances(X)
+        return dpt.asnumpy(euclidean_distances)
+
+
 def _get_namespace(*arrays):
     return dpt, True
diff --git a/sklearn_numba_dpex/kmeans/kernels/__init__.py b/sklearn_numba_dpex/kmeans/kernels/__init__.py
index b0444d3..a2ed738 100644
--- a/sklearn_numba_dpex/kmeans/kernels/__init__.py
+++ b/sklearn_numba_dpex/kmeans/kernels/__init__.py
@@ -14,6 +14,8 @@
     make_select_samples_far_from_centroid_kernel,
     make_centroid_shifts_kernel,
     make_reduce_centroid_data_kernel,
+    make_is_same_clustering_kernel,
+    make_get_nb_distinct_clusters_kernel,
 )
 
 
@@ -29,4 +31,6 @@
     "make_select_samples_far_from_centroid_kernel",
     "make_centroid_shifts_kernel",
     "make_reduce_centroid_data_kernel",
+    "make_is_same_clustering_kernel",
+    "make_get_nb_distinct_clusters_kernel",
 )
diff --git a/sklearn_numba_dpex/kmeans/kernels/utils.py b/sklearn_numba_dpex/kmeans/kernels/utils.py
index ea33a12..187552d 100644
--- a/sklearn_numba_dpex/kmeans/kernels/utils.py
+++ b/sklearn_numba_dpex/kmeans/kernels/utils.py
@@ -3,6 +3,7 @@
 import numpy as np
 
 import numba_dpex as dpex
+import dpctl.tensor as dpt
 
 
 zero_idx = np.int64(0)
@@ -245,3 +246,82 @@ def reduce_centroid_data(
             empty_clusters_list[current_n_empty_clusters] = cluster_idx
 
     return reduce_centroid_data[global_size, work_group_size]
+
+
+@lru_cache
+def make_is_same_clustering_kernel(n_samples, n_clusters, work_group_size, device):
+    # TODO: are there possible optimizations for this kernel?
+    # - fusing the two kernels (requires a lock?)
+    # - early stop
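+
+    # The check runs in two passes: `_build_mapping` records, for every label value
+    # occurring in `labels1`, the label that the same sample has in `labels2`, then
+    # `_is_same_clustering` checks that this mapping holds for every sample and
+    # writes 0 to `result` whenever a sample does not follow it.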
+
+    def is_same_clustering(labels1, labels2):
+        mapping = dpt.empty(sh=(n_clusters,), dtype=np.int32, device=device)
+        result = dpt.asarray([1], dtype=np.int32, device=device)
+        _build_mapping[global_size, work_group_size](labels1, labels2, mapping)
+        _is_same_clustering[global_size, work_group_size](
+            labels1, labels2, mapping, result
+        )
+        return bool(result[0])
+
+    @dpex.kernel
+    # fmt: off
+    def _build_mapping(
+        labels1,        # IN      (n_samples,)
+        labels2,        # IN      (n_samples,)
+        mapping,        # BUFFER  (n_clusters,)
+    ):
+        # fmt: on
+        """Record, for each label value occurring in `labels1`, the label that the
+        same sample has in `labels2`. `mapping` is an int32 buffer of size
+        `n_clusters`; entries for labels that never occur in `labels1` are left
+        uninitialized and are never read by `_is_same_clustering`."""
+        sample_idx = dpex.get_global_id(zero_idx)
+        if sample_idx >= n_samples:
+            return
+
+        label1 = labels1[sample_idx]
+        label2 = labels2[sample_idx]
+        mapping[label1] = label2
+
+    @dpex.kernel
+    # fmt: off
+    def _is_same_clustering(
+        labels1,        # IN      (n_samples,)
+        labels2,        # IN      (n_samples,)
+        mapping,        # IN      (n_clusters,)
+        result,         # OUT     (1,)
+    ):
+        # fmt: on
+        sample_idx = dpex.get_global_id(zero_idx)
+        if sample_idx >= n_samples:
+            return
+
+        if mapping[labels1[sample_idx]] != labels2[sample_idx]:
+            result[zero_idx] = zero_idx
+
+    global_size = math.ceil(n_samples / work_group_size) * work_group_size
+
+    return is_same_clustering
+
+
+@lru_cache
+def make_get_nb_distinct_clusters_kernel(
+    n_samples, n_clusters, work_group_size, device
+):
+    one_incr = np.int32(1)
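+
+    # NB: `clusters_seen` is used as an array of per-cluster flags and
+    # `nb_distinct_clusters` as a single global counter. Each work-item atomically
+    # marks the label of its sample in `clusters_seen`; only the work-item that
+    # observes a previous count of 0 for its label increments the counter, so
+    # `nb_distinct_clusters` ends up holding the number of distinct labels.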
+
+    @dpex.kernel
+    def _get_nb_distinct_clusters(labels, clusters_seen, nb_distinct_clusters):
+        sample_idx = dpex.get_global_id(zero_idx)
+
+        if sample_idx >= n_samples:
+            return
+
+        label = labels[sample_idx]
+
+        if clusters_seen[label] > zero_idx:
+            return
+
+        previous_count = dpex.atomic.add(clusters_seen, label, one_incr)
+        if previous_count == zero_idx:
+            dpex.atomic.add(nb_distinct_clusters, zero_idx, one_incr)
+
+    global_size = math.ceil(n_samples / work_group_size) * work_group_size
+    return _get_nb_distinct_clusters[global_size, work_group_size]