Skip to content

Commit

Permalink
[BugFix] Remove raisers in specs
Browse files Browse the repository at this point in the history
ghstack-source-id: a005a62847aa2ff1d286f2c4ad13fd14f9e631d3
Pull Request resolved: #2651
  • Loading branch information
vmoens committed Dec 14, 2024
1 parent 9e2d214 commit bb6f87a
Showing 1 changed file with 0 additions and 8 deletions.
8 changes: 0 additions & 8 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1402,7 +1402,6 @@ def type_check(self, value: torch.Tensor, key: NestedKey | None = None) -> None:
spec.type_check(val)

def is_in(self, value) -> bool:
raise RuntimeError
if self.dim == 0 and not hasattr(value, "unbind"):
# We don't use unbind because value could be a tuple or a nested tensor
return all(
Expand Down Expand Up @@ -1834,7 +1833,6 @@ def _project(self, val: torch.Tensor) -> torch.Tensor:
return val

def is_in(self, val: torch.Tensor) -> bool:
raise RuntimeError
if self.mask is None:
shape = torch.broadcast_shapes(self._safe_shape, val.shape)
shape_match = val.shape == shape
Expand Down Expand Up @@ -2288,7 +2286,6 @@ def _project(self, val: torch.Tensor) -> torch.Tensor:
return val

def is_in(self, val: torch.Tensor) -> bool:
raise RuntimeError
val_shape = _remove_neg_shapes(tensordict.utils._shape(val))
shape = torch.broadcast_shapes(self._safe_shape, val_shape)
shape = list(shape)
Expand Down Expand Up @@ -2489,7 +2486,6 @@ def one(self, shape=None):
)

def is_in(self, val: Any) -> bool:
raise RuntimeError
shape = torch.broadcast_shapes(self._safe_shape, val.shape)
return (
is_non_tensor(val)
Expand Down Expand Up @@ -2682,7 +2678,6 @@ def rand(self, shape: torch.Size = None) -> torch.Tensor:
return torch.empty(shape, device=self.device, dtype=self.dtype).random_()

def is_in(self, val: torch.Tensor) -> bool:
raise RuntimeError
shape = torch.broadcast_shapes(self._safe_shape, val.shape)
return val.shape == shape and val.dtype == self.dtype

Expand Down Expand Up @@ -3034,7 +3029,6 @@ def index(self, index: INDEX_TYPING, tensor_to_index: torch.Tensor) -> torch.Ten
return torch.cat(out, -1)

def is_in(self, val: torch.Tensor) -> bool:
raise RuntimeError
vals = self._split(val)
if vals is None:
return False
Expand Down Expand Up @@ -3435,7 +3429,6 @@ def _project(self, val: torch.Tensor) -> torch.Tensor:
return val

def is_in(self, val: torch.Tensor) -> bool:
raise RuntimeError
if self.mask is None:
shape = torch.broadcast_shapes(self._safe_shape, val.shape)
shape_match = val.shape == shape
Expand Down Expand Up @@ -4066,7 +4059,6 @@ def _project(self, val: torch.Tensor) -> torch.Tensor:
return val.squeeze(0) if val_is_scalar else val

def is_in(self, val: torch.Tensor) -> bool:
raise RuntimeError
if self.mask is not None:
vals = val.unbind(-1)
splits = self._split_self()
Expand Down

1 comment on commit bb6f87a

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'CPU Benchmark Results'.
The benchmark result of this commit is worse than the previous benchmark result, exceeding the alert threshold of 2.

Benchmark suite Current: bb6f87a Previous: 9e2d214 Ratio
benchmarks/test_replaybuffer_benchmark.py::test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-True] 22.821619751468905 iter/sec (stddev: 0.16449992306829694) 48.819954522949345 iter/sec (stddev: 0.0004559314935193707) 2.14

This comment was automatically generated by workflow using github-action-benchmark.

CC: @vmoens

Please sign in to comment.