Skip to content

Commit

Permalink
remove X_train and X_test from parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
Vincent-Maladiere committed Dec 18, 2024
1 parent 863f9cf commit 52864a2
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 32 deletions.
46 changes: 16 additions & 30 deletions hazardous/metrics/_concordance_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,11 @@
def concordance_index_incidence(
y_test,
y_pred,
y_train=None,
ipcw_estimator="km",
time_grid=None,
taus=None,
y_train=None,
X_train=None,
X_test=None,
event_of_interest=1,
ipcw_estimator="km",
tied_tol=1e-8,
):
r"""Time-dependent concordance index for prognostic models with competing risks \
Expand Down Expand Up @@ -94,34 +92,26 @@ def concordance_index_incidence(
Cumulative incidence for the event of interest, at the time points
from the input time_grid.
y_train : array, dictionnary or dataframe of shape (n_samples, 2), default=None
The train target, consisting in the 'event' and 'duration' columns.
Only used when ipcw_estimator = "km".
ipcw_estimator : None, 'km', or fitted estimator, default="km"
The inverse probability of censoring weighted (IPCW) estimator.
- Pass None to set uniform weights to all samples.
- Pass "km" to use the Kaplan-Meier IPCW estimator. It fits using y_train,
which must be set.
time_grid: array of shape (n_time_grid,), default=None
Time points used to predict the cumulative incidence.
taus: array of shape (n_taus,), default=None
float or vector, timepoints at which the concordance index is
evaluated.
y_train : array, dictionnary or dataframe of shape (n_samples, 2)
The train target, consisting in the 'event' and 'duration' columns.
X_train: array or dataframe of shape (n_samples_train, n_features), default=None
Covariates, used to learn a censor model if the inverse probability of
censoring weights (IPCW) is conditional on features (for instance Cox).
Unused if ipcw is None or 'km'.
X_test: array or dataframe of shape (n_samples_test, n_features), default=None
Covariates, used to predict weights a censor model if the inverse probability
of censoring weights (IPCW) is conditional on features (for instance Cox).
Unused if ipcw is None or 'km'.
event_of_interest: int, default=1
For competing risks, the event of interest.
ipcw_estimator : {None or 'km'}, default="km"
The inverse probability of censoring weighted (IPCW) estimator.
- None set uniform weights to all samples
- "km" use the Kaplan-Meier estimator
tied_tol : float, default=1e-8
The tolerance range to consider two probabilities equal.
Expand Down Expand Up @@ -149,13 +139,11 @@ def concordance_index_incidence(
c_index_report = _concordance_index_incidence_report(
y_test,
y_pred,
y_train=y_train,
ipcw_estimator=ipcw_estimator,
time_grid=time_grid,
taus=taus,
y_train=y_train,
X_train=X_train,
X_test=X_test,
event_of_interest=event_of_interest,
ipcw_estimator=ipcw_estimator,
tied_tol=tied_tol,
)
return np.array(c_index_report["cindex"])
Expand All @@ -164,13 +152,11 @@ def concordance_index_incidence(
def _concordance_index_incidence_report(
y_test,
y_pred,
y_train=None,
ipcw_estimator="km",
time_grid=None,
taus=None,
y_train=None,
X_train=None,
X_test=None,
event_of_interest=1,
ipcw_estimator="km",
tied_tol=1e-8,
):
"""Report version of function `concordance_index_incidence`.
Expand Down
14 changes: 12 additions & 2 deletions hazardous/tests/test_cindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,12 @@ def test_concordance_index_incidence_report_competitive():
time_grid = [10, 20, 30]

res = _concordance_index_incidence_report(
y_test, y_pred, time_grid, taus=None, y_train=y_test, event_of_interest=1
y_test,
y_pred,
y_train=y_test,
time_grid=time_grid,
taus=None,
event_of_interest=1,
)
assert res == {
"cindex": [1.0],
Expand All @@ -472,7 +477,12 @@ def test_concordance_index_incidence_report_competitive():
}

res = _concordance_index_incidence_report(
y_test, y_pred, time_grid, taus=None, y_train=y_test, event_of_interest=2
y_test,
y_pred,
y_train=y_test,
time_grid=time_grid,
taus=None,
event_of_interest=2,
)

assert res == {
Expand Down

0 comments on commit 52864a2

Please sign in to comment.