diff --git a/_modules/grl/generative_models/conditional_flow_model/independent_conditional_flow_model.html b/_modules/grl/generative_models/conditional_flow_model/independent_conditional_flow_model.html
index a5412bc..e16dcf5 100644
--- a/_modules/grl/generative_models/conditional_flow_model/independent_conditional_flow_model.html
+++ b/_modules/grl/generative_models/conditional_flow_model/independent_conditional_flow_model.html
@@ -569,10 +569,24 @@

Source code for grl.generative_models.conditional_flow_model.independent_conditional_flow_model

             # x.shape = (B*N, D)
 
         if condition is not None:
-            condition = torch.repeat_interleave(
-                condition, torch.prod(extra_batch_size), dim=0
-            )
-            # condition.shape = (B*N, D)
+            if isinstance(condition, torch.Tensor):
+                condition = torch.repeat_interleave(
+                    condition, torch.prod(extra_batch_size), dim=0
+                )
+                # condition.shape = (B*N, D)
+            elif isinstance(condition, TensorDict):
+                condition = TensorDict(
+                    {
+                        key: torch.repeat_interleave(
+                            condition[key], torch.prod(extra_batch_size), dim=0
+                        )
+                        for key in condition.keys()
+                    },
+                    batch_size=torch.prod(extra_batch_size) * condition.shape,
+                    device=condition.device,
+                )
+            else:
+                raise NotImplementedError("Not implemented")
 
         if isinstance(solver, DPMSolver):
             raise NotImplementedError("Not implemented")
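The TensorDict branch added above repeats every entry of the condition along dim 0 so that one condition row is shared by all N samples drawn for it. A minimal standalone sketch of that pattern, assuming the tensordict package; the example shapes, the value of extra_batch_size, and the int-cast batch size are illustrative, not taken from the library:

import torch
from tensordict import TensorDict

extra_batch_size = torch.tensor([4])  # N samples per condition row (assumed value)
condition = TensorDict(
    {"state": torch.randn(2, 3), "goal": torch.randn(2, 5)},
    batch_size=[2],  # B = 2 condition rows
)
repeated = TensorDict(
    {
        key: torch.repeat_interleave(
            condition[key], torch.prod(extra_batch_size), dim=0
        )
        for key in condition.keys()
    },
    # the diff scales the batch size by N; an explicit int form is used here for clarity
    batch_size=[int(torch.prod(extra_batch_size)) * condition.batch_size[0]],
    device=condition.device,
)
assert repeated["state"].shape == (8, 3)  # (B*N, D)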

diff --git a/_modules/grl/generative_models/conditional_flow_model/optimal_transport_conditional_flow_model.html b/_modules/grl/generative_models/conditional_flow_model/optimal_transport_conditional_flow_model.html
index f9b344f..0b66935 100644
--- a/_modules/grl/generative_models/conditional_flow_model/optimal_transport_conditional_flow_model.html
+++ b/_modules/grl/generative_models/conditional_flow_model/optimal_transport_conditional_flow_model.html
@@ -520,10 +520,24 @@

Source code for grl.generative_models.conditional_flow_model.optimal_transport_conditional_flow_model

             # x.shape = (B*N, D)
 
         if condition is not None:
-            condition = torch.repeat_interleave(
-                condition, torch.prod(extra_batch_size), dim=0
-            )
-            # condition.shape = (B*N, D)
+            if isinstance(condition, torch.Tensor):
+                condition = torch.repeat_interleave(
+                    condition, torch.prod(extra_batch_size), dim=0
+                )
+                # condition.shape = (B*N, D)
+            elif isinstance(condition, TensorDict):
+                condition = TensorDict(
+                    {
+                        key: torch.repeat_interleave(
+                            condition[key], torch.prod(extra_batch_size), dim=0
+                        )
+                        for key in condition.keys()
+                    },
+                    batch_size=torch.prod(extra_batch_size) * condition.shape,
+                    device=condition.device,
+                )
+            else:
+                raise NotImplementedError("Not implemented")
 
         if isinstance(solver, DPMSolver):
             raise NotImplementedError("Not implemented")

diff --git a/_modules/grl/generative_models/diffusion_model/diffusion_model.html b/_modules/grl/generative_models/diffusion_model/diffusion_model.html
index 4724196..6940591 100644
--- a/_modules/grl/generative_models/diffusion_model/diffusion_model.html
+++ b/_modules/grl/generative_models/diffusion_model/diffusion_model.html
@@ -547,10 +547,16 @@

Source code for grl.generative_models.diffusion_model.diffusion_model

                 )
                 # condition.shape = (B*N, D)
             elif isinstance(condition, TensorDict):
-                for key in condition.keys():
-                    condition[key] = torch.repeat_interleave(
-                        condition[key], torch.prod(extra_batch_size), dim=0
-                    )
+                condition = TensorDict(
+                    {
+                        key: torch.repeat_interleave(
+                            condition[key], torch.prod(extra_batch_size), dim=0
+                        )
+                        for key in condition.keys()
+                    },
+                    batch_size=torch.prod(extra_batch_size) * condition.shape,
+                    device=condition.device,
+                )
             else:
                 raise NotImplementedError("Not implemented")
 
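One plausible reason for replacing the in-place key loop with a freshly constructed TensorDict is that tensordict validates each value's leading dimensions against the container's batch_size, so writing back (B*N, ...) tensors into a TensorDict whose batch_size is still B would be rejected. A hypothetical illustration, not library code:

import torch
from tensordict import TensorDict

td = TensorDict({"state": torch.randn(2, 3)}, batch_size=[2])
try:
    # attempt to store an (8, 3) tensor while batch_size is still [2]
    td["state"] = torch.repeat_interleave(td["state"], 4, dim=0)
except (RuntimeError, ValueError) as err:
    print("in-place write rejected:", err)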

diff --git a/_modules/grl/generative_models/diffusion_model/energy_conditional_diffusion_model.html b/_modules/grl/generative_models/diffusion_model/energy_conditional_diffusion_model.html
index 3eb6820..c43221d 100644
--- a/_modules/grl/generative_models/diffusion_model/energy_conditional_diffusion_model.html
+++ b/_modules/grl/generative_models/diffusion_model/energy_conditional_diffusion_model.html
@@ -650,26 +650,31 @@

Source code for grl.generative_models.diffusion_model.energy_conditional_diffusion_model

             # x.shape = (B*N, D)
 
         if condition is not None:
-            if isinstance(condition, TensorDict):
-                repeated_condition = TensorDict(
+            if isinstance(condition, torch.Tensor):
+                condition = torch.repeat_interleave(
+                    condition, torch.prod(extra_batch_size), dim=0
+                )
+                # condition.shape = (B*N, D)
+            elif isinstance(condition, treetensor.torch.Tensor):
+                for key in condition.keys():
+                    condition[key] = torch.repeat_interleave(
+                        condition[key], torch.prod(extra_batch_size), dim=0
+                    )
+                # condition.shape = (B*N, D)
+            elif isinstance(condition, TensorDict):
+                condition = TensorDict(
                     {
                         key: torch.repeat_interleave(
-                            value, torch.prod(extra_batch_size), dim=0
+                            condition[key], torch.prod(extra_batch_size), dim=0
                         )
-                        for key, value in condition.items()
+                        for key in condition.keys()
                     },
-                    batch_size=int(
-                        torch.prod(
-                            torch.tensor([*condition.batch_size, extra_batch_size])
-                        )
-                    ),
+                    batch_size=torch.prod(extra_batch_size) * condition.shape,
+                    device=condition.device,
                 )
-                repeated_condition.to(condition.device)
-                condition = repeated_condition
             else:
-                condition = torch.repeat_interleave(
-                    condition, torch.prod(extra_batch_size), dim=0
-                )
+                raise NotImplementedError("Not implemented")
+
         if isinstance(solver, DPMSolver):
 
             def noise_function_with_energy_guidance(t, x, condition):
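The type dispatch added here can also be exercised outside the class. The helper below is a hypothetical sketch: the name repeat_condition is not from the library, and the treetensor.torch.Tensor branch that the source handles with a key loop is omitted, leaving only the torch.Tensor and TensorDict cases.

import torch
from tensordict import TensorDict


def repeat_condition(condition, extra_batch_size: torch.Tensor):
    # Hypothetical helper mirroring the dispatch added in the diff (Tensor / TensorDict only).
    n = torch.prod(extra_batch_size)
    if isinstance(condition, torch.Tensor):
        return torch.repeat_interleave(condition, n, dim=0)  # (B, D) -> (B*N, D)
    elif isinstance(condition, TensorDict):
        return TensorDict(
            {
                key: torch.repeat_interleave(condition[key], n, dim=0)
                for key in condition.keys()
            },
            batch_size=[int(n) * condition.batch_size[0]],
            device=condition.device,
        )
    else:
        raise NotImplementedError("Not implemented")


# usage: a plain (B, D) tensor condition becomes (B*N, D)
out = repeat_condition(torch.randn(2, 6), torch.tensor([3, 1]))
assert out.shape == (6, 6)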