use devices/dtypes based on passed in tensors #68

Open
wants to merge 5 commits into base: master
5 changes: 2 additions & 3 deletions allrank/models/losses/bce.py
@@ -2,7 +2,6 @@
from torch.nn import BCELoss

from allrank.data.dataset_loading import PADDED_Y_VALUE
-from allrank.models.model_utils import get_torch_device


def bce(y_pred, y_true, padded_value_indicator=PADDED_Y_VALUE):
@@ -13,7 +12,7 @@ def bce(y_pred, y_true, padded_value_indicator=PADDED_Y_VALUE):
:param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1
:return: loss value, a torch.Tensor
"""
-device = get_torch_device()
+device = y_pred.device

y_pred = y_pred.clone()
y_true = y_true.clone()
@@ -25,7 +24,7 @@ def bce(y_pred, y_true, padded_value_indicator=PADDED_Y_VALUE):
ls[mask] = 0.0

document_loss = torch.sum(ls, dim=-1)
-sum_valid = torch.sum(valid_mask, dim=-1).type(torch.float32) > torch.tensor(0.0, dtype=torch.float32, device=device)
+sum_valid = torch.sum(valid_mask, dim=-1).type(y_pred.dtype) > torch.tensor(0.0, dtype=y_pred.dtype, device=device)

loss_output = torch.sum(document_loss) / torch.sum(sum_valid)

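The edits above, and the matching ones in the other loss files below, replace the global get_torch_device() lookup with the device and dtype of the tensors the caller passes in. A minimal sketch of that pattern, assuming nothing beyond PyTorch; scale_by_count is a hypothetical helper, not part of allRank:

import torch

def scale_by_count(values, mask):
    # Pattern applied throughout this PR: build every auxiliary tensor with
    # the device and dtype of the inputs instead of a globally configured device.
    device = values.device
    dtype = values.dtype
    counts = torch.sum(mask, dim=-1).type(dtype)
    zero = torch.tensor(0.0, dtype=dtype, device=device)
    summed = torch.sum(values * mask, dim=-1)
    return summed / torch.clamp(counts, min=1.0), counts > zero

# Works unchanged for float32 CPU tensors, float64 tensors, or half-precision
# CUDA tensors, because nothing is pinned to a particular device or dtype.
vals = torch.rand(2, 5, dtype=torch.float64)
mask = (torch.rand(2, 5) > 0.3).to(vals.dtype)
per_slate, has_valid = scale_by_count(vals, mask)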
11 changes: 5 additions & 6 deletions allrank/models/losses/loss_utils.py
@@ -2,7 +2,6 @@
import numpy as np

from allrank.models.losses import DEFAULT_EPS
-from allrank.models.model_utils import get_torch_device


def sinkhorn_scaling(mat, mask=None, tol=1e-6, max_iter=50):
@@ -41,11 +40,11 @@ def deterministic_neural_sort(s, tau, mask):
:param mask: mask indicating padded elements
:return: approximate permutation matrices of shape [batch_size, slate_length, slate_length]
"""
-dev = get_torch_device()
+dev = s.device

n = s.size()[1]
-one = torch.ones((n, 1), dtype=torch.float32, device=dev)
-s = s.masked_fill(mask[:, :, None], -1e8)
+one = torch.ones((n, 1), dtype=s.dtype, device=dev)
+s = s.masked_fill(mask[:, :, None], -1e4)
A_s = torch.abs(s - s.permute(0, 2, 1))
A_s = A_s.masked_fill(mask[:, :, None] | mask[:, None, :], 0.0)

@@ -54,7 +53,7 @@
temp = [n - m + 1 - 2 * (torch.arange(n - m, device=dev) + 1) for m in mask.squeeze(-1).sum(dim=1)]
temp = [t.type(torch.float32) for t in temp]
temp = [torch.cat((t, torch.zeros(n - len(t), device=dev))) for t in temp]
-scaling = torch.stack(temp).type(torch.float32).to(dev) # type: ignore
+scaling = torch.stack(temp).type(s.dtype).to(dev) # type: ignore

s = s.masked_fill(mask[:, :, None], 0.0)
C = torch.matmul(s, scaling.unsqueeze(-2))
@@ -95,7 +94,7 @@ def stochastic_neural_sort(s, n_samples, tau, mask, beta=1.0, log_scores=True, e
:param eps: epsilon for the logarithm function
:return: approximate permutation matrices of shape [n_samples, batch_size, slate_length, slate_length]
"""
-dev = get_torch_device()
+dev = s.device

batch_size = s.size()[0]
n = s.size()[1]
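Besides taking the device from s, the hunk above lowers the padding fill value from -1e8 to -1e4, presumably so the constant stays representable when scores arrive in half precision (float16 tops out around 65504). A small check of that assumption, plus the dtype-following pattern; the tensors here are hypothetical:

import torch

# -1e8 overflows float16 and becomes -inf, which can turn into NaNs further
# down the NeuralSort computation; -1e4 stays finite.
print(torch.tensor(-1e8, dtype=torch.float16))   # tensor(-inf, dtype=torch.float16)
print(torch.tensor(-1e4, dtype=torch.float16))   # tensor(-10000., dtype=torch.float16)

# With the patched code, auxiliary tensors follow the dtype/device of s:
s = torch.rand(2, 4, 1, dtype=torch.float16)
one = torch.ones((s.size()[1], 1), dtype=s.dtype, device=s.device)
print(one.dtype)                                 # torch.float16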
9 changes: 4 additions & 5 deletions allrank/models/losses/neuralNDCG.py
@@ -4,7 +4,6 @@
from allrank.models.losses import DEFAULT_EPS
from allrank.models.losses.loss_utils import deterministic_neural_sort, sinkhorn_scaling, stochastic_neural_sort
from allrank.models.metrics import dcg
-from allrank.models.model_utils import get_torch_device


def neuralNDCG(y_pred, y_true, padded_value_indicator=PADDED_Y_VALUE, temperature=1., powered_relevancies=True, k=None,
@@ -24,7 +23,7 @@ def neuralNDCG(y_pred, y_true, padded_value_indicator=PADDED_Y_VALUE, temperatur
:param log_scores: log_scores parameter for NeuralSort algorithm, used if stochastic == True
:return: loss value, a torch.Tensor
"""
-dev = get_torch_device()
+dev = y_pred.device

if k is None:
k = y_true.shape[1]
@@ -90,7 +89,7 @@ def neuralNDCG_transposed(y_pred, y_true, padded_value_indicator=PADDED_Y_VALUE,
:param tol: tolerance for Sinkhorn scaling
:return: loss value, a torch.Tensor
"""
-dev = get_torch_device()
+dev = y_pred.device

if k is None:
k = y_true.shape[1]
@@ -107,7 +106,7 @@ def neuralNDCG_transposed(y_pred, y_true, padded_value_indicator=PADDED_Y_VALUE,
P_hat_masked = sinkhorn_scaling(P_hat.view(P_hat.shape[0] * y_pred.shape[0], y_pred.shape[1], y_pred.shape[1]),
mask.repeat_interleave(P_hat.shape[0], dim=0), tol=tol, max_iter=max_iter)
P_hat_masked = P_hat_masked.view(P_hat.shape[0], y_pred.shape[0], y_pred.shape[1], y_pred.shape[1])
-discounts = (torch.tensor(1) / torch.log2(torch.arange(y_true.shape[-1], dtype=torch.float) + 2.)).to(dev)
+discounts = (torch.tensor(1) / torch.log2(torch.arange(y_true.shape[-1], dtype=y_pred.dtype) + 2.)).to(dev)

# This takes care of the @k metric truncation - if something is @>k, it is useless and gets 0.0 discount
discounts[k:] = 0.
@@ -133,4 +132,4 @@ def neuralNDCG_transposed(y_pred, y_true, padded_value_indicator=PADDED_Y_VALUE,
return torch.tensor(0.)

mean_ndcg = ndcg.sum() / ((~idcg_mask).sum() * ndcg.shape[0]) # type: ignore
-return -1. * mean_ndcg # -1 cause we want to maximize NDCG
+return -1. * mean_ndcg # -1 because we want to maximize NDCG
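In neuralNDCG_transposed, the DCG discount vector is now built in y_pred's dtype instead of hard-coded float32, avoiding an implicit cast in mixed- or double-precision runs. A self-contained illustration of that single line; the inputs here are made up:

import torch

y_pred = torch.rand(2, 5, dtype=torch.float64)
y_true = torch.randint(0, 3, (2, 5)).to(y_pred.dtype)
dev = y_pred.device
k = 3

# Same expression as in the patched loss: discounts inherit y_pred.dtype.
discounts = (torch.tensor(1) / torch.log2(torch.arange(y_true.shape[-1], dtype=y_pred.dtype) + 2.)).to(dev)
discounts[k:] = 0.   # @k truncation: positions past k get zero discount
print(discounts)     # float64: [1.0000, 0.6309, 0.5000, 0.0000, 0.0000]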
11 changes: 5 additions & 6 deletions allrank/models/losses/ordinal.py
@@ -2,7 +2,6 @@
from torch.nn import BCELoss

from allrank.data.dataset_loading import PADDED_Y_VALUE
-from allrank.models.model_utils import get_torch_device


def with_ordinals(y, n, padded_value_indicator=PADDED_Y_VALUE):
@@ -13,11 +12,11 @@ def with_ordinals(y, n, padded_value_indicator=PADDED_Y_VALUE):
:param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1
:return: ordinals, shape [batch_size, slate_length, n]
"""
-dev = get_torch_device()
-one_to_n = torch.arange(start=1, end=n + 1, dtype=torch.float, device=dev)
+dev = y.device
+one_to_n = torch.arange(start=1, end=n + 1, dtype=y.dtype, device=dev)
unsqueezed = y.unsqueeze(2).repeat(1, 1, n)
mask = unsqueezed == padded_value_indicator
-ordinals = (unsqueezed >= one_to_n).type(torch.float)
+ordinals = (unsqueezed >= one_to_n).type(y.dtype)
ordinals[mask] = padded_value_indicator
return ordinals

@@ -31,7 +30,7 @@ def ordinal(y_pred, y_true, n, padded_value_indicator=PADDED_Y_VALUE):
:param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1
:return: loss value, a torch.Tensor
"""
-device = get_torch_device()
+device = y_pred.device

y_pred = y_pred.clone()
y_true = with_ordinals(y_true.clone(), n)
@@ -43,7 +42,7 @@ def ordinal(y_pred, y_true, n, padded_value_indicator=PADDED_Y_VALUE):
ls[mask] = 0.0

document_loss = torch.sum(ls, dim=2)
-sum_valid = torch.sum(valid_mask, dim=2).type(torch.float32) > torch.tensor(0.0, dtype=torch.float32, device=device)
+sum_valid = torch.sum(valid_mask, dim=2).type(y_pred.dtype) > torch.tensor(0.0, dtype=y_pred.dtype, device=device)

loss_output = torch.sum(document_loss) / torch.sum(sum_valid)

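with_ordinals now builds its 1..n grid and the resulting ordinal targets in y's dtype and on y's device. A quick sanity check that mirrors the patched body on a tiny hand-made slate (the values are hypothetical):

import torch

y = torch.tensor([[2., 1., -1.]], dtype=torch.float64)   # -1 marks padding
n = 2
one_to_n = torch.arange(start=1, end=n + 1, dtype=y.dtype, device=y.device)
unsqueezed = y.unsqueeze(2).repeat(1, 1, n)
ordinals = (unsqueezed >= one_to_n).type(y.dtype)
ordinals[unsqueezed == -1] = -1.

print(ordinals.dtype)   # torch.float64, inherited from y
print(ordinals[0])      # [[1., 1.], [1., 0.], [-1., -1.]]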
13 changes: 7 additions & 6 deletions setup.py
@@ -7,16 +7,17 @@
README = (HERE / "README.md").read_text()

reqs = [
"scikit-learn>=1.1.0, <=1.2.1",
"pandas>=1.0.5, <=1.3.5",
"numpy>=1.18.5, <=1.21.6",
"scipy>=1.4.1, <=1.7.3",
"scikit-learn>=1.1.0",
"pandas>=1.0.5",
"numpy>=1.18.5",
"scipy>=1.4.1",
"attrs>=19.3.0",
"flatten_dict>=0.3.0",
"tensorboardX>=2.1.0",
"gcsfs==0.6.2",
"gcsfs>=0.6.2",
"google-auth>=2.15.0",
"fsspec <= 2023.1.0"
"fsspec",
"torchvision",
]

setup(