
label_gain used in 'rank_xendcg' objective despite docs saying only 'lambdarank' #6722

Open
AlexKC99 opened this issue Nov 14, 2024 · 0 comments
AlexKC99 commented Nov 14, 2024

Description

https://lightgbm.readthedocs.io/en/latest/Parameters.html
According to the docs, label_gain applies only to the 'lambdarank' objective. However, when using the 'rank_xendcg' objective, changing the label_gain parameter changes the training result. I demonstrate this with a dummy example below:
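For reference, the documented default label_gain is 2^i - 1 for an integer label i, so the linear gain vector [0, 1, 2, 3, 4] used below genuinely differs from the default for labels >= 2. A quick sanity check of the two gain vectors, independent of LightGBM:

```python
# Documented default label_gain: gain(i) = 2^i - 1, i.e. 0, 1, 3, 7, 15, ...
default_gain = [2**i - 1 for i in range(5)]
custom_gain = [0, 1, 2, 3, 4]

print(default_gain)  # [0, 1, 3, 7, 15]

# The vectors diverge from label 2 onward, so any sensitivity of
# 'rank_xendcg' to label_gain should show up in the comparison below.
print([i for i, (d, c) in enumerate(zip(default_gain, custom_gain)) if d != c])  # [2, 3, 4]
```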

Reproducible example

import lightgbm as lgb
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

np.random.seed(42)
num_samples = 100
num_features = 5

X = pd.DataFrame(np.random.rand(num_samples, num_features), columns=[f"feature_{i}" for i in range(num_features)])
y = np.random.choice([0, 1, 2, 3, 4], size=num_samples)
groups = [20] * 5

X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

train_size = int(len(X_train) / num_samples * sum(groups))
groups_train = groups[:train_size // 20]
groups_val = groups[train_size // 20:]

train_data = lgb.Dataset(X_train, label=y_train, group=groups_train)
val_data = lgb.Dataset(X_val, label=y_val, group=groups_val, reference=train_data)

params = {
    'objective': 'rank_xendcg',
    'metric': 'ndcg',
    'num_leaves': 31,
    'learning_rate': 0.05,
    'eval_at': [20],
    'verbose': -1
}

print("Training with default label_gain...")
ranker_model_default = lgb.train(
    params,
    train_data,
    num_boost_round=10,
    valid_sets=[train_data, val_data],
    valid_names=['train', 'val'],
)

params['label_gain'] = [0, 1, 2, 3, 4]
print("\nTraining with custom label_gain...")
ranker_model_custom_gain = lgb.train(
    params,
    train_data,
    num_boost_round=10,
    valid_sets=[train_data, val_data],
    valid_names=['train', 'val'],
)

print("\nDefault label_gain NDCG:", ranker_model_default.best_score)
print("Custom label_gain NDCG:", ranker_model_custom_gain.best_score)

Default label_gain NDCG: defaultdict(<class 'collections.OrderedDict'>, {'train': OrderedDict([('ndcg@20', 0.8419975079526356)]), 'val': OrderedDict([('ndcg@20', 0.6362625303185142)])})
Custom label_gain NDCG: defaultdict(<class 'collections.OrderedDict'>, {'train': OrderedDict([('ndcg@20', 0.8813059338868674)]), 'val': OrderedDict([('ndcg@20', 0.7695575603309358)])})
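For context on what the ndcg@20 numbers above measure: NDCG@k is a DCG of per-label gains with log2 position discounts, normalized by the ideal ordering. A minimal sketch (my own illustration, not LightGBM's implementation):

```python
import numpy as np

def ndcg_at_k(labels, scores, gains, k):
    """Minimal NDCG@k sketch: gains[label] is the gain of an integer label."""
    labels = np.asarray(labels)
    gains = np.asarray(gains, dtype=float)
    order = np.argsort(scores)[::-1][:k]      # indices of the top-k scores
    rel = gains[labels[order]]                # gains in predicted order
    dcg = np.sum(rel / np.log2(np.arange(2, len(rel) + 2)))
    ideal = np.sort(gains[labels])[::-1][:k]  # best possible ordering
    idcg = np.sum(ideal / np.log2(np.arange(2, len(ideal) + 2)))
    return dcg / idcg if idcg > 0 else 0.0

# A perfect ranking scores NDCG = 1.0 for any gain vector, but for imperfect
# rankings the chosen gain vector changes the metric value itself.
print(ndcg_at_k([2, 1, 0], [3.0, 2.0, 1.0], [0, 1, 3], 3))  # 1.0
```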

Environment info

Name: lightgbm
Version: 4.0.0
Summary: LightGBM Python Package

Additional Comments

This leads to undesirable behaviour in which the mapping from target values to gains is not what the documentation describes. Am I missing something?
