Transformers that need to see target (e.g., recursive feature elimination) #874
Comments
cc @pazzo83
Would it make sense to create a new model type for this? It's like a supervised transformer.
I think the general consensus is a move away from types to traits. I'm afraid the discussions are a little scattered. See, for example, #852 (comment). We could add a trait for "supervised transformers", but perhaps this is unnecessary as the
So would my transformer that relies on a target subtype
I guess I'm not familiar enough with the inner workings of the API yet to know whether that check needs to be modified, but could you expand a bit on what you mean by using traits to allow for this functionality?
I was able to partially get this working with the following patches (after I declared my own abstract type that my transformers subtype):

```julia
import MLJBase
import MLJModelInterface
const MMI = MLJModelInterface

# Abstract supertype for transformers that need the target in training
abstract type TargetTransformer <: MMI.Unsupervised end

# Accept (X, y) instead of the single-argument `Unsupervised` default
MLJModelInterface.fit_data_scitype(M::Type{<:TargetTransformer}) =
    Tuple{MMI.input_scitype(M), MMI.target_scitype(M)}

# Reuse MLJBase's supervised-model argument checks and scitype warning
MLJBase.check(model::TargetTransformer, args...; full=false) =
    MLJBase.check_supervised(model, full, args...)
MLJBase.warn_scitype(model::TargetTransformer, X, y) =
    "The scitype of `y`, in `machine(model, X, y, ...)` "*
    "is incompatible with "*
    "`model=$model`:\nscitype(y) = "*
    "$(MLJBase.elscitype(y))\ntarget_scitype(model) "*
    "= $(MLJBase.target_scitype(model))."
```

However, when put into a pipeline, it no longer works: it seems that because the model is still unsupervised, the target is not getting passed through (see here: https://github.com/JuliaAI/MLJBase.jl/blob/dev/src/composition/models/pipelines.jl#L72). This might be kind of hacky, and I think you are suggesting something a bit different?
@pazzo83 Thanks for looking at this. This isn't far from what I was imagining. Only, rather than introduce a new abstract type, I'd overload

```julia
MLJModelInterface.fit_data_scitype(M::Type{<:MyTransformer}) =
    Tuple{MLJModelInterface.input_scitype(M), MLJModelInterface.target_scitype(M)}
```

Then, to fix the type checking, modify the existing
Ditto
This is all a little tricky, as we want flexibility of design but also want to catch users' unintentional mistakes with informative errors. Warnings are better than errors here, but even warnings should be thrown only as necessary. And we'd need to test the changes with a dummy "supervised" transformer in tests.
No. That will require a bit more work. Nevertheless, I'm pretty sure one could include these transformers in custom composite models (exported learning networks) without issues. So they would be useful even without the pipeline enhancement. I'd support a PR that fixes the checks without worrying about pipelines just yet. Would also be great to have an actual "supervised" transformer implementation to try this out on. Have you already started on something?
Thanks for the feedback! I can definitely put together a PR for this. I have some local code I've been working on, so I can incorporate your feedback and go from there.
I've been looking at this over the last couple of days based on the feedback here: JuliaAI/MLJBase.jl#705
Would it work if we removed all the various

```julia
# Proposed single argument check, written as code inside MLJBase
function check(model::Model, args...; full=false)
    nowarns = true
    F = fit_data_scitype(model)
    # skip the check when the model declares no scitype requirements
    (F >: Unknown || F >: Tuple{Unknown} || F >: NTuple{<:Any,Unknown}) &&
        return true
    # scitypes of the data actually supplied, e.g. (X, y)
    S = Tuple{elscitype.(args)...}
    if !(S <: F)
        @warn warn_generic_scitype_mismatch(S, F)
        nowarns = false
    end
    return nowarns
end
```

I got it working if I rewrote the line:
Yes! That is what I think we should do. And it indeed looks like you found a bug with the
I would expand the return value of
"The number and/or types of data arguments do not match what the specified model supports. Commonly, but non-exclusively, supervised models are constructed using the syntax
Thanks for getting back to this!
Just a note that scitype checks have now (MLJBase 0.18.0) been relaxed to allow transformers that need a target.
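The heart of the `check` function proposed earlier in the thread is Julia tuple subtyping, `S <: F`. A minimal sketch of that idea, assuming ordinary Julia types in place of scientific types and `typeof` in place of `elscitype` (both simplifications); `check_sketch`, `unsupervised_F` and `target_F` are hypothetical names, not MLJBase API:

```julia
# Stand-in for the core test in the proposed `check`: ordinary Julia
# types play the role of scitypes, and `typeof` the role of `elscitype`.
# `check_sketch` is a hypothetical helper, not part of MLJBase.
function check_sketch(F::Type, args...)
    S = Tuple{map(typeof, args)...}   # types of the supplied arguments
    if !(S <: F)
        @warn "Argument types $S do not match the declared $F"
        return false
    end
    return true
end

X = [1.0 2.0; 3.0 4.0]   # stand-in for the input features
y = [1, 0]               # stand-in for a target

# Analogue of the `Unsupervised` fallback: a single training argument.
unsupervised_F = Tuple{AbstractMatrix{<:Real}}
# Analogue of the overridden trait: features plus a target.
target_F = Tuple{AbstractMatrix{<:Real}, AbstractVector{<:Integer}}

check_sketch(unsupervised_F, X)      # returns true
check_sketch(unsupervised_F, X, y)   # returns false and logs a warning
check_sketch(target_F, X, y)         # returns true
```

Because a one-element tuple type is never a subtype of a two-element one (or vice versa), the single-argument default rejects `(X, y)`, while a two-element `fit_data_scitype` accepts it.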
Resolved.
A number of feature-reduction strategies only make sense in the context of a supervised learning task because they must consult a target variable when trained. For example, one might want to drop features that correlate poorly with the target. In fact, all but the first of sklearn's feature selectors are of this kind.
At the level of the basic model API, a transformer (or any other model) can specify any number of arguments to be used in training. So there is nothing wrong with a transformer with a `fit` method like

There is now a trait defined in MLJModelInterface, `fit_data_scitype`, to explicitly articulate the acceptable `fit` signatures (up to scitype). For any model type that subtypes `Unsupervised`, this falls back to a single argument whose scitype must coincide with `input_scitype(model)`. So for transformers that need the target in training, you would override the trait with a declaration such as:

and be sure to declare a `target_scitype`, just as you would for a supervised model. That should do it.

It may be that some argument checks for machines have to be tweaked in MLJBase (edit: now done), but this should be very easy and essentially non-breaking.
Most happy to provide support to anyone wishing to implement such transformers.