diff --git a/src/gluonts/model/forecast_generator.py b/src/gluonts/model/forecast_generator.py index b9ff10c8dc..2db978ea72 100644 --- a/src/gluonts/model/forecast_generator.py +++ b/src/gluonts/model/forecast_generator.py @@ -87,7 +87,7 @@ def make_predictions(prediction_net, inputs: dict): # Feed inputs as positional arguments for MXNet predictors import mxnet as mx - if isinstance(prediction_net, (mx.gluon.Block)): + if isinstance(prediction_net, mx.gluon.Block): return prediction_net(*inputs.values()) except ImportError: pass