diff --git a/_modules/grl/generative_models/conditional_flow_model/independent_conditional_flow_model.html b/_modules/grl/generative_models/conditional_flow_model/independent_conditional_flow_model.html
index a5412bc..e16dcf5 100644
--- a/_modules/grl/generative_models/conditional_flow_model/independent_conditional_flow_model.html
+++ b/_modules/grl/generative_models/conditional_flow_model/independent_conditional_flow_model.html
@@ -569,10 +569,24 @@
Source code for grl.generative_models.conditional_flow_model.independent_con
# x.shape = (B*N, D)
if condition is not None:
- condition = torch.repeat_interleave(
- condition, torch.prod(extra_batch_size), dim=0
- )
- # condition.shape = (B*N, D)
+ if isinstance(condition, torch.Tensor):
+ condition = torch.repeat_interleave(
+ condition, torch.prod(extra_batch_size), dim=0
+ )
+ # condition.shape = (B*N, D)
+ elif isinstance(condition, TensorDict):
+ condition = TensorDict(
+ {
+ key: torch.repeat_interleave(
+ condition[key], torch.prod(extra_batch_size), dim=0
+ )
+ for key in condition.keys()
+ },
+ batch_size=torch.prod(extra_batch_size) * condition.shape,
+ device=condition.device,
+ )
+ else:
+ raise NotImplementedError("Not implemented")
if isinstance(solver, DPMSolver):
raise NotImplementedError("Not implemented")
diff --git a/_modules/grl/generative_models/conditional_flow_model/optimal_transport_conditional_flow_model.html b/_modules/grl/generative_models/conditional_flow_model/optimal_transport_conditional_flow_model.html
index f9b344f..0b66935 100644
--- a/_modules/grl/generative_models/conditional_flow_model/optimal_transport_conditional_flow_model.html
+++ b/_modules/grl/generative_models/conditional_flow_model/optimal_transport_conditional_flow_model.html
@@ -520,10 +520,24 @@ Source code for grl.generative_models.conditional_flow_model.optimal_transpo
# x.shape = (B*N, D)
if condition is not None:
- condition = torch.repeat_interleave(
- condition, torch.prod(extra_batch_size), dim=0
- )
- # condition.shape = (B*N, D)
+ if isinstance(condition, torch.Tensor):
+ condition = torch.repeat_interleave(
+ condition, torch.prod(extra_batch_size), dim=0
+ )
+ # condition.shape = (B*N, D)
+ elif isinstance(condition, TensorDict):
+ condition = TensorDict(
+ {
+ key: torch.repeat_interleave(
+ condition[key], torch.prod(extra_batch_size), dim=0
+ )
+ for key in condition.keys()
+ },
+ batch_size=torch.prod(extra_batch_size) * condition.shape,
+ device=condition.device,
+ )
+ else:
+ raise NotImplementedError("Not implemented")
if isinstance(solver, DPMSolver):
raise NotImplementedError("Not implemented")
diff --git a/_modules/grl/generative_models/diffusion_model/diffusion_model.html b/_modules/grl/generative_models/diffusion_model/diffusion_model.html
index 4724196..6940591 100644
--- a/_modules/grl/generative_models/diffusion_model/diffusion_model.html
+++ b/_modules/grl/generative_models/diffusion_model/diffusion_model.html
@@ -547,10 +547,16 @@ Source code for grl.generative_models.diffusion_model.diffusion_model
)
# condition.shape = (B*N, D)
elif isinstance(condition, TensorDict):
- for key in condition.keys():
- condition[key] = torch.repeat_interleave(
- condition[key], torch.prod(extra_batch_size), dim=0
- )
+ condition = TensorDict(
+ {
+ key: torch.repeat_interleave(
+ condition[key], torch.prod(extra_batch_size), dim=0
+ )
+ for key in condition.keys()
+ },
+ batch_size=torch.prod(extra_batch_size) * condition.shape,
+ device=condition.device,
+ )
else:
raise NotImplementedError("Not implemented")
diff --git a/_modules/grl/generative_models/diffusion_model/energy_conditional_diffusion_model.html b/_modules/grl/generative_models/diffusion_model/energy_conditional_diffusion_model.html
index 3eb6820..c43221d 100644
--- a/_modules/grl/generative_models/diffusion_model/energy_conditional_diffusion_model.html
+++ b/_modules/grl/generative_models/diffusion_model/energy_conditional_diffusion_model.html
@@ -650,26 +650,31 @@ Source code for grl.generative_models.diffusion_model.energy_conditional_dif
# x.shape = (B*N, D)
if condition is not None:
- if isinstance(condition, TensorDict):
- repeated_condition = TensorDict(
+ if isinstance(condition, torch.Tensor):
+ condition = torch.repeat_interleave(
+ condition, torch.prod(extra_batch_size), dim=0
+ )
+ # condition.shape = (B*N, D)
+ elif isinstance(condition, treetensor.torch.Tensor):
+ for key in condition.keys():
+ condition[key] = torch.repeat_interleave(
+ condition[key], torch.prod(extra_batch_size), dim=0
+ )
+ # condition.shape = (B*N, D)
+ elif isinstance(condition, TensorDict):
+ condition = TensorDict(
{
key: torch.repeat_interleave(
- value, torch.prod(extra_batch_size), dim=0
+ condition[key], torch.prod(extra_batch_size), dim=0
)
- for key, value in condition.items()
+ for key in condition.keys()
},
- batch_size=int(
- torch.prod(
- torch.tensor([*condition.batch_size, extra_batch_size])
- )
- ),
+ batch_size=torch.prod(extra_batch_size) * condition.shape,
+ device=condition.device,
)
- repeated_condition.to(condition.device)
- condition = repeated_condition
else:
- condition = torch.repeat_interleave(
- condition, torch.prod(extra_batch_size), dim=0
- )
+ raise NotImplementedError("Not implemented")
+
if isinstance(solver, DPMSolver):
def noise_function_with_energy_guidance(t, x, condition):