Skip to content

Commit

Permalink
Update utils_ate_bounds.py
Browse files Browse the repository at this point in the history
  • Loading branch information
jaabmar authored Apr 25, 2024
1 parent f8d3b8f commit eafe188
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions src/test_confounding/ate_bounds/utils_ate_bounds.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Dict, Optional, Union

import numpy as np
from mlinsights.mlmodel import QuantileLinearRegression
# from mlinsights.mlmodel import QuantileLinearRegression
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.linear_model import QuantileRegressor
from sklearn.metrics import make_scorer, mean_pinball_loss
Expand All @@ -26,7 +26,7 @@ def get_quantile_regressor(
quantile_param_dist: Optional[Dict[str, Union[float, int]]] = None,
fast_solver: bool = False,
) -> Optional[
Union[QuantileRegressor, GradientBoostingRegressor, QuantileLinearRegression]
Union[QuantileRegressor, GradientBoostingRegressor]
]:
"""
Trains a quantile regressor model on given data and returns the best model based on test scores.
Expand All @@ -46,7 +46,7 @@ def get_quantile_regressor(
fast_solver (bool, optional): Use a faster solver for linear quantile regression. Default is False.
Returns:
Optional[Union[QuantileRegressor, GradientBoostingRegressor, QuantileLinearRegression]]:
Optional[Union[QuantileRegressor, GradientBoostingRegressor]]:
Best model object based on test scores.
"""

Expand All @@ -58,7 +58,8 @@ def get_quantile_regressor(

# Solve linear quantile regression approximately (no linear program) --> much faster and scales to 100k+ datapoints
if fast_solver:
quant_reg = QuantileLinearRegression(quantile=tau, max_iter=500)
# quant_reg = QuantileLinearRegression(quantile=tau, max_iter=500)
quant_reg = QuantileRegressor(quantile=tau, max_iter=500)
quant_reg.fit(X, y)
return quant_reg

Expand Down

0 comments on commit eafe188

Please sign in to comment.