-
Notifications
You must be signed in to change notification settings - Fork 90
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support passing extra data for loss function via MLJ interface #249
Conversation
@OkonSamuel @ablaom I was not sure whether there is a way to pass additional custom data to a machine, so I am currently simply allowing the user to pass a |
Benchmark Results
Benchmark PlotsA plot of the benchmark results have been uploaded as an artifact to the workflow run for this PR. |
b0403e8
to
08a13e1
Compare
3708906
to
31b2787
Compare
You want to provide Subsampling of training data in MLJ. When observations are subsampled by If per-observation weights are not ever going to be supported, then perhaps two signatures Another possibility, which I quite like, is to insist that |
@MilesCranmer this works for my use case. Thanks again working on this quickly. |
[Diff since v0.23.1](v0.23.1...v0.23.2) **Merged pull requests:** - Formatting overhaul (#278) (@MilesCranmer) - Avoid julia-formatter on pre-commit.ci (#279) (@MilesCranmer) - Make it easier to select expression from Pareto front for evaluation (#289) (@MilesCranmer) **Closed issues:** - Garbage collection too passive on worker processes (#237) - How can I set the maximum number of nests? (#285)
@ablaom it seems like function fit_only!(
mach::Machine{<:Any,cache_data};
rows=nothing,
verbosity=1,
force=false,
composite=nothing,
) where cache_data |
I've tried a few different strategies it doesn't seem like there's a good way to let users to pass arbitrary data (of any shape) to be used in a custom loss function. I think this isn't a limitation necessarily, it just is a point at which high-level interfaces should not be used, as such levels of customisation would break various assumptions anyways. For now I think we need to close this @tomaklutfu, doesn't seem like there's any robust way to do this right now. I would recommend either:
Cheers, |
Correct. Custom kwargs to |
Thanks @MilesCranmer . I did use custom loss function sub-typed via a struct with fields for extra data. It worked without hurdles. |
For example, say we create a custom loss function that compares both
f(x)
andf'(x)
against data. We can access the values withdataset.y
and the derivatives withdataset.extra.y
. This.extra
property allows you to store arbitrary named tuples for accessing in a custom loss function.Here, we have also taken advantage of mini-batching, using the
idx
to sample from bothdataset.y
as well asdataset.extra
.You can now use this loss function this by passing a
NamedTuple
for thew
input tomachine
, which is usually a vector of weights. If you pass a vector, it will be treated as the weights. But if you pass a NamedTuple, it will get added to theextra
property ofDataset
.e.g.,