diff --git a/src/measures/loss_functions_interface.jl b/src/measures/loss_functions_interface.jl index 76838fef..d7ba5ea5 100644 --- a/src/measures/loss_functions_interface.jl +++ b/src/measures/loss_functions_interface.jl @@ -115,7 +115,7 @@ MMI.target_scitype(::Type{<:DistanceLoss}) = Union{Vec{Continuous},Vec{Count}} function value(measure::DistanceLoss, yhat, X, y, ::Nothing, ::Val{false}, ::Val{true}) - return LossFunctions.value(getfield(measure, :loss), yhat, y) + return LossFunctions.value(getfield(measure, :loss), y, yhat) end function value(measure::DistanceLoss, yhat, X, y, w, @@ -137,7 +137,7 @@ function value(measure::MarginLoss, yhat, X, y, ::Nothing, check_pools(yhat, y) probs_of_observed = broadcast(pdf, yhat, y) return (LossFunctions.value).(getfield(measure, :loss), - _scale.(probs_of_observed), 1) + 1, _scale.(probs_of_observed)) end function value(measure::MarginLoss, yhat, X, y, w, diff --git a/test/measures/loss_functions_interface.jl b/test/measures/loss_functions_interface.jl index 04c6af9d..ba7b3763 100644 --- a/test/measures/loss_functions_interface.jl +++ b/test/measures/loss_functions_interface.jl @@ -36,10 +36,10 @@ const DISTANCE_LOSSES = MLJBase.DISTANCE_LOSSES for M_ex in MARGIN_LOSSES m = eval(:(MLJBase.$M_ex())) @test MLJBase.value(m, yhat, X, y, nothing) ≈ - LossFunctions.value(getfield(m, :loss), yhatm, ym) ≈ + LossFunctions.value(getfield(m, :loss), ym, yhatm) ≈ m(yhat, y) @test mean(MLJBase.value(m, yhat, X, y, w)) ≈ - LossFunctions.value(getfield(m, :loss), yhatm, ym, + LossFunctions.value(getfield(m, :loss), ym, yhatm, LossFunctions.AggMode.WeightedMean(w)) ≈ mean(m(yhat, y, w)) end @@ -58,10 +58,10 @@ end m_ex = MLJBase.snakecase(M_ex) @test m == eval(:(MLJBase.$m_ex)) @test MLJBase.value(m, yhat, X, y, nothing) ≈ - LossFunctions.value(getfield(m, :loss), yhat, y) ≈ + LossFunctions.value(getfield(m, :loss), y, yhat) ≈ m(yhat, y) @test mean(MLJBase.value(m, yhat, X, y, w)) ≈ - LossFunctions.value(getfield(m, :loss), yhat, y, + LossFunctions.value(getfield(m, :loss), y, yhat, LossFunctions.AggMode.WeightedMean(w)) ≈ mean(m(yhat ,y, w)) end