ENH: Store fitted attributes as dpt and return dpt tensors #65
c04e512 to bd8308c
I wanted to duplicate the […] But with the current plugin design in scikit-learn/scikit-learn#24497 I don't see a way around it: one module can only implement one engine. Maybe there's a workaround that consists of exposing several top-level namespaces (e.g. […]). Or WDYT about changing the plugin spec to enable a different syntax that allows passing a dotted string to the […]? E.g. rewriting https://github.com/scikit-learn/scikit-learn/pull/24497/files#diff-1d31de81e903bd6529fbe68f8009b7113e3b7de4f1465572ef88af4d03a7dc5bR35 such that entry points whose value prefixes match the user-supplied string are selected (i.e. using prefix selection rather than […]).
Edit: somewhat related to #21
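The prefix-selection idea above can be sketched as follows. This is only an illustration under stated assumptions, not the code from scikit-learn/scikit-learn#24497: `EntryPointSpec` and `select_by_prefix` are hypothetical names, the `sklearn_numba_dpex.testing.engine` and `other_plugin.engine` module paths are made up for the demo, and entry point values are assumed to look like `package.module:ClassName`.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class EntryPointSpec:
    """Hypothetical stand-in for an installed entry point."""

    name: str  # engine name, e.g. "kmeans"
    value: str  # "package.module:ClassName"


def select_by_prefix(specs, engine_name, provider):
    """Return the specs for `engine_name` whose module path starts with `provider`.

    `provider` may be a dotted string. Matching is done on whole dotted
    components, so "pkg" does not match "pkg_other".
    """
    selected = []
    for spec in specs:
        if spec.name != engine_name:
            continue
        module_path = spec.value.split(":", 1)[0]
        if module_path == provider or module_path.startswith(provider + "."):
            selected.append(spec)
    return selected


# Made-up entry points, for illustration only.
specs = [
    EntryPointSpec("kmeans", "sklearn_numba_dpex.kmeans.engine:KMeansEngine"),
    EntryPointSpec("kmeans", "sklearn_numba_dpex.testing.engine:KMeansEngineDebug"),
    EntryPointSpec("kmeans", "other_plugin.engine:KMeansEngine"),
]

# A top-level name selects every engine of that package; a longer dotted
# string narrows the selection to a single module.
all_dpex = select_by_prefix(specs, "kmeans", "sklearn_numba_dpex")
debug_only = select_by_prefix(specs, "kmeans", "sklearn_numba_dpex.testing")
```

With this kind of matching, one package could expose several engines for the same estimator, at the cost of the user string possibly being ambiguous when it matches more than one entry point.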
I think we should allow programmatically registering additional provider names (without in addition the […]):
from sklearn._engine import register_engine
register_engine(
    engine_name="kmeans",
    engine_class="sklearn_numba_dpex.kmeans.engine:KMeansEngineDebug",
    provider_name="sklearn_numba_dpex_debug",
)
In this case we break the assumption that the provider name is the same as the top-level import package of the engine class, so this might require some refactoring of the engine lookup / import logic to avoid relying on that assumption.
I agree, although since it requires work on the sklearn side, I think we can still move this PR forward with an environment-variable-based switch in the meantime. I'll add a commit going in this direction.
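A rough sketch of what an environment-variable-based switch could look like is given below. Everything in it is an assumption made for illustration: the variable name `EXAMPLE_ENGINE_TESTING_MODE` and the two stand-in classes are hypothetical and are not taken from this PR.

```python
import os

ENV_VAR = "EXAMPLE_ENGINE_TESTING_MODE"  # hypothetical variable name


class Engine:
    """Stand-in for an engine that returns results unchanged."""

    def get_labels(self, labels):
        return labels


class TestingEngine(Engine):
    """Stand-in variant that converts results for a host-side test suite."""

    def get_labels(self, labels):
        return list(labels)


def select_engine_class(environ=None):
    """Pick the engine class from the environment, defaulting to Engine."""
    environ = os.environ if environ is None else environ
    if environ.get(ENV_VAR, "0") == "1":
        return TestingEngine
    return Engine
```

Passing the mapping explicitly, as in `select_engine_class({ENV_VAR: "1"})`, makes the switch easy to exercise in tests without mutating `os.environ`.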
bd8308c to c9a356d
Out of WIP
…r for running sklearn tests
c9a356d to ef49d7e
Thanks @fcharras. Here are a few comments.
initialized with value 1.
"""
sample_idx = dpex.get_global_id(zero_idx)
if sample_idx >= n_samples:
- if sample_idx >= n_samples:
+ # Early return if result is already False.
+ if sample_idx >= n_samples or result[zero_idx] == zero_idx:
Note that this adds a read operation and doesn't stop future tasks from being scheduled; it only ensures that they exit early after the read. So on average it's better when the clusterings are different, but not when they are the same. It is still an improvement overall, since different clusterings should be the most common outcome.
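The trade-off can be illustrated with a sequential, pure-Python emulation of one work item per sample. This is a sketch under assumptions, not the actual numba_dpex kernel: it assumes the kernel compares two label arrays element-wise, and `check_labels_equal` is a hypothetical helper that counts flag reads and label comparisons.

```python
def check_labels_equal(labels_a, labels_b, early_exit):
    """Sequentially emulate one work item per sample.

    Returns (result_flag, n_flag_reads, n_label_comparisons).
    """
    n_samples = len(labels_a)
    result = [1]  # shared flag, initialized with value 1 ("equal so far")
    flag_reads = 0
    comparisons = 0
    for sample_idx in range(n_samples):
        if early_exit:
            flag_reads += 1  # the extra read added by the suggestion
            if result[0] == 0:
                continue  # the work item still starts, then exits early
        comparisons += 1
        if labels_a[sample_idx] != labels_b[sample_idx]:
            result[0] = 0
    return result[0], flag_reads, comparisons


same = check_labels_equal([0, 1, 1, 0], [0, 1, 1, 0], early_exit=True)
different = check_labels_equal([0, 1, 1, 0], [1, 1, 1, 0], early_exit=True)
baseline = check_labels_equal([0, 1, 1, 0], [1, 1, 1, 0], early_exit=False)
```

In this emulation, identical inputs pay one extra flag read per sample with no saving, while inputs that differ at the first sample skip all remaining comparisons. On a real device the work items run in parallel, so the saving depends on scheduling and is less clear-cut than in this sequential model.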
Co-authored-by: Julien Jerphanion <[email protected]>
Comments should be addressed in the last commit, and the suggestions have been applied.
LGTM. Thank you, @fcharras!
Co-authored-by: Julien Jerphanion <[email protected]>
So in fact […]
The test was the issue, we have […]
The compatibility layer that keeps the sklearn tests working is still TODO. Apart from that, I think this would be the final state of KMeansEngine.