Discussion of the API for Clustering models #852
I think this is also a question of MLJ's vision for the future. Does MLJ want to standardize the different tasks (classification, regression, dimensionality reduction, outlier detection, clustering, association rules, ...), such that they are clearly defined but restricted? In that case, it might make sense to focus on supervised learning, as caret does. Personally, I would like to see the core API task-independent and flexible, so that it does not rule out use cases.
I also think that's awkward. I think the resulting clustering assignments should live in the `report`.
Sounds good, although I would see that as a responsibility of the individual algorithm/package authors.
It would make sense to rely on the trait for this, imho. By the way, what is the use of an
Having worked with the new outlier detection subtypes, I'm pretty sure subtypes are not the way to go. Additionally, I've implemented the type hierarchy in JuliaAI/MLJBase.jl#656 (comment) for MMI/Base and I did not like the API. I'm pretty sure a trait-based system is preferable. However, I learned that refactoring is not that bad and quite easily doable. The reason why subtypes don't work is that they mix up different concepts: e.g., the supervised/unsupervised type defines characteristics of the input data, while the probabilistic/deterministic trait defines characteristics of the output data. Each time you want to define something that works for all probabilistic models (classifiers, clusterers, outlier detectors, ...), you'd have to define/rely on some type union. Mixins would capture such relationships, but Julia does not have mixins. The most consistent solution would probably be to directly subtype from
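To make the union problem concrete, here is a minimal sketch; every name below is hypothetical, invented for illustration, and none of it is MLJ's actual hierarchy:

```julia
# Hypothetical types/traits only, to illustrate the point above.
abstract type Model end
abstract type Supervised <: Model end
abstract type Unsupervised <: Model end

# Subtyping forces two orthogonal concepts into one tree:
abstract type ProbabilisticSupervised <: Supervised end
abstract type ProbabilisticUnsupervised <: Unsupervised end

# "All probabilistic models" is then only expressible as a type union:
const Probabilistic = Union{ProbabilisticSupervised,ProbabilisticUnsupervised}

# With traits, input-data and output-data characteristics stay independent:
is_supervised(::Type{<:Model}) = false       # input-data characteristic
prediction_type(::Type{<:Model}) = :unknown  # output-data characteristic

struct MyProbClusterer <: Unsupervised end
prediction_type(::Type{MyProbClusterer}) = :probabilistic

# filter/dispatch on the trait; no union needed:
is_probabilistic(M::Type{<:Model}) = prediction_type(M) === :probabilistic
```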
Thanks for chiming in here with some detailed feedback. Very much appreciated. Will get back to this eventually.
Clustering models in MLJ are implemented as `Unsupervised` models. While they share common functionality, this has not been properly documented, and there is no model subtype or trait that tags a model as implementing that common interface.

I am opening this issue to summarize the existing interface, for purposes of discussing possible enhancements/modifications, and how we might go about formalizing the interface.
Existing interface

Existing "clustering" models in the MLJ registry share the following interface:

- **fit**: `fit(model, verbosity, X)` sees the training data `X` and learns parameters, output as `fitresult`, required to label new data. As for general models, training-related outcomes that are not part of `fitresult`, but which the user may want to access, are returned in the `report` (a named tuple with informative keys).

- **predict**: `predict(model, fitresult, Xnew)`, if implemented, returns either (i) the cluster labels (assignments) for new data `Xnew`, as a `CategoricalVector` (unordered; scitype `AbstractVector{<:Multiclass}`); or (ii) probabilistic predictions for the cluster labels (a vector of `UnivariateFinite` distributions). Important: the categorical vector (or `UnivariateFinite` vector) includes all cluster labels in its pool, not just the predicted ones. So, for example, in the deterministic case, `levels(predict(model, fitresult, Xnew))` is the same for all `Xnew`: a vector with one element per cluster.

- **transform**: `transform(model, fitresult, Xnew)`, if implemented, performs dimension reduction, returning a table with `Continuous` columns, one for each cluster.
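For concreteness, here is a minimal sketch of a model implementing this contract through `MLJModelInterface`; `NaiveKMeans` and its helper `assign_labels` are invented for illustration, and the "training" logic cuts every corner a real clusterer would not:

```julia
import MLJModelInterface as MMI
using CategoricalArrays, LinearAlgebra

# Hypothetical clusterer, for illustration only.
mutable struct NaiveKMeans <: MMI.Unsupervised
    k::Int
end

# Helper (hypothetical): nearest-center labels, with all k levels in the pool.
function assign_labels(Xm, centers)
    k = size(centers, 1)
    raw = [argmin([norm(Xm[i, :] - centers[j, :]) for j in 1:k])
           for i in 1:size(Xm, 1)]
    return categorical(raw, levels=1:k)   # full pool, even for unused clusters
end

function MMI.fit(model::NaiveKMeans, verbosity, X)
    Xm = MMI.matrix(X)                    # observations as rows
    centers = Xm[1:model.k, :]            # crude "training": first k rows
    cache = nothing
    report = (assignments=assign_labels(Xm, centers),)  # training labels
    return centers, cache, report
end

# Deterministic predict: levels(...) is the same for every Xnew.
MMI.predict(model::NaiveKMeans, fitresult, Xnew) =
    assign_labels(MMI.matrix(Xnew), fitresult)

# transform: "dimension reduction" to one Continuous column per cluster.
function MMI.transform(model::NaiveKMeans, fitresult, Xnew)
    Xm = MMI.matrix(Xnew)
    k = size(fitresult, 1)
    dists = [norm(Xm[i, :] - fitresult[j, :]) for i in 1:size(Xm, 1), j in 1:k]
    return MMI.table(dists)
end
```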
Models that do not generalize to new data (e.g., ScikitLearn's DBSCAN, AgglomerativeClustering) implement neither `predict` nor `transform` (because all MLJ operations are understood to generalize to new data). The cluster labels (we could call them training labels, to distinguish them from the new `prediction`s of other models) appear in the `fitresult`, which the user accesses using `fitted_params` to get a user-friendly version.
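At user level, access for such a non-generalizing model looks roughly like this sketch, where `model` and `X` stand for any such clusterer and its training table:

```julia
using MLJ

mach = machine(model, X)   # bind to training data only
fit!(mach)
fitted_params(mach)        # user-friendly view of the fitresult (training labels)
report(mach)               # other training-related outcomes
```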
The trait `input_scitype(::Type{MyClusterer})` returns the required scitype of `X` (always `Table(Continuous)`). The trait `output_scitype(::Type{MyClusterer})` returns the scitype of the output of `transform` (also `Table(Continuous)`).
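As a model author, one declares these traits roughly as follows (a sketch, reusing the hypothetical `NaiveKMeans` from above):

```julia
import MLJModelInterface as MMI

MMI.input_scitype(::Type{<:NaiveKMeans}) = MMI.Table(MMI.Continuous)
MMI.output_scitype(::Type{<:NaiveKMeans}) = MMI.Table(MMI.Continuous)
```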
I notice that the ScikitLearn clusterers just bundle all training outcomes into the `fitresult` (and nothing in the `report`), which does not strictly comply with the published API (e.g., here). Also, the same API would imply that "non-generalizing" models should place all training outcomes in the `report`, instead of the `fitresult`, but they do the opposite.

I also notice that when training labels are added to the report, they are often just integer vectors, while for consistency they should be categorical vectors, as returned by `predict`.
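The conversion the consistency point asks for is a one-liner; a sketch, assuming `CategoricalArrays` and `k` clusters in total:

```julia
using CategoricalArrays

raw = [1, 3, 1, 3]                     # integer labels, as some packages report them
k = 3
labels = categorical(raw, levels=1:k)  # categorical vector with the full pool
levels(labels)                         # [1, 2, 3], even though cluster 2 is unobserved
```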
I believe `GMMClusterer` is the only probabilistic clusterer.

Comment
So, does this interface rule out some clustering models we have yet to encounter?
Are there further requirements we should impose?
I have thought that models that do not generalise could be conceptualised as `Static` transformers, but that imposes the requirement that `transform` returns everything of interest (there is no `fit` to generate a report or fitresult), which can be awkward.
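A hypothetical sketch of what that recasting would look like (`StaticClusterer` and its placeholder logic are invented; for `Static` models there is no `fit`, so `transform` receives `nothing` as the fitresult):

```julia
import MLJModelInterface as MMI

mutable struct StaticClusterer <: MMI.Static
    radius::Float64
end

function MMI.transform(model::StaticClusterer, ::Nothing, X)
    n = MMI.nrows(X)
    assignments = rand(1:3, n)   # placeholder for real cluster assignments
    # With no fit, anything of interest (labels, diagnostics, ...) must be
    # crammed into this single return value:
    return assignments
end
```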
For consistency, I'd have thought the `target_scitype` trait should return `AbstractVector{<:Multiclass}`, as this is the scitype of what `predict` (or `predict_mode`) returns. But I see this has not been implemented consistently.
Currently there is no way to distinguish which models predict probabilities, which predict actual labels, and which do not predict at all. The existing `prediction_type` trait could make this distinction (`:probabilistic`, `:deterministic`, `:unknown`). At present the models I have checked all return `:unknown` (the fallback).
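Concretely, the declarations proposed here would look something like this (again using the hypothetical `NaiveKMeans`; these are proposed, not current, declarations):

```julia
import MLJModelInterface as MMI

MMI.target_scitype(::Type{<:NaiveKMeans}) = AbstractVector{<:MMI.Multiclass}
MMI.prediction_type(::Type{<:NaiveKMeans}) = :deterministic
```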
Another question is whether we anchor the interface with new subtype(s) or with traits.