From a4cc32649196daaa43ac4acc9c97ea3b530cbca0 Mon Sep 17 00:00:00 2001
From: zjowowen <93968541+zjowowen@users.noreply.github.com>
Date: Wed, 30 Oct 2024 12:54:54 +0000
Subject: [PATCH] Deploying to gh-pages from @ opendilab/GenerativeRL@4d9276e327e9c2b982c58e7c95f7fdbba26c6a78 🚀
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 _modules/grl/algorithms/gmpg.html                  | 1963 +++++++++++++++++
 _modules/grl/algorithms/gmpo.html                  | 1799 +++++++++++++++
 _modules/grl/algorithms/qgpo.html                  | 1201 ++++++++++
 _modules/grl/algorithms/srpo.html                  | 1137 ++++++++++
 .../simulators/gym_env_simulator.html              |  780 +++++++
 .../one_shot_value_function.html                   |  516 +++++
 .../rl_modules/value_network/q_network.html        |  558 +++++
 .../value_network/value_network.html               |  562 +++++
 _modules/index.html                                |    8 +
 api_doc/algorithms/index.html                      | 1248 ++++++++++-
 api_doc/rl_modules/index.html                      |  516 ++++-
 genindex.html                                      |  244 +-
 objects.inv                                        |  Bin 4464 -> 6340 bytes
 py-modindex.html                                   |   10 +
 searchindex.js                                     |    2 +-
 15 files changed, 10515 insertions(+), 29 deletions(-)
 create mode 100644 _modules/grl/algorithms/gmpg.html
 create mode 100644 _modules/grl/algorithms/gmpo.html
 create mode 100644 _modules/grl/algorithms/qgpo.html
 create mode 100644 _modules/grl/algorithms/srpo.html
 create mode 100644 _modules/grl/rl_modules/simulators/gym_env_simulator.html
 create mode 100644 _modules/grl/rl_modules/value_network/one_shot_value_function.html
 create mode 100644 _modules/grl/rl_modules/value_network/q_network.html
 create mode 100644 _modules/grl/rl_modules/value_network/value_network.html

diff --git a/_modules/grl/algorithms/gmpg.html b/_modules/grl/algorithms/gmpg.html
new file mode 100644
index 0000000..f5defb3
--- /dev/null
+++ b/_modules/grl/algorithms/gmpg.html
@@ -0,0 +1,1963 @@
+grl.algorithms.gmpg — GenerativeRL v0.0.1 documentation

Source code for grl.algorithms.gmpg

+import os
+import copy
+from typing import List, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+from easydict import EasyDict
+from rich.progress import track
+from tensordict import TensorDict
+from torchrl.data import TensorDictReplayBuffer
+from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
+
+import wandb
+from grl.agents.gm import GPAgent
+
+from grl.datasets import create_dataset
+from grl.datasets.gp import GPDataset, GPD4RLDataset
+from grl.generative_models.diffusion_model import DiffusionModel
+from grl.generative_models.conditional_flow_model.optimal_transport_conditional_flow_model import (
+    OptimalTransportConditionalFlowModel,
+)
+from grl.generative_models.conditional_flow_model.independent_conditional_flow_model import (
+    IndependentConditionalFlowModel,
+)
+from grl.generative_models.bridge_flow_model.schrodinger_bridge_conditional_flow_model import (
+    SchrodingerBridgeConditionalFlowModel,
+)
+
+from grl.rl_modules.simulators import create_simulator
+from grl.rl_modules.value_network.q_network import DoubleQNetwork
+from grl.rl_modules.value_network.value_network import VNetwork, DoubleVNetwork
+from grl.utils.config import merge_two_dicts_into_newone
+from grl.utils.log import log
+from grl.utils import set_seed
+from grl.utils.statistics import sort_files_by_criteria
+from grl.generative_models.metric import compute_likelihood
+from grl.utils.plot import plot_distribution, plot_histogram2d_x_y
+
+
+def asymmetric_l2_loss(u, tau):
+    return torch.mean(torch.abs(tau - (u < 0).float()) * u**2)
+
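`asymmetric_l2_loss` is the expectile-regression loss familiar from IQL-style value learning: residuals with `u < 0` are weighted by `1 - tau` and the rest by `tau`, so minimizing it pushes the value estimate toward an upper expectile of its targets when `tau > 0.5`. A minimal illustrative sketch (the numbers are placeholders, not library defaults):

import torch

u = torch.tensor([-2.0, -0.5, 0.5, 2.0])    # residuals, e.g. target_q - v
tau = 0.7                                    # expectile level (placeholder value)
weights = torch.abs(tau - (u < 0).float())   # 0.3 where u < 0, 0.7 elsewhere
loss = torch.mean(weights * u ** 2)          # equals asymmetric_l2_loss(u, tau)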
+
+
[docs]class GMPGCritic(nn.Module): + """ + Overview: + Critic network used by the GMPG algorithm. + Interfaces: + ``__init__``, ``forward``, ``compute_double_q``, ``in_support_ql_loss`` + """ + +
[docs] def __init__(self, config: EasyDict): + """ + Overview: + Initialization of the GMPG critic network. + Arguments: + config (:obj:`EasyDict`): The configuration dict. + """ + + super().__init__() + self.config = config + self.q_alpha = config.q_alpha + self.q = DoubleQNetwork(config.DoubleQNetwork) + self.q_target = copy.deepcopy(self.q).requires_grad_(False) + self.v = VNetwork(config.VNetwork)
+ +
[docs] def forward( + self, + action: Union[torch.Tensor, TensorDict], + state: Union[torch.Tensor, TensorDict] = None, + ) -> torch.Tensor: + """ + Overview: + Return the output of the GMPG critic, i.e. the Q value of the action conditioned on the state. + Arguments: + action (:obj:`torch.Tensor`): The input action. + state (:obj:`torch.Tensor`): The input state. + """ + + return self.q(action, state)
+ +
[docs] def compute_double_q( + self, + action: Union[torch.Tensor, TensorDict], + state: Union[torch.Tensor, TensorDict] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Overview: + Return the output of two Q networks. + Arguments: + action (:obj:`Union[torch.Tensor, TensorDict]`): The input action. + state (:obj:`Union[torch.Tensor, TensorDict]`): The input state. + Returns: + q1 (:obj:`Union[torch.Tensor, TensorDict]`): The output of the first Q network. + q2 (:obj:`Union[torch.Tensor, TensorDict]`): The output of the second Q network. + """ + return self.q.compute_double_q(action, state)
+ +
[docs] def in_support_ql_loss( + self, + action: Union[torch.Tensor, TensorDict], + state: Union[torch.Tensor, TensorDict], + reward: Union[torch.Tensor, TensorDict], + next_state: Union[torch.Tensor, TensorDict], + done: Union[torch.Tensor, TensorDict], + fake_next_action: Union[torch.Tensor, TensorDict], + discount_factor: float = 1.0, + ) -> torch.Tensor: + """ + Overview: + Calculate the Q loss. + Arguments: + action (:obj:`torch.Tensor`): The input action. + state (:obj:`torch.Tensor`): The input state. + reward (:obj:`torch.Tensor`): The input reward. + next_state (:obj:`torch.Tensor`): The input next state. + done (:obj:`torch.Tensor`): The input done. + fake_next_action (:obj:`torch.Tensor`): The input fake next action. + discount_factor (:obj:`float`): The discount factor. + """ + with torch.no_grad(): + softmax = nn.Softmax(dim=1) + next_energy = ( + self.q_target( + fake_next_action, + torch.stack([next_state] * fake_next_action.shape[1], axis=1), + ) + .detach() + .squeeze(dim=-1) + ) + next_v = torch.sum( + softmax(self.q_alpha * next_energy) * next_energy, dim=-1, keepdim=True + ) + # Update Q function + targets = reward + (1.0 - done.float()) * discount_factor * next_v.detach() + q0, q1 = self.q.compute_double_q(action, state) + q_loss = ( + torch.nn.functional.mse_loss(q0, targets) + + torch.nn.functional.mse_loss(q1, targets) + ) / 2 + return q_loss, torch.mean(q0), torch.mean(targets)
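The target in `in_support_ql_loss` is a soft value estimate computed only over behaviour-sampled candidate actions, which keeps the bootstrap within the support of the dataset policy. A schematic of the target computation, where `next_q` stands for target-Q values of shape `(batch, num_fake_actions)` and `q_alpha`, `reward`, `done`, `discount_factor` are the quantities used above:

soft_weights = torch.softmax(q_alpha * next_q, dim=-1)            # sharper as q_alpha grows
next_v = torch.sum(soft_weights * next_q, dim=-1, keepdim=True)   # in-support soft value
targets = reward + (1.0 - done.float()) * discount_factor * next_v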
+ + def v_loss(self, state, action, next_state, tau): + with torch.no_grad(): + target_q = self.q_target(action, state).detach() + next_v = self.v(next_state).detach() + # Update value function + v = self.v(state) + adv = target_q - v + v_loss = asymmetric_l2_loss(adv, tau) + return v_loss, next_v + + def iql_q_loss(self, state, action, reward, done, next_v, discount): + q_target = reward + (1.0 - done.float()) * discount * next_v.detach() + qs = self.q.compute_double_q(action, state) + q_loss = sum(torch.nn.functional.mse_loss(q, q_target) for q in qs) / len(qs) + return q_loss, torch.mean(qs[0]), torch.mean(q_target)
+ + +
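A condensed sketch of how `v_loss` and `iql_q_loss` are combined during critic training (it mirrors the critic training loop further below; `critic`, `data`, the optimizers, `tau`, `discount` and `momentum` stand in for the configured objects):

v_loss, next_v = critic.v_loss(data["s"], data["a"], data["s_"], tau)
v_optimizer.zero_grad(); v_loss.backward(); v_optimizer.step()

q_loss, q, q_target = critic.iql_q_loss(data["s"], data["a"], data["r"], data["d"], next_v, discount)
q_optimizer.zero_grad(); q_loss.backward(); q_optimizer.step()

# Soft (Polyak) update of the target Q network.
for param, target_param in zip(critic.q.parameters(), critic.q_target.parameters()):
    target_param.data.copy_(momentum * param.data + (1 - momentum) * target_param.data)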
[docs]class GMPGPolicy(nn.Module): + +
[docs] def __init__(self, config: EasyDict): + super().__init__() + self.config = config + self.device = config.device + + self.critic = GMPGCritic(config.critic) + self.model_type = config.model_type + if self.model_type == "DiffusionModel": + self.base_model = DiffusionModel(config.model) + self.guided_model = DiffusionModel(config.model) + self.model_loss_type = config.model_loss_type + assert self.model_loss_type in ["score_matching", "flow_matching"] + elif self.model_type == "OptimalTransportConditionalFlowModel": + self.base_model = OptimalTransportConditionalFlowModel(config.model) + self.guided_model = OptimalTransportConditionalFlowModel(config.model) + elif self.model_type == "IndependentConditionalFlowModel": + self.base_model = IndependentConditionalFlowModel(config.model) + self.guided_model = IndependentConditionalFlowModel(config.model) + elif self.model_type == "SchrodingerBridgeConditionalFlowModel": + self.base_model = SchrodingerBridgeConditionalFlowModel(config.model) + self.guided_model = SchrodingerBridgeConditionalFlowModel(config.model) + else: + raise NotImplementedError + self.softmax = nn.Softmax(dim=1)
+ +
[docs] def forward( + self, state: Union[torch.Tensor, TensorDict] + ) -> Union[torch.Tensor, TensorDict]: + """ + Overview: + Return the output of the GMPG policy, which is the action conditioned on the state. + Arguments: + state (:obj:`Union[torch.Tensor, TensorDict]`): The input state. + Returns: + action (:obj:`Union[torch.Tensor, TensorDict]`): The output action. + """ + return self.sample(state)
+ +
[docs] def sample( + self, + state: Union[torch.Tensor, TensorDict], + batch_size: Union[torch.Size, int, Tuple[int], List[int]] = None, + solver_config: EasyDict = None, + t_span: torch.Tensor = None, + with_grad: bool = False, + ) -> Union[torch.Tensor, TensorDict]: + """ + Overview: + Return the output of the GMPG policy, which is the action conditioned on the state. + Arguments: + state (:obj:`Union[torch.Tensor, TensorDict]`): The input state. + batch_size (:obj:`Union[torch.Size, int, Tuple[int], List[int]]`): The batch size. + solver_config (:obj:`EasyDict`): The configuration for the ODE solver. + t_span (:obj:`torch.Tensor`): The time span for the ODE solver or SDE solver. + with_grad (:obj:`bool`): Whether to calculate the gradient. + Returns: + action (:obj:`Union[torch.Tensor, TensorDict]`): The output action. + """ + + return self.guided_model.sample( + t_span=t_span, + condition=state, + batch_size=batch_size, + with_grad=with_grad, + solver_config=solver_config, + )
+ +
[docs] def behaviour_policy_sample( + self, + state: Union[torch.Tensor, TensorDict], + batch_size: Union[torch.Size, int, Tuple[int], List[int]] = None, + solver_config: EasyDict = None, + t_span: torch.Tensor = None, + with_grad: bool = False, + ) -> Union[torch.Tensor, TensorDict]: + """ + Overview: + Return the output of behaviour policy, which is the action conditioned on the state. + Arguments: + state (:obj:`Union[torch.Tensor, TensorDict]`): The input state. + batch_size (:obj:`Union[torch.Size, int, Tuple[int], List[int]]`): The batch size. + solver_config (:obj:`EasyDict`): The configuration for the ODE solver. + t_span (:obj:`torch.Tensor`): The time span for the ODE solver or SDE solver. + with_grad (:obj:`bool`): Whether to calculate the gradient. + Returns: + action (:obj:`Union[torch.Tensor, TensorDict]`): The output action. + """ + return self.base_model.sample( + t_span=t_span, + condition=state, + batch_size=batch_size, + with_grad=with_grad, + solver_config=solver_config, + )
+ +
[docs] def compute_q( + self, + state: Union[torch.Tensor, TensorDict], + action: Union[torch.Tensor, TensorDict], + ) -> torch.Tensor: + """ + Overview: + Calculate the Q value. + Arguments: + state (:obj:`Union[torch.Tensor, TensorDict]`): The input state. + action (:obj:`Union[torch.Tensor, TensorDict]`): The input action. + Returns: + q (:obj:`torch.Tensor`): The Q value. + """ + + return self.critic(action, state)
+ +
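A hypothetical call pattern for `sample` and `compute_q` (the constructed `policy` instance, `obs_dim` and the step count are placeholders):

states = torch.randn(32, obs_dim, device=policy.device)
t_span = torch.linspace(0.0, 1.0, 32).to(policy.device)   # integration grid for the ODE/SDE solver
actions = policy.sample(state=states, t_span=t_span)       # sampled from the guided model
q_values = policy.compute_q(states, actions)                # evaluated by the critic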
[docs] def behaviour_policy_loss( + self, + action: Union[torch.Tensor, TensorDict], + state: Union[torch.Tensor, TensorDict], + maximum_likelihood: bool = False, + ): + """ + Overview: + Calculate the behaviour policy loss. + Arguments: + action (:obj:`torch.Tensor`): The input action. + state (:obj:`torch.Tensor`): The input state. + """ + + if self.model_type == "DiffusionModel": + if self.model_loss_type == "score_matching": + if maximum_likelihood: + return self.base_model.score_matching_loss(action, state) + else: + return self.base_model.score_matching_loss( + action, state, weighting_scheme="vanilla" + ) + elif self.model_loss_type == "flow_matching": + return self.base_model.flow_matching_loss(action, state) + elif self.model_type in [ + "OptimalTransportConditionalFlowModel", + "IndependentConditionalFlowModel", + "SchrodingerBridgeConditionalFlowModel", + ]: + x0 = self.base_model.gaussian_generator(batch_size=state.shape[0]) + return self.base_model.flow_matching_loss(x0=x0, x1=action, condition=state)
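Depending on `model_type`, the behaviour-cloning loss above is either a score-matching / flow-matching loss of the diffusion model or a conditional flow-matching loss between Gaussian noise and dataset actions. A schematic of the flow-based branch (`base_model` and `data` are placeholders for the configured model and a sampled batch):

x0 = base_model.gaussian_generator(batch_size=data["s"].shape[0])   # noise endpoints
bc_loss = base_model.flow_matching_loss(x0=x0, x1=data["a"], condition=data["s"])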
+ + def policy_gradient_loss( + self, + state: Union[torch.Tensor, TensorDict], + gradtime_step: int = 1000, + beta: float = 1.0, + repeats: int = 1, + ): + t_span = torch.linspace(0.0, 1.0, gradtime_step).to(state.device) + + def log_grad(name, grad): + wandb.log( + { + f"{name}_mean": grad.mean().item(), + f"{name}_max": grad.max().item(), + f"{name}_min": grad.min().item(), + }, + commit=False, + ) + + if repeats == 1: + state_repeated = state + else: + state_repeated = torch.repeat_interleave( + state, repeats=repeats, dim=0 + ).requires_grad_() + + action_repeated = self.guided_model.sample( + t_span=t_span, condition=state_repeated, with_grad=True + ) + + q_value_repeated = self.critic(action_repeated, state_repeated).squeeze(dim=-1) + + log_p = compute_likelihood( + model=self.guided_model, + x=action_repeated, + condition=state_repeated, + t=t_span, + using_Hutchinson_trace_estimator=True, + ) + + bits_ratio = torch.prod( + torch.tensor(action_repeated.shape[1], device=state.device) + ) * torch.log(torch.tensor(2.0, device=state.device)) + + log_p_per_dim = log_p / bits_ratio + log_mu = compute_likelihood( + model=self.base_model, + x=action_repeated, + condition=state_repeated, + t=t_span, + using_Hutchinson_trace_estimator=True, + ) + + log_mu_per_dim = log_mu / bits_ratio + + if repeats > 1: + q_value_repeated = q_value_repeated.reshape(-1, repeats) + log_p_per_dim = log_p_per_dim.reshape(-1, repeats) + log_mu_per_dim = log_mu_per_dim.reshape(-1, repeats) + + return ( + ( + -beta * q_value_repeated.mean(dim=1) + + log_p_per_dim(dim=1) + - log_mu_per_dim(dim=1) + ), + -beta * q_value_repeated.detach().mean(), + log_p_per_dim.detach().mean(), + -log_mu_per_dim.detach().mean(), + ) + else: + return ( + (-beta * q_value_repeated + log_p_per_dim - log_mu_per_dim).mean(), + -beta * q_value_repeated.detach().mean(), + log_p_per_dim.detach().mean(), + -log_mu_per_dim.detach().mean(), + ) + + def policy_gradient_loss_by_REINFORCE( + self, + state: Union[torch.Tensor, TensorDict], + gradtime_step: int = 1000, + beta: float = 1.0, + repeats: int = 1, + weight_clamp: float = 100.0, + ): + t_span = torch.linspace(0.0, 1.0, gradtime_step).to(state.device) + + state_repeated = torch.repeat_interleave(state, repeats=repeats, dim=0) + action_repeated = self.base_model.sample( + t_span=t_span, condition=state_repeated, with_grad=False + ) + q_value_repeated = self.critic(action_repeated, state_repeated).squeeze(dim=-1) + v_value_repeated = self.critic.v(state_repeated).squeeze(dim=-1) + + weight = ( + torch.exp(beta * (q_value_repeated - v_value_repeated)).clamp( + max=weight_clamp + ) + / weight_clamp + ) + + log_p = compute_likelihood( + model=self.guided_model, + x=action_repeated, + condition=state_repeated, + t=t_span, + using_Hutchinson_trace_estimator=True, + ) + bits_ratio = torch.prod( + torch.tensor(action_repeated.shape[1], device=state.device) + ) * torch.log(torch.tensor(2.0, device=state.device)) + log_p_per_dim = log_p / bits_ratio + log_mu = compute_likelihood( + model=self.base_model, + x=action_repeated, + condition=state_repeated, + t=t_span, + using_Hutchinson_trace_estimator=True, + ) + log_mu_per_dim = log_mu / bits_ratio + + loss = ( + ( + -beta * q_value_repeated.detach() + + log_p_per_dim.detach() + - log_mu_per_dim.detach() + ) + * log_p_per_dim + * weight + ) + with torch.no_grad(): + loss_q = -beta * q_value_repeated.detach().mean() + loss_p = log_p_per_dim.detach().mean() + loss_u = -log_mu_per_dim.detach().mean() + return loss, loss_q, loss_p, loss_u + + def 
policy_gradient_loss_by_REINFORCE_softmax( + self, + state: Union[torch.Tensor, TensorDict], + gradtime_step: int = 1000, + beta: float = 1.0, + repeats: int = 10, + ): + assert repeats > 1 + t_span = torch.linspace(0.0, 1.0, gradtime_step).to(state.device) + + state_repeated = torch.repeat_interleave(state, repeats=repeats, dim=0) + action_repeated = self.base_model.sample( + t_span=t_span, condition=state_repeated, with_grad=False + ) + q_value_repeated = self.critic(action_repeated, state_repeated).squeeze(dim=-1) + q_value_reshaped = q_value_repeated.reshape(-1, repeats) + + weight = nn.Softmax(dim=1)(q_value_reshaped * beta) + weight = weight.reshape(-1) + + log_p = compute_likelihood( + model=self.guided_model, + x=action_repeated, + condition=state_repeated, + t=t_span, + using_Hutchinson_trace_estimator=True, + ) + bits_ratio = torch.prod( + torch.tensor(action_repeated.shape[1], device=state.device) + ) * torch.log(torch.tensor(2.0, device=state.device)) + log_p_per_dim = log_p / bits_ratio + log_mu = compute_likelihood( + model=self.base_model, + x=action_repeated, + condition=state_repeated, + t=t_span, + using_Hutchinson_trace_estimator=True, + ) + log_mu_per_dim = log_mu / bits_ratio + + loss = ( + ( + -beta * q_value_repeated.detach() + + log_p_per_dim.detach() + - log_mu_per_dim.detach() + ) + * log_p_per_dim + * weight + ) + loss_q = -beta * q_value_repeated.detach().mean() + loss_p = log_p_per_dim.detach().mean() + loss_u = -log_mu_per_dim.detach().mean() + return loss, loss_q, loss_p, loss_u + + def policy_gradient_loss_add_matching_loss( + self, + action: Union[torch.Tensor, TensorDict], + state: Union[torch.Tensor, TensorDict], + maximum_likelihood: bool = False, + gradtime_step: int = 1000, + beta: float = 1.0, + repeats: int = 1, + ): + + t_span = torch.linspace(0.0, 1.0, gradtime_step).to(state.device) + + if repeats == 1: + state_repeated = state + else: + state_repeated = torch.repeat_interleave( + state, repeats=repeats, dim=0 + ).requires_grad_() + + action_repeated = self.guided_model.sample( + t_span=t_span, condition=state_repeated, with_grad=True + ) + + q_value_repeated = self.critic(action_repeated, state_repeated).squeeze(dim=-1) + + loss_q = -beta * q_value_repeated.mean() + + loss_matching = self.behaviour_policy_loss( + action=action, state=state, maximum_likelihood=maximum_likelihood + ) + + loss = loss_q + loss_matching + + return loss, loss_q, loss_matching
+ + +
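Taken together, the policy-gradient losses above optimize a KL-regularized objective: for an action sampled from the guided model, the per-sample loss is (up to the per-dimension normalization) `-beta * Q(s, a) + log p_guided(a|s) - log p_base(a|s)`, with both log-likelihoods estimated via Hutchinson's trace estimator. A schematic restatement, not the exact implementation:

def gmpg_objective(q_value, log_p_guided, log_p_base, beta=1.0):
    # Trade off maximizing the critic's Q value against staying close to the behaviour (base) model.
    return -beta * q_value + log_p_guided - log_p_base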
[docs]class GMPGAlgorithm: + """ + Overview: + The Generative Model Policy Gradient (GMPG) algorithm. + Interfaces: + ``__init__``, ``train``, ``deploy`` + """ + +
[docs] def __init__( + self, + config: EasyDict = None, + simulator=None, + dataset: GPDataset = None, + model: Union[torch.nn.Module, torch.nn.ModuleDict] = None, + seed=None, + ): + """ + Overview: + Initialize algorithm. + Arguments: + config (:obj:`EasyDict`): The configuration , which must contain the following keys: + train (:obj:`EasyDict`): The training configuration. + deploy (:obj:`EasyDict`): The deployment configuration. + simulator (:obj:`object`): The environment simulator. + dataset (:obj:`GPDataset`): The dataset. + model (:obj:`Union[torch.nn.Module, torch.nn.ModuleDict]`): The model. + Interface: + ``__init__``, ``train``, ``deploy`` + """ + self.config = config + self.simulator = simulator + self.dataset = dataset + self.seed_value = set_seed(seed) + + # --------------------------------------- + # Customized model initialization code ↓ + # --------------------------------------- + + if model is not None: + self.model = model + self.behaviour_policy_train_epoch = 0 + self.critic_train_epoch = 0 + self.guided_policy_train_epoch = 0 + else: + self.model = torch.nn.ModuleDict() + config = self.config.train + assert hasattr(config.model, "GPPolicy") + + if torch.__version__ >= "2.0.0": + self.model["GPPolicy"] = torch.compile( + GMPGPolicy(config.model.GPPolicy).to(config.model.GPPolicy.device) + ) + else: + self.model["GPPolicy"] = GMPGPolicy(config.model.GPPolicy).to( + config.model.GPPolicy.device + ) + + if ( + hasattr(config.parameter, "checkpoint_path") + and config.parameter.checkpoint_path is not None + ): + if not os.path.exists(config.parameter.checkpoint_path): + log.warning( + f"Checkpoint path {config.parameter.checkpoint_path} does not exist" + ) + self.behaviour_policy_train_epoch = -1 + self.critic_train_epoch = -1 + self.guided_policy_train_epoch = -1 + else: + base_model_files = sort_files_by_criteria( + folder_path=config.parameter.checkpoint_path, + start_string="basemodel_", + end_string=".pt", + ) + if len(base_model_files) == 0: + self.behaviour_policy_train_epoch = -1 + log.warning( + f"No basemodel file found in {config.parameter.checkpoint_path}" + ) + else: + checkpoint = torch.load( + os.path.join( + config.parameter.checkpoint_path, + base_model_files[0], + ), + map_location="cpu", + ) + self.model["GPPolicy"].base_model.load_state_dict( + checkpoint["base_model"] + ) + self.behaviour_policy_train_epoch = checkpoint.get( + "behaviour_policy_train_epoch", -1 + ) + + guided_model_files = sort_files_by_criteria( + folder_path=config.parameter.checkpoint_path, + start_string="guidedmodel_", + end_string=".pt", + ) + if len(guided_model_files) == 0: + self.guided_policy_train_epoch = -1 + log.warning( + f"No guidedmodel file found in {config.parameter.checkpoint_path}" + ) + else: + checkpoint = torch.load( + os.path.join( + config.parameter.checkpoint_path, + guided_model_files[0], + ), + map_location="cpu", + ) + self.model["GPPolicy"].guided_model.load_state_dict( + checkpoint["guided_model"] + ) + self.guided_policy_train_epoch = checkpoint.get( + "guided_policy_train_epoch", -1 + ) + + critic_model_files = sort_files_by_criteria( + folder_path=config.parameter.checkpoint_path, + start_string="critic_", + end_string=".pt", + ) + if len(critic_model_files) == 0: + self.critic_train_epoch = -1 + log.warning( + f"No criticmodel file found in {config.parameter.checkpoint_path}" + ) + else: + checkpoint = torch.load( + os.path.join( + config.parameter.checkpoint_path, + critic_model_files[0], + ), + map_location="cpu", + ) + 
self.model["GPPolicy"].critic.load_state_dict( + checkpoint["critic_model"] + ) + self.critic_train_epoch = checkpoint.get( + "critic_train_epoch", -1 + )
+ + # --------------------------------------- + # Customized model initialization code ↑ + # --------------------------------------- + +
[docs] def train(self, config: EasyDict = None, seed=None): + """ + Overview: + Train the model using the given configuration. \ + A weight-and-bias run will be created automatically when this function is called. + Arguments: + config (:obj:`EasyDict`): The training configuration. + seed (:obj:`int`): The random seed. + """ + + config = ( + merge_two_dicts_into_newone( + self.config.train if hasattr(self.config, "train") else EasyDict(), + config, + ) + if config is not None + else self.config.train + ) + + config["seed"] = self.seed_value if seed is None else seed + + if not hasattr(config, "wandb"): + config["wandb"] = dict(project=config.project) + elif not hasattr(config.wandb, "project"): + config.wandb["project"] = config.project + + with wandb.init(**config.wandb) as wandb_run: + if not hasattr(config.parameter.guided_policy, "beta"): + config.parameter.guided_policy.beta = 1.0 + + assert config.parameter.algorithm_type in [ + "GMPG", + "GMPG_REINFORCE", + "GMPG_REINFORCE_softmax", + "GMPG_add_matching", + ] + run_name = f"{config.parameter.critic.method}-beta-{config.parameter.guided_policy.beta}-T-{config.parameter.guided_policy.gradtime_step}-batch-{config.parameter.guided_policy.batch_size}-lr-{config.parameter.guided_policy.learning_rate}-seed-{self.seed_value}" + wandb.run.name = run_name + wandb.run.save() + + config = merge_two_dicts_into_newone(EasyDict(wandb_run.config), config) + wandb_run.config.update(config) + self.config.train = config + + self.simulator = ( + create_simulator(config.simulator) + if hasattr(config, "simulator") + else self.simulator + ) + self.dataset = ( + create_dataset(config.dataset) + if hasattr(config, "dataset") + else self.dataset + ) + + # --------------------------------------- + # Customized training code ↓ + # --------------------------------------- + + def save_checkpoint(model, iteration=None, model_type=False): + if iteration == None: + iteration = 0 + if model_type == "base_model": + if ( + hasattr(config.parameter, "checkpoint_path") + and config.parameter.checkpoint_path is not None + ): + if not os.path.exists(config.parameter.checkpoint_path): + os.makedirs(config.parameter.checkpoint_path) + torch.save( + dict( + base_model=model["GPPolicy"].base_model.state_dict(), + behaviour_policy_train_epoch=self.behaviour_policy_train_epoch, + behaviour_policy_train_iter=iteration, + ), + f=os.path.join( + config.parameter.checkpoint_path, + f"basemodel_{self.behaviour_policy_train_epoch}_{iteration}.pt", + ), + ) + elif model_type == "guided_model": + if ( + hasattr(config.parameter, "checkpoint_path") + and config.parameter.checkpoint_path is not None + ): + if not os.path.exists(config.parameter.checkpoint_path): + os.makedirs(config.parameter.checkpoint_path) + torch.save( + dict( + guided_model=model[ + "GPPolicy" + ].guided_model.state_dict(), + guided_policy_train_epoch=self.guided_policy_train_epoch, + guided_policy_train_iteration=iteration, + ), + f=os.path.join( + config.parameter.checkpoint_path, + f"guidedmodel_{self.guided_policy_train_epoch}_{iteration}.pt", + ), + ) + elif model_type == "critic_model": + if ( + hasattr(config.parameter, "checkpoint_path") + and config.parameter.checkpoint_path is not None + ): + if not os.path.exists(config.parameter.checkpoint_path): + os.makedirs(config.parameter.checkpoint_path) + torch.save( + dict( + critic_model=model["GPPolicy"].critic.state_dict(), + critic_train_epoch=self.critic_train_epoch, + critic_train_iter=iteration, + ), + f=os.path.join( + config.parameter.checkpoint_path, + 
f"critic_{self.critic_train_epoch}_{iteration}.pt", + ), + ) + else: + raise NotImplementedError + + def generate_fake_action(model, states, action_augment_num): + + fake_actions_sampled = [] + for states in track( + np.array_split(states, states.shape[0] // 4096 + 1), + description="Generate fake actions", + ): + + fake_actions_ = model.behaviour_policy_sample( + state=states, + batch_size=action_augment_num, + t_span=( + torch.linspace(0.0, 1.0, config.parameter.t_span).to( + states.device + ) + if hasattr(config.parameter, "t_span") + and config.parameter.t_span is not None + else None + ), + ) + fake_actions_sampled.append(torch.einsum("nbd->bnd", fake_actions_)) + + fake_actions = torch.cat(fake_actions_sampled, dim=0) + return fake_actions + + def evaluate(model, train_epoch, repeat=1): + evaluation_results = dict() + + def policy(obs: np.ndarray) -> np.ndarray: + if isinstance(obs, np.ndarray): + obs = torch.tensor( + obs, + dtype=torch.float32, + device=config.model.GPPolicy.device, + ).unsqueeze(0) + elif isinstance(obs, dict): + for key in obs: + obs[key] = torch.tensor( + obs[key], + dtype=torch.float32, + device=config.model.GPPolicy.device, + ).unsqueeze(0) + if obs[key].dim() == 1 and obs[key].shape[0] == 1: + obs[key] = obs[key].unsqueeze(1) + obs = TensorDict(obs, batch_size=[1]) + action = ( + model.sample( + condition=obs, + t_span=( + torch.linspace(0.0, 1.0, config.parameter.t_span).to( + config.model.GPPolicy.device + ) + if hasattr(config.parameter, "t_span") + and config.parameter.t_span is not None + else None + ), + ) + .squeeze(0) + .cpu() + .detach() + .numpy() + ) + return action + + eval_results = self.simulator.evaluate( + policy=policy, num_episodes=repeat + ) + return_results = [ + eval_results[i]["total_return"] for i in range(repeat) + ] + log.info(f"Return: {return_results}") + return_mean = np.mean(return_results) + return_std = np.std(return_results) + return_max = np.max(return_results) + return_min = np.min(return_results) + evaluation_results[f"evaluation/return_mean"] = return_mean + evaluation_results[f"evaluation/return_std"] = return_std + evaluation_results[f"evaluation/return_max"] = return_max + evaluation_results[f"evaluation/return_min"] = return_min + + if isinstance(self.dataset, GPD4RLDataset): + import d4rl + + env_id = config.dataset.args.env_id + evaluation_results[f"evaluation/return_mean_normalized"] = ( + d4rl.get_normalized_score(env_id, return_mean) + ) + evaluation_results[f"evaluation/return_std_normalized"] = ( + d4rl.get_normalized_score(env_id, return_std) + ) + evaluation_results[f"evaluation/return_max_normalized"] = ( + d4rl.get_normalized_score(env_id, return_max) + ) + evaluation_results[f"evaluation/return_min_normalized"] = ( + d4rl.get_normalized_score(env_id, return_min) + ) + + if repeat > 1: + log.info( + f"Train epoch: {train_epoch}, return_mean: {return_mean}, return_std: {return_std}, return_max: {return_max}, return_min: {return_min}" + ) + else: + log.info(f"Train epoch: {train_epoch}, return: {return_mean}") + + return evaluation_results + + # --------------------------------------- + # behavior training code ↓ + # --------------------------------------- + behaviour_policy_optimizer = torch.optim.Adam( + self.model["GPPolicy"].base_model.model.parameters(), + lr=config.parameter.behaviour_policy.learning_rate, + ) + + replay_buffer = TensorDictReplayBuffer( + storage=self.dataset.storage, + batch_size=config.parameter.behaviour_policy.batch_size, + sampler=SamplerWithoutReplacement(), + prefetch=10, + 
pin_memory=True, + ) + + behaviour_policy_train_iter = 0 + + logp_min = [] + logp_max = [] + logp_mean = [] + logp_sum = [] + end_return = [] + for epoch in track( + range(config.parameter.behaviour_policy.epochs), + description="Behaviour policy training", + ): + if self.behaviour_policy_train_epoch >= epoch: + continue + if ( + hasattr(config.parameter.evaluation, "analysis_interval") + and epoch % config.parameter.evaluation.analysis_interval == 0 + ): + + if hasattr(config.parameter.evaluation, "analysis_repeat"): + analysis_repeat = config.parameter.evaluation.analysis_repeat + else: + analysis_repeat = 10 + + analysis_counter = 0 + for index, data in enumerate(replay_buffer): + if analysis_counter == 0: + if not os.path.exists(config.parameter.checkpoint_path): + os.makedirs(config.parameter.checkpoint_path) + plot_distribution( + data["a"].detach().cpu().numpy(), + os.path.join( + config.parameter.checkpoint_path, + f"action_base_{epoch}.png", + ), + ) + + action = self.model["GPPolicy"].behaviour_policy_sample( + state=data["s"].to(config.model.GPPolicy.device), + t_span=( + torch.linspace( + 0.0, 1.0, config.parameter.t_span + ).to(config.model.GPPolicy.device) + if hasattr(config.parameter, "t_span") + and config.parameter.t_span is not None + else None + ), + ) + + evaluation_results = evaluate( + self.model["GPPolicy"].base_model, + train_epoch=epoch, + repeat=( + 1 + if not hasattr(config.parameter.evaluation, "repeat") + else config.parameter.evaluation.repeat + ), + ) + + if analysis_counter == 0: + plot_distribution( + action.detach().cpu().numpy(), + os.path.join( + config.parameter.checkpoint_path, + f"action_base_model_{epoch}_{evaluation_results['evaluation/return_mean']}.png", + ), + ) + + log_p = compute_likelihood( + model=self.model["GPPolicy"].base_model, + x=data["a"].to(config.model.GPPolicy.device), + condition=data["s"].to(config.model.GPPolicy.device), + t=torch.linspace(0.0, 1.0, 100).to( + config.model.GPPolicy.device + ), + using_Hutchinson_trace_estimator=True, + ) + logp_max.append(log_p.max().detach().cpu().numpy()) + logp_min.append(log_p.min().detach().cpu().numpy()) + logp_mean.append(log_p.mean().detach().cpu().numpy()) + logp_sum.append(log_p.sum().detach().cpu().numpy()) + end_return.append(evaluation_results["evaluation/return_mean"]) + + wandb.log(data=evaluation_results, commit=False) + + analysis_counter += 1 + if analysis_counter >= analysis_repeat: + logp_dict = { + "logp_max": logp_max, + "logp_min": logp_min, + "logp_mean": logp_mean, + "logp_sum": logp_sum, + "end_return": end_return, + } + np.savez( + os.path.join( + config.parameter.checkpoint_path, + f"logp_data_based_{epoch}.npz", + ), + **logp_dict, + ) + plot_histogram2d_x_y( + end_return, + logp_mean, + os.path.join( + config.parameter.checkpoint_path, + f"return_logp_base_{epoch}.png", + ), + ) + break + + counter = 1 + behaviour_policy_loss_sum = 0 + for index, data in enumerate(replay_buffer): + + behaviour_policy_loss = self.model[ + "GPPolicy" + ].behaviour_policy_loss( + action=data["a"].to(config.model.GPPolicy.device), + state=data["s"].to(config.model.GPPolicy.device), + maximum_likelihood=( + config.parameter.behaviour_policy.maximum_likelihood + if hasattr( + config.parameter.behaviour_policy, "maximum_likelihood" + ) + else False + ), + ) + behaviour_policy_optimizer.zero_grad() + behaviour_policy_loss.backward() + behaviour_policy_optimizer.step() + + counter += 1 + behaviour_policy_loss_sum += behaviour_policy_loss.item() + + behaviour_policy_train_iter += 1 + 
self.behaviour_policy_train_epoch = epoch + + wandb.log( + data=dict( + behaviour_policy_train_iter=behaviour_policy_train_iter, + behaviour_policy_train_epoch=epoch, + behaviour_policy_loss=behaviour_policy_loss_sum / counter, + ), + commit=True, + ) + + if ( + hasattr(config.parameter, "checkpoint_freq") + and (epoch + 1) % config.parameter.checkpoint_freq == 0 + ): + save_checkpoint( + self.model, + iteration=behaviour_policy_train_iter, + model_type="base_model", + ) + + # --------------------------------------- + # behavior training code ↑ + # --------------------------------------- + # --------------------------------------- + # critic training code ↓ + # --------------------------------------- + + q_optimizer = torch.optim.Adam( + self.model["GPPolicy"].critic.q.parameters(), + lr=config.parameter.critic.learning_rate, + ) + v_optimizer = torch.optim.Adam( + self.model["GPPolicy"].critic.v.parameters(), + lr=config.parameter.critic.learning_rate, + ) + + replay_buffer = TensorDictReplayBuffer( + storage=self.dataset.storage, + batch_size=config.parameter.critic.batch_size, + sampler=SamplerWithoutReplacement(), + prefetch=10, + pin_memory=True, + ) + + critic_train_iter = 0 + for epoch in track( + range(config.parameter.critic.epochs), description="Critic training" + ): + if self.critic_train_epoch >= epoch: + continue + + counter = 1 + + v_loss_sum = 0.0 + v_sum = 0.0 + q_loss_sum = 0.0 + q_sum = 0.0 + q_target_sum = 0.0 + for index, data in enumerate(replay_buffer): + + v_loss, next_v = self.model["GPPolicy"].critic.v_loss( + state=data["s"].to(config.model.GPPolicy.device), + action=data["a"].to(config.model.GPPolicy.device), + next_state=data["s_"].to(config.model.GPPolicy.device), + tau=config.parameter.critic.tau, + ) + v_optimizer.zero_grad(set_to_none=True) + v_loss.backward() + v_optimizer.step() + q_loss, q, q_target = self.model["GPPolicy"].critic.iql_q_loss( + state=data["s"].to(config.model.GPPolicy.device), + action=data["a"].to(config.model.GPPolicy.device), + reward=data["r"].to(config.model.GPPolicy.device), + done=data["d"].to(config.model.GPPolicy.device), + next_v=next_v, + discount=config.parameter.critic.discount_factor, + ) + q_optimizer.zero_grad(set_to_none=True) + q_loss.backward() + q_optimizer.step() + + # Update target + for param, target_param in zip( + self.model["GPPolicy"].critic.q.parameters(), + self.model["GPPolicy"].critic.q_target.parameters(), + ): + target_param.data.copy_( + config.parameter.critic.update_momentum * param.data + + (1 - config.parameter.critic.update_momentum) + * target_param.data + ) + + counter += 1 + + q_loss_sum += q_loss.item() + q_sum += q.mean().item() + q_target_sum += q_target.mean().item() + + v_loss_sum += v_loss.item() + v_sum += next_v.mean().item() + + critic_train_iter += 1 + self.critic_train_epoch = epoch + + wandb.log( + data=dict(v_loss=v_loss_sum / counter, v=v_sum / counter), + commit=False, + ) + + wandb.log( + data=dict( + critic_train_iter=critic_train_iter, + critic_train_epoch=epoch, + q_loss=q_loss_sum / counter, + q=q_sum / counter, + q_target=q_target_sum / counter, + ), + commit=True, + ) + + if ( + hasattr(config.parameter, "checkpoint_freq") + and (epoch + 1) % config.parameter.checkpoint_freq == 0 + ): + save_checkpoint( + self.model, + iteration=critic_train_iter, + model_type="critic_model", + ) + # --------------------------------------- + # critic training code ↑ + # --------------------------------------- + + # --------------------------------------- + # guided policy training code ↓ + # 
--------------------------------------- + + if not self.guided_policy_train_epoch > 0: + self.model["GPPolicy"].guided_model.model.load_state_dict( + self.model["GPPolicy"].base_model.model.state_dict() + ) + + guided_policy_optimizer = torch.optim.Adam( + self.model["GPPolicy"].guided_model.parameters(), + lr=config.parameter.guided_policy.learning_rate, + ) + + replay_buffer = TensorDictReplayBuffer( + storage=self.dataset.storage, + batch_size=config.parameter.guided_policy.batch_size, + sampler=SamplerWithoutReplacement(), + prefetch=10, + pin_memory=True, + ) + + logp_min = [] + logp_max = [] + logp_mean = [] + logp_sum = [] + end_return = [] + + guided_policy_train_iter = 0 + beta = config.parameter.guided_policy.beta + for epoch in track( + range(config.parameter.guided_policy.epochs), + description="Guided policy training", + ): + + if self.guided_policy_train_epoch >= epoch: + continue + + counter = 1 + guided_policy_loss_sum = 0.0 + for index, data in enumerate(replay_buffer): + if config.parameter.algorithm_type == "GMPG": + ( + guided_policy_loss, + q_loss, + log_p_loss, + log_u_loss, + ) = self.model["GPPolicy"].policy_gradient_loss( + data["s"].to(config.model.GPPolicy.device), + gradtime_step=config.parameter.guided_policy.gradtime_step, + beta=beta, + repeats=( + config.parameter.guided_policy.repeats + if hasattr(config.parameter.guided_policy, "repeats") + else 1 + ), + ) + elif config.parameter.algorithm_type == "GMPG_REINFORCE": + ( + guided_policy_loss, + q_loss, + log_p_loss, + log_u_loss, + ) = self.model["GPPolicy"].policy_gradient_loss_by_REINFORCE( + data["s"].to(config.model.GPPolicy.device), + gradtime_step=config.parameter.guided_policy.gradtime_step, + beta=beta, + repeats=( + config.parameter.guided_policy.repeats + if hasattr(config.parameter.guided_policy, "repeats") + else 1 + ), + weight_clamp=( + config.parameter.guided_policy.weight_clamp + if hasattr( + config.parameter.guided_policy, "weight_clamp" + ) + else 100.0 + ), + ) + elif config.parameter.algorithm_type == "GMPG_REINFORCE_softmax": + ( + guided_policy_loss, + q_loss, + log_p_loss, + log_u_loss, + ) = self.model[ + "GPPolicy" + ].policy_gradient_loss_by_REINFORCE_softmax( + data["s"].to(config.model.GPPolicy.device), + gradtime_step=config.parameter.guided_policy.gradtime_step, + beta=beta, + repeats=( + config.parameter.guided_policy.repeats + if hasattr(config.parameter.guided_policy, "repeats") + else 32 + ), + ) + elif config.parameter.algorithm_type == "GMPG_add_matching": + guided_policy_loss = self.model[ + "GPPolicy" + ].policy_gradient_loss_add_matching_loss( + data["a"].to(config.model.GPPolicy.device), + data["s"].to(config.model.GPPolicy.device), + maximum_likelihood=( + config.parameter.guided_policy.maximum_likelihood + if hasattr( + config.parameter.guided_policy, "maximum_likelihood" + ) + else False + ), + gradtime_step=config.parameter.guided_policy.gradtime_step, + beta=beta, + repeats=( + config.parameter.guided_policy.repeats + if hasattr(config.parameter.guided_policy, "repeats") + else 1 + ), + ) + else: + raise NotImplementedError + guided_policy_optimizer.zero_grad() + guided_policy_loss = guided_policy_loss * ( + data["s"].shape[0] / config.parameter.guided_policy.batch_size + ) + guided_policy_loss = guided_policy_loss.mean() + guided_policy_loss.backward() + guided_policy_optimizer.step() + counter += 1 + + if config.parameter.algorithm_type == "GMPG_add_matching": + wandb.log( + data=dict( + guided_policy_train_iter=guided_policy_train_iter, + 
guided_policy_train_epoch=epoch, + guided_policy_loss=guided_policy_loss.item(), + ), + commit=False, + ) + if ( + hasattr(config.parameter, "checkpoint_freq") + and (guided_policy_train_iter + 1) + % config.parameter.checkpoint_freq + == 0 + ): + save_checkpoint( + self.model, + iteration=guided_policy_train_iter, + model_type="guided_model", + ) + + elif config.parameter.algorithm_type in [ + "GMPG", + "GMPG_REINFORCE", + "GMPG_REINFORCE_softmax", + ]: + wandb.log( + data=dict( + guided_policy_train_iter=guided_policy_train_iter, + guided_policy_train_epoch=epoch, + guided_policy_loss=guided_policy_loss.item(), + q_loss=q_loss.item(), + log_p_loss=log_p_loss.item(), + log_u_loss=log_u_loss.item(), + ), + commit=False, + ) + if ( + hasattr(config.parameter, "checkpoint_freq") + and (guided_policy_train_iter + 1) + % config.parameter.checkpoint_freq + == 0 + ): + save_checkpoint( + self.model, + iteration=guided_policy_train_iter, + model_type="guided_model", + ) + + guided_policy_loss_sum += guided_policy_loss.item() + + self.guided_policy_train_epoch = epoch + + if ( + hasattr(config.parameter.evaluation, "analysis_interval") + and guided_policy_train_iter + % config.parameter.evaluation.analysis_interval + == 0 + ): + if hasattr(config.parameter.evaluation, "analysis_repeat"): + analysis_repeat = ( + config.parameter.evaluation.analysis_repeat + ) + else: + analysis_repeat = 10 + + if hasattr( + config.parameter.evaluation, "analysis_distribution" + ): + analysis_distribution = ( + config.parameter.evaluation.analysis_distribution + ) + else: + analysis_distribution = True + + analysis_counter = 0 + for index, data in enumerate(replay_buffer): + + if analysis_counter == 0 and analysis_distribution: + if not os.path.exists(config.parameter.checkpoint_path): + os.makedirs(config.parameter.checkpoint_path) + plot_distribution( + data["a"].detach().cpu().numpy(), + os.path.join( + config.parameter.checkpoint_path, + f"action_guided_{guided_policy_train_iter}.png", + ), + ) + + action = self.model["GPPolicy"].sample( + state=data["s"].to(config.model.GPPolicy.device), + t_span=( + torch.linspace( + 0.0, 1.0, config.parameter.t_span + ).to(config.model.GPPolicy.device) + if hasattr(config.parameter, "t_span") + and config.parameter.t_span is not None + else None + ), + ) + + evaluation_results = evaluate( + self.model["GPPolicy"].guided_model, + train_epoch=epoch, + repeat=( + 1 + if not hasattr( + config.parameter.evaluation, "repeat" + ) + else config.parameter.evaluation.repeat + ), + ) + + log_p = compute_likelihood( + model=self.model["GPPolicy"].guided_model, + x=data["a"].to(config.model.GPPolicy.device), + condition=data["s"].to(config.model.GPPolicy.device), + t=torch.linspace(0.0, 1.0, 100).to( + config.model.GPPolicy.device + ), + using_Hutchinson_trace_estimator=True, + ) + + logp_max.append(log_p.max().detach().cpu().numpy()) + logp_min.append(log_p.min().detach().cpu().numpy()) + logp_mean.append(log_p.mean().detach().cpu().numpy()) + logp_sum.append(log_p.sum().detach().cpu().numpy()) + end_return.append( + evaluation_results["evaluation/return_mean"] + ) + + if analysis_counter == 0 and analysis_distribution: + plot_distribution( + action.detach().cpu().numpy(), + os.path.join( + config.parameter.checkpoint_path, + f"action_guided_model_{guided_policy_train_iter}_{evaluation_results['evaluation/return_mean']}.png", + ), + ) + + analysis_counter += 1 + wandb.log(data=evaluation_results, commit=False) + if analysis_counter > analysis_repeat: + logp_dict = { + "logp_max": 
logp_max, + "logp_min": logp_min, + "logp_mean": logp_mean, + "logp_sum": logp_sum, + "end_return": end_return, + } + np.savez( + os.path.join( + config.parameter.checkpoint_path, + f"logp_data_guided_{epoch}.npz", + ), + **logp_dict, + ) + plot_histogram2d_x_y( + end_return, + logp_mean, + os.path.join( + config.parameter.checkpoint_path, + f"return_logp_guided_{guided_policy_train_iter}.png", + ), + ) + break + + if ( + config.parameter.evaluation.eval + and hasattr(config.parameter.evaluation, "interval") + and guided_policy_train_iter + % config.parameter.evaluation.interval + == 0 + ): + evaluation_results = evaluate( + self.model["GPPolicy"].guided_model, + train_epoch=epoch, + repeat=( + 1 + if not hasattr(config.parameter.evaluation, "repeat") + else config.parameter.evaluation.repeat + ), + ) + wandb.log(data=evaluation_results, commit=False) + guided_policy_train_iter += 1 + wandb.log( + data=dict( + guided_policy_train_iter=guided_policy_train_iter, + guided_policy_train_epoch=epoch, + ), + commit=True, + ) + + # --------------------------------------- + # guided policy training code ↑ + # --------------------------------------- + + # --------------------------------------- + # Customized training code ↑ + # --------------------------------------- + + wandb.finish()
+ + def deploy(self, config: EasyDict = None) -> GPAgent: + + if config is not None: + config = merge_two_dicts_into_newone(self.config.deploy, config) + else: + config = self.config.deploy + + assert "GPPolicy" in self.model, "The model must be trained first." + return GPAgent( + config=config, + model=copy.deepcopy(self.model["GPPolicy"].guided_model), + )
+
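A hypothetical end-to-end usage of `GMPGAlgorithm` (`my_config` is a placeholder EasyDict providing the `train` and `deploy` sections described in the docstrings above):

algorithm = GMPGAlgorithm(config=my_config)
algorithm.train()              # behaviour policy -> critic -> guided policy, logged to wandb
agent = algorithm.deploy()     # GPAgent wrapping a copy of the trained guided model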
\ No newline at end of file

diff --git a/_modules/grl/algorithms/gmpo.html b/_modules/grl/algorithms/gmpo.html
new file mode 100644
index 0000000..a60f5b1
--- /dev/null
+++ b/_modules/grl/algorithms/gmpo.html
@@ -0,0 +1,1799 @@
+grl.algorithms.gmpo — GenerativeRL v0.0.1 documentation

Source code for grl.algorithms.gmpo

+import os
+import copy
+from typing import List, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+from easydict import EasyDict
+from rich.progress import track
+from tensordict import TensorDict
+from torchrl.data import TensorDictReplayBuffer
+from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
+
+import wandb
+from grl.agents.gm import GPAgent
+
+from grl.datasets import create_dataset
+from grl.datasets.gp import GPDataset, GPD4RLDataset, GPD4RLTensorDictDataset
+from grl.generative_models.diffusion_model import DiffusionModel
+from grl.generative_models.conditional_flow_model.optimal_transport_conditional_flow_model import (
+    OptimalTransportConditionalFlowModel,
+)
+from grl.generative_models.conditional_flow_model.independent_conditional_flow_model import (
+    IndependentConditionalFlowModel,
+)
+from grl.generative_models.bridge_flow_model.schrodinger_bridge_conditional_flow_model import (
+    SchrodingerBridgeConditionalFlowModel,
+)
+
+from grl.rl_modules.simulators import create_simulator
+from grl.rl_modules.value_network.q_network import DoubleQNetwork
+from grl.rl_modules.value_network.value_network import VNetwork, DoubleVNetwork
+from grl.utils.config import merge_two_dicts_into_newone
+from grl.utils.log import log
+from grl.utils import set_seed
+from grl.utils.plot import plot_distribution, plot_histogram2d_x_y
+from grl.utils.statistics import sort_files_by_criteria
+from grl.generative_models.metric import compute_likelihood
+
+
+def asymmetric_l2_loss(u, tau):
+    return torch.mean(torch.abs(tau - (u < 0).float()) * u**2)
+
+
+
[docs]class GMPOCritic(nn.Module): + """ + Overview: + Critic network for the GMPO algorithm. + Interfaces: + ``__init__``, ``forward`` + """ + +
[docs] def __init__(self, config: EasyDict): + """ + Overview: + Initialization of GMPO critic network. + Arguments: + config (:obj:`EasyDict`): The configuration dict. + """ + + super().__init__() + self.config = config + self.q_alpha = config.q_alpha + self.q = DoubleQNetwork(config.DoubleQNetwork) + self.q_target = copy.deepcopy(self.q).requires_grad_(False) + self.v = VNetwork(config.VNetwork)
+ +
[docs] def forward( + self, + action: Union[torch.Tensor, TensorDict], + state: Union[torch.Tensor, TensorDict] = None, + ) -> torch.Tensor: + """ + Overview: + Return the output of GMPO critic. + Arguments: + action (:obj:`torch.Tensor`): The input action. + state (:obj:`torch.Tensor`): The input state. + """ + + return self.q(action, state)
+ +
[docs] def compute_double_q( + self, + action: Union[torch.Tensor, TensorDict], + state: Union[torch.Tensor, TensorDict] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Overview: + Return the output of two Q networks. + Arguments: + action (:obj:`Union[torch.Tensor, TensorDict]`): The input action. + state (:obj:`Union[torch.Tensor, TensorDict]`): The input state. + Returns: + q1 (:obj:`Union[torch.Tensor, TensorDict]`): The output of the first Q network. + q2 (:obj:`Union[torch.Tensor, TensorDict]`): The output of the second Q network. + """ + return self.q.compute_double_q(action, state)
+ + def v_loss(self, state, action, next_state, tau): + with torch.no_grad(): + target_q = self.q_target(action, state).detach() + next_v = self.v(next_state).detach() + # Update value function + v = self.v(state) + adv = target_q - v + v_loss = asymmetric_l2_loss(adv, tau) + return v_loss, next_v + + def iql_q_loss(self, state, action, reward, done, next_v, discount): + q_target = reward + (1.0 - done.float()) * discount * next_v.detach() + qs = self.q.compute_double_q(action, state) + q_loss = sum(torch.nn.functional.mse_loss(q, q_target) for q in qs) / len(qs) + return q_loss, torch.mean(qs[0]), torch.mean(q_target)
+ + +
[docs]class GMPOPolicy(nn.Module): + """ + Overview: + GMPO policy network for the GMPO algorithm, which includes the base model (optional), the guided model and the critic. + Interfaces: + ``__init__``, ``forward``, ``sample``, ``compute_q``, ``behaviour_policy_loss``, ``policy_optimization_loss_by_advantage_weighted_regression``, ``policy_optimization_loss_by_advantage_weighted_regression_softmax`` + """ + +
[docs] def __init__(self, config: EasyDict): + """ + Overview: + Initialize the GMPO policy network. + Arguments: + config (:obj:`EasyDict`): The configuration dict. + """ + super().__init__() + self.config = config + self.device = config.device + + self.critic = GMPOCritic(config.critic) + self.model_type = config.model_type + if self.model_type == "DiffusionModel": + self.base_model = DiffusionModel(config.model) + self.guided_model = DiffusionModel(config.model) + self.model_loss_type = config.model_loss_type + assert self.model_loss_type in ["score_matching", "flow_matching"] + elif self.model_type == "OptimalTransportConditionalFlowModel": + self.base_model = OptimalTransportConditionalFlowModel(config.model) + self.guided_model = OptimalTransportConditionalFlowModel(config.model) + elif self.model_type == "IndependentConditionalFlowModel": + self.base_model = IndependentConditionalFlowModel(config.model) + self.guided_model = IndependentConditionalFlowModel(config.model) + elif self.model_type == "SchrodingerBridgeConditionalFlowModel": + self.base_model = SchrodingerBridgeConditionalFlowModel(config.model) + self.guided_model = SchrodingerBridgeConditionalFlowModel(config.model) + else: + raise NotImplementedError + self.softmax = nn.Softmax(dim=1)
+ +
[docs] def forward( + self, state: Union[torch.Tensor, TensorDict] + ) -> Union[torch.Tensor, TensorDict]: + """ + Overview: + Return the output of GMPO policy, which is the action conditioned on the state. + Arguments: + state (:obj:`Union[torch.Tensor, TensorDict]`): The input state. + Returns: + action (:obj:`Union[torch.Tensor, TensorDict]`): The output action. + """ + return self.sample(state)
+ +
[docs] def sample( + self, + state: Union[torch.Tensor, TensorDict], + batch_size: Union[torch.Size, int, Tuple[int], List[int]] = None, + solver_config: EasyDict = None, + t_span: torch.Tensor = None, + with_grad: bool = False, + ) -> Union[torch.Tensor, TensorDict]: + """ + Overview: + Return the output of GMPO policy, which is the action conditioned on the state. + Arguments: + state (:obj:`Union[torch.Tensor, TensorDict]`): The input state. + batch_size (:obj:`Union[torch.Size, int, Tuple[int], List[int]]`): The batch size. + solver_config (:obj:`EasyDict`): The configuration for the ODE solver. + t_span (:obj:`torch.Tensor`): The time span for the ODE solver or SDE solver. + Returns: + action (:obj:`Union[torch.Tensor, TensorDict]`): The output action. + """ + + return self.guided_model.sample( + t_span=t_span, + condition=state, + batch_size=batch_size, + with_grad=with_grad, + solver_config=solver_config, + )
+ +
[docs] def behaviour_policy_sample( + self, + state: Union[torch.Tensor, TensorDict], + batch_size: Union[torch.Size, int, Tuple[int], List[int]] = None, + solver_config: EasyDict = None, + t_span: torch.Tensor = None, + with_grad: bool = False, + ) -> Union[torch.Tensor, TensorDict]: + """ + Overview: + Return the output of behaviour policy, which is the action conditioned on the state. + Arguments: + state (:obj:`Union[torch.Tensor, TensorDict]`): The input state. + batch_size (:obj:`Union[torch.Size, int, Tuple[int], List[int]]`): The batch size. + solver_config (:obj:`EasyDict`): The configuration for the ODE solver. + t_span (:obj:`torch.Tensor`): The time span for the ODE solver or SDE solver. + with_grad (:obj:`bool`): Whether to calculate the gradient. + Returns: + action (:obj:`Union[torch.Tensor, TensorDict]`): The output action. + """ + return self.base_model.sample( + t_span=t_span, + condition=state, + batch_size=batch_size, + with_grad=with_grad, + solver_config=solver_config, + )
+ +
[docs] def compute_q( + self, + state: Union[torch.Tensor, TensorDict], + action: Union[torch.Tensor, TensorDict], + ) -> torch.Tensor: + """ + Overview: + Calculate the Q value. + Arguments: + state (:obj:`Union[torch.Tensor, TensorDict]`): The input state. + action (:obj:`Union[torch.Tensor, TensorDict]`): The input action. + Returns: + q (:obj:`torch.Tensor`): The Q value. + """ + + return self.critic(action, state)
+ +
[docs] def behaviour_policy_loss( + self, + action: Union[torch.Tensor, TensorDict], + state: Union[torch.Tensor, TensorDict], + maximum_likelihood: bool = False, + ): + """ + Overview: + Calculate the behaviour policy loss. + Arguments: + action (:obj:`torch.Tensor`): The input action. + state (:obj:`torch.Tensor`): The input state. + """ + + if self.model_type == "DiffusionModel": + if self.model_loss_type == "score_matching": + if maximum_likelihood: + return self.base_model.score_matching_loss(action, state) + else: + return self.base_model.score_matching_loss( + action, state, weighting_scheme="vanilla" + ) + elif self.model_loss_type == "flow_matching": + return self.base_model.flow_matching_loss(action, state) + elif self.model_type in [ + "OptimalTransportConditionalFlowModel", + "IndependentConditionalFlowModel", + "SchrodingerBridgeConditionalFlowModel", + ]: + x0 = self.base_model.gaussian_generator(batch_size=state.shape[0]) + return self.base_model.flow_matching_loss(x0=x0, x1=action, condition=state)
+ +
[docs] def policy_optimization_loss_by_advantage_weighted_regression( + self, + action: Union[torch.Tensor, TensorDict], + state: Union[torch.Tensor, TensorDict], + maximum_likelihood: bool = False, + beta: float = 1.0, + weight_clamp: float = 100.0, + ): + """ + Overview: + Calculate the behaviour policy loss. + Arguments: + action (:obj:`torch.Tensor`): The input action. + state (:obj:`torch.Tensor`): The input state. + """ + + if self.model_type == "DiffusionModel": + if self.model_loss_type == "score_matching": + if maximum_likelihood: + model_loss = self.guided_model.score_matching_loss( + action, state, average=False + ) + else: + model_loss = self.guided_model.score_matching_loss( + action, state, weighting_scheme="vanilla", average=False + ) + elif self.model_loss_type == "flow_matching": + model_loss = self.guided_model.flow_matching_loss( + action, state, average=False + ) + elif self.model_type in [ + "OptimalTransportConditionalFlowModel", + "IndependentConditionalFlowModel", + "SchrodingerBridgeConditionalFlowModel", + ]: + x0 = self.guided_model.gaussian_generator(batch_size=state.shape[0]) + model_loss = self.guided_model.flow_matching_loss( + x0=x0, x1=action, condition=state, average=False + ) + else: + raise NotImplementedError + + with torch.no_grad(): + q_value = self.critic(action, state).squeeze(dim=-1) + v_value = self.critic.v(state).squeeze(dim=-1) + weight = torch.exp(beta * (q_value - v_value)) + + clamped_weight = weight.clamp(max=weight_clamp) + + # calculate the number of clamped_weight<weight + clamped_ratio = torch.mean( + torch.tensor(clamped_weight < weight, dtype=torch.float32) + ) + + return ( + torch.mean(model_loss * clamped_weight), + torch.mean(weight), + torch.mean(clamped_weight), + clamped_ratio, + )
+ +
[docs] def policy_optimization_loss_by_advantage_weighted_regression_softmax(
+        self,
+        state: Union[torch.Tensor, TensorDict],
+        fake_action: Union[torch.Tensor, TensorDict],
+        maximum_likelihood: bool = False,
+        beta: float = 1.0,
+    ):
+        """
+        Overview:
+            Calculate the softmax-weighted advantage regression loss for the guided policy, where the matching loss of each candidate (fake) action is weighted by a softmax over the candidates' Q values.
+        Arguments:
+            state (:obj:`torch.Tensor`): The input state.
+            fake_action (:obj:`torch.Tensor`): The candidate actions sampled from the behaviour policy, of shape (batch_size, num_candidates, action_dim).
+            maximum_likelihood (:obj:`bool`): Whether to use the maximum likelihood weighting of the score matching loss.
+            beta (:obj:`float`): The inverse temperature of the softmax weighting.
+        Returns:
+            loss (:obj:`torch.Tensor`): The softmax-weighted matching loss.
+            energy (:obj:`torch.Tensor`): The mean Q value of the candidate actions.
+            relative_energy (:obj:`torch.Tensor`): The mean softmax weight.
+            model_loss (:obj:`torch.Tensor`): The mean unweighted matching loss.
+        """
+        action = fake_action
+
+        action_reshape = action.reshape(
+            action.shape[0] * action.shape[1], *action.shape[2:]
+        )
+        state_repeat = torch.stack([state] * action.shape[1], axis=1)
+        state_repeat_reshape = state_repeat.reshape(
+            state_repeat.shape[0] * state_repeat.shape[1], *state_repeat.shape[2:]
+        )
+        energy = self.critic(action_reshape, state_repeat_reshape).detach()
+        energy = energy.reshape(action.shape[0], action.shape[1]).squeeze(dim=-1)
+
+        if self.model_type == "DiffusionModel":
+            if self.model_loss_type == "score_matching":
+                if maximum_likelihood:
+                    model_loss = self.guided_model.score_matching_loss(
+                        action_reshape, state_repeat_reshape, average=False
+                    )
+                else:
+                    model_loss = self.guided_model.score_matching_loss(
+                        action_reshape,
+                        state_repeat_reshape,
+                        weighting_scheme="vanilla",
+                        average=False,
+                    )
+            elif self.model_loss_type == "flow_matching":
+                model_loss = self.guided_model.flow_matching_loss(
+                    action_reshape, state_repeat_reshape, average=False
+                )
+        elif self.model_type in [
+            "OptimalTransportConditionalFlowModel",
+            "IndependentConditionalFlowModel",
+            "SchrodingerBridgeConditionalFlowModel",
+        ]:
+            x0 = self.guided_model.gaussian_generator(
+                batch_size=state.shape[0] * action.shape[1]
+            )
+            model_loss = self.guided_model.flow_matching_loss(
+                x0=x0, x1=action_reshape, condition=state_repeat_reshape, average=False
+            )
+        else:
+            raise NotImplementedError
+
+        model_loss = model_loss.reshape(action.shape[0], action.shape[1]).squeeze(
+            dim=-1
+        )
+
+        relative_energy = nn.Softmax(dim=1)(energy * beta)
+
+        loss = torch.mean(torch.sum(relative_energy * model_loss, axis=-1), dim=1)
+
+        return (
+            loss,
+            torch.mean(energy),
+            torch.mean(relative_energy),
+            torch.mean(model_loss),
+        )
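In this softmax variant, each state carries several candidate actions, and their per-candidate matching losses are weighted by a softmax over the candidates' Q values. A toy, self-contained illustration of just the weighting (names, shapes, and values are illustrative, not taken from the repository):

import torch
import torch.nn as nn

batch_size, num_candidates = 4, 8
beta = 1.0

energy = torch.randn(batch_size, num_candidates)      # Q(s, a_i) for each candidate action
model_loss = torch.rand(batch_size, num_candidates)   # per-candidate matching loss

relative_energy = nn.Softmax(dim=1)(energy * beta)    # weights sum to 1 over the candidates
loss = torch.mean(torch.sum(relative_energy * model_loss, dim=-1))

# Candidates with higher Q values contribute more to the guided-policy update,
# which biases the generative policy toward high-value actions.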
+ + +
[docs]class GMPOAlgorithm:
+    """
+    Overview:
+        The Generative Model Policy Optimization (GMPO) algorithm.
+    Interfaces:
+        ``__init__``, ``train``, ``deploy``
+    """
+
[docs] def __init__( + self, + config: EasyDict = None, + simulator=None, + dataset: GPDataset = None, + model: Union[torch.nn.Module, torch.nn.ModuleDict] = None, + seed=None, + ): + """ + Overview: + Initialize the GMPO && GPG algorithm. + Arguments: + config (:obj:`EasyDict`): The configuration , which must contain the following keys: + train (:obj:`EasyDict`): The training configuration. + deploy (:obj:`EasyDict`): The deployment configuration. + simulator (:obj:`object`): The environment simulator. + dataset (:obj:`GPDataset`): The dataset. + model (:obj:`Union[torch.nn.Module, torch.nn.ModuleDict]`): The model. + Interface: + ``__init__``, ``train``, ``deploy`` + """ + self.config = config + self.simulator = simulator + self.dataset = dataset + self.seed_value = set_seed(seed) + + # --------------------------------------- + # Customized model initialization code ↓ + # --------------------------------------- + + if model is not None: + self.model = model + self.behaviour_policy_train_epoch = 0 + self.critic_train_epoch = 0 + self.guided_policy_train_epoch = 0 + else: + self.model = torch.nn.ModuleDict() + config = self.config.train + assert hasattr(config.model, "GPPolicy") + + if torch.__version__ >= "2.0.0": + self.model["GPPolicy"] = torch.compile( + GMPOPolicy(config.model.GPPolicy).to(config.model.GPPolicy.device) + ) + else: + self.model["GPPolicy"] = GMPOPolicy(config.model.GPPolicy).to( + config.model.GPPolicy.device + ) + + if ( + hasattr(config.parameter, "checkpoint_path") + and config.parameter.checkpoint_path is not None + ): + if not os.path.exists(config.parameter.checkpoint_path): + log.warning( + f"Checkpoint path {config.parameter.checkpoint_path} does not exist" + ) + self.behaviour_policy_train_epoch = -1 + self.critic_train_epoch = -1 + self.guided_policy_train_epoch = -1 + else: + base_model_files = sort_files_by_criteria( + folder_path=config.parameter.checkpoint_path, + start_string="basemodel_", + end_string=".pt", + ) + if len(base_model_files) == 0: + self.behaviour_policy_train_epoch = -1 + log.warning( + f"No basemodel file found in {config.parameter.checkpoint_path}" + ) + else: + checkpoint = torch.load( + os.path.join( + config.parameter.checkpoint_path, + base_model_files[0], + ), + map_location="cpu", + ) + self.model["GPPolicy"].base_model.load_state_dict( + checkpoint["base_model"] + ) + self.behaviour_policy_train_epoch = checkpoint.get( + "behaviour_policy_train_epoch", -1 + ) + + guided_model_files = sort_files_by_criteria( + folder_path=config.parameter.checkpoint_path, + start_string="guidedmodel_", + end_string=".pt", + ) + if len(guided_model_files) == 0: + self.guided_policy_train_epoch = -1 + log.warning( + f"No guidedmodel file found in {config.parameter.checkpoint_path}" + ) + else: + checkpoint = torch.load( + os.path.join( + config.parameter.checkpoint_path, + guided_model_files[0], + ), + map_location="cpu", + ) + self.model["GPPolicy"].guided_model.load_state_dict( + checkpoint["guided_model"] + ) + self.guided_policy_train_epoch = checkpoint.get( + "guided_policy_train_epoch", -1 + ) + + critic_model_files = sort_files_by_criteria( + folder_path=config.parameter.checkpoint_path, + start_string="critic_", + end_string=".pt", + ) + if len(critic_model_files) == 0: + self.critic_train_epoch = -1 + log.warning( + f"No criticmodel file found in {config.parameter.checkpoint_path}" + ) + else: + checkpoint = torch.load( + os.path.join( + config.parameter.checkpoint_path, + critic_model_files[0], + ), + map_location="cpu", + ) + 
self.model["GPPolicy"].critic.load_state_dict( + checkpoint["critic_model"] + ) + self.critic_train_epoch = checkpoint.get( + "critic_train_epoch", -1 + )
+ + # --------------------------------------- + # Customized model initialization code ↑ + # --------------------------------------- + +
[docs] def train(self, config: EasyDict = None, seed=None): + """ + Overview: + Train the model using the given configuration. \ + A weight-and-bias run will be created automatically when this function is called. + Arguments: + config (:obj:`EasyDict`): The training configuration. + seed (:obj:`int`): The random seed. + """ + + config = ( + merge_two_dicts_into_newone( + self.config.train if hasattr(self.config, "train") else EasyDict(), + config, + ) + if config is not None + else self.config.train + ) + + config["seed"] = self.seed_value if seed is None else seed + + if not hasattr(config, "wandb"): + config["wandb"] = dict(project=config.project) + elif not hasattr(config.wandb, "project"): + config.wandb["project"] = config.project + + with wandb.init(**config.wandb) as wandb_run: + if not hasattr(config.parameter.guided_policy, "beta"): + config.parameter.guided_policy.beta = 1.0 + + assert config.parameter.algorithm_type in [ + "GMPO", + "GMPO_softmax_static", + "GMPO_softmax_sample", + ] + run_name = f"{config.parameter.critic.method}-tau-{config.parameter.critic.tau}-beta-{config.parameter.guided_policy.beta}-batch-{config.parameter.guided_policy.batch_size}-lr-{config.parameter.guided_policy.learning_rate}-{config.model.GPPolicy.model.model.type}-{self.seed_value}" + wandb.run.name = run_name + wandb.run.save() + + config = merge_two_dicts_into_newone(EasyDict(wandb_run.config), config) + wandb_run.config.update(config) + self.config.train = config + + self.simulator = ( + create_simulator(config.simulator) + if hasattr(config, "simulator") + else self.simulator + ) + self.dataset = ( + create_dataset(config.dataset) + if hasattr(config, "dataset") + else self.dataset + ) + + # --------------------------------------- + # Customized training code ↓ + # --------------------------------------- + + def save_checkpoint(model, iteration=None, model_type=False): + if iteration == None: + iteration = 0 + if model_type == "base_model": + if ( + hasattr(config.parameter, "checkpoint_path") + and config.parameter.checkpoint_path is not None + ): + if not os.path.exists(config.parameter.checkpoint_path): + os.makedirs(config.parameter.checkpoint_path) + torch.save( + dict( + base_model=model["GPPolicy"].base_model.state_dict(), + behaviour_policy_train_epoch=self.behaviour_policy_train_epoch, + behaviour_policy_train_iter=iteration, + ), + f=os.path.join( + config.parameter.checkpoint_path, + f"basemodel_{self.behaviour_policy_train_epoch}_{iteration}.pt", + ), + ) + elif model_type == "guided_model": + if ( + hasattr(config.parameter, "checkpoint_path") + and config.parameter.checkpoint_path is not None + ): + if not os.path.exists(config.parameter.checkpoint_path): + os.makedirs(config.parameter.checkpoint_path) + torch.save( + dict( + guided_model=model[ + "GPPolicy" + ].guided_model.state_dict(), + guided_policy_train_epoch=self.guided_policy_train_epoch, + guided_policy_train_iteration=iteration, + ), + f=os.path.join( + config.parameter.checkpoint_path, + f"guidedmodel_{self.guided_policy_train_epoch}_{iteration}.pt", + ), + ) + elif model_type == "critic_model": + if ( + hasattr(config.parameter, "checkpoint_path") + and config.parameter.checkpoint_path is not None + ): + if not os.path.exists(config.parameter.checkpoint_path): + os.makedirs(config.parameter.checkpoint_path) + torch.save( + dict( + critic_model=model["GPPolicy"].critic.state_dict(), + critic_train_epoch=self.critic_train_epoch, + critic_train_iter=iteration, + ), + f=os.path.join( + config.parameter.checkpoint_path, + 
f"critic_{self.critic_train_epoch}_{iteration}.pt", + ), + ) + else: + raise NotImplementedError + + def generate_fake_action(model, states, action_augment_num): + + fake_actions_sampled = [] + for states in track( + np.array_split(states, states.shape[0] // 4096 + 1), + description="Generate fake actions", + ): + + fake_actions_ = model.behaviour_policy_sample( + state=states, + batch_size=action_augment_num, + t_span=( + torch.linspace(0.0, 1.0, config.parameter.t_span).to( + states.device + ) + if hasattr(config.parameter, "t_span") + and config.parameter.t_span is not None + else None + ), + ) + fake_actions_sampled.append(torch.einsum("nbd->bnd", fake_actions_)) + + fake_actions = torch.cat(fake_actions_sampled, dim=0) + return fake_actions + + def evaluate(model, train_epoch, repeat=1): + evaluation_results = dict() + + def policy(obs: np.ndarray) -> np.ndarray: + if isinstance(obs, np.ndarray): + obs = torch.tensor( + obs, + dtype=torch.float32, + device=config.model.GPPolicy.device, + ).unsqueeze(0) + elif isinstance(obs, dict): + for key in obs: + obs[key] = torch.tensor( + obs[key], + dtype=torch.float32, + device=config.model.GPPolicy.device, + ).unsqueeze(0) + if obs[key].dim() == 1 and obs[key].shape[0] == 1: + obs[key] = obs[key].unsqueeze(1) + obs = TensorDict(obs, batch_size=[1]) + action = ( + model.sample( + condition=obs, + t_span=( + torch.linspace(0.0, 1.0, config.parameter.t_span).to( + config.model.GPPolicy.device + ) + if hasattr(config.parameter, "t_span") + and config.parameter.t_span is not None + else None + ), + ) + .squeeze(0) + .cpu() + .detach() + .numpy() + ) + return action + + eval_results = self.simulator.evaluate( + policy=policy, num_episodes=repeat + ) + return_results = [ + eval_results[i]["total_return"] for i in range(repeat) + ] + log.info(f"Return: {return_results}") + return_mean = np.mean(return_results) + return_std = np.std(return_results) + return_max = np.max(return_results) + return_min = np.min(return_results) + evaluation_results[f"evaluation/return_mean"] = return_mean + evaluation_results[f"evaluation/return_std"] = return_std + evaluation_results[f"evaluation/return_max"] = return_max + evaluation_results[f"evaluation/return_min"] = return_min + + if isinstance(self.dataset, GPD4RLDataset) or isinstance( + self.dataset, GPD4RLTensorDictDataset + ): + import d4rl + + env_id = config.dataset.args.env_id + evaluation_results[f"evaluation/return_mean_normalized"] = ( + d4rl.get_normalized_score(env_id, return_mean) + ) + evaluation_results[f"evaluation/return_std_normalized"] = ( + d4rl.get_normalized_score(env_id, return_std) + ) + evaluation_results[f"evaluation/return_max_normalized"] = ( + d4rl.get_normalized_score(env_id, return_max) + ) + evaluation_results[f"evaluation/return_min_normalized"] = ( + d4rl.get_normalized_score(env_id, return_min) + ) + + if repeat > 1: + log.info( + f"Train epoch: {train_epoch}, return_mean: {return_mean}, return_std: {return_std}, return_max: {return_max}, return_min: {return_min}" + ) + else: + log.info(f"Train epoch: {train_epoch}, return: {return_mean}") + + return evaluation_results + + # --------------------------------------- + # behavior training code ↓ + # --------------------------------------- + + behaviour_policy_optimizer = torch.optim.Adam( + self.model["GPPolicy"].base_model.model.parameters(), + lr=config.parameter.behaviour_policy.learning_rate, + ) + + replay_buffer = TensorDictReplayBuffer( + storage=self.dataset.storage, + batch_size=config.parameter.behaviour_policy.batch_size, + 
sampler=SamplerWithoutReplacement(), + prefetch=10, + pin_memory=True, + ) + + behaviour_policy_train_iter = 0 + for epoch in track( + range(config.parameter.behaviour_policy.epochs), + description="Behaviour policy training", + ): + if self.behaviour_policy_train_epoch >= epoch: + continue + + if ( + hasattr(config.parameter.evaluation, "analysis_interval") + and epoch % config.parameter.evaluation.analysis_interval == 0 + ): + for index, data in enumerate(replay_buffer): + + evaluation_results = evaluate( + self.model["GPPolicy"].base_model, + train_epoch=epoch, + repeat=( + 1 + if not hasattr(config.parameter.evaluation, "repeat") + else config.parameter.evaluation.repeat + ), + ) + if not os.path.exists(config.parameter.checkpoint_path): + os.makedirs(config.parameter.checkpoint_path) + plot_distribution( + data["a"], + os.path.join( + config.parameter.checkpoint_path, + f"action_base_{epoch}.png", + ), + ) + + action = self.model["GPPolicy"].sample( + state=data["s"].to(config.model.GPPolicy.device), + t_span=( + torch.linspace(0.0, 1.0, config.parameter.t_span).to( + config.model.GPPolicy.device + ) + if hasattr(config.parameter, "t_span") + and config.parameter.t_span is not None + else None + ), + ) + plot_distribution( + action, + os.path.join( + config.parameter.checkpoint_path, + f"action_base_model_{epoch}_{evaluation_results['evaluation/return_mean']}.png", + ), + ) + + wandb.log(data=evaluation_results, commit=False) + break + + counter = 1 + behaviour_policy_loss_sum = 0 + for index, data in enumerate(replay_buffer): + + behaviour_policy_loss = self.model[ + "GPPolicy" + ].behaviour_policy_loss( + action=data["a"].to(config.model.GPPolicy.device), + state=data["s"].to(config.model.GPPolicy.device), + maximum_likelihood=( + config.parameter.behaviour_policy.maximum_likelihood + if hasattr( + config.parameter.behaviour_policy, "maximum_likelihood" + ) + else False + ), + ) + behaviour_policy_optimizer.zero_grad() + behaviour_policy_loss.backward() + behaviour_policy_optimizer.step() + + counter += 1 + behaviour_policy_loss_sum += behaviour_policy_loss.item() + + behaviour_policy_train_iter += 1 + self.behaviour_policy_train_epoch = epoch + + wandb.log( + data=dict( + behaviour_policy_train_iter=behaviour_policy_train_iter, + behaviour_policy_train_epoch=epoch, + behaviour_policy_loss=behaviour_policy_loss_sum / counter, + ), + commit=True, + ) + + if ( + hasattr(config.parameter, "checkpoint_freq") + and (epoch + 1) % config.parameter.checkpoint_freq == 0 + ): + save_checkpoint( + self.model, + iteration=behaviour_policy_train_iter, + model_type="base_model", + ) + + # --------------------------------------- + # behavior training code ↑ + # --------------------------------------- + + # --------------------------------------- + # make fake action ↓ + # --------------------------------------- + + if config.parameter.algorithm_type in ["GMPO_softmax_static"]: + data_augmentation = True + else: + data_augmentation = False + + self.model["GPPolicy"].base_model.eval() + if data_augmentation: + + fake_actions = generate_fake_action( + self.model["GPPolicy"], + self.dataset.states[:].to(config.model.GPPolicy.device), + config.parameter.action_augment_num, + ) + fake_next_actions = generate_fake_action( + self.model["GPPolicy"], + self.dataset.next_states[:].to(config.model.GPPolicy.device), + config.parameter.action_augment_num, + ) + + self.dataset.load_fake_actions( + fake_actions=fake_actions.to("cpu"), + fake_next_actions=fake_next_actions.to("cpu"), + ) + + # 
--------------------------------------- + # make fake action ↑ + # --------------------------------------- + + # --------------------------------------- + # critic training code ↓ + # --------------------------------------- + + q_optimizer = torch.optim.Adam( + self.model["GPPolicy"].critic.q.parameters(), + lr=config.parameter.critic.learning_rate, + ) + v_optimizer = torch.optim.Adam( + self.model["GPPolicy"].critic.v.parameters(), + lr=config.parameter.critic.learning_rate, + ) + + replay_buffer = TensorDictReplayBuffer( + storage=self.dataset.storage, + batch_size=config.parameter.critic.batch_size, + sampler=SamplerWithoutReplacement(), + prefetch=10, + pin_memory=True, + ) + + critic_train_iter = 0 + for epoch in track( + range(config.parameter.critic.epochs), description="Critic training" + ): + if self.critic_train_epoch >= epoch: + continue + + counter = 1 + + v_loss_sum = 0.0 + v_sum = 0.0 + q_loss_sum = 0.0 + q_sum = 0.0 + q_target_sum = 0.0 + for index, data in enumerate(replay_buffer): + + v_loss, next_v = self.model["GPPolicy"].critic.v_loss( + state=data["s"].to(config.model.GPPolicy.device), + action=data["a"].to(config.model.GPPolicy.device), + next_state=data["s_"].to(config.model.GPPolicy.device), + tau=config.parameter.critic.tau, + ) + v_optimizer.zero_grad(set_to_none=True) + v_loss.backward() + v_optimizer.step() + q_loss, q, q_target = self.model["GPPolicy"].critic.iql_q_loss( + state=data["s"].to(config.model.GPPolicy.device), + action=data["a"].to(config.model.GPPolicy.device), + reward=data["r"].to(config.model.GPPolicy.device), + done=data["d"].to(config.model.GPPolicy.device), + next_v=next_v, + discount=config.parameter.critic.discount_factor, + ) + q_optimizer.zero_grad(set_to_none=True) + q_loss.backward() + q_optimizer.step() + + # Update target + for param, target_param in zip( + self.model["GPPolicy"].critic.q.parameters(), + self.model["GPPolicy"].critic.q_target.parameters(), + ): + target_param.data.copy_( + config.parameter.critic.update_momentum * param.data + + (1 - config.parameter.critic.update_momentum) + * target_param.data + ) + + counter += 1 + + q_loss_sum += q_loss.item() + q_sum += q.mean().item() + q_target_sum += q_target.mean().item() + + v_loss_sum += v_loss.item() + v_sum += next_v.mean().item() + + critic_train_iter += 1 + self.critic_train_epoch = epoch + + wandb.log( + data=dict(v_loss=v_loss_sum / counter, v=v_sum / counter), + commit=False, + ) + + wandb.log( + data=dict( + critic_train_iter=critic_train_iter, + critic_train_epoch=epoch, + q_loss=q_loss_sum / counter, + q=q_sum / counter, + q_target=q_target_sum / counter, + ), + commit=True, + ) + + if ( + hasattr(config.parameter, "checkpoint_freq") + and (epoch + 1) % config.parameter.checkpoint_freq == 0 + ): + save_checkpoint( + self.model, + iteration=critic_train_iter, + model_type="critic_model", + ) + # --------------------------------------- + # critic training code ↑ + # --------------------------------------- + + # --------------------------------------- + # guided policy training code ↓ + # --------------------------------------- + + if not self.guided_policy_train_epoch > 0: + if ( + hasattr(config.parameter.guided_policy, "copy_from_basemodel") + and config.parameter.guided_policy.copy_from_basemodel + ): + self.model["GPPolicy"].guided_model.model.load_state_dict( + self.model["GPPolicy"].base_model.model.state_dict() + ) + + guided_policy_optimizer = torch.optim.Adam( + self.model["GPPolicy"].guided_model.parameters(), + 
lr=config.parameter.guided_policy.learning_rate, + ) + guided_policy_train_iter = 0 + logp_mean = [] + end_return = [] + beta = config.parameter.guided_policy.beta + + replay_buffer = TensorDictReplayBuffer( + storage=self.dataset.storage, + batch_size=config.parameter.guided_policy.batch_size, + sampler=SamplerWithoutReplacement(), + prefetch=10, + pin_memory=True, + ) + + for epoch in track( + range(config.parameter.guided_policy.epochs), + description="Guided policy training", + ): + + if self.guided_policy_train_epoch >= epoch: + continue + + if ( + hasattr(config.parameter.evaluation, "analysis_interval") + and epoch % config.parameter.evaluation.analysis_interval == 0 + ): + timlimited = 0 + for index, data in enumerate(replay_buffer): + if timlimited == 0: + if not os.path.exists(config.parameter.checkpoint_path): + os.makedirs(config.parameter.checkpoint_path) + plot_distribution( + data["a"].detach().cpu().numpy(), + os.path.join( + config.parameter.checkpoint_path, + f"action_base_{epoch}.png", + ), + ) + + action = self.model["GPPolicy"].sample( + state=data["s"].to(config.model.GPPolicy.device), + t_span=( + torch.linspace( + 0.0, 1.0, config.parameter.t_span + ).to(config.model.GPPolicy.device) + if hasattr(config.parameter, "t_span") + and config.parameter.t_span is not None + else None + ), + ) + + evaluation_results = evaluate( + self.model["GPPolicy"].guided_model, + train_epoch=epoch, + repeat=( + 1 + if not hasattr(config.parameter.evaluation, "repeat") + else config.parameter.evaluation.repeat + ), + ) + + log_p = compute_likelihood( + model=self.model["GPPolicy"].guided_model, + x=data["a"].to(config.model.GPPolicy.device), + condition=data["s"].to(config.model.GPPolicy.device), + t=torch.linspace(0.0, 1.0, 100).to( + config.model.GPPolicy.device + ), + using_Hutchinson_trace_estimator=True, + ) + logp_mean.append(log_p.mean().detach().cpu().numpy()) + end_return.append(evaluation_results["evaluation/return_mean"]) + + if timlimited == 0: + plot_distribution( + action.detach().cpu().numpy(), + os.path.join( + config.parameter.checkpoint_path, + f"action_guided_model_{epoch}_{evaluation_results['evaluation/return_mean']}.png", + ), + ) + timlimited += 1 + wandb.log(data=evaluation_results, commit=False) + + if timlimited > 10: + logp_dict = { + "logp_mean": logp_mean, + "end_return": end_return, + } + np.savez( + os.path.join( + config.parameter.checkpoint_path, + f"logp_data_guided_{epoch}.npz", + ), + **logp_dict, + ) + plot_histogram2d_x_y( + end_return, + logp_mean, + os.path.join( + config.parameter.checkpoint_path, + f"return_logp_guided_{epoch}.png", + ), + ) + break + + counter = 1 + guided_policy_loss_sum = 0.0 + if config.parameter.algorithm_type == "GMPO": + weight_sum = 0.0 + clamped_weight_sum = 0.0 + clamped_ratio_sum = 0.0 + elif config.parameter.algorithm_type in [ + "GMPO_softmax_static", + "GMPO_softmax_sample", + ]: + energy_sum = 0.0 + relative_energy_sum = 0.0 + matching_loss_sum = 0.0 + + for index, data in enumerate(replay_buffer): + if config.parameter.algorithm_type == "GMPO": + ( + guided_policy_loss, + weight, + clamped_weight, + clamped_ratio, + ) = self.model[ + "GPPolicy" + ].policy_optimization_loss_by_advantage_weighted_regression( + data["a"].to(config.model.GPPolicy.device), + data["s"].to(config.model.GPPolicy.device), + maximum_likelihood=( + config.parameter.guided_policy.maximum_likelihood + if hasattr( + config.parameter.guided_policy, "maximum_likelihood" + ) + else False + ), + beta=beta, + weight_clamp=( + 
config.parameter.guided_policy.weight_clamp + if hasattr( + config.parameter.guided_policy, "weight_clamp" + ) + else 100.0 + ), + ) + weight_sum += weight + clamped_weight_sum += clamped_weight + clamped_ratio_sum += clamped_ratio + elif config.parameter.algorithm_type == "GMPO_softmax_static": + ( + guided_policy_loss, + energy, + relative_energy, + matching_loss, + ) = self.model[ + "GPPolicy" + ].policy_optimization_loss_by_advantage_weighted_regression_softmax( + data["s"].to(config.model.GPPolicy.device), + data["fake_a"].to(config.model.GPPolicy.device), + maximum_likelihood=( + config.parameter.guided_policy.maximum_likelihood + if hasattr( + config.parameter.guided_policy, "maximum_likelihood" + ) + else False + ), + beta=beta, + ) + energy_sum += energy + relative_energy_sum += relative_energy + matching_loss_sum += matching_loss + elif config.parameter.algorithm_type == "GMPO_softmax_sample": + fake_actions_ = self.model["GPPolicy"].behaviour_policy_sample( + state=data["s"].to(config.model.GPPolicy.device), + t_span=( + torch.linspace(0.0, 1.0, config.parameter.t_span).to( + config.model.GPPolicy.device + ) + if hasattr(config.parameter, "t_span") + and config.parameter.t_span is not None + else None + ), + batch_size=config.parameter.action_augment_num, + ) + fake_actions_ = torch.einsum("nbd->bnd", fake_actions_) + ( + guided_policy_loss, + energy, + relative_energy, + matching_loss, + ) = self.model[ + "GPPolicy" + ].policy_optimization_loss_by_advantage_weighted_regression_softmax( + data["s"].to(config.model.GPPolicy.device), + fake_actions_, + maximum_likelihood=( + config.parameter.guided_policy.maximum_likelihood + if hasattr( + config.parameter.guided_policy, "maximum_likelihood" + ) + else False + ), + beta=beta, + ) + energy_sum += energy + relative_energy_sum += relative_energy + matching_loss_sum += matching_loss + else: + raise NotImplementedError + guided_policy_optimizer.zero_grad() + guided_policy_loss.backward() + guided_policy_optimizer.step() + counter += 1 + + guided_policy_loss_sum += guided_policy_loss.item() + + guided_policy_train_iter += 1 + self.guided_policy_train_epoch = epoch + + if ( + config.parameter.evaluation.eval + and hasattr(config.parameter.evaluation, "epoch_interval") + and (self.guided_policy_train_epoch + 1) + % config.parameter.evaluation.epoch_interval + == 0 + ): + evaluation_results = evaluate( + self.model["GPPolicy"].guided_model, + train_epoch=epoch, + repeat=( + 1 + if not hasattr(config.parameter.evaluation, "repeat") + else config.parameter.evaluation.repeat + ), + ) + wandb.log(data=evaluation_results, commit=False) + wandb.log( + data=dict( + guided_policy_train_iter=guided_policy_train_iter, + guided_policy_train_epoch=epoch, + ), + commit=True, + ) + + if config.parameter.algorithm_type == "GMPO": + wandb.log( + data=dict( + weight=weight_sum / counter, + clamped_weight=clamped_weight_sum / counter, + clamped_ratio=clamped_ratio_sum / counter, + ), + commit=False, + ) + elif config.parameter.algorithm_type in [ + "GMPO_softmax_static", + "GMPO_softmax_sample", + ]: + wandb.log( + data=dict( + energy=energy_sum / counter, + relative_energy=relative_energy_sum / counter, + matching_loss=matching_loss_sum / counter, + ), + commit=False, + ) + + wandb.log( + data=dict( + guided_policy_train_iter=guided_policy_train_iter, + guided_policy_train_epoch=epoch, + guided_policy_loss=guided_policy_loss_sum / counter, + ), + commit=True, + ) + + if ( + hasattr(config.parameter, "checkpoint_freq") + and (epoch + 1) % 
config.parameter.checkpoint_freq == 0 + ): + save_checkpoint( + self.model, + iteration=guided_policy_train_iter, + model_type="guided_model", + ) + + # --------------------------------------- + # guided policy training code ↑ + # --------------------------------------- + + # --------------------------------------- + # Customized training code ↑ + # --------------------------------------- + + wandb.finish()
+
+    def deploy(self, config: EasyDict = None) -> GPAgent:
+        """
+        Overview:
+            Deploy the trained guided policy as a GPAgent.
+        Arguments:
+            config (:obj:`EasyDict`): The deployment configuration.
+        """
+
+        if config is not None:
+            config = merge_two_dicts_into_newone(self.config.deploy, config)
+        else:
+            config = self.config.deploy
+
+        assert "GPPolicy" in self.model, "The model must be trained first."
+        return GPAgent(
+            config=config,
+            model=copy.deepcopy(self.model["GPPolicy"].guided_model),
+        )
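Putting the pieces together, a typical end-to-end use of this class is construct, train, then deploy. The sketch below is hypothetical: "my_config" is assumed to be a complete EasyDict with "train" and "deploy" sections, including "train.model.GPPolicy" and the "train.parameter.*" entries referenced above.

algorithm = GMPOAlgorithm(config=my_config, seed=0)
algorithm.train()           # trains behaviour policy, critic, and guided policy in sequence
agent = algorithm.deploy()  # returns a GPAgent wrapping a copy of the trained guided model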
+ + + + + + + + + \ No newline at end of file diff --git a/_modules/grl/algorithms/qgpo.html b/_modules/grl/algorithms/qgpo.html new file mode 100644 index 0000000..a501f0e --- /dev/null +++ b/_modules/grl/algorithms/qgpo.html @@ -0,0 +1,1201 @@ + + + + + + + + + + + + + + + grl.algorithms.qgpo — GenerativeRL v0.0.1 documentation + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Source code for grl.algorithms.qgpo

+#############################################################
+# This QGPO model is a modified implementation based on https://github.com/ChenDRAG/CEP-energy-guided-diffusion
+#############################################################
+
+import copy
+from typing import List, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+from easydict import EasyDict
+from rich.progress import Progress, track
+from tensordict import TensorDict
+from torchrl.data import TensorDictReplayBuffer
+from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
+
+import wandb
+from grl.agents.qgpo import QGPOAgent
+from grl.datasets import create_dataset
+from grl.datasets.qgpo import QGPODataset
+from grl.generative_models.diffusion_model.energy_conditional_diffusion_model import (
+    EnergyConditionalDiffusionModel,
+)
+from grl.rl_modules.simulators import create_simulator
+from grl.rl_modules.value_network.q_network import DoubleQNetwork
+from grl.utils.config import merge_two_dicts_into_newone
+from grl.utils.log import log
+from grl.utils.model_utils import save_model, load_model
+
+
+
[docs]class QGPOCritic(nn.Module): + """ + Overview: + Critic network for QGPO algorithm. + Interfaces: + ``__init__``, ``forward`` + """ + +
[docs] def __init__(self, config: EasyDict): + """ + Overview: + Initialization of QGPO critic network. + Arguments: + config (:obj:`EasyDict`): The configuration dict. + """ + + super().__init__() + self.config = config + self.q_alpha = config.q_alpha + self.q = DoubleQNetwork(config.DoubleQNetwork) + self.q_target = copy.deepcopy(self.q).requires_grad_(False)
+ +
[docs] def forward( + self, + action: Union[torch.Tensor, TensorDict], + state: Union[torch.Tensor, TensorDict] = None, + ) -> torch.Tensor: + """ + Overview: + Return the output of QGPO critic. + Arguments: + action (:obj:`torch.Tensor`): The input action. + state (:obj:`torch.Tensor`): The input state. + """ + + return self.q(action, state)
+ +
[docs] def compute_double_q( + self, + action: Union[torch.Tensor, TensorDict], + state: Union[torch.Tensor, TensorDict] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Overview: + Return the output of two Q networks. + Arguments: + action (:obj:`Union[torch.Tensor, TensorDict]`): The input action. + state (:obj:`Union[torch.Tensor, TensorDict]`): The input state. + Returns: + q1 (:obj:`Union[torch.Tensor, TensorDict]`): The output of the first Q network. + q2 (:obj:`Union[torch.Tensor, TensorDict]`): The output of the second Q network. + """ + return self.q.compute_double_q(action, state)
+ +
[docs] def q_loss( + self, + action: Union[torch.Tensor, TensorDict], + state: Union[torch.Tensor, TensorDict], + reward: Union[torch.Tensor, TensorDict], + next_state: Union[torch.Tensor, TensorDict], + done: Union[torch.Tensor, TensorDict], + fake_next_action: Union[torch.Tensor, TensorDict], + discount_factor: float = 1.0, + ) -> torch.Tensor: + """ + Overview: + Calculate the Q loss. + Arguments: + action (:obj:`torch.Tensor`): The input action. + state (:obj:`torch.Tensor`): The input state. + reward (:obj:`torch.Tensor`): The input reward. + next_state (:obj:`torch.Tensor`): The input next state. + done (:obj:`torch.Tensor`): The input done. + fake_next_action (:obj:`torch.Tensor`): The input fake next action. + discount_factor (:obj:`float`): The discount factor. + """ + with torch.no_grad(): + softmax = nn.Softmax(dim=1) + if isinstance(next_state, TensorDict): + new_next_state = next_state.clone(False) + for key, value in next_state.items(): + if isinstance(value, torch.Tensor): + stacked_value = torch.stack( + [value] * fake_next_action.shape[1], axis=1 + ) + new_next_state.set(key, stacked_value) + else: + new_next_state = torch.stack( + [next_state] * fake_next_action.shape[1], axis=1 + ) + next_energy = ( + self.q_target( + fake_next_action, + new_next_state, + ) + .detach() + .squeeze(dim=-1) + ) + next_v = torch.sum( + softmax(self.q_alpha * next_energy) * next_energy, dim=-1, keepdim=True + ) + # Update Q function + targets = reward + (1.0 - done.float()) * discount_factor * next_v.detach() + q0, q1 = self.q.compute_double_q(action, state) + q_loss = ( + torch.nn.functional.mse_loss(q0, targets) + + torch.nn.functional.mse_loss(q1, targets) + ) / 2 + return q_loss
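The target value in q_loss above is an in-support softmax aggregation over the fake next actions, roughly V(s') = sum_i softmax(alpha * Q(s', a_i))_i * Q(s', a_i). A toy numeric sketch of just that aggregation (values are made up):

import torch
import torch.nn as nn

q_alpha = 1.0
# Q values of 4 fake next actions for a single next state (hypothetical numbers).
next_energy = torch.tensor([[0.5, 1.0, 2.0, 4.0]])

softmax = nn.Softmax(dim=1)
next_v = torch.sum(softmax(q_alpha * next_energy) * next_energy, dim=-1, keepdim=True)

# next_v is a soft maximum over the candidate Q values: as q_alpha grows it approaches
# max_i Q(s', a_i), and as q_alpha goes to 0 it approaches the plain mean.
# The TD target is then reward + (1 - done) * discount * next_v.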
+ + +
[docs]class QGPOPolicy(nn.Module): + """ + Overview: + QGPO policy network. + Interfaces: + ``__init__``, ``forward``, ``sample``, ``behaviour_policy_sample``, ``compute_q``, ``behaviour_policy_loss``, ``energy_guidance_loss``, ``q_loss`` + """ + +
[docs] def __init__(self, config: EasyDict): + super().__init__() + self.config = config + self.device = config.device + + self.critic = QGPOCritic(config.critic) + self.diffusion_model = EnergyConditionalDiffusionModel( + config.diffusion_model, energy_model=self.critic + )
+ +
[docs] def forward( + self, state: Union[torch.Tensor, TensorDict] + ) -> Union[torch.Tensor, TensorDict]: + """ + Overview: + Return the output of QGPO policy, which is the action conditioned on the state. + Arguments: + state (:obj:`Union[torch.Tensor, TensorDict]`): The input state. + Returns: + action (:obj:`Union[torch.Tensor, TensorDict]`): The output action. + """ + return self.sample(state)
+ +
[docs] def sample( + self, + state: Union[torch.Tensor, TensorDict], + batch_size: Union[torch.Size, int, Tuple[int], List[int]] = None, + guidance_scale: Union[torch.Tensor, float] = torch.tensor(1.0), + solver_config: EasyDict = None, + t_span: torch.Tensor = None, + ) -> Union[torch.Tensor, TensorDict]: + """ + Overview: + Return the output of QGPO policy, which is the action conditioned on the state. + Arguments: + state (:obj:`Union[torch.Tensor, TensorDict]`): The input state. + guidance_scale (:obj:`Union[torch.Tensor, float]`): The guidance scale. + solver_config (:obj:`EasyDict`): The configuration for the ODE solver. + t_span (:obj:`torch.Tensor`): The time span for the ODE solver or SDE solver. + Returns: + action (:obj:`Union[torch.Tensor, TensorDict]`): The output action. + """ + return self.diffusion_model.sample( + t_span=t_span, + condition=state, + batch_size=batch_size, + guidance_scale=guidance_scale, + with_grad=False, + solver_config=solver_config, + )
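A brief hypothetical usage sketch for sample: the same state can be decoded with different guidance scales, which is what the evaluation loop further down does when it sweeps config.parameter.evaluation.guidance_scale. "policy" is assumed to be a trained QGPOPolicy, and the observation dimension is illustrative.

import torch

state = torch.randn(1, 17)              # illustrative observation
t_span = torch.linspace(0.0, 1.0, 32)

for guidance_scale in (0.0, 1.0, 2.0):
    action = policy.sample(
        state=state,
        guidance_scale=guidance_scale,  # 0.0 effectively recovers the behaviour policy
        t_span=t_span,
    )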
+ +
[docs] def behaviour_policy_sample( + self, + state: Union[torch.Tensor, TensorDict], + batch_size: Union[torch.Size, int, Tuple[int], List[int]] = None, + solver_config: EasyDict = None, + t_span: torch.Tensor = None, + ) -> Union[torch.Tensor, TensorDict]: + """ + Overview: + Return the output of behaviour policy, which is the action conditioned on the state. + Arguments: + state (:obj:`Union[torch.Tensor, TensorDict]`): The input state. + solver_config (:obj:`EasyDict`): The configuration for the ODE solver. + t_span (:obj:`torch.Tensor`): The time span for the ODE solver or SDE solver. + Returns: + action (:obj:`Union[torch.Tensor, TensorDict]`): The output action. + """ + return self.diffusion_model.sample_without_energy_guidance( + t_span=t_span, + condition=state, + batch_size=batch_size, + solver_config=solver_config, + )
+ +
[docs] def compute_q( + self, + state: Union[torch.Tensor, TensorDict], + action: Union[torch.Tensor, TensorDict], + ) -> torch.Tensor: + """ + Overview: + Calculate the Q value. + Arguments: + state (:obj:`Union[torch.Tensor, TensorDict]`): The input state. + action (:obj:`Union[torch.Tensor, TensorDict]`): The input action. + Returns: + q (:obj:`torch.Tensor`): The Q value. + """ + + return self.critic(action, state)
+ +
[docs] def behaviour_policy_loss( + self, + action: Union[torch.Tensor, TensorDict], + state: Union[torch.Tensor, TensorDict], + ): + """ + Overview: + Calculate the behaviour policy loss. + Arguments: + action (:obj:`torch.Tensor`): The input action. + state (:obj:`torch.Tensor`): The input state. + """ + + return self.diffusion_model.score_matching_loss( + action, state, weighting_scheme="vanilla" + )
+ +
[docs] def energy_guidance_loss( + self, + state: Union[torch.Tensor, TensorDict], + fake_next_action: Union[torch.Tensor, TensorDict], + ) -> torch.Tensor: + """ + Overview: + Calculate the energy guidance loss of QGPO. + Arguments: + state (:obj:`Union[torch.Tensor, TensorDict]`): The input state. + fake_next_action (:obj:`Union[torch.Tensor, TensorDict]`): The input fake next action. + """ + + return self.diffusion_model.energy_guidance_loss(fake_next_action, state)
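A minimal hypothetical sketch of how energy_guidance_loss is consumed during training, mirroring the loop later in this file where each batch carries pre-generated candidate actions under the key "fake_a". "policy" is assumed to be a QGPOPolicy; the learning rate is illustrative.

import torch

energy_guidance_optimizer = torch.optim.Adam(
    policy.diffusion_model.energy_guidance.parameters(), lr=3e-4
)

def energy_guidance_step(batch):
    # batch["s"]: states, batch["fake_a"]: candidate actions of shape (B, K, action_dim).
    loss = policy.energy_guidance_loss(batch["s"], batch["fake_a"])
    energy_guidance_optimizer.zero_grad()
    loss.backward()
    energy_guidance_optimizer.step()
    return loss.item()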
+ +
[docs] def q_loss( + self, + action: Union[torch.Tensor, TensorDict], + state: Union[torch.Tensor, TensorDict], + reward: Union[torch.Tensor, TensorDict], + next_state: Union[torch.Tensor, TensorDict], + done: Union[torch.Tensor, TensorDict], + fake_next_action: Union[torch.Tensor, TensorDict], + discount_factor: float = 1.0, + ) -> torch.Tensor: + """ + Overview: + Calculate the Q loss. + Arguments: + action (:obj:`torch.Tensor`): The input action. + state (:obj:`torch.Tensor`): The input state. + reward (:obj:`torch.Tensor`): The input reward. + next_state (:obj:`torch.Tensor`): The input next state. + done (:obj:`torch.Tensor`): The input done. + fake_next_action (:obj:`torch.Tensor`): The input fake next action. + discount_factor (:obj:`float`): The discount factor. + """ + return self.critic.q_loss( + action, state, reward, next_state, done, fake_next_action, discount_factor + )
+ + +
[docs]class QGPOAlgorithm:
+    """
+    Overview:
+        Q-guided policy optimization (QGPO) algorithm, an offline reinforcement learning algorithm that uses an energy-based conditional diffusion model for policy modeling.
+    Interfaces:
+        ``__init__``, ``train``, ``deploy``
+    """
+
[docs] def __init__( + self, + config: EasyDict = None, + simulator=None, + dataset: QGPODataset = None, + model: Union[torch.nn.Module, torch.nn.ModuleDict] = None, + ): + """ + Overview: + Initialize the QGPO algorithm. + Arguments: + config (:obj:`EasyDict`): The configuration , which must contain the following keys: + train (:obj:`EasyDict`): The training configuration. + deploy (:obj:`EasyDict`): The deployment configuration. + simulator (:obj:`object`): The environment simulator. + dataset (:obj:`QGPODataset`): The dataset. + model (:obj:`Union[torch.nn.Module, torch.nn.ModuleDict]`): The model. + Interface: + ``__init__``, ``train``, ``deploy`` + """ + self.config = config + self.simulator = simulator + self.dataset = dataset + + # --------------------------------------- + # Customized model initialization code ↓ + # --------------------------------------- + + self.model = model if model is not None else torch.nn.ModuleDict() + + if model is not None: + self.model = model + self.behaviour_policy_train_epoch = 0 + self.energy_guidance_train_epoch = 0 + self.critic_train_epoch = 0 + else: + self.model = torch.nn.ModuleDict() + config = self.config.train + assert hasattr(config.model, "QGPOPolicy") + + if torch.__version__ >= "2.0.0": + self.model["QGPOPolicy"] = torch.compile( + QGPOPolicy(config.model.QGPOPolicy).to( + config.model.QGPOPolicy.device + ) + ) + else: + self.model["QGPOPolicy"] = QGPOPolicy(config.model.QGPOPolicy).to( + config.model.QGPOPolicy.device + ) + + if ( + hasattr(config.parameter, "checkpoint_path") + and config.parameter.checkpoint_path is not None + ): + self.behaviour_policy_train_epoch = load_model( + path=config.parameter.checkpoint_path, + model=self.model["QGPOPolicy"].diffusion_model.model, + optimizer=None, + prefix="behaviour_policy", + ) + + self.energy_guidance_train_epoch = load_model( + path=config.parameter.checkpoint_path, + model=self.model["QGPOPolicy"].diffusion_model.energy_guidance, + optimizer=None, + prefix="energy_guidance", + ) + + self.critic_train_epoch = load_model( + path=config.parameter.checkpoint_path, + model=self.model["QGPOPolicy"].critic, + optimizer=None, + prefix="critic", + ) + else: + self.behaviour_policy_train_epoch = 0 + self.energy_guidance_train_epoch = 0 + self.critic_train_epoch = 0
+ + # --------------------------------------- + # Customized model initialization code ↑ + # --------------------------------------- + +
[docs] def train(self, config: EasyDict = None): + """ + Overview: + Train the model using the given configuration. \ + A weight-and-bias run will be created automatically when this function is called. + Arguments: + config (:obj:`EasyDict`): The training configuration. + """ + + config = ( + merge_two_dicts_into_newone( + self.config.train if hasattr(self.config, "train") else EasyDict(), + config, + ) + if config is not None + else self.config.train + ) + + with wandb.init(**config.wandb) as wandb_run: + config = merge_two_dicts_into_newone(EasyDict(wandb_run.config), config) + wandb_run.config.update(config) + self.config.train = config + + self.simulator = ( + create_simulator(config.simulator) + if hasattr(config, "simulator") + else self.simulator + ) + self.dataset = ( + create_dataset(config.dataset) + if hasattr(config, "dataset") + else self.dataset + ) + + # --------------------------------------- + # Customized training code ↓ + # --------------------------------------- + + def generate_fake_action(model, states, action_augment_num): + # model.eval() + fake_actions_sampled = [] + if isinstance(states, TensorDict): + from torchrl.data import LazyTensorStorage + + storage = LazyTensorStorage(max_size=states.shape[0]) + storage.set( + range(states.shape[0]), + TensorDict( + { + "s": states, + }, + batch_size=[states.shape[0]], + ), + ) + for index in torch.split(torch.arange(0, states.shape[0], 1), 4096): + index = index.int() + data = storage[index] + fake_actions_per_state = [] + for _ in range(action_augment_num): + fake_actions_per_state.append( + model.sample( + state=data["s"].to(config.model.QGPOPolicy.device), + guidance_scale=0.0, + t_span=( + torch.linspace( + 0.0, 1.0, config.parameter.fake_data_t_span + ).to(config.model.QGPOPolicy.device) + if config.parameter.fake_data_t_span is not None + else None + ), + ) + ) + fake_actions_sampled.append( + torch.stack(fake_actions_per_state, dim=1) + ) + else: + for states in track( + np.array_split(states, states.shape[0] // 4096 + 1), + description="Generate fake actions", + ): + # TODO: mkae it batchsize + fake_actions_per_state = [] + for _ in range(action_augment_num): + fake_actions_per_state.append( + model.sample( + state=states, + guidance_scale=0.0, + t_span=( + torch.linspace( + 0.0, 1.0, config.parameter.fake_data_t_span + ).to(states.device) + if config.parameter.fake_data_t_span is not None + else None + ), + ) + ) + fake_actions_sampled.append( + torch.stack(fake_actions_per_state, dim=1) + ) + fake_actions = torch.cat(fake_actions_sampled, dim=0) + return fake_actions + + def evaluate(model, epoch): + evaluation_results = dict() + for guidance_scale in config.parameter.evaluation.guidance_scale: + + def policy(obs: np.ndarray) -> np.ndarray: + if isinstance(obs, np.ndarray): + obs = torch.tensor( + obs, + dtype=torch.float32, + device=config.model.QGPOPolicy.device, + ).unsqueeze(0) + elif isinstance(obs, dict): + for key in obs: + obs[key] = torch.tensor( + obs[key], + dtype=torch.float32, + device=config.model.QGPOPolicy.device, + ).unsqueeze(0) + if obs[key].dim() == 1 and obs[key].shape[0] == 1: + obs[key] = obs[key].unsqueeze(1) + obs = TensorDict(obs, batch_size=[1]) + else: + raise ValueError("Unsupported observation type.") + action = ( + model.sample( + state=obs, + guidance_scale=guidance_scale, + t_span=( + torch.linspace( + 0.0, 1.0, config.parameter.fake_data_t_span + ).to(config.model.QGPOPolicy.device) + if config.parameter.fake_data_t_span is not None + else None + ), + ) + .squeeze(0) + .cpu() + 
.detach() + .numpy() + ) + return action + + evaluation_results[ + f"evaluation/guidance_scale:[{guidance_scale}]/total_return" + ] = self.simulator.evaluate(policy=policy,)[0]["total_return"] + log.info( + f"Train epoch: {epoch}, guidance_scale: {guidance_scale}, total_return: {evaluation_results[f'evaluation/guidance_scale:[{guidance_scale}]/total_return']}" + ) + + return evaluation_results + + replay_buffer = TensorDictReplayBuffer( + storage=self.dataset.storage, + batch_size=config.parameter.behaviour_policy.batch_size, + sampler=SamplerWithoutReplacement(), + prefetch=10, + pin_memory=True, + ) + + behaviour_model_optimizer = torch.optim.Adam( + self.model["QGPOPolicy"].diffusion_model.model.parameters(), + lr=config.parameter.behaviour_policy.learning_rate, + ) + + for epoch in track( + range(config.parameter.behaviour_policy.epochs), + description="Behaviour policy training", + ): + + if self.behaviour_policy_train_epoch >= epoch: + continue + + counter = 0 + behaviour_model_training_loss_sum = 0.0 + for index, data in enumerate(replay_buffer): + + behaviour_model_training_loss = self.model[ + "QGPOPolicy" + ].behaviour_policy_loss( + data["a"].to(config.model.QGPOPolicy.device), + data["s"].to(config.model.QGPOPolicy.device), + ) + behaviour_model_optimizer.zero_grad() + behaviour_model_training_loss.backward() + behaviour_model_optimizer.step() + + counter += 1 + behaviour_model_training_loss_sum += ( + behaviour_model_training_loss.item() + ) + + if ( + epoch == 0 + or (epoch + 1) % config.parameter.evaluation.evaluation_interval + == 0 + ): + evaluation_results = evaluate(self.model["QGPOPolicy"], epoch=epoch) + wandb_run.log(data=evaluation_results, commit=False) + save_model( + path=config.parameter.checkpoint_path, + model=self.model["QGPOPolicy"].diffusion_model.model, + optimizer=behaviour_model_optimizer, + iteration=epoch, + prefix="behaviour_policy", + ) + + self.behaviour_policy_train_epoch = epoch + + wandb_run.log( + data=dict( + behaviour_policy_train_epoch=epoch, + behaviour_model_training_loss=behaviour_model_training_loss_sum + / counter, + ), + commit=True, + ) + + fake_actions = generate_fake_action( + self.model["QGPOPolicy"], + self.dataset.states[:].to(config.model.QGPOPolicy.device), + config.parameter.action_augment_num, + ).to("cpu") + fake_next_actions = generate_fake_action( + self.model["QGPOPolicy"], + self.dataset.next_states[:].to(config.model.QGPOPolicy.device), + config.parameter.action_augment_num, + ).to("cpu") + + self.dataset.load_fake_actions( + fake_actions=fake_actions, + fake_next_actions=fake_next_actions, + ) + + # TODO add notation + replay_buffer = TensorDictReplayBuffer( + storage=self.dataset.storage, + batch_size=config.parameter.energy_guided_policy.batch_size, + sampler=SamplerWithoutReplacement(), + prefetch=10, + pin_memory=True, + ) + + q_optimizer = torch.optim.Adam( + self.model["QGPOPolicy"].critic.q.parameters(), + lr=config.parameter.critic.learning_rate, + ) + + energy_guidance_optimizer = torch.optim.Adam( + self.model["QGPOPolicy"].diffusion_model.energy_guidance.parameters(), + lr=config.parameter.energy_guidance.learning_rate, + ) + + with Progress() as progress: + critic_training = progress.add_task( + "Critic training", + total=config.parameter.critic.stop_training_epochs, + ) + energy_guidance_training = progress.add_task( + "Energy guidance training", + total=config.parameter.energy_guidance.epochs, + ) + + for epoch in range(config.parameter.energy_guidance.epochs): + + if self.energy_guidance_train_epoch >= 
epoch: + continue + + counter = 0 + q_loss_sum = 0.0 + energy_guidance_loss_sum = 0.0 + + for index, data in enumerate(replay_buffer): + + if epoch < config.parameter.critic.stop_training_epochs: + + q_loss = self.model["QGPOPolicy"].q_loss( + data["a"].to(config.model.QGPOPolicy.device), + data["s"].to(config.model.QGPOPolicy.device), + data["r"].to(config.model.QGPOPolicy.device), + data["s_"].to(config.model.QGPOPolicy.device), + data["d"].to(config.model.QGPOPolicy.device), + data["fake_a_"].to(config.model.QGPOPolicy.device), + discount_factor=config.parameter.critic.discount_factor, + ) + + q_optimizer.zero_grad() + q_loss.backward() + q_optimizer.step() + q_loss_sum += q_loss.item() + + # Update target + for param, target_param in zip( + self.model["QGPOPolicy"].critic.q.parameters(), + self.model["QGPOPolicy"].critic.q_target.parameters(), + ): + target_param.data.copy_( + config.parameter.critic.update_momentum * param.data + + (1 - config.parameter.critic.update_momentum) + * target_param.data + ) + + energy_guidance_loss = self.model[ + "QGPOPolicy" + ].energy_guidance_loss( + data["s"].to(config.model.QGPOPolicy.device), + data["fake_a"].to(config.model.QGPOPolicy.device), + ) + energy_guidance_optimizer.zero_grad() + energy_guidance_loss.backward() + energy_guidance_optimizer.step() + energy_guidance_loss_sum += energy_guidance_loss.item() + + counter += 1 + + if epoch < config.parameter.critic.stop_training_epochs: + progress.update(critic_training, advance=1) + progress.update(energy_guidance_training, advance=1) + + if ( + epoch == 0 + or (epoch + 1) % config.parameter.evaluation.evaluation_interval + == 0 + ): + evaluation_results = evaluate( + self.model["QGPOPolicy"], epoch=epoch + ) + wandb_run.log(data=evaluation_results, commit=False) + save_model( + path=config.parameter.checkpoint_path, + model=self.model[ + "QGPOPolicy" + ].diffusion_model.energy_guidance, + optimizer=energy_guidance_optimizer, + iteration=epoch, + prefix="energy_guidance", + ) + save_model( + path=config.parameter.checkpoint_path, + model=self.model["QGPOPolicy"].critic, + optimizer=q_optimizer, + iteration=epoch, + prefix="critic", + ) + + self.energy_guidance_train_epoch = epoch + self.critic_train_epoch = epoch + + wandb_run.log( + data=dict( + energy_guidance_train_epoch=epoch, + critic_train_epoch=epoch, + q_loss=q_loss_sum / counter, + energy_guidance_loss=energy_guidance_loss_sum / counter, + ), + commit=True, + ) + + # --------------------------------------- + # Customized training code ↑ + # --------------------------------------- + + wandb.finish()
+ +
[docs] def deploy(self, config: EasyDict = None) -> QGPOAgent: + """ + Overview: + Deploy the model using the given configuration. + Arguments: + config (:obj:`EasyDict`): The deployment configuration. + """ + + if config is not None: + config = merge_two_dicts_into_newone(self.config.deploy, config) + else: + config = self.config.deploy + + assert "QGPOPolicy" in self.model, "The model must be trained first." + return QGPOAgent( + config=config, + model=copy.deepcopy(self.model), + )
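As with the GMPO algorithm earlier in this patch, typical end-to-end usage is construct, train, then deploy. The sketch below assumes "my_config" is a complete QGPO configuration, which is not shown here.

algorithm = QGPOAlgorithm(config=my_config)
algorithm.train()           # behaviour policy first, then critic and energy guidance
agent = algorithm.deploy()  # returns a QGPOAgent built from a copy of the trained model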
+ + + + + + + + + \ No newline at end of file diff --git a/_modules/grl/algorithms/srpo.html b/_modules/grl/algorithms/srpo.html new file mode 100644 index 0000000..3cca6b4 --- /dev/null +++ b/_modules/grl/algorithms/srpo.html @@ -0,0 +1,1137 @@ + + + + + + + + + + + + + + + grl.algorithms.srpo — GenerativeRL v0.0.1 documentation + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Source code for grl.algorithms.srpo

+#############################################################
+# This SRPO model is a modified implementation based on https://github.com/thu-ml/SRPO
+#############################################################
+import copy
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+from easydict import EasyDict
+from rich.progress import track
+from tensordict import TensorDict
+from torchrl.data import TensorDictReplayBuffer
+from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
+from grl.rl_modules.value_network.value_network import VNetwork, DoubleVNetwork
+import wandb
+from grl.agents.srpo import SRPOAgent
+from grl.datasets import create_dataset
+from grl.neural_network.encoders import get_encoder
+from grl.generative_models.sro import SRPOConditionalDiffusionModel
+from grl.neural_network import MultiLayerPerceptron
+from grl.rl_modules.simulators import create_simulator
+from grl.rl_modules.value_network.q_network import DoubleQNetwork
+from grl.utils import set_seed
+from grl.utils.config import merge_two_dicts_into_newone
+from grl.utils.log import log
+from grl.utils.model_utils import save_model, load_model
+
+
+class Dirac_Policy(nn.Module):
+    """
+    Overview:
+        The deterministic policy network used in SRPO algorithm.
+    Interfaces:
+        ``__init__``, ``forward``, ``select_actions``
+    """
+
+    def __init__(self, config: EasyDict):
+        super().__init__()
+        action_dim = config.action_dim
+        state_dim = config.state_dim
+        layer = config.layer
+        self.net = MultiLayerPerceptron(
+            hidden_sizes=[state_dim] + [256 for _ in range(layer)],
+            output_size=action_dim,
+            activation="relu",
+            final_activation="tanh",
+        )
+
+        if hasattr(config, "state_encoder"):
+            self.state_encoder = get_encoder(config.state_encoder.type)(
+                **config.state_encoder.args
+            )
+        else:
+            self.state_encoder = torch.nn.Identity()
+
+    def forward(self, state: torch.Tensor):
+        state = self.state_encoder(state)
+        return self.net(state)
+
+    def select_actions(self, state: torch.Tensor):
+        return self(state)
+
+
+def asymmetric_l2_loss(u, tau):
+    return torch.mean(torch.abs(tau - (u < 0).float()) * u**2)
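asymmetric_l2_loss is the expectile loss used by IQL-style value learning: residuals u are penalized with weight tau when non-negative and (1 - tau) when negative, so tau > 0.5 pushes V(s) toward an upper expectile of the target Q. A tiny numeric check with made-up values:

import torch

u = torch.tensor([2.0, -2.0])  # advantages target_q - V: one positive, one negative
tau = 0.7

# weight is tau for u >= 0 and (1 - tau) for u < 0
loss = torch.mean(torch.abs(tau - (u < 0).float()) * u**2)
# = (0.7 * 4 + 0.3 * 4) / 2 = 2.0; raising tau penalizes underestimating the target Q more.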
+
+
+
[docs]class SRPOCritic(nn.Module):
+    """
+    Overview:
+        The critic network used in the SRPO algorithm.
+    Interfaces:
+        ``__init__``, ``forward``, ``v_loss``, ``iql_q_loss``
+    """
+
[docs] def __init__(self, config: EasyDict): + """ + Overview: + Initialize the critic network. + Arguments: + config (:obj:`EasyDict`): The configuration. + """ + super().__init__() + self.config = config + self.q_alpha = config.q_alpha + self.q = DoubleQNetwork(config.DoubleQNetwork) + self.q_target = copy.deepcopy(self.q).requires_grad_(False) + self.v = VNetwork(config.VNetwork)
+ +
[docs] def forward( + self, + action: Union[torch.Tensor, TensorDict], + state: Union[torch.Tensor, TensorDict] = None, + ) -> torch.Tensor: + """ + Overview: + Return the output of critic. + Arguments: + action (:obj:`torch.Tensor`): The input action. + state (:obj:`torch.Tensor`): The input state. + """ + + return self.q(action, state)
+ + def v_loss(self, state, action, next_state, tau): + with torch.no_grad(): + target_q = self.q_target(action, state).detach() + next_v = self.v(next_state).detach() + # Update value function + v = self.v(state) + adv = target_q - v + v_loss = asymmetric_l2_loss(adv, tau) + return v_loss, next_v + + def iql_q_loss(self, state, action, reward, done, next_v, discount): + q_target = reward + (1.0 - done.float()) * discount * next_v.detach() + qs = self.q.compute_double_q(action, state) + q_loss = sum(torch.nn.functional.mse_loss(q, q_target) for q in qs) / len(qs) + return q_loss, torch.mean(qs[0]), torch.mean(q_target)
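Taken together, v_loss and iql_q_loss implement an IQL-style critic update: the value network regresses toward an expectile of the target Q, and the Q networks regress toward the one-step TD target built from that value. A compact hypothetical sketch of one update step, assuming "critic" is an SRPOCritic and "batch" provides the usual keys ("s", "a", "r", "s_", "d"); learning rates and hyperparameters are illustrative, and the soft update of the target Q network is omitted.

import torch

v_optimizer = torch.optim.Adam(critic.v.parameters(), lr=3e-4)
q_optimizer = torch.optim.Adam(critic.q.parameters(), lr=3e-4)

def critic_step(batch, tau=0.7, discount=0.99):
    # Expectile regression of V(s) toward the frozen target Q.
    v_loss, next_v = critic.v_loss(batch["s"], batch["a"], batch["s_"], tau)
    v_optimizer.zero_grad()
    v_loss.backward()
    v_optimizer.step()

    # TD regression of both Q heads toward r + (1 - d) * discount * V(s').
    q_loss, q_mean, q_target_mean = critic.iql_q_loss(
        batch["s"], batch["a"], batch["r"], batch["d"], next_v, discount
    )
    q_optimizer.zero_grad()
    q_loss.backward()
    q_optimizer.step()
    return v_loss.item(), q_loss.item()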
+ + +
[docs]class SRPOPolicy(nn.Module): + """ + Overview: + The SRPO policy network. + Interfaces: + ``__init__``, ``forward``, ``sample``, ``behaviour_policy_loss``, ``srpo_actor_loss`` + """ + +
[docs] def __init__(self, config: EasyDict): + """ + Overview: + Initialize the SRPO policy network. + Arguments: + config (:obj:`EasyDict`): The configuration. + """ + super().__init__() + self.config = config + self.device = config.device + + self.policy = Dirac_Policy(config.policy_model) + self.critic = SRPOCritic(config.critic) + self.sro = SRPOConditionalDiffusionModel( + config=config.diffusion_model, + value_model=self.critic, + distribution_model=self.policy, + )
+ +
[docs] def sample( + self, + state: Union[torch.Tensor, TensorDict], + batch_size: Union[torch.Size, int, Tuple[int], List[int]] = None, + solver_config: EasyDict = None, + t_span: torch.Tensor = None, + ) -> Union[torch.Tensor, TensorDict]: + """ + Overview: + Return the output of SRPO policy, which is the action conditioned on the state. + Arguments: + state (:obj:`Union[torch.Tensor, TensorDict]`): The input state. + solver_config (:obj:`EasyDict`): The configuration for the ODE solver. + t_span (:obj:`torch.Tensor`): The time span for the ODE solver or SDE solver. + Returns: + action (:obj:`Union[torch.Tensor, TensorDict]`): The output action. + """ + return self.sro.diffusion_model.sample( + t_span=t_span, + condition=state, + batch_size=batch_size, + with_grad=False, + solver_config=solver_config, + )
+ +
[docs] def forward( + self, state: Union[torch.Tensor, TensorDict] + ) -> Union[torch.Tensor, TensorDict]: + """ + Overview: + Return the output of SRPO policy, which is the action conditioned on the state. + Arguments: + state (:obj:`Union[torch.Tensor, TensorDict]`): The input state. + Returns: + action (:obj:`Union[torch.Tensor, TensorDict]`): The output action. + """ + return self.policy.select_actions(state)
+ +
[docs] def behaviour_policy_loss( + self, + action: Union[torch.Tensor, TensorDict], + state: Union[torch.Tensor, TensorDict], + ): + """ + Overview: + Calculate the behaviour policy loss. + Arguments: + action (:obj:`torch.Tensor`): The input action. + state (:obj:`torch.Tensor`): The input state. + """ + + return self.sro.score_matching_loss(action, state)
+ +
[docs] def srpo_actor_loss( + self, + state, + ) -> torch.Tensor: + """ + Overview: + Calculate the SRPO actor loss. + Arguments: + state (:obj:`torch.Tensor`): The input state. + Returns: + loss (:obj:`torch.Tensor`): The SRPO actor loss. + q (:obj:`torch.Tensor`): The Q value of the policy action. + """ + loss, q = self.sro.srpo_loss(state) + return loss, q
+ + +
[docs]class SRPOAlgorithm: + +
[docs] def __init__( + self, + config: EasyDict = None, + simulator=None, + dataset=None, + model: Union[torch.nn.Module, torch.nn.ModuleDict] = None, + ): + """ + Overview: + Initialize the SRPO algorithm. + Arguments: + config (:obj:`EasyDict`): The configuration , which must contain the following keys: + train (:obj:`EasyDict`): The training configuration. + deploy (:obj:`EasyDict`): The deployment configuration. + simulator (:obj:`object`): The environment simulator. + dataset (:obj:`Dataset`): The dataset. + model (:obj:`Union[torch.nn.Module, torch.nn.ModuleDict]`): The model. + Interface: + ``__init__``, ``train``, ``deploy`` + """ + self.config = config + self.simulator = simulator + self.dataset = dataset + + # --------------------------------------- + # Customized model initialization code ↓ + # --------------------------------------- + + self.model = model if model is not None else torch.nn.ModuleDict() + + if model is not None: + self.model = model + self.behaviour_train_epoch = 0 + self.critic_train_epoch = 0 + self.policy_train_epoch = 0 + else: + self.model = torch.nn.ModuleDict() + config = self.config.train + assert hasattr(config.model, "SRPOPolicy") + + if torch.__version__ >= "2.0.0": + self.model["SRPOPolicy"] = torch.compile( + SRPOPolicy(config.model.SRPOPolicy).to( + config.model.SRPOPolicy.device + ) + ) + else: + self.model["SRPOPolicy"] = SRPOPolicy(config.model.SRPOPolicy).to( + config.model.SRPOPolicy.device + ) + + if ( + hasattr(config.parameter, "checkpoint_path") + and config.parameter.checkpoint_path is not None + ): + self.behaviour_train_epoch = load_model( + path=config.parameter.checkpoint_path, + model=self.model["SRPOPolicy"].sro.diffusion_model.model, + optimizer=None, + prefix="behaviour_policy", + ) + + self.critic_train_epoch = load_model( + path=config.parameter.checkpoint_path, + model=self.model["SRPOPolicy"].critic, + optimizer=None, + prefix="critic", + ) + + self.policy_train_epoch = load_model( + path=config.parameter.checkpoint_path, + model=self.model["SRPOPolicy"].policy, + optimizer=None, + prefix="policy", + ) + else: + self.behaviour_policy_train_epoch = 0 + self.energy_guidance_train_epoch = 0 + self.critic_train_epoch = 0
+ + # --------------------------------------- + # Customized model initialization code ↑ + # --------------------------------------- + +
[docs] def train(self, config: EasyDict = None): + """ + Overview: + Train the model using the given configuration. \ + A weight-and-bias run will be created automatically when this function is called. + Arguments: + config (:obj:`EasyDict`): The training configuration. + """ + set_seed(self.config.deploy.env["seed"]) + + config = ( + merge_two_dicts_into_newone( + self.config.train if hasattr(self.config, "train") else EasyDict(), + config, + ) + if config is not None + else self.config.train + ) + + with wandb.init( + project=( + config.project if hasattr(config, "project") else __class__.__name__ + ), + **config.wandb if hasattr(config, "wandb") else {}, + ) as wandb_run: + config = merge_two_dicts_into_newone(EasyDict(wandb_run.config), config) + wandb_run.config.update(config) + self.config.train = config + + self.simulator = ( + create_simulator(config.simulator) + if hasattr(config, "simulator") + else self.simulator + ) + self.dataset = ( + create_dataset(config.dataset) + if hasattr(config, "dataset") + else self.dataset + ) + + def evaluate(model, train_epoch, method="diffusion", repeat=1): + evaluation_results = dict() + + if method == "diffusion": + + def policy(obs: np.ndarray) -> np.ndarray: + if isinstance(obs, np.ndarray): + obs = torch.tensor( + obs, + dtype=torch.float32, + device=config.model.SRPOPolicy.device, + ).unsqueeze(0) + elif isinstance(obs, dict): + for key in obs: + obs[key] = torch.tensor( + obs[key], + dtype=torch.float32, + device=config.model.SRPOPolicy.device, + ).unsqueeze(0) + if obs[key].dim() == 1 and obs[key].shape[0] == 1: + obs[key] = obs[key].unsqueeze(1) + obs = TensorDict(obs, batch_size=[1]) + + action = ( + model.sample( + state=obs, + t_span=( + torch.linspace( + 0.0, 1.0, config.parameter.t_span + ).to(config.model.SRPOPolicy.device) + if hasattr(config.parameter, "t_span") + and config.parameter.t_span is not None + else None + ), + ) + .squeeze(0) + .cpu() + .detach() + .numpy() + ) + return action + + elif method == "diracpolicy": + + def policy(obs: np.ndarray) -> np.ndarray: + if isinstance(obs, np.ndarray): + obs = torch.tensor( + obs, + dtype=torch.float32, + device=config.model.SRPOPolicy.device, + ).unsqueeze(0) + elif isinstance(obs, dict): + for key in obs: + obs[key] = torch.tensor( + obs[key], + dtype=torch.float32, + device=config.model.SRPOPolicy.device, + ).unsqueeze(0) + if obs[key].dim() == 1 and obs[key].shape[0] == 1: + obs[key] = obs[key].unsqueeze(1) + obs = TensorDict(obs, batch_size=[1]) + + action = model(obs).squeeze(0).cpu().detach().numpy() + return action + + eval_results = self.simulator.evaluate( + policy=policy, num_episodes=repeat + ) + return_results = [ + eval_results[i]["total_return"] for i in range(repeat) + ] + log.info(f"Return: {return_results}") + return_mean = np.mean(return_results) + return_std = np.std(return_results) + return_max = np.max(return_results) + return_min = np.min(return_results) + evaluation_results[f"evaluation/return_mean"] = return_mean + evaluation_results[f"evaluation/return_std"] = return_std + evaluation_results[f"evaluation/return_max"] = return_max + evaluation_results[f"evaluation/return_min"] = return_min + + if repeat > 1: + log.info( + f"Train epoch: {train_epoch}, return_mean: {return_mean}, return_std: {return_std}, return_max: {return_max}, return_min: {return_min}" + ) + else: + log.info(f"Train epoch: {train_epoch}, return: {return_mean}") + + return evaluation_results + + # --------------------------------------- + # Customized training code ↓ + # 
--------------------------------------- + + replay_buffer = TensorDictReplayBuffer( + storage=self.dataset.storage, + batch_size=config.parameter.behaviour_policy.batch_size, + sampler=SamplerWithoutReplacement(), + prefetch=10, + pin_memory=True, + ) + + behaviour_model_optimizer = torch.optim.Adam( + self.model["SRPOPolicy"].sro.diffusion_model.model.parameters(), + lr=config.parameter.behaviour_policy.learning_rate, + ) + + for epoch in track( + range(config.parameter.behaviour_policy.iterations), + description="Behaviour policy training", + ): + if self.behaviour_train_epoch >= epoch: + continue + + counter = 0 + behaviour_model_training_loss_sum = 0.0 + for index, data in enumerate(replay_buffer): + behaviour_model_training_loss = self.model[ + "SRPOPolicy" + ].behaviour_policy_loss( + data["a"].to(config.model.SRPOPolicy.device), + data["s"].to(config.model.SRPOPolicy.device), + ) + behaviour_model_optimizer.zero_grad() + behaviour_model_training_loss.backward() + behaviour_model_optimizer.step() + counter += 1 + behaviour_model_training_loss_sum += ( + behaviour_model_training_loss.item() + ) + + self.behaviour_policy_train_epoch = epoch + + if ( + epoch == 0 + or (epoch + 1) % config.parameter.evaluation.evaluation_interval + == 0 + ): + evaluation_results = evaluate( + self.model["SRPOPolicy"], + train_epoch=epoch, + method="diffusion", + repeat=( + 1 + if not hasattr(config.parameter.evaluation, "repeat") + else config.parameter.evaluation.repeat + ), + ) + wandb_run.log(data=evaluation_results, commit=False) + save_model( + path=config.parameter.checkpoint_path, + model=self.model["SRPOPolicy"].sro.diffusion_model.model, + optimizer=behaviour_model_optimizer, + iteration=epoch, + prefix="behaviour_policy", + ) + + wandb_run.log( + data=dict( + behaviour_policy_train_epoch=epoch, + behaviour_model_training_loss=behaviour_model_training_loss_sum + / counter, + ), + commit=True, + ) + + replay_buffer = TensorDictReplayBuffer( + storage=self.dataset.storage, + batch_size=config.parameter.critic.batch_size, + sampler=SamplerWithoutReplacement(), + prefetch=10, + pin_memory=True, + ) + + q_optimizer = torch.optim.Adam( + self.model["SRPOPolicy"].critic.q.parameters(), + lr=config.parameter.critic.learning_rate, + ) + v_optimizer = torch.optim.Adam( + self.model["SRPOPolicy"].critic.v.parameters(), + lr=config.parameter.critic.learning_rate, + ) + + for epoch in track( + range(config.parameter.critic.iterations), + description="Critic training", + ): + if self.critic_train_epoch >= epoch: + continue + + counter = 1 + + v_loss_sum = 0.0 + v_sum = 0.0 + q_loss_sum = 0.0 + q_sum = 0.0 + q_target_sum = 0.0 + for index, data in enumerate(replay_buffer): + + v_loss, next_v = self.model["SRPOPolicy"].critic.v_loss( + state=data["s"].to(config.model.SRPOPolicy.device), + action=data["a"].to(config.model.SRPOPolicy.device), + next_state=data["s_"].to(config.model.SRPOPolicy.device), + tau=config.parameter.critic.tau, + ) + v_optimizer.zero_grad(set_to_none=True) + v_loss.backward() + v_optimizer.step() + q_loss, q, q_target = self.model["SRPOPolicy"].critic.iql_q_loss( + state=data["s"].to(config.model.SRPOPolicy.device), + action=data["a"].to(config.model.SRPOPolicy.device), + reward=data["r"].to(config.model.SRPOPolicy.device), + done=data["d"].to(config.model.SRPOPolicy.device), + next_v=next_v, + discount=config.parameter.critic.discount_factor, + ) + q_optimizer.zero_grad(set_to_none=True) + q_loss.backward() + q_optimizer.step() + + # Update target + for param, target_param in zip( + 
self.model["SRPOPolicy"].critic.q.parameters(), + self.model["SRPOPolicy"].critic.q_target.parameters(), + ): + target_param.data.copy_( + config.parameter.critic.update_momentum * param.data + + (1 - config.parameter.critic.update_momentum) + * target_param.data + ) + + counter += 1 + + q_loss_sum += q_loss.item() + q_sum += q.mean().item() + q_target_sum += q_target.mean().item() + + v_loss_sum += v_loss.item() + v_sum += next_v.mean().item() + self.critic_train_epoch = epoch + + wandb.log( + data=dict(v_loss=v_loss_sum / counter, v=v_sum / counter), + commit=False, + ) + + wandb.log( + data=dict( + critic_train_epoch=epoch, + q_loss=q_loss_sum / counter, + q=q_sum / counter, + q_target=q_target_sum / counter, + ), + commit=True, + ) + + if ( + epoch == 0 + or (epoch + 1) % config.parameter.evaluation.evaluation_interval + == 0 + ): + save_model( + path=config.parameter.checkpoint_path, + model=self.model["SRPOPolicy"].critic, + optimizer=q_optimizer, + iteration=epoch, + prefix="critic", + ) + + replay_buffer = TensorDictReplayBuffer( + storage=self.dataset.storage, + batch_size=config.parameter.policy.batch_size, + sampler=SamplerWithoutReplacement(), + prefetch=10, + pin_memory=True, + ) + SRPO_policy_optimizer = torch.optim.Adam( + self.model["SRPOPolicy"].policy.parameters(), + lr=config.parameter.policy.learning_rate, + ) + SRPO_policy_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + SRPO_policy_optimizer, + T_max=config.parameter.policy.tmax, + eta_min=0.0, + ) + + for epoch in track( + range(config.parameter.policy.iterations), + description="Policy training", + ): + counter = 0 + policy_loss_sum = 0 + if self.policy_train_epoch >= epoch: + continue + + for index, data in enumerate(replay_buffer): + self.model["SRPOPolicy"].sro.diffusion_model.model.eval() + policy_loss, q = self.model["SRPOPolicy"].srpo_actor_loss( + data["s"].to(config.model.SRPOPolicy.device) + ) + policy_loss = policy_loss.sum(-1).mean() + SRPO_policy_optimizer.zero_grad(set_to_none=True) + policy_loss.backward() + SRPO_policy_optimizer.step() + SRPO_policy_lr_scheduler.step() + counter += 1 + policy_loss_sum += policy_loss + + if ( + epoch == 0 + or (epoch + 1) % config.parameter.evaluation.evaluation_interval + == 0 + ): + evaluation_results = evaluate( + self.model["SRPOPolicy"], + train_epoch=epoch, + method="diracpolicy", + repeat=( + 1 + if not hasattr(config.parameter.evaluation, "repeat") + else config.parameter.evaluation.repeat + ), + ) + wandb_run.log( + data=evaluation_results, + commit=False, + ) + save_model( + path=config.parameter.checkpoint_path, + model=self.model["SRPOPolicy"].policy, + optimizer=SRPO_policy_optimizer, + iteration=epoch, + prefix="policy", + ) + wandb.log( + data=dict( + policy_loss=policy_loss_sum / counter, + ), + commit=True, + ) + # --------------------------------------- + # Customized training code ↑ + # --------------------------------------- + + wandb.finish()
+ +
[docs] def deploy(self, config: EasyDict = None) -> SRPOAgent: + """ + Overview: + Deploy the model using the given configuration. + Arguments: + config (:obj:`EasyDict`): The deployment configuration. + """ + + if config is not None: + config = merge_two_dicts_into_newone(self.config.deploy, config) + else: + config = self.config.deploy + + return SRPOAgent( + config=config, + model=copy.deepcopy(self.model), + )
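The intended workflow for SRPOAlgorithm is therefore config, then train(), then deploy(). A hedged sketch is shown below; my_srpo_config is a placeholder for a configuration that supplies the train and deploy sections described in the docstrings and is not defined in this file:

    from easydict import EasyDict
    from grl.algorithms import SRPOAlgorithm

    config = EasyDict(my_srpo_config)   # placeholder: must provide config.train and config.deploy
    algorithm = SRPOAlgorithm(config)
    algorithm.train()                   # behaviour diffusion model -> IQL critic -> Dirac policy
    agent = algorithm.deploy()          # SRPOAgent wrapping a deep copy of the trained model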
+
+ +
+ +
+
+ + +
+ +
+

+ © Copyright 2024, OpenDILab Contributors. + +

+
+ +
+ Built with Sphinx using a theme provided by Read the + Docs. +
+ + +
+
+
+ +
+
+
+ +
+
+
+
+
+ + + + + + + + + + + + + + + + + + + + + + + + + +
+
+ + + + + +
+
+
+
+ + +
+
+
+ + +
+ + + + + + + + + \ No newline at end of file diff --git a/_modules/grl/rl_modules/simulators/gym_env_simulator.html b/_modules/grl/rl_modules/simulators/gym_env_simulator.html new file mode 100644 index 0000000..9276eef --- /dev/null +++ b/_modules/grl/rl_modules/simulators/gym_env_simulator.html @@ -0,0 +1,780 @@ + + + + + + + + + + + + + + + grl.rl_modules.simulators.gym_env_simulator — GenerativeRL v0.0.1 documentation + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+
+ +
+
+ + + + + + + + + + + +
+
+
+ + + + + + + + + + + + + + + + +
+ + + + +
+
+ +
+ Shortcuts +
+
+ +
+
+ +
+ +
+
+ +

Source code for grl.rl_modules.simulators.gym_env_simulator

+from typing import Callable, Dict, List, Union
+
+import gym
+import torch
+
+
+
[docs]class GymEnvSimulator: + """ + Overview: + A simple gym environment simulator in GenerativeRL. + This simulator is used to collect episodes and steps using a given policy in a gym environment. + It runs in a single process and is suitable for small-scale experiments. + Interfaces: + ``__init__``, ``collect_episodes``, ``collect_steps``, ``evaluate`` + """ + +
[docs] def __init__(self, env_id: str) -> None: + """ + Overview: + Initialize the GymEnvSimulator according to the given configuration. + Arguments: + env_id (:obj:`str`): The id of the gym environment to simulate. + """ + self.env_id = env_id + self.collect_env = gym.make(self.env_id) + + if gym.__version__ >= "0.26.0": + self.last_state_obs, _ = self.collect_env.reset() + self.last_state_done = False + self.last_state_truncated = False + else: + self.last_state_obs = self.collect_env.reset() + self.last_state_done = False + + self.observation_space = self.collect_env.observation_space + self.action_space = self.collect_env.action_space
+ +
[docs] def collect_episodes( + self, + policy: Union[Callable, torch.nn.Module], + num_episodes: int = None, + num_steps: int = None, + ) -> List[Dict]: + """ + Overview: + Collect several episodes using the given policy. The environment will be reset at the beginning of each episode. + No history will be stored in this method. The collected information of steps will be returned as a list of dictionaries. + Arguments: + policy (:obj:`Union[Callable, torch.nn.Module]`): The policy to collect episodes. + num_episodes (:obj:`int`): The number of episodes to collect. + num_steps (:obj:`int`): The number of steps to collect. + """ + assert num_episodes is not None or num_steps is not None + if num_episodes is not None: + data_list = [] + with torch.no_grad(): + if gym.__version__ >= "0.26.0": + for i in range(num_episodes): + obs, _ = self.collect_env.reset() + done = False + truncated = False + while not done and not truncated: + action = policy(obs) + next_obs, reward, done, truncated, _ = ( + self.collect_env.step(action) + ) + data_list.append( + dict( + obs=obs, + action=action, + reward=reward, + truncated=truncated, + done=done, + next_obs=next_obs, + ) + ) + obs = next_obs + else: + for i in range(num_episodes): + obs = self.collect_env.reset() + done = False + while not done: + action = policy(obs) + next_obs, reward, done, _ = self.collect_env.step(action) + data_list.append( + dict( + obs=obs, + action=action, + reward=reward, + done=done, + next_obs=next_obs, + ) + ) + obs = next_obs + return data_list + elif num_steps is not None: + data_list = [] + with torch.no_grad(): + if gym.__version__ >= "0.26.0": + while len(data_list) < num_steps: + obs, _ = self.collect_env.reset() + done = False + truncated = False + while not done and not truncated: + action = policy(obs) + next_obs, reward, done, truncated, _ = ( + self.collect_env.step(action) + ) + data_list.append( + dict( + obs=obs, + action=action, + reward=reward, + truncated=truncated, + done=done, + next_obs=next_obs, + ) + ) + obs = next_obs + else: + while len(data_list) < num_steps: + obs = self.collect_env.reset() + done = False + while not done: + action = policy(obs) + next_obs, reward, done, _ = self.collect_env.step(action) + data_list.append( + dict( + obs=obs, + action=action, + reward=reward, + done=done, + next_obs=next_obs, + ) + ) + obs = next_obs + return data_list
+ +
[docs] def collect_steps( + self, + policy: Union[Callable, torch.nn.Module], + num_episodes: int = None, + num_steps: int = None, + random_policy: bool = False, + ) -> List[Dict]: + """ + Overview: + Collect several steps using the given policy. The environment will not be reset until the end of the episode. + Last observation will be stored in this method. The collected information of steps will be returned as a list of dictionaries. + Arguments: + policy (:obj:`Union[Callable, torch.nn.Module]`): The policy to collect steps. + num_episodes (:obj:`int`): The number of episodes to collect. + num_steps (:obj:`int`): The number of steps to collect. + random_policy (:obj:`bool`): Whether to use a random policy. + """ + assert num_episodes is not None or num_steps is not None + if num_episodes is not None: + data_list = [] + with torch.no_grad(): + if gym.__version__ >= "0.26.0": + for i in range(num_episodes): + obs, _ = self.collect_env.reset() + done = False + truncated = False + while not done and not truncated: + if random_policy: + action = self.collect_env.action_space.sample() + else: + action = policy(obs) + next_obs, reward, done, truncated, _ = ( + self.collect_env.step(action) + ) + data_list.append( + dict( + obs=obs, + action=action, + reward=reward, + truncated=truncated, + done=done, + next_obs=next_obs, + ) + ) + obs = next_obs + self.last_state_obs, _ = self.collect_env.reset() + self.last_state_done = False + self.last_state_truncated = False + else: + for i in range(num_episodes): + obs = self.collect_env.reset() + done = False + while not done: + if random_policy: + action = self.collect_env.action_space.sample() + else: + action = policy(obs) + next_obs, reward, done, _ = self.collect_env.step(action) + data_list.append( + dict( + obs=obs, + action=action, + reward=reward, + done=done, + next_obs=next_obs, + ) + ) + obs = next_obs + self.last_state_obs = self.collect_env.reset() + self.last_state_done = False + return data_list + elif num_steps is not None: + data_list = [] + with torch.no_grad(): + if gym.__version__ >= "0.26.0": + while len(data_list) < num_steps: + if not self.last_state_done or not self.last_state_truncated: + if random_policy: + action = self.collect_env.action_space.sample() + else: + action = policy(self.last_state_obs) + next_obs, reward, done, truncated, _ = ( + self.collect_env.step(action) + ) + data_list.append( + dict( + obs=self.last_state_obs, + action=action, + reward=reward, + truncated=truncated, + done=done, + next_obs=next_obs, + ) + ) + self.last_state_obs = next_obs + self.last_state_done = done + self.last_state_truncated = truncated + else: + self.last_state_obs, _ = self.collect_env.reset() + self.last_state_done = False + self.last_state_truncated = False + else: + while len(data_list) < num_steps: + if not self.last_state_done: + if random_policy: + action = self.collect_env.action_space.sample() + else: + action = policy(self.last_state_obs) + next_obs, reward, done, _ = self.collect_env.step(action) + data_list.append( + dict( + obs=self.last_state_obs, + action=action, + reward=reward, + done=done, + next_obs=next_obs, + ) + ) + self.last_state_obs = next_obs + self.last_state_done = done + else: + self.last_state_obs = self.collect_env.reset() + self.last_state_done = False + return data_list
+ +
[docs] def evaluate( + self, + policy: Union[Callable, torch.nn.Module], + num_episodes: int = None, + render_args: Dict = None, + ) -> List[Dict]: + """ + Overview: + Evaluate the given policy using the environment. The environment will be reset at the beginning of each episode. + No history will be stored in this method. The evaluation resultswill be returned as a list of dictionaries. + """ + if num_episodes is None: + num_episodes = 1 + + if render_args is not None: + render = True + else: + render = False + + def render_env(env, render_args): + # TODO: support different render modes + render_output = env.render( + **render_args, + ) + return render_output + + eval_results = [] + + env = gym.make(self.env_id) + for i in range(num_episodes): + if render: + render_output = [] + data_list = [] + with torch.no_grad(): + if gym.__version__ >= "0.26.0": + obs, _ = env.reset() + if render: + render_output.append(render_env(env, render_args)) + done = False + truncated = False + while not done and not truncated: + action = policy(obs) + next_obs, reward, done, truncated, _ = env.step(action) + data_list.append( + dict( + obs=obs, + action=action, + reward=reward, + truncated=truncated, + done=done, + next_obs=next_obs, + ) + ) + obs = next_obs + if render: + render_output.append(render_env(env, render_args)) + else: + step = 0 + obs = env.reset() + if render: + render_output.append(render_env(env, render_args)) + done = False + while not done: + action = policy(obs) + next_obs, reward, done, _ = env.step(action) + step += 1 + if render: + render_output.append(render_env(env, render_args)) + data_list.append( + dict( + obs=obs, + action=action, + reward=reward, + done=done, + next_obs=next_obs, + ) + ) + obs = next_obs + if render: + render_output.append(render_env(env, render_args)) + + eval_results.append( + dict( + total_return=sum([d["reward"] for d in data_list]), + total_steps=len(data_list), + data_list=data_list, + render_output=render_output if render else None, + ) + ) + + return eval_results
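evaluate() rolls out full episodes and reports the total return per episode, which is how the algorithms above score their checkpoints. A small usage sketch with a random policy (the environment id is only an example):

    from grl.rl_modules import GymEnvSimulator

    simulator = GymEnvSimulator(env_id="CartPole-v1")  # any gym environment id works

    def random_policy(obs):
        # Ignore the observation and sample a random action from the action space.
        return simulator.action_space.sample()

    results = simulator.evaluate(policy=random_policy, num_episodes=3)
    for episode in results:
        print(episode["total_return"], episode["total_steps"])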
+
+ +
+ +
+
+ + +
+ +
+

+ + +
+
+
+ +
+
+
+ +
+
+
+
+
+ + + + + + + + + + + + + + + + + + + + + + + + + +
+
+ + + + + +
+
+
+
+ + +
+
+
+ + +
+ + + + + + + + + \ No newline at end of file diff --git a/_modules/grl/rl_modules/value_network/one_shot_value_function.html b/_modules/grl/rl_modules/value_network/one_shot_value_function.html new file mode 100644 index 0000000..a23a5ab --- /dev/null +++ b/_modules/grl/rl_modules/value_network/one_shot_value_function.html @@ -0,0 +1,516 @@ + + + + + + + + + + + + + + + grl.rl_modules.value_network.one_shot_value_function — GenerativeRL v0.0.1 documentation + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+
+ +
+
+ + + + + + + + + + + +
+
+
+ + + + + + + + + + + + + + + + +
+ +
    + +
+ + +
+
+ +
+ Shortcuts +
+
+ +
+
+ +
+ +
+
+ +

Source code for grl.rl_modules.value_network.one_shot_value_function

+import copy
+from typing import Tuple, Union
+
+import torch
+import torch.nn as nn
+from easydict import EasyDict
+from tensordict import TensorDict
+
+from grl.rl_modules.value_network.value_network import DoubleVNetwork
+
+
+
[docs]class OneShotValueFunction(nn.Module): + """ + Overview: + Value network for one-shot cases, which means that no Bellman backup is needed for training. + Interfaces: + ``__init__``, ``forward`` + """ + +
[docs] def __init__(self, config: EasyDict): + """ + Overview: + Initialization of one-shot value network. + Arguments: + config (:obj:`EasyDict`): The configuration dict. + """ + + super().__init__() + self.config = config + self.v_alpha = config.v_alpha + self.v = DoubleVNetwork(config.DoubleVNetwork) + self.v_target = copy.deepcopy(self.v).requires_grad_(False)
+ +
[docs] def forward( + self, + state: Union[torch.Tensor, TensorDict], + condition: Union[torch.Tensor, TensorDict] = None, + ) -> torch.Tensor: + """ + Overview: + Return the output of one-shot value network. + Arguments: + state (:obj:`Union[torch.Tensor, TensorDict]`): The input state. + condition (:obj:`Union[torch.Tensor, TensorDict]`): The input condition. + """ + + return self.v(state, condition)
+ +
[docs] def compute_double_v( + self, + state: Union[torch.Tensor, TensorDict], + condition: Union[torch.Tensor, TensorDict] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Overview: + Return the output of two value networks. + Arguments: + state (:obj:`Union[torch.Tensor, TensorDict]`): The input state. + condition (:obj:`Union[torch.Tensor, TensorDict]`): The input condition. + Returns: + v1 (:obj:`Union[torch.Tensor, TensorDict]`): The output of the first value network. + v2 (:obj:`Union[torch.Tensor, TensorDict]`): The output of the second value network. + """ + return self.v.compute_double_v(state, condition=condition)
+ +
[docs] def v_loss( + self, + state: Union[torch.Tensor, TensorDict], + value: Union[torch.Tensor, TensorDict], + condition: Union[torch.Tensor, TensorDict] = None, + ) -> torch.Tensor: + """ + Overview: + Calculate the v loss. + Arguments: + state (:obj:`Union[torch.Tensor, TensorDict]`): The input state. + value (:obj:`Union[torch.Tensor, TensorDict]`): The input value. + condition (:obj:`Union[torch.Tensor, TensorDict]`): The input condition. + Returns: + v_loss (:obj:`torch.Tensor`): The v loss. + """ + + # Update value function + targets = value + v0, v1 = self.v.compute_double_v(state, condition=condition) + v_loss = ( + torch.nn.functional.mse_loss(v0, targets) + + torch.nn.functional.mse_loss(v1, targets) + ) / 2 + return v_loss
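Since no Bellman backup is involved, v_loss is a plain regression of both value heads onto a supplied target (for example a Monte Carlo return). A hedged sketch of one update step, assuming the value function, optimizer, and batch tensors already exist:

    # Hypothetical update step: `value_function` is an OneShotValueFunction,
    # `state` a batch of states and `target_value` the regression target.
    def one_shot_value_step(value_function, optimizer, state, target_value, condition=None):
        loss = value_function.v_loss(state, target_value, condition=condition)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        return loss.item()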
+
+ +
+ +
+
+ + +
+ +
+

+ + +
+
+
+ +
+
+
+ +
+
+
+
+
+ + + + + + + + + + + + + + + + + + + + + + + + + +
+
+ + + + + +
+
+
+
+ + +
+
+
+ + +
+ + + + + + + + + \ No newline at end of file diff --git a/_modules/grl/rl_modules/value_network/q_network.html b/_modules/grl/rl_modules/value_network/q_network.html new file mode 100644 index 0000000..0caa8ff --- /dev/null +++ b/_modules/grl/rl_modules/value_network/q_network.html @@ -0,0 +1,558 @@ + + + + + + + + + + + + + + + grl.rl_modules.value_network.q_network — GenerativeRL v0.0.1 documentation + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+
+ +
+
+ + + + + + + + + + + +
+
+
+ + + + + + + + + + + + + + + + +
+ + + + +
+
+ +
+ Shortcuts +
+
+ +
+
+ +
+ +
+
+ +

Source code for grl.rl_modules.value_network.q_network

+from typing import Tuple, Union
+
+import torch
+import torch.nn as nn
+from easydict import EasyDict
+from tensordict import TensorDict
+
+from grl.neural_network import get_module
+from grl.neural_network.encoders import get_encoder
+
+
+
[docs]class QNetwork(nn.Module): + """ + Overview: + Q network, which is used to approximate the Q value. + Interfaces: + ``__init__``, ``forward`` + """ + +
[docs] def __init__(self, config: EasyDict): + """ + Overview: + Initialization of Q network. + Arguments: + config (:obj:`EasyDict`): The configuration dict. + """ + super().__init__() + self.config = config + self.model = torch.nn.ModuleDict() + if hasattr(config, "action_encoder"): + self.model["action_encoder"] = get_encoder(config.action_encoder.type)( + **config.action_encoder.args + ) + else: + self.model["action_encoder"] = torch.nn.Identity() + if hasattr(config, "state_encoder"): + self.model["state_encoder"] = get_encoder(config.state_encoder.type)( + **config.state_encoder.args + ) + else: + self.model["state_encoder"] = torch.nn.Identity() + # TODO + # specific backbone network + self.model["backbone"] = get_module(config.backbone.type)( + **config.backbone.args + )
+ +
[docs] def forward( + self, + action: Union[torch.Tensor, TensorDict], + state: Union[torch.Tensor, TensorDict], + ) -> torch.Tensor: + """ + Overview: + Return output of Q networks. + Arguments: + action (:obj:`Union[torch.Tensor, TensorDict]`): The input action. + state (:obj:`Union[torch.Tensor, TensorDict]`): The input state. + Returns: + q (:obj:`Union[torch.Tensor, TensorDict]`): The output of Q network. + """ + action_embedding = self.model["action_encoder"](action) + state_embedding = self.model["state_encoder"](state) + return self.model["backbone"](action_embedding, state_embedding)
+ + +
[docs]class DoubleQNetwork(nn.Module): + """ + Overview: + Double Q network, which has two Q networks. + Interfaces: + ``__init__``, ``forward``, ``compute_double_q``, ``compute_mininum_q`` + """ + +
[docs] def __init__(self, config: EasyDict): + super().__init__() + + self.model = torch.nn.ModuleDict() + self.model["q1"] = QNetwork(config) + self.model["q2"] = QNetwork(config)
+ +
[docs] def compute_double_q( + self, + action: Union[torch.Tensor, TensorDict], + state: Union[torch.Tensor, TensorDict], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Overview: + Return the output of two Q networks. + Arguments: + action (:obj:`Union[torch.Tensor, TensorDict]`): The input action. + state (:obj:`Union[torch.Tensor, TensorDict]`): The input state. + Returns: + q1 (:obj:`Union[torch.Tensor, TensorDict]`): The output of the first Q network. + q2 (:obj:`Union[torch.Tensor, TensorDict]`): The output of the second Q network. + """ + + return self.model["q1"](action, state), self.model["q2"](action, state)
+ +
[docs] def compute_mininum_q( + self, + action: Union[torch.Tensor, TensorDict], + state: Union[torch.Tensor, TensorDict], + ) -> torch.Tensor: + """ + Overview: + Return the minimum output of two Q networks. + Arguments: + action (:obj:`Union[torch.Tensor, TensorDict]`): The input action. + state (:obj:`Union[torch.Tensor, TensorDict]`): The input state. + Returns: + minimum_q (:obj:`Union[torch.Tensor, TensorDict]`): The minimum output of Q network. + """ + + return torch.min(*self.compute_double_q(action, state))
+ +
[docs] def forward( + self, + action: Union[torch.Tensor, TensorDict], + state: Union[torch.Tensor, TensorDict], + ) -> torch.Tensor: + """ + Overview: + Return the minimum output of two Q networks. + Arguments: + action (:obj:`Union[torch.Tensor, TensorDict]`): The input action. + state (:obj:`Union[torch.Tensor, TensorDict]`): The input state. + Returns: + minimum_q (:obj:`Union[torch.Tensor, TensorDict]`): The minimum output of Q network. + """ + + return self.compute_mininum_q(action, state)
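A QNetwork is assembled entirely from its config: optional action and state encoders plus a backbone resolved through get_module, and DoubleQNetwork simply holds two such heads and returns their minimum in forward. A hedged construction sketch; "MyQBackbone" and its args are placeholders for whatever module names are actually registered via get_module in your setup, not names guaranteed by this file:

    import torch
    from easydict import EasyDict
    from grl.rl_modules.value_network.q_network import DoubleQNetwork

    # Placeholder backbone name and args, illustrative only.
    config = EasyDict(dict(
        backbone=dict(
            type="MyQBackbone",
            args=dict(obs_dim=17, action_dim=6, hidden_dim=256),
        ),
    ))

    q_net = DoubleQNetwork(config)
    state = torch.randn(32, 17)
    action = torch.randn(32, 6)
    q_min = q_net(action, state)                    # forward() = min(q1, q2)
    q1, q2 = q_net.compute_double_q(action, state)  # both heads separately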
+
+ +
+ +
+
+ + +
+ +
+

+ + +
+
+
+ +
+
+
+ +
+
+
+
+
+ + + + + + + + + + + + + + + + + + + + + + + + + +
+
+ + + + + +
+
+
+
+ + +
+
+
+ + +
+ + + + + + + + + \ No newline at end of file diff --git a/_modules/grl/rl_modules/value_network/value_network.html b/_modules/grl/rl_modules/value_network/value_network.html new file mode 100644 index 0000000..58a911f --- /dev/null +++ b/_modules/grl/rl_modules/value_network/value_network.html @@ -0,0 +1,562 @@ + + + + + + + + + + + + + + + grl.rl_modules.value_network.value_network — GenerativeRL v0.0.1 documentation + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+
+ +
+
+ + + + + + + + + + + +
+
+
+ + + + + + + + + + + + + + + + +
+ + + + +
+
+ +
+ Shortcuts +
+
+ +
+
+ +
+ +
+
+ +

Source code for grl.rl_modules.value_network.value_network

+from typing import Tuple, Union
+
+import torch
+import torch.nn as nn
+from easydict import EasyDict
+from tensordict import TensorDict
+
+from grl.neural_network import get_module
+from grl.neural_network.encoders import get_encoder
+
+
+
[docs]class VNetwork(nn.Module): + """ + Overview: + Value network, which is used to approximate the value function. + Interfaces: + ``__init__``, ``forward`` + """ + +
[docs] def __init__(self, config: EasyDict): + """ + Overview: + Initialization of value network. + Arguments: + config (:obj:`EasyDict`): The configuration dict. + """ + super().__init__() + self.config = config + self.model = torch.nn.ModuleDict() + if hasattr(config, "state_encoder"): + self.model["state_encoder"] = get_encoder(config.state_encoder.type)( + **config.state_encoder.args + ) + else: + self.model["state_encoder"] = torch.nn.Identity() + if hasattr(config, "condition_encoder"): + self.model["condition_encoder"] = get_encoder( + config.condition_encoder.type + )(**config.condition_encoder.args) + else: + self.model["condition_encoder"] = torch.nn.Identity() + # TODO + # specific backbone network + self.model["backbone"] = get_module(config.backbone.type)( + **config.backbone.args + )
+ +
[docs] def forward( + self, + state: Union[torch.Tensor, TensorDict], + condition: Union[torch.Tensor, TensorDict] = None, + ) -> torch.Tensor: + """ + Overview: + Return output of value networks. + Arguments: + state (:obj:`Union[torch.Tensor, TensorDict]`): The input state. + condition (:obj:`Union[torch.Tensor, TensorDict]`): The input condition. + Returns: + value (:obj:`Union[torch.Tensor, TensorDict]`): The output of value network. + """ + + state_embedding = self.model["state_encoder"](state) + if condition is not None: + condition_encoder_embedding = self.model["condition_encoder"](condition) + return self.model["backbone"](state_embedding, condition_encoder_embedding) + else: + return self.model["backbone"](state_embedding)
+ + +
[docs]class DoubleVNetwork(nn.Module): + """ + Overview: + Double value network, which has two value networks. + Interfaces: + ``__init__``, ``forward``, ``compute_double_v``, ``compute_mininum_v`` + """ + +
[docs] def __init__(self, config: EasyDict): + super().__init__() + + self.model = torch.nn.ModuleDict() + self.model["v1"] = VNetwork(config) + self.model["v2"] = VNetwork(config)
+ +
[docs] def compute_double_v( + self, + state: Union[torch.Tensor, TensorDict], + condition: Union[torch.Tensor, TensorDict], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Overview: + Return the output of two value networks. + Arguments: + state (:obj:`Union[torch.Tensor, TensorDict]`): The input state. + condition (:obj:`Union[torch.Tensor, TensorDict]`): The input condition. + Returns: + v1 (:obj:`Union[torch.Tensor, TensorDict]`): The output of the first value network. + v2 (:obj:`Union[torch.Tensor, TensorDict]`): The output of the second value network. + """ + + return self.model["v1"](state, condition), self.model["v2"](state, condition)
+ +
[docs] def compute_mininum_v( + self, + state: Union[torch.Tensor, TensorDict], + condition: Union[torch.Tensor, TensorDict], + ) -> torch.Tensor: + """ + Overview: + Return the minimum output of two value networks. + Arguments: + state (:obj:`Union[torch.Tensor, TensorDict]`): The input state. + condition (:obj:`Union[torch.Tensor, TensorDict]`): The input condition. + Returns: + minimum_v (:obj:`Union[torch.Tensor, TensorDict]`): The minimum output of value network. + """ + + return torch.min(*self.compute_double_v(state, condition=condition))
+ +
[docs] def forward( + self, + state: Union[torch.Tensor, TensorDict], + condition: Union[torch.Tensor, TensorDict], + ) -> torch.Tensor: + """ + Overview: + Return the minimum output of two value networks. + Arguments: + state (:obj:`Union[torch.Tensor, TensorDict]`): The input state. + condition (:obj:`Union[torch.Tensor, TensorDict]`): The input condition. + Returns: + minimum_v (:obj:`Union[torch.Tensor, TensorDict]`): The minimum output of value network. + """ + + return self.compute_mininum_v(state, condition=condition)
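VNetwork follows the same pattern for state (and optional condition) inputs, and DoubleVNetwork again returns the minimum of two heads. A hedged sketch of the unconditional case, with the same caveat that the backbone name and args are placeholders:

    import torch
    from easydict import EasyDict
    from grl.rl_modules.value_network.value_network import DoubleVNetwork

    # "MyVBackbone" is again a placeholder for a registered module name.
    config = EasyDict(dict(
        backbone=dict(type="MyVBackbone", args=dict(obs_dim=17, hidden_dim=256)),
    ))

    v_net = DoubleVNetwork(config)
    state = torch.randn(32, 17)
    v_min = v_net(state, condition=None)  # min(v1, v2); the condition branch is optional in VNetwork.forward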
+
+ +
+ +
+
+ + +
+ +
+

+ + +
+
+
+ +
+
+
+ +
+
+
+
+
+ + + + + + + + + + + + + + + + + + + + + + + + + +
+
+ + + + + +
+
+
+
+ + +
+
+
+ + +
+ + + + + + + + + \ No newline at end of file diff --git a/_modules/index.html b/_modules/index.html index 2d694a2..dc405a8 100644 --- a/_modules/index.html +++ b/_modules/index.html @@ -307,6 +307,10 @@

All modules for which code is available

diff --git a/api_doc/algorithms/index.html b/api_doc/algorithms/index.html index 2bac6e9..d72dc97 100644 --- a/api_doc/algorithms/index.html +++ b/api_doc/algorithms/index.html @@ -310,43 +310,1153 @@
-
-

grl.algorithms¶

+
+

grl.algorithms¶

QGPOCritic¶

+
+
+class grl.algorithms.QGPOCritic(config)[source]¶
+
+
Overview:

Critic network for the QGPO algorithm.

+
+
Interfaces:

__init__, forward

+
+
+
+
+__init__(config)[source]¶
+
+
Overview:

Initialization of QGPO critic network.

+
+
+
+
Parameters:
+

config (EasyDict) – The configuration dict.

+
+
+
+ +
+
+compute_double_q(action, state=None)[source]¶
+
+
Overview:

Return the output of two Q networks.

+
+
+
+
Parameters:
+
    +
  • action (Union[torch.Tensor, TensorDict]) – The input action.

  • +
  • state (Union[torch.Tensor, TensorDict]) – The input state.

  • +
+
+
Returns:
+

The output of the first Q network. +q2 (Union[torch.Tensor, TensorDict]): The output of the second Q network.

+
+
Return type:
+

q1 (Union[torch.Tensor, TensorDict])

+
+
+
+ +
+
+forward(action, state=None)[source]¶
+
+
Overview:

Return the output of QGPO critic.

+
+
+
+
Parameters:
+
    +
  • action (torch.Tensor) – The input action.

  • +
  • state (torch.Tensor) – The input state.

  • +
+
+
Return type:
+

Tensor

+
+
+
+ +
+
+q_loss(action, state, reward, next_state, done, fake_next_action, discount_factor=1.0)[source]¶
+
+
Overview:

Calculate the Q loss.

+
+
+
+
Parameters:
+
    +
  • action (torch.Tensor) – The input action.

  • +
  • state (torch.Tensor) – The input state.

  • +
  • reward (torch.Tensor) – The input reward.

  • +
  • next_state (torch.Tensor) – The input next state.

  • +
  • done (torch.Tensor) – The input done.

  • +
  • fake_next_action (torch.Tensor) – The input fake next action.

  • +
  • discount_factor (float) – The discount factor.

  • +
+
+
Return type:
+

Tensor

+
+
+
+ +
+

QGPOPolicy¶

+
+
+class grl.algorithms.QGPOPolicy(config)[source]¶
+
+
Overview:

QGPO policy network.

+
+
Interfaces:

__init__, forward, sample, behaviour_policy_sample, compute_q, behaviour_policy_loss, energy_guidance_loss, q_loss

+
+
+
+
+__init__(config)[source]¶
+

Initialize internal Module state, shared by both nn.Module and ScriptModule.

+
+ +
+
+behaviour_policy_loss(action, state)[source]¶
+
+
Overview:

Calculate the behaviour policy loss.

+
+
+
+
Parameters:
+
    +
  • action (torch.Tensor) – The input action.

  • +
  • state (torch.Tensor) – The input state.

  • +
+
+
+
+ +
+
+behaviour_policy_sample(state, batch_size=None, solver_config=None, t_span=None)[source]¶
+
+
Overview:

Return the output of behaviour policy, which is the action conditioned on the state.

+
+
+
+
Parameters:
+
    +
  • state (Union[torch.Tensor, TensorDict]) – The input state.

  • +
  • solver_config (EasyDict) – The configuration for the ODE solver.

  • +
  • t_span (torch.Tensor) – The time span for the ODE solver or SDE solver.

  • +
+
+
Returns:
+

The output action.

+
+
Return type:
+

action (Union[torch.Tensor, TensorDict])

+
+
+
+ +
+
+compute_q(state, action)[source]¶
+
+
Overview:

Calculate the Q value.

+
+
+
+
Parameters:
+
    +
  • state (Union[torch.Tensor, TensorDict]) – The input state.

  • +
  • action (Union[torch.Tensor, TensorDict]) – The input action.

  • +
+
+
Returns:
+

The Q value.

+
+
Return type:
+

q (torch.Tensor)

+
+
+
+ +
+
+energy_guidance_loss(state, fake_next_action)[source]¶
+
+
Overview:

Calculate the energy guidance loss of QGPO.

+
+
+
+
Parameters:
+
    +
  • state (Union[torch.Tensor, TensorDict]) – The input state.

  • +
  • fake_next_action (Union[torch.Tensor, TensorDict]) – The input fake next action.

  • +
+
+
Return type:
+

Tensor

+
+
+
+ +
+
+forward(state)[source]¶
+
+
Overview:

Return the output of QGPO policy, which is the action conditioned on the state.

+
+
+
+
Parameters:
+

state (Union[torch.Tensor, TensorDict]) – The input state.

+
+
Returns:
+

The output action.

+
+
Return type:
+

action (Union[torch.Tensor, TensorDict])

+
+
+
+ +
+
+q_loss(action, state, reward, next_state, done, fake_next_action, discount_factor=1.0)[source]¶
+
+
Overview:

Calculate the Q loss.

+
+
+
+
Parameters:
+
    +
  • action (torch.Tensor) – The input action.

  • +
  • state (torch.Tensor) – The input state.

  • +
  • reward (torch.Tensor) – The input reward.

  • +
  • next_state (torch.Tensor) – The input next state.

  • +
  • done (torch.Tensor) – The input done.

  • +
  • fake_next_action (torch.Tensor) – The input fake next action.

  • +
  • discount_factor (float) – The discount factor.

  • +
+
+
Return type:
+

Tensor

+
+
+
+ +
+
+sample(state, batch_size=None, guidance_scale=tensor(1.), solver_config=None, t_span=None)[source]¶
+
+
Overview:

Return the output of QGPO policy, which is the action conditioned on the state.

+
+
+
+
Parameters:
+
    +
  • state (Union[torch.Tensor, TensorDict]) – The input state.

  • +
  • guidance_scale (Union[torch.Tensor, float]) – The guidance scale.

  • +
  • solver_config (EasyDict) – The configuration for the ODE solver.

  • +
  • t_span (torch.Tensor) – The time span for the ODE solver or SDE solver.

  • +
+
+
Returns:
+

The output action.

+
+
Return type:
+

action (Union[torch.Tensor, TensorDict])

+
+
+
+ +
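The guidance_scale argument is what makes sample() different from behaviour_policy_sample(): a scale of 0.0 reproduces the behaviour diffusion model, while larger values tilt sampling toward actions that the learned energy guidance scores highly. A hedged usage sketch, assuming a trained policy and an observation batch already exist:

    import torch

    def act_with_guidance(qgpo_policy, obs: torch.Tensor, scale: float = 1.5) -> torch.Tensor:
        # qgpo_policy is a trained grl.algorithms.QGPOPolicy; obs has shape (batch, obs_dim).
        actions = qgpo_policy.sample(
            state=obs,
            guidance_scale=torch.tensor(scale),   # 0.0 -> behaviour policy; larger -> greedier w.r.t. Q
            t_span=torch.linspace(0.0, 1.0, 32),  # optional ODE/SDE time grid
        )
        return actions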
+

QGPOAlgorithm¶

+
+
+class grl.algorithms.QGPOAlgorithm(config=None, simulator=None, dataset=None, model=None)[source]¶
+
+
Overview:

Q-guided policy optimization (QGPO) algorithm, which is an offline reinforcement learning algorithm that uses an energy-based diffusion model for policy modeling.

+
+
Interfaces:

__init__, train, deploy

+
+
+
+
+__init__(config=None, simulator=None, dataset=None, model=None)[source]¶
+
+
Overview:

Initialize the QGPO algorithm.

+
+
+
+
Parameters:
+
    +
• config (EasyDict) – The configuration, which must contain the following keys: +train (EasyDict): The training configuration. +deploy (EasyDict): The deployment configuration.

  • +
  • simulator (object) – The environment simulator.

  • +
  • dataset (QGPODataset) – The dataset.

  • +
  • model (Union[torch.nn.Module, torch.nn.ModuleDict]) – The model.

  • +
+
+
+
+
Interface:

__init__, train, deploy

+
+
+
+ +
+
+deploy(config=None)[source]¶
+
+
Overview:

Deploy the model using the given configuration.

+
+
+
+
Parameters:
+

config (EasyDict) – The deployment configuration.

+
+
Return type:
+

QGPOAgent

+
+
+
+ +
+
+train(config=None)[source]¶
+
+
Overview:

Train the model using the given configuration. A Weights & Biases run will be created automatically when this function is called.

+
+
+
+
Parameters:
+

config (EasyDict) – The training configuration.

+
+
+
+ +
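As with the other algorithms on this page, the intended usage is config, train(), deploy(). A hedged sketch; my_qgpo_config is a placeholder for a configuration providing the train and deploy keys described above and is not defined here:

    from easydict import EasyDict
    from grl.algorithms import QGPOAlgorithm

    config = EasyDict(my_qgpo_config)   # placeholder: must provide config.train and config.deploy
    algorithm = QGPOAlgorithm(config=config)
    algorithm.train()                   # trains behaviour model, energy guidance and critic; logs to Weights & Biases
    agent = algorithm.deploy()          # returns a QGPOAgent built from the trained model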
+

SRPOCritic¶

+
+
+class grl.algorithms.SRPOCritic(config)[source]¶
+
+
Overview:

The critic network used in the SRPO algorithm.

+
+
Interfaces:

__init__, forward, v_loss, iql_q_loss

+
+
+
+
+__init__(config)[source]¶
+
+
Overview:

Initialize the critic network.

+
+
+
+
Parameters:
+

config (EasyDict) – The configuration.

+
+
+
+ +
+
+forward(action, state=None)[source]¶
+
+
Overview:

Return the output of critic.

+
+
+
+
Parameters:
+
    +
  • action (torch.Tensor) – The input action.

  • +
  • state (torch.Tensor) – The input state.

  • +
+
+
Return type:
+

Tensor

+
+
+
+ +
+

SRPOPolicy¶

+
+
+class grl.algorithms.SRPOPolicy(config)[source]¶
+
+
Overview:

The SRPO policy network.

+
+
Interfaces:

__init__, forward, sample, behaviour_policy_loss, srpo_actor_loss

+
+
+
+
+__init__(config)[source]¶
+
+
Overview:

Initialize the SRPO policy network.

+
+
+
+
Parameters:
+

config (EasyDict) – The configuration.

+
+
+
+ +
+
+behaviour_policy_loss(action, state)[source]¶
+
+
Overview:

Calculate the behaviour policy loss.

+
+
+
+
Parameters:
+
    +
  • action (torch.Tensor) – The input action.

  • +
  • state (torch.Tensor) – The input state.

  • +
+
+
+
+ +
+
+forward(state)[source]¶
+
+
Overview:

Return the output of SRPO policy, which is the action conditioned on the state.

+
+
+
+
Parameters:
+

state (Union[torch.Tensor, TensorDict]) – The input state.

+
+
Returns:
+

The output action.

+
+
Return type:
+

action (Union[torch.Tensor, TensorDict])

+
+
+
+ +
+
+sample(state, batch_size=None, solver_config=None, t_span=None)[source]¶
+
+
Overview:

Return the output of SRPO policy, which is the action conditioned on the state.

+
+
+
+
Parameters:
+
    +
  • state (Union[torch.Tensor, TensorDict]) – The input state.

  • +
  • solver_config (EasyDict) – The configuration for the ODE solver.

  • +
  • t_span (torch.Tensor) – The time span for the ODE solver or SDE solver.

  • +
+
+
Returns:
+

The output action.

+
+
Return type:
+

action (Union[torch.Tensor, TensorDict])

+
+
+
+ +
+
+srpo_actor_loss(state)[source]¶
+
+
Overview:

Calculate the SRPO actor loss.

+
+
+
+
Parameters:
+

state (torch.Tensor) – The input state.

+
+
Return type:
+

Tensor

+
+
+
+ +
+

SRPOAlgorithm¶

+
+
+class grl.algorithms.SRPOAlgorithm(config=None, simulator=None, dataset=None, model=None)[source]¶
+
+
+__init__(config=None, simulator=None, dataset=None, model=None)[source]¶
+
+
Overview:

Initialize the SRPO algorithm.

+
+
+
+
Parameters:
+
    +
• config (EasyDict) – The configuration, which must contain the following keys: +train (EasyDict): The training configuration. +deploy (EasyDict): The deployment configuration.

  • +
  • simulator (object) – The environment simulator.

  • +
  • dataset (Dataset) – The dataset.

  • +
  • model (Union[torch.nn.Module, torch.nn.ModuleDict]) – The model.

  • +
+
+
+
+
Interface:

__init__, train, deploy

+
+
+
+ +
+
+deploy(config=None)[source]¶
+
+
Overview:

Deploy the model using the given configuration.

+
+
+
+
Parameters:
+

config (EasyDict) – The deployment configuration.

+
+
Return type:
+

SRPOAgent

+
+
+
+ +
+
+train(config=None)[source]¶
+
+
Overview:

Train the model using the given configuration. A Weights & Biases run will be created automatically when this function is called.

+
+
+
+
Parameters:
+

config (EasyDict) – The training configuration.

+
+
+
+ +
+

GMPOCritic¶

+
+
+class grl.algorithms.GMPOCritic(config)[source]¶
+
+
Overview:

Critic network for the GMPO algorithm.

+
+
Interfaces:

__init__, forward

+
+
+
+
+__init__(config)[source]¶
+
+
Overview:

Initialization of GMPO critic network.

+
+
+
+
Parameters:
+

config (EasyDict) – The configuration dict.

+
+
+
+ +
+
+compute_double_q(action, state=None)[source]¶
+
+
Overview:

Return the output of two Q networks.

+
+
+
+
Parameters:
+
    +
  • action (Union[torch.Tensor, TensorDict]) – The input action.

  • +
  • state (Union[torch.Tensor, TensorDict]) – The input state.

  • +
+
+
Returns:
+

The output of the first Q network. +q2 (Union[torch.Tensor, TensorDict]): The output of the second Q network.

+
+
Return type:
+

q1 (Union[torch.Tensor, TensorDict])

+
+
+
+ +
+
+forward(action, state=None)[source]¶
+
+
Overview:

Return the output of GMPO critic.

+
+
+
+
Parameters:
+
    +
  • action (torch.Tensor) – The input action.

  • +
  • state (torch.Tensor) – The input state.

  • +
+
+
Return type:
+

Tensor

+
+
+
+ +
+

GMPOPolicy¶

+
+
+class grl.algorithms.GMPOPolicy(config)[source]¶
+
+
Overview:

GMPO policy network for the GMPO algorithm, which includes the base model (optional), the guided model, and the critic.

+
+
Interfaces:

__init__, forward, sample, compute_q, behaviour_policy_loss, policy_optimization_loss_by_advantage_weighted_regression, policy_optimization_loss_by_advantage_weighted_regression_softmax

+
+
+
+
+__init__(config)[source]¶
+
+
Overview:

Initialize the GMPO policy network.

+
+
+
+
Parameters:
+

config (EasyDict) – The configuration dict.

+
+
+
+ +
+
+behaviour_policy_loss(action, state, maximum_likelihood=False)[source]¶
+
+
Overview:

Calculate the behaviour policy loss.

+
+
+
+
Parameters:
+
    +
  • action (torch.Tensor) – The input action.

  • +
  • state (torch.Tensor) – The input state.

  • +
+
+
+
+ +
+
+behaviour_policy_sample(state, batch_size=None, solver_config=None, t_span=None, with_grad=False)[source]¶
+
+
Overview:

Return the output of behaviour policy, which is the action conditioned on the state.

+
+
+
+
Parameters:
+
    +
  • state (Union[torch.Tensor, TensorDict]) – The input state.

  • +
  • batch_size (Union[torch.Size, int, Tuple[int], List[int]]) – The batch size.

  • +
  • solver_config (EasyDict) – The configuration for the ODE solver.

  • +
  • t_span (torch.Tensor) – The time span for the ODE solver or SDE solver.

  • +
  • with_grad (bool) – Whether to calculate the gradient.

  • +
+
+
Returns:
+

The output action.

+
+
Return type:
+

action (Union[torch.Tensor, TensorDict])

+
+
+
+ +
+
+compute_q(state, action)[source]¶
+
+
Overview:

Calculate the Q value.

+
+
+
+
Parameters:
+
    +
  • state (Union[torch.Tensor, TensorDict]) – The input state.

  • +
  • action (Union[torch.Tensor, TensorDict]) – The input action.

  • +
+
+
Returns:
+

The Q value.

+
+
Return type:
+

q (torch.Tensor)

+
+
+
+ +
+
+forward(state)[source]¶
+
+
Overview:

Return the output of GMPO policy, which is the action conditioned on the state.

+
+
+
+
Parameters:
+

state (Union[torch.Tensor, TensorDict]) – The input state.

+
+
Returns:
+

The output action.

+
+
Return type:
+

action (Union[torch.Tensor, TensorDict])

+
+
+
+ +
+
+policy_optimization_loss_by_advantage_weighted_regression(action, state, maximum_likelihood=False, beta=1.0, weight_clamp=100.0)[source]¶
+
+
Overview:

Calculate the policy optimization loss by advantage-weighted regression.

+
+
+
+
Parameters:
+
    +
  • action (torch.Tensor) – The input action.

  • +
  • state (torch.Tensor) – The input state.

  • +
+
+
+
+ +
+
+policy_optimization_loss_by_advantage_weighted_regression_softmax(state, fake_action, maximum_likelihood=False, beta=1.0)[source]¶
+
+
Overview:

Calculate the policy optimization loss by advantage-weighted regression (softmax variant over the fake actions).

+
+
+
+
Parameters:
+
    +
  • action (torch.Tensor) – The input action.

  • +
  • state (torch.Tensor) – The input state.

  • +
+
+
+
+ +
+
+sample(state, batch_size=None, solver_config=None, t_span=None, with_grad=False)[source]¶
+
+
Overview:

Return the output of GMPO policy, which is the action conditioned on the state.

+
+
+
+
Parameters:
+
    +
  • state (Union[torch.Tensor, TensorDict]) – The input state.

  • +
  • batch_size (Union[torch.Size, int, Tuple[int], List[int]]) – The batch size.

  • +
  • solver_config (EasyDict) – The configuration for the ODE solver.

  • +
  • t_span (torch.Tensor) – The time span for the ODE solver or SDE solver.

  • +
+
+
Returns:
+

The output action.

+
+
Return type:
+

action (Union[torch.Tensor, TensorDict])

+
+
+
+ +
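The advantage-weighted-regression losses above follow the standard AWR recipe: the likelihood of dataset (or fake) actions is reweighted by a clamped exp(beta * advantage), so the generative policy is pulled toward actions the critic prefers. A minimal standalone sketch of that weighting (this is the general recipe, not the library's exact implementation):

    import torch

    def awr_weighted_nll(log_prob, q_value, value, beta=1.0, weight_clamp=100.0):
        # Advantage-weighted regression: weight each log-likelihood term by exp(beta * A(s, a)).
        advantage = q_value - value
        weight = torch.exp(beta * advantage).clamp(max=weight_clamp).detach()
        return -(weight * log_prob).mean()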
+

GMPOAlgorithm¶

+
+
+class grl.algorithms.GMPOAlgorithm(config=None, simulator=None, dataset=None, model=None, seed=None)[source]¶
+
+
Overview:

The Generative Model Policy Optimization (GMPO) algorithm.

+
+
Interfaces:

__init__, train, deploy

+
+
+
+
+__init__(config=None, simulator=None, dataset=None, model=None, seed=None)[source]¶
+
+
Overview:

Initialize the GMPO / GPG algorithm.

+
+
+
+
Parameters:
+
    +
• config (EasyDict) – The configuration, which must contain the following keys: +train (EasyDict): The training configuration. +deploy (EasyDict): The deployment configuration.

  • +
  • simulator (object) – The environment simulator.

  • +
  • dataset (GPDataset) – The dataset.

  • +
  • model (Union[torch.nn.Module, torch.nn.ModuleDict]) – The model.

  • +
+
+
+
+
Interface:

__init__, train, deploy

+
+
+
+ +
+
+train(config=None, seed=None)[source]¶
+
+
Overview:

Train the model using the given configuration. A Weights & Biases run will be created automatically when this function is called.

+
+
+
+
Parameters:
+
    +
  • config (EasyDict) – The training configuration.

  • +
  • seed (int) – The random seed.

  • +
+
+
+
+ +
+

GMPGCritic¶

+
+
+class grl.algorithms.GMPGCritic(config)[source]¶
+
+
Overview:

Critic network.

+
+
Interfaces:

__init__, forward

+
+
+
+
+__init__(config)[source]¶
+
+
Overview:

Initialization of the GMPG critic network.

+
+
+
+
Parameters:
+

config (EasyDict) – The configuration dict.

+
+
+
+ +
+
+compute_double_q(action, state=None)[source]¶
+
+
Overview:

Return the output of two Q networks.

+
+
+
+
Parameters:
+
    +
  • action (Union[torch.Tensor, TensorDict]) – The input action.

  • +
  • state (Union[torch.Tensor, TensorDict]) – The input state.

  • +
+
+
Returns:
+

The output of the first Q network. +q2 (Union[torch.Tensor, TensorDict]): The output of the second Q network.

+
+
Return type:
+

q1 (Union[torch.Tensor, TensorDict])

+
+
+
+ +
+
+forward(action, state=None)[source]¶
+
+
Overview:

Return the output of the GMPG critic.

+
+
+
+
Parameters:
+
    +
  • action (torch.Tensor) – The input action.

  • +
  • state (torch.Tensor) – The input state.

  • +
+
+
Return type:
+

Tensor

+
+
+
+ +
+
+in_support_ql_loss(action, state, reward, next_state, done, fake_next_action, discount_factor=1.0)[source]¶
+
+
Overview:

Calculate the Q loss.

+
+
+
+
Parameters:
+
    +
  • action (torch.Tensor) – The input action.

  • +
  • state (torch.Tensor) – The input state.

  • +
  • reward (torch.Tensor) – The input reward.

  • +
  • next_state (torch.Tensor) – The input next state.

  • +
  • done (torch.Tensor) – The input done.

  • +
  • fake_next_action (torch.Tensor) – The input fake next action.

  • +
  • discount_factor (float) – The discount factor.

  • +
+
+
Return type:
+

Tensor

+
+
+
+ +
GMPGPolicy

class grl.algorithms.GMPGPolicy(config)

__init__(config)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

behaviour_policy_loss(action, state, maximum_likelihood=False)

Overview:

Calculate the behaviour policy loss.

Parameters:

  • action (torch.Tensor) – The input action.

  • state (torch.Tensor) – The input state.

behaviour_policy_sample(state, batch_size=None, solver_config=None, t_span=None, with_grad=False)

Overview:

Return the output of the behaviour policy, which is the action conditioned on the state.

Parameters:

  • state (Union[torch.Tensor, TensorDict]) – The input state.

  • batch_size (Union[torch.Size, int, Tuple[int], List[int]]) – The batch size.

  • solver_config (EasyDict) – The configuration for the ODE solver.

  • t_span (torch.Tensor) – The time span for the ODE or SDE solver.

  • with_grad (bool) – Whether to calculate the gradient.

Returns:

action (Union[torch.Tensor, TensorDict]): The output action.

compute_q(state, action)

Overview:

Calculate the Q value.

Parameters:

  • state (Union[torch.Tensor, TensorDict]) – The input state.

  • action (Union[torch.Tensor, TensorDict]) – The input action.

Returns:

q (torch.Tensor): The Q value.

forward(state)

Overview:

Return the output of the GMPG policy, which is the action conditioned on the state.

Parameters:

state (Union[torch.Tensor, TensorDict]) – The input state.

Returns:

action (Union[torch.Tensor, TensorDict]): The output action.

sample(state, batch_size=None, solver_config=None, t_span=None, with_grad=False)

Overview:

Return the output of the GMPG policy, which is the action conditioned on the state.

Parameters:

  • state (Union[torch.Tensor, TensorDict]) – The input state.

  • batch_size (Union[torch.Size, int, Tuple[int], List[int]]) – The batch size.

  • solver_config (EasyDict) – The configuration for the ODE solver.

  • t_span (torch.Tensor) – The time span for the ODE or SDE solver.

Returns:

action (Union[torch.Tensor, TensorDict]): The output action.
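For deployment, the documented sample interface maps directly onto a small action-selection helper. The sketch below assumes policy is an already-constructed GMPGPolicy (or any object exposing sample(state, t_span, with_grad)); the 32-step uniform time grid is an arbitrary choice, not a library default:

import torch


def select_action(policy, obs, num_t_steps: int = 32):
    """Sample a single action from a policy exposing the documented `sample` interface."""
    state = torch.as_tensor(obs, dtype=torch.float32).unsqueeze(0)  # add a batch dimension
    t_span = torch.linspace(0.0, 1.0, num_t_steps)                  # integration times for the ODE/SDE solver
    with torch.no_grad():
        action = policy.sample(state=state, t_span=t_span, with_grad=False)
    return action.squeeze(0).cpu().numpy()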
GMPGAlgorithm

class grl.algorithms.GMPGAlgorithm(config=None, simulator=None, dataset=None, model=None, seed=None)

Overview:

The Generative Model Policy Gradient (GMPG) algorithm.

Interfaces:

__init__, train, deploy

__init__(config=None, simulator=None, dataset=None, model=None, seed=None)

Overview:

Initialize the GMPG algorithm.

Parameters:

  • config (EasyDict) – The configuration, which must contain the following keys: train (EasyDict): the training configuration; deploy (EasyDict): the deployment configuration.

  • simulator (object) – The environment simulator.

  • dataset (GPDataset) – The dataset.

  • model (Union[torch.nn.Module, torch.nn.ModuleDict]) – The model.

Interfaces:

__init__, train, deploy

train(config=None, seed=None)

Overview:

Train the model using the given configuration. A Weights & Biases run is created automatically when this function is called.

Parameters:

  • config (EasyDict) – The training configuration.

  • seed (int) – The random seed.
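After training, the deployed agent can be rolled out with the documented act interface. The sketch below follows the classic gym step API used in the Quick Start; the environment id and the assumption that agent comes from GMPGAlgorithm(...).deploy() are illustrative:

import gym


def evaluate_agent(agent, env_id: str = "HalfCheetah-v3", num_steps: int = 1000):
    """Roll out a deployed agent (e.g. the return value of GMPGAlgorithm(...).deploy())."""
    env = gym.make(env_id)
    obs = env.reset()
    total_reward = 0.0
    for _ in range(num_steps):
        obs, reward, done, _ = env.step(agent.act(obs))  # act() is the documented inference call
        total_reward += reward
        if done:
            obs = env.reset()
    return total_reward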
grl.rl_modules

GymEnvSimulator

class grl.rl_modules.GymEnvSimulator(env_id)

Overview:

A simple gym environment simulator in GenerativeRL. This simulator is used to collect episodes and steps using a given policy in a gym environment. It runs in a single process and is suitable for small-scale experiments.

Interfaces:

__init__, collect_episodes, collect_steps, evaluate

__init__(env_id)

Overview:

Initialize the GymEnvSimulator according to the given configuration.

Parameters:

env_id (str) – The id of the gym environment to simulate.

collect_episodes(policy, num_episodes=None, num_steps=None)

Overview:

Collect several episodes using the given policy. The environment is reset at the beginning of each episode. No history is stored in this method; the collected step information is returned as a list of dictionaries.

Parameters:

  • policy (Union[Callable, torch.nn.Module]) – The policy used to collect episodes.

  • num_episodes (int) – The number of episodes to collect.

  • num_steps (int) – The number of steps to collect.

Return type:

List[Dict]

collect_steps(policy, num_episodes=None, num_steps=None, random_policy=False)

Overview:

Collect several steps using the given policy. The environment is not reset until the end of the episode. The last observation is stored in this method; the collected step information is returned as a list of dictionaries.

Parameters:

  • policy (Union[Callable, torch.nn.Module]) – The policy used to collect steps.

  • num_episodes (int) – The number of episodes to collect.

  • num_steps (int) – The number of steps to collect.

  • random_policy (bool) – Whether to use a random policy.

Return type:

List[Dict]

evaluate(policy, num_episodes=None, render_args=None)

Overview:

Evaluate the given policy using the environment. The environment is reset at the beginning of each episode. No history is stored in this method; the evaluation results are returned as a list of dictionaries.

Return type:

List[Dict]
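A minimal usage sketch: collect a couple of episodes with a plain callable policy. CartPole-v1 is only an example environment id, and the assumption that the simulator calls the policy with the current observation follows from the documented policy type (Union[Callable, torch.nn.Module]):

import numpy as np
from grl.rl_modules import GymEnvSimulator

simulator = GymEnvSimulator(env_id="CartPole-v1")


def random_policy(obs):
    # Ignore the observation and pick one of CartPole's two discrete actions.
    return np.random.randint(2)


steps = simulator.collect_episodes(random_policy, num_episodes=2)
print(f"collected {len(steps)} step dictionaries; keys of the first: {list(steps[0].keys())}")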

OneShotValueFunction

class grl.rl_modules.OneShotValueFunction(config)

Overview:

Value network for one-shot cases, in which no Bellman backup is needed for training.

Interfaces:

__init__, forward

__init__(config)

Overview:

Initialization of the one-shot value network.

Parameters:

config (EasyDict) – The configuration dict.

compute_double_v(state, condition=None)

Overview:

Return the output of the two value networks.

Parameters:

  • state (Union[torch.Tensor, TensorDict]) – The input state.

  • condition (Union[torch.Tensor, TensorDict]) – The input condition.

Returns:

v1 (Union[torch.Tensor, TensorDict]): The output of the first value network. v2 (Union[torch.Tensor, TensorDict]): The output of the second value network.

forward(state, condition=None)

Overview:

Return the output of the one-shot value network.

Parameters:

  • state (Union[torch.Tensor, TensorDict]) – The input state.

  • condition (Union[torch.Tensor, TensorDict]) – The input condition.

Return type:

torch.Tensor

v_loss(state, value, condition=None)

Overview:

Calculate the value loss.

Parameters:

  • state (Union[torch.Tensor, TensorDict]) – The input state.

  • value (Union[torch.Tensor, TensorDict]) – The input value.

  • condition (Union[torch.Tensor, TensorDict]) – The input condition.

Returns:

v_loss (torch.Tensor): The value loss.
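As a standalone PyTorch sketch of the "one-shot" idea (not the grl class itself): the value network is regressed directly onto precomputed return targets, so no Bellman backup or target network is needed. All sizes are arbitrary:

import torch
import torch.nn as nn
import torch.nn.functional as F

value_net = nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 1))  # toy value head
state = torch.randn(8, 4)       # batch of 8 states with 4-dimensional observations
mc_return = torch.randn(8, 1)   # precomputed returns used as regression targets
v_loss = F.mse_loss(value_net(state), mc_return)
v_loss.backward()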

VNetwork

class grl.rl_modules.VNetwork(config)

Overview:

Value network, which is used to approximate the value function.

Interfaces:

__init__, forward

__init__(config)

Overview:

Initialization of the value network.

Parameters:

config (EasyDict) – The configuration dict.

forward(state, condition=None)

Overview:

Return the output of the value network.

Parameters:

  • state (Union[torch.Tensor, TensorDict]) – The input state.

  • condition (Union[torch.Tensor, TensorDict]) – The input condition.

Returns:

value (Union[torch.Tensor, TensorDict]): The output of the value network.

DoubleVNetwork

class grl.rl_modules.DoubleVNetwork(config)

Overview:

Double value network, which has two value networks.

Interfaces:

__init__, forward, compute_double_v, compute_mininum_v

__init__(config)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

compute_double_v(state, condition)

Overview:

Return the output of the two value networks.

Parameters:

  • state (Union[torch.Tensor, TensorDict]) – The input state.

  • condition (Union[torch.Tensor, TensorDict]) – The input condition.

Returns:

v1 (Union[torch.Tensor, TensorDict]): The output of the first value network. v2 (Union[torch.Tensor, TensorDict]): The output of the second value network.

compute_mininum_v(state, condition)

Overview:

Return the minimum of the two value networks' outputs.

Parameters:

  • state (Union[torch.Tensor, TensorDict]) – The input state.

  • condition (Union[torch.Tensor, TensorDict]) – The input condition.

Returns:

minimum_v (Union[torch.Tensor, TensorDict]): The minimum of the two value estimates.

forward(state, condition)

Overview:

Return the minimum of the two value networks' outputs.

Parameters:

  • state (Union[torch.Tensor, TensorDict]) – The input state.

  • condition (Union[torch.Tensor, TensorDict]) – The input condition.

Returns:

minimum_v (Union[torch.Tensor, TensorDict]): The minimum of the two value estimates.
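Conceptually, the double-value trick evaluates two independently initialized value heads and keeps their element-wise minimum, a pessimistic estimate that is less prone to overestimation. A standalone PyTorch sketch with arbitrary sizes (not the grl class):

import torch
import torch.nn as nn

v1 = nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 1))  # first value head
v2 = nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 1))  # second value head

state = torch.randn(8, 4)
v_min = torch.min(v1(state), v2(state))  # what compute_mininum_v / forward return conceptually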

QNetwork

class grl.rl_modules.QNetwork(config)

Overview:

Q network, which is used to approximate the Q value.

Interfaces:

__init__, forward

__init__(config)

Overview:

Initialization of the Q network.

Parameters:

config (EasyDict) – The configuration dict.

forward(action, state)

Overview:

Return the output of the Q network.

Parameters:

  • action (Union[torch.Tensor, TensorDict]) – The input action.

  • state (Union[torch.Tensor, TensorDict]) – The input state.

Returns:

q (Union[torch.Tensor, TensorDict]): The output of the Q network.

DoubleQNetwork

class grl.rl_modules.DoubleQNetwork(config)

Overview:

Double Q network, which has two Q networks.

Interfaces:

__init__, forward, compute_double_q, compute_mininum_q

__init__(config)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

compute_double_q(action, state)

Overview:

Return the output of the two Q networks.

Parameters:

  • action (Union[torch.Tensor, TensorDict]) – The input action.

  • state (Union[torch.Tensor, TensorDict]) – The input state.

Returns:

q1 (Union[torch.Tensor, TensorDict]): The output of the first Q network. q2 (Union[torch.Tensor, TensorDict]): The output of the second Q network.

compute_mininum_q(action, state)

Overview:

Return the minimum of the two Q networks' outputs.

Parameters:

  • action (Union[torch.Tensor, TensorDict]) – The input action.

  • state (Union[torch.Tensor, TensorDict]) – The input state.

Returns:

minimum_q (Union[torch.Tensor, TensorDict]): The minimum of the two Q estimates.

forward(action, state)

Overview:

Return the minimum of the two Q networks' outputs.

Parameters:

  • action (Union[torch.Tensor, TensorDict]) – The input action.

  • state (Union[torch.Tensor, TensorDict]) – The input state.

Returns:

minimum_q (Union[torch.Tensor, TensorDict]): The minimum of the two Q estimates.
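The minimum of the two Q heads is what typically enters the Bellman target of a critic update (clipped double-Q learning). The sketch below is standalone PyTorch under assumed, broadcast-compatible shapes, not the grl implementation:

import torch


def clipped_double_q_target(target_q_net, reward, next_state, next_action, done, discount=0.99):
    """Build a TD target from the minimum of two target Q heads.

    `target_q_net` is assumed to expose the documented forward(action, state),
    which already returns the minimum of the two heads (as DoubleQNetwork.forward does).
    """
    with torch.no_grad():
        q_next = target_q_net(next_action, next_state)   # min(Q1, Q2) at the next state-action
        return reward + discount * (1.0 - done) * q_next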
@@ -376,12 +826,62 @@

DoubleQNetwork diff --git a/genindex.html b/genindex.html index 7ac03a4..2cc0329 100644 --- a/genindex.html +++ b/genindex.html @@ -311,6 +311,7 @@

Index

_ | A + | B | C | D | E @@ -322,6 +323,7 @@

Index

| M | N | O + | P | Q | S | T @@ -338,6 +340,30 @@

_

  • (grl.agents.QGPOAgent method)
  • (grl.agents.SRPOAgent method) +
  • +
  • (grl.algorithms.GMPGAlgorithm method) +
  • +
  • (grl.algorithms.GMPGCritic method) +
  • +
  • (grl.algorithms.GMPGPolicy method) +
  • +
  • (grl.algorithms.GMPOAlgorithm method) +
  • +
  • (grl.algorithms.GMPOCritic method) +
  • +
  • (grl.algorithms.GMPOPolicy method) +
  • +
  • (grl.algorithms.QGPOAlgorithm method) +
  • +
  • (grl.algorithms.QGPOCritic method) +
  • +
  • (grl.algorithms.QGPOPolicy method) +
  • +
  • (grl.algorithms.SRPOAlgorithm method) +
  • +
  • (grl.algorithms.SRPOCritic method) +
  • +
  • (grl.algorithms.SRPOPolicy method)
  • (grl.datasets.GPD4RLDataset method)
  • @@ -380,6 +406,18 @@

    _

  • (grl.numerical_methods.SDE method)
  • (grl.numerical_methods.SDESolver method) +
  • +
  • (grl.rl_modules.DoubleQNetwork method) +
  • +
  • (grl.rl_modules.DoubleVNetwork method) +
  • +
  • (grl.rl_modules.GymEnvSimulator method) +
  • +
  • (grl.rl_modules.OneShotValueFunction method) +
  • +
  • (grl.rl_modules.QNetwork method) +
  • +
  • (grl.rl_modules.VNetwork method)
  • @@ -399,13 +437,71 @@

    A

    +

    B

    + + + +
    +

    C

    @@ -527,6 +665,18 @@

    G

  • data_prediction_function_with_energy_guidance() (grl.generative_models.EnergyConditionalDiffusionModel method)
  • +
  • deploy() (grl.algorithms.QGPOAlgorithm method) + +
  • diffusion() (grl.numerical_methods.GaussianConditionalProbabilityPath method)
  • diffusion_squared() (grl.numerical_methods.GaussianConditionalProbabilityPath method) @@ -447,6 +549,10 @@

    D

  • DiT2D (in module grl.neural_network)
  • DiT3D (class in grl.neural_network) +
  • +
  • DoubleQNetwork (class in grl.rl_modules) +
  • +
  • DoubleVNetwork (class in grl.rl_modules)
  • DPMSolver (class in grl.numerical_methods)
  • @@ -462,11 +568,17 @@

    D

    E

    @@ -488,9 +600,25 @@

    F

  • flow_matching_loss_with_mask() (grl.generative_models.IndependentConditionalFlowModel method)
  • -
  • forward() (grl.neural_network.ConcatenateLayer method) +
  • forward() (grl.algorithms.GMPGCritic method)
  • - grl.datasets + grl.algorithms
  • @@ -592,6 +758,8 @@

    H

    I

    +

    P

    + + + +
    +

    Q

    @@ -709,9 +909,17 @@

    Q

    S

    @@ -809,6 +1035,8 @@

    U

    V

    +
    diff --git a/objects.inv b/objects.inv index 871a0dd9b64c008095b0fb365f4e75cffb06bf47..3dd9c5f718b5d7dd8dfa9c9be6f0e2097b1fbeb3 100644 GIT binary patch delta 6266 zcmV-=7=`EXBE&I}fqzYN+c?s`=U1r8?^T*jG^f37jWeE9IiV-Y`pTjySjHQQph-xH z^XnJ<5DAj_MuVC|QZ~_e`soJHXn-_uD6>s+Ksl{e#oND+WPi%g>R+p*!111>zqLib z;e3yNZeHjyyJGdJM?uocMt!mf2c9iEUBv1dj*8J?tdnZagJ#mH$)J9A$u2; zB&BWuf8RYs?nd;rJ<7~*`IPX_$ReeE+909%_3LTPQa1Ak~}2bhv$aiWOTKRsnAegy-5CgqQ$+=JoaJg!bfK}s6GW^qQU${+fH zWH^2EZkrYcux&cPC-jn>F*%iS(L4aPdX8Qt$AXP+DEv#$dIduS>rx#akiijHM2iAM zv_pIECkh&NqQGG%0)G!X5pdYK;B|JBg5o3ol~7Fb z#>Mf|Tb%6ANlsZZ7XLu_@Iui(F43VxRmG&P&mgYIGd(6h!I(faev=A=gy0bp(7+mF z?opAEH%LT)pk4z41oZ(1&~-tAItL6gxH-I6fM{960n8Bt3}AbNn^MZe1ACMeUhgL)*WDdw8Mk#qaZQ7*>?bm;UgrVjz7kPh~Z;IfEZo_0>tnE2GDho zp|-{Yh>Zgbpq=4jBZm&`MX=F(fh!9VCn+T*fZ2y*aggm9u620Itw*`4=^%2^(_zS! zOnb14k`6-}x}t7rxk zfvHi{6QC)cKO-hpk++hO zf`3E3ZcHdXqnM3&l!*hxOiB?cX0n@G{QP{XH}5{zzl-{v65kY9-5Ft2)Sirs676x? z>{&ja^5%z-Dp@)!O=MwIJZI!bd`xKig7ZUgW$p4owWAV35oplyL(S*4_0rdKwKmWotol#vwE zw;+~%Z1!vaAo(8G&`B0-megnpGHvQ~E-{a~a%hLmWIugLV#Ztff%Jz>2XxxQO>uUb zuig#dr(iRUA?IW_2EnFHLpuK;_02EjM6cg5{}D`rB_4K~uOBPICPq5-K{KUKOn<0A z?7y6H4E|)36m+Jf(G(<8(&-F48QW`byZSMyzAwrn*rtJ<<`8qDjzNe8oK9!+VP2BO zv`q4<`Q-h&{c!nq51$JeG6=Sax6!}!c%SMH*$z?L*S}r5df-XML|u@X#z=A4VeUvF zu2`ikeyX3v(YLtBl6-Mm@FB=TZGXuApAKrfDBmc?!w`)nBF357+K{AM0%&R2d~TBn z+64wb8;x|y4BcO{@s9ZPciGmWALR+EDY#yW@>ssWr_+~ zw%d{3Z6*Q+ZbG37&=&oMj(;q@OtKwwI&0IfsKWbGl6?_>4+sbk>?wsJP=7Ux1crD- zWu3~>@vOWyL6%JI#_IU_z(XN}F;=;SQMv6^o)|h#S!vrM;;#&6QmRaPD!aLT_`D<8 zS&~JopbaoJz1J5V=M){-;sim$-zBH2!b#38a2`tXlsw@K)3=9&zJCCEO0R75SCT#7 zlf&-I*JZ8Te9D=RD|9veA^MJi+LWA=5=-AihEu$c2i5VOF0Q@aEU#Ftt0=#oV!FSA z+RKi~we-f->nSPGerf$Bex}#Z$%Ynse#R;84p&=He?P1i)_WD56tqW0w;uuLT80*6 zPGPv<8B~uMmT{O2YJUgXY`Z|swhPc~yP(as3)^hNld93emNmt(`&qP1SQ~28tvk&5 zkUhHa5ezy%J_DdRpK9QM@Vl6NL3X>oR5n-(AzAM{Ad|ylCfsZGmb^x zEz+qfv$$PXix=cWv0^1Y2yJegvyuM(GQ29Ah9fbe7D^q3Tz^50j&$+`KQiL134qa6 zdt8>7k8sWP7^1lzCCzmX0)A;OSagu)x_di*%+EV~JY@+brB{1Ruz(tDikgxvLn)0> zfh!i|t1##lMFlW4Iw#plmX#QzMmgHO5cg#R`Mj5d)U|&TUieT}bbnRR{lBy1ss6@JaV^AZxE4Z= zb#o}1M~0Zacy}j;3Yfcm!&1lm>W%p$E6qb>b*3;tD9Y?e)t8yL^Qmi#ADgkLkHm3p zmf`Xe(L+o9K8+)|Do0srOUX;MLw2ZOGB~Zf+E78%sip_*$!9Lnk$kK1k$h_0oBCkR zzQ97uq~hp|Q3I7}dk~?ezOL4A(5Qi?_9R+cOH3kKu;m-J=`@scj+0$ z(h^r%P#s`;3b|gynepsR1PJpPs=27MC6_H55Cbk?S_x;mf;-(nm3R8j5ogBYBlJP&i=i#$;SB5kt~rcp9k_K3XY+=ORX_|H;0$D(Ny7xq z$$#|>S_x;mkUd7*K^*%70m@e+RlXXz=c@!dUyY3N)d(P8C7}6gq|a9)-+Yxo5(&l%OXMT5~>wY~f zh=K0v17I%i#Quic>pHWi5fIzixsWkU8^T;7R`cEpUz-5!_W@>X$XMaol zGaA7S*jfo^y05BF>DVmw4LC$1UaI4D(Kh z3~oNRuHkGR#1c<<)w+R6dM^}=;7;Y{}x(Cx1}NBf2O4dj^B z4pEm8*?b!Zu!I@ktE(#X;0G~1jeodEXDG3h4|VS)%$SmR2_us6A?>?_8S@b?VFV)f zjJ=mI+;z>;5?5MKy;ie?5r{Z5_MZ_rHj)`(E@4mCJDKKg(3Y0C(t>JtO)@PIab^si zOmp0fl4+g1vtckDTAkv|C7kJcrwtqzaoR8waOcW?jFXg-60o;SL5(FfW`FR1(a@LM zP!Xdzp0w#`u$JcMrF}f4%SSu-B<`P&KGAQNfwMaW#*} z7+JTH7df)hiuL~vab8sD)+jK5-c2^xA{rg=VTBSX2Qs~Uh*Nqb% z5@*V!ENT0@Q|d}5Hv+m)?>-yXacR$`)SR5CvhjT4-1VmwZ92G9=in;|yMXNSBI!h| zy6tEsyzR(H*-d}0Zc}2fdkQeVQOS%$yMxW$59Lithp4#-%8a(=_J1o^?nZZSnFzsI zp@~A?=8Qv=u{5B8^ox!Ng#uz`n+mGI)p8ymiJzm-4ZwJO}0}2kQZ?BHJu9Q#49?j`GuV5)ut31PQf)SN-G;Mi^^iUZxZ|WV&v?hl z9%xyQwX?)AslH#)cN__L$e{iV^U0gmT$69$>_C?D5`Xx1ik+dF?G!h#klq9Tyh3N9 z<{Jys>LGp792c2nF)fq4YA(IF;#^=vCLDA~KQA2Dyy98-I2L$Y*vvq~ZDBXC1-uvL zlQBbcO}-&>fm+IMi{@!F#V{~S`s&(aRZb)ftdDQIxu&0Lg08=#R@t|UGK)pHN!8GP zfM1TASbwEklUrc~cTvEM%*asVB+L=*3J#1bRw;|0>Sr;baginYwVgWR6n$$s1=kSV z`M{a4k?G3x*8*W7=LDNvElf3>S3MBToC&{gtwRLai{{Qwn29-XR>FClC=e($F@m*A 
zH*$0;*?K)k^am-w>rP-0O83Vvj7#-8ZW#YGp?@f6|FYx_m6l59_5T032Rv!2CCx|!oJwFppH2wjOEJF|l9D~yM-gYf05A;d-b&pOy=IVjAd-ik^?!th=2&|LJY@%4L7hANJc~UbVMai8TJv;7qZu$ z(JiYmyd;_m7%98i7srZZXX!?~LVqn+_x9oQuKCGK=^;5)6;5*QJzM?pT=Il7Oy8K> zUL4pMfZlZ?+^f5XLF9pn81fhSd*kCuF=Fr5m8Kkowr&FO;6I7_27)ja;hdsFnNSqK zqygJl(#NPUV@9Y-KTKAx6~lS}z6a|q(sgYMguu>8$vG*p{7;<2xQEjUfPWF#Mf-SA z74PYiewwO=0t)W2`qTG{)mq+RBh5fK!QFKoS3$UGITz9W@(yZsdLx8(&gD<;7xhlB z)E{Cg|HRenDJjwZDt_vf!$F1iTZx~S9dV69zv@f)giWUC`5C7e<;(kLq#O<^u-{&q zuVxQbUM}gdznIp0a@c+Ox_<(|2<#%9vvpyGuH&gst#Sy5&7cDNtnSyc?YJtZC-(-k5Lui5Z+0L@Xl`tS0RS*PCtZqKz~EH0w2OV%@E!} z4B;x=5Z;*%;T`A@t^y9>o$(Ofxej6eZTVEp!44&y$2(rPw4knKJk|S5ibGw?iOwxn z*0qe|pHHM1>Iz1GnZ<2$Yv26OB>PwXLx~#qW=~E}8G2QI6CzYIt$Ex0`uCUV2OtWq zXJ@VFvDQa3tz*Y*+<#4<0F6ywDNB7QYGS(s*KtVKi+Qn6$+0+56t|D#S2!zX$adl& zRvnp#*q9Gt)&OlO?`4?s>o*s3_%GeGQZCfbP(%4qIh1{(%^7?sKbj5YM~I<(mJH=H zoE0-v|?Pk&iLN$J7)?(O(7Kks@Z zb+H#qh~YfpBOT8&IX?qZl4U5RF)DCXXWf8?%%CmhVGOG%DneM*;@SxN*@MnWc9JfJ zolNN=4dv_%AUUdD2(9ZP^#uk-0M=i!m6zL^3v#ukU1xc&umF>FU7Bv2pnGKGld5((H@E3Ml-sBm16b{k z`erY}`VRStO9buf7{*AK@!Qv#kI)mvnRW}C7wUJfn}3H=pqpM~Qz4d9BiW~%Q*2Si z`6Ibrt~)hFhFWLU3`}HX))s~h;=K{|(ORgW(gv86H@^R)dnT%Vb9tnN86=JH*5r)J zGe$qmfrly=4XdboB2>i|k@%3WixTH4E@nYQq-tK}wt2F8=${Bx4Tp)YMMtEHRL!g0 z^-orxyMH541*%Gfa~8o<#9}O6XDBIioFDZ1B=x@Xk5~8I<9D+D!*!d1?$`U(ElE!` zLed5`k!Q+mz2~K5z^W={I>lX!56>5Yq5semo0)LZ46uqOS=OM++@O?hyz_~Bu1CM& zA6u##*OlpgJ>m3w%*sltH)<~2*z|Y(c(tn^Yk#3{(9ilbZlytDHKUpzhoAK-r$cIu z^EKa@EO}_*xB=Uq2#m))RL{71x^A9YlH+SywTnso&}z2~CZ`IO%|)Z=w{AD7KEXX< z&ZO16A1bArj`$b(v7%(v?_jZD&=#)>R5Aj`_AdiZf9ioBH9pd1wH~OG}Z5xT7J-3G%Jk6^vtilJ8fTIpuPq zd5Y^9SfTgg{_go`z^)nEnxU>Q8wwCp>CNO_FTES_2E`@OIuWk}^3Wic0)Tkk44~mcX@l z8ciSD@MKMuz9bdJ>3V+fg9X7U&aVG`RWye+m{YnGugTN@qqJ!rHVW7E-!Y=H9)G`Z z-N8ZCx30>nud-rQR(%ygx4A6C+rZJzjtyn@kYA?r<-_Ji7jAl??7%WC`f^kOO=LK* zzN81-Cp4*$iq}|T0J}VYr@0wJgsHkjLvpQ$>LewR3agtiG@oG#@$Ju1lI@L&;Eg`k z6zuSbuoSyy;}t(&S-IoGnrEoc4|BcC)5}Vh{rGN%Yqk@>EEFgx-qx3b7U6T_h)xw( zs;Tzy#u_9a*RI$EpH3oR`>d_AYs(+lfPAM~Tl2PBF83A7Co!tQdRw*b;6b4lPt$dP kM2C3EA<3ep7~U%ac}6MG379Y&O;kD>ZmHw{02CHj5P+IT3;+NC delta 4376 zcmV+z5$Ep2G4LXgfqz?b+cp+{&#&N&_La`s>Qmo#;?&KI9lKGsuN<0$WNaytAt5{N zuV3&YUIF5E4%L~aV*xndcg}|^ArUwf$yInDDJ#AF{nuj_pAzEz+Y56V#98#Q$@)ms znEbqYVaH_C6!HJ%`5z9Q?bY4Wf2uz!j3^1q(z{-O5Zhh_0e?*?3xe7a!d|l21|^J` z9l#%VPmlINbk>ex^f=am?7q1K(pftO(&JbMayZW0mA_;!XpkeoABQ{>)w3>WZ#sf4 z3l1w#T^F`DZD8Y&g(Z>n$aCgGpF9`gBHaG+?f$m;C~6HNuw|6xR-udM?h%+bRh>ep zG{J%L48kigkADuitV#LHNDHk}NZ|-ChVFX5a|BZmDUmbDh=30fqzD;(B?0g8BoPk~ zvr-Eq#jNbrE#2==CC$=@>a(cbE#m6}FE>UQib}~KFGx%y#uJ16DXss&q|CD*$;uLu zVxMGh!7*gf3r!E`WLfbR7I9F9$DB8^wTO7FWf4J*EPu%k70WXsMVir)tkLSwYp1Q1 zb0y)|j+iibqwFQv)1M>`e&VxVfNh0VSw>m0POGem)*7i^Ns>jBy`x0-nI|0HKV@l5 zE9fvmnCwBRU`H%1fZQ)_YeEG2gcVVQWM_nm>;e~R-|Ch3X z2fy)23VNlaU_x0bsW@>bV{`VVsUO4guT}j(+cYp^#+VaT9$-AcsW`Ko(PNkdtO(OG z&k7b?wjHit@AzvWO%Ko>;!XJ1?(b7>%zAii+JE}@b5lp2WOV3IW*RMG+-`2^LQry} zBzUgUg2(Sco`mV@Zov;h2DF&%Kk8}QQNAXoeHRTRMYLCDzNV9Q3qZrb=5w2*pxMCS z7NZuIEYSG88gCt+ewJ-5y16_-)pf3YN*ZcDEuy_Q>K1^Crimyer-J*K684rAe+j=H z9e)>D_-=Gu@QPDMAtNdOcOSl!LNv?Bu{A-L=Ms1Lc4Puaf4F~w)`LUhsZoK0YH>@y zg9yNJfTOMup+oH+wCXiqv_HThSE17|p?~uQC838GR>EPQ(!5@&?BBsKuWzBM{rVQZ z;yEUC)pJP1D!QrV3>hov^NA5hv@KZWY=03dXA56BThz+o=+Hgr;gm31Pfbsxh)B+g zW;@ck&P2e#RVbtYta5R?^e5SGP|OyMHWs zm8_zdpfxZRt=AQuri>i;!wG?ezY9-gNyBvTfb&#j&*3voD0_bj*$bdjYO<@}!(@M- z9d=*8tsCXlmy~NgrFD0>YHD*YxBk2nr`>L6yTvsattCAYc=`Q?aNd;!|? 
zj>)BXgYw5IEJ(aIe?j-`5}dqik$?T3Mznn^Xhr_6TdmA{5l#l$q8@KQ1J1P^4P{Q@ zxZo?O4ind7oE6j#wAyw^)wV;ZwjFAbXPTpDZS4 ze*~S*&7U4n-=DJgX9w-iPU#O zv{WZXC2#!IISbFa;Y7(4lv6yS> znq4)GPsxyFEcnt6^}VI86LUkYvD&!H_69ng)3~Yi2BJol9I}G4m!o>-nwIQmuP$d9 zcXK%Ahc3f_Wny^u@O1Z|^Cu0;h^NxXl&`g_$UfD7YB1$lLZf$6rhnF|VtT6oR5N9^ zGG-$g9}Q&8tYpl386S0J%&cV0dKn+}WGoWD&oE*eTN?F&y0hC<-q^4nDKsmuvNW%} zfmKDORe6<-dF2hPDkiPUs|Ph38OMh8h^Q%JWoa(sz^WqClCiQemvLZKF->KhoTefj z=LM&(c*9dymJs=aM}PcFk+d1w_!G0y={G7m3#y(*hhlv6j$L^2vwh@QmR>0qfM)KE z7lMj1Mh?45D3Q-XIxyK(l3B}DjgRi_Zh9iqQl_|3nlx~$#xXP!HUZf$`lS=`(tj%ep70d_EwWAbWpkZl z&OHSfzh$}zqDLRGVqEL)N&f$!QChgLI0ZPy8 zK?+#P=VT2j$oXVGECJ5=Mx_A*6ubLQQUa=JwvQ}gT1+YhcA(yo!qzL2V)Ud)Ivsl=GhrG%fS43rN2&lmP>^7wM=R|g9 zF0sRJqu3oSVP%CFIQ;H~_Nl{e=9rqGZ+KXrk_-!X34aa6CuUw+Sg}_ta9~z}@HpwU zlopc1L1F5Kxcg;s0yh*8rea7t3Wg_!Lx8NO(q#ui9beFpF+~R6>Crr^-NhJR4YbW6 z%X8%+!PZv~XSW36fP6IM6-2(IsBj`)P&DbW{1+5c>JvJ{O5*=R^hS|*ZB=phAPsMl zfB;!frGLw=A$5E~L&g{RGEac|o5dDLn4;lrC!1eF0)*ot*WeWGfo9cXIqZhB>WpkF zE5yLzH=MOEZ*eon)C6(I$>N$BD4=x}8{AiMIqe2cz(bX)sJ~q=sqaTL zBc0n%FE2Fn;h5wO9#q!j>jE{{dCizoRwlRfyB_#;Yyr=W>idNBC~j;hZfc4fO^U4z z#kQu{>QD%7H*#&7I=6vpl!0+xvHEw4Vt>5G*DVYb*ZC?igs{%g24$UtvKWBWcdRN) zK)8+6!bTUm>K6rx2&HWeFOk(QX%~*g)FQdbPMmD*6=Yi(-F>Ubw|5w>DGYDcUCws) zw!df>)a@Qtvz&b(f_q6}LJ4p_q(mC-)-Id-4?IkZ4LU}7PQ~-pr9wi@kQerW$c_kRu_2J&z^aI`M7%k!aL11 z%fST#*{PQ=Q#UmLj1g`GSiH>Luz$Q=fULbFy=*Q?FWO7eQv&hiw{VlZQEBMVAX}-< zb@p~<;8DZ1+Xe3zTCdtlBjmyNE#oq#Vab84zKDPr)`b|B12$a8UO>|7vXa{e<=+r3 zTvzA`uPf9d+r;^?xxmt#c3|7x<*7${Xs>h}e9tzxV6wVj(!v`JQ1w-knt%QauZMJz zQgubDA|J$U)db++|76tHCQ646cD;%)lc6`ZneH(BW55Rbh$1g6@*Od zNyOsy1!_2*5wOKMzjD8-I)5FhJH%T01m%xYSdjQ4I(6i5P}sT^biZzos|nrA*U*Gd zrewdT5tZJ9jHu6u91aRux0m`?vxhP*))ee_&U&96c3;1F02s&;(Ud(;H**3_#fb%#LKXLW5Hby-Cdw(26FqH{5JQ6r1tD*y5JvwtVeadZ-XpN9LS z+K3qzar7T}o0ogGxXs(;gSd?%zTP(PydX?O@Y}q-r>bpUK}naOO296>lXl^qTNf@d zy6{fjg?B(*xCHOQJ4F}XL3H5~uM6*VyYLRQ3zvvpc&FWkceY)4_-*;Lm_;0ma1P(` z+ARwzEA6S?Wm4>I4Sy#(yRove(jNa@MvC56(EE!dXsTQJ&Ht=q_mzL2vy8imv(s}z zek}ebL`XBWysba}@pbwGkOkIPYu48>>z{h8To|AK*JhZoPArWSKD-}E`dZTCDG6Ct z7|~8l@7|A})7Rbkvw}EFN%@ko*DyJe<&$H@>_Yp;l2Z?nT(g?`vHuFgiY zLD@8=&Su_xPK+kX-Zw1Nb-X%J3bg5pT$O|cC6VnyPq9UY^CUU%VE(AHG+Ae68Xs!WE((zQBY2I*?_&3`c2kZNoz(O&`2E0{W{ zdEq*yKdNMJS0yWsn^qk7R9Ka}i-gUP8RMSVvC`ADiY5!WX@1R)obmGgGESh{YAzh`eA%e*$1 zcziHuvVXmt6x@Mh^P?se?dBzW^I7JfajJD{lj{#~44u;iH4j13YiA{Dh5pM^y}K$^ zjIpT_{W?(tZdh81+>bj7CkF+Gwb^Ej$M`UfJ<3HnAI22XY6f25J`>P_ucqTvN%roII)AU%b>7BQ<-%EnI;)PlLG{H$LfiC> z9JizZ40nFQnU?r-vw1xnCUb^~2b<_rvg}CzCb8eVL*Wab-SmB1K>XC@cukJiy`ZA+ z@Tevgi(~_@6|iuEH}ifT7U7XFQqaEz>lp*TC2-AgtkYx@KY3L}FJZ}OiJ2Q|M_8ttG5%tBNPzi@BTTXL3pTzXsHId3bUtoULon&y5tjlG#LS#w8m!F zPython Module Index

        grl.agents + + +     + grl.algorithms +     @@ -348,6 +353,11 @@

    Python Module Index

        grl.numerical_methods + + +     + grl.rl_modules +     diff --git a/searchindex.js b/searchindex.js index c1a09cb..3739b5a 100644 --- a/searchindex.js +++ b/searchindex.js @@ -1 +1 @@ -Search.setIndex({"docnames": ["api_doc/agents/index", "api_doc/algorithms/index", "api_doc/datasets/index", "api_doc/generative_models/index", "api_doc/neural_network/index", "api_doc/numerical_methods/index", "api_doc/rl_modules/index", "api_doc/utils/index", "concepts/index", "index", "tutorials/installation/index", "tutorials/quick_start/index", "user_guide/evaluating_agents", "user_guide/index", "user_guide/installation", "user_guide/training_agents", "user_guide/training_generative_models"], "filenames": ["api_doc/agents/index.rst", "api_doc/algorithms/index.rst", "api_doc/datasets/index.rst", "api_doc/generative_models/index.rst", "api_doc/neural_network/index.rst", "api_doc/numerical_methods/index.rst", "api_doc/rl_modules/index.rst", "api_doc/utils/index.rst", "concepts/index.rst", "index.rst", "tutorials/installation/index.rst", "tutorials/quick_start/index.rst", "user_guide/evaluating_agents.rst", "user_guide/index.rst", "user_guide/installation.rst", "user_guide/training_agents.rst", "user_guide/training_generative_models.rst"], "titles": ["grl.agents", "grl.algorithms", "grl.datasets", "grl.generative_models", "grl.neural_network", "grl.numerical_methods", "grl.rl_modules", "grl.utils", "Concepts", "GenerativeRL Documentation", "Installation", "Quick Start", "How to evaluate RL agents performance", "User Guide", "How to install GenerativeRL and its dependencies", "How to train and deploy reinforcement learning agents", "How to train generative models"], "terms": {"class": [0, 2, 3, 4, 5, 8, 12, 15, 16], "config": [0, 3, 5, 11, 12, 15, 16], "model": [0, 3, 4, 5, 9, 13, 15], "sourc": [0, 2, 3, 4, 5, 7, 8], "overview": [0, 2, 3, 4, 5, 7], "The": [0, 2, 3, 4, 5, 7, 8, 11, 12, 16], "qgpo": [0, 2, 11, 15], "algorithm": [0, 2, 8, 9, 11, 12, 15], "interfac": [0, 2, 3, 4, 5, 8, 11], "__init__": [0, 2, 3, 4, 5, 16], "action": [0, 2, 8, 12], "initi": [0, 2, 3, 4, 5, 8, 11, 16], "paramet": [0, 2, 3, 4, 5, 7, 8], "easydict": [0, 3, 5, 16], "configur": [0, 3, 5, 8, 11, 15, 16], "union": [0, 3, 4, 5], "torch": [0, 3, 4, 5], "nn": [0, 3, 5, 16], "modul": [0, 3, 5, 12, 15, 16], "moduledict": [0, 3], "act": [0, 11, 12, 15], "ob": 0, "return_as_torch_tensor": 0, "fals": [0, 3, 4, 5, 7], "given": [0, 3, 5, 8], "an": [0, 3, 4, 5, 8, 11, 12, 15], "observ": [0, 8, 11, 12, 15], "return": [0, 3, 4, 5, 7, 12, 15, 16], "np": 0, "ndarrai": 0, "tensor": [0, 3, 4, 5, 8], "dict": [0, 16], "bool": [0, 3, 4, 5, 7], "whether": [0, 3, 4, 5, 7, 8, 14], "type": [0, 3, 4, 5, 7, 8, 16], "srpo": 0, "train": [0, 2, 3, 8, 9, 11, 13], "gener": [0, 2, 3, 5, 7, 9, 13], "polici": [0, 2, 3, 8, 11], "thi": [0, 3, 4, 8, 9, 11, 16], "i": [0, 2, 3, 4, 5, 7, 8, 9, 11, 12, 13, 14, 16], "design": [0, 3, 9], "us": [0, 3, 4, 5, 7, 8, 9, 10, 11, 12, 14, 15, 16], "gmpgalgorithm": [0, 9], "gmpoalgorithm": [0, 9], "numpi": 0, "arrai": 0, "env_id": [2, 11, 12, 15], "base": [2, 3, 5, 8], "contrast": [2, 3], "energi": [2, 3], "predict": [2, 3, 5], "which": [2, 3, 5, 8, 12, 14, 16], "need": [2, 8, 11, 15], "true": [2, 3, 4, 5, 7], "fake": 2, "sampl": [2, 3, 5, 8, 16], "from": [2, 3, 5, 8, 10, 11, 15, 16], "support": [2, 3, 5, 11, 16], "behaviour": 2, "__getitem__": 2, "__len__": 2, "method": [2, 3, 5, 8, 11, 12, 15, 16], "str": [2, 3, 5], "environ": [2, 8, 11, 12, 14, 15], "id": 2, "d4rl": [2, 14], "sometim": 2, "data": [2, 3, 4, 5, 8, 16], "augment": 
[2, 8], "diffus": [3, 4, 5, 8, 9, 11, 16], "variou": [3, 8], "continu": [3, 8], "time": [3, 5, 8, 16], "path": [3, 5, 8, 13], "comput": [3, 5, 8], "score": [3, 8, 16], "function": [3, 4, 5, 8, 11, 16], "veloc": [3, 16], "It": [3, 5, 14], "can": [3, 5, 8, 10, 11, 12, 14, 15, 16], "via": 3, "nois": [3, 5, 16], "both": [3, 16], "match": [3, 8, 16], "loss": [3, 16], "flow": [3, 5, 8, 9, 16], "ar": [3, 8, 15, 16], "score_funct": 3, "score_matching_loss": [3, 16], "velocity_funct": 3, "flow_matching_loss": [3, 16], "data_prediction_funct": [3, 5], "t": [3, 4, 5, 8, 16], "x": [3, 4, 5, 16], "condit": [3, 4, 5, 16], "none": [3, 4, 5, 7, 16], "state": [3, 5, 11], "frac": [3, 5, 16], "sigma": [3, 4, 5, 8, 16], "x_t": [3, 8, 16], "2": [3, 4, 5, 14, 16], "nabla_": [3, 16], "log": [3, 5, 8, 11, 16], "p_": 3, "theta": [3, 16], "": [3, 5, 11], "input": [3, 4, 5, 12, 15, 16], "tensordict": [3, 4, 5], "treetensor": 3, "dpo_loss": 3, "ref_dm": 3, "beta": [3, 5], "process": [3, 5, 8, 11, 16], "direct": [3, 8], "optim": [3, 8, 11], "dpo": 3, "develop": [3, 10], "featur": 3, "recommend": [3, 5], "averag": 3, "across": [3, 8], "batch": [3, 4], "forward_sampl": 3, "t_span": [3, 5], "with_grad": 3, "solver_config": 3, "forward": [3, 4, 8, 16], "note": [3, 14], "revers": [3, 8, 16], "thu": 3, "form": [3, 5, 16], "rather": 3, "encod": 3, "latent": [3, 4], "space": [3, 5], "span": 3, "gradient": [3, 8], "solver": [3, 5, 16], "forward_sample_process": 3, "all": [3, 4], "intermedi": [3, 5], "log_prob": 3, "using_hutchinson_trace_estim": 3, "probabl": [3, 4, 5], "noise_funct": [3, 5, 16], "batch_siz": 3, "x_0": [3, 16], "final": [3, 4, 5], "size": [3, 4], "int": [3, 4, 5, 7], "tupl": [3, 4], "list": [3, 4, 13], "provid": [3, 7, 8, 9, 11, 15, 16], "gaussian": [3, 5, 16], "distribut": [3, 8, 16], "result": [3, 8], "shape": [3, 4, 5, 16], "where": [3, 5, 16], "number": [3, 4, 5, 11], "step": [3, 5, 11, 12, 14, 15], "b": 3, "could": 3, "scalar": [3, 5], "b1": 3, "b2": 3, "n": [3, 4, 5, 11], "d": [3, 5, 16], "dimens": [3, 4, 8], "d1": 3, "d2": 3, "extra": 3, "If": [3, 5, 7], "sample_forward_process": 3, "repeat": 3, "same": [3, 5, 16], "sample_forward_process_with_fixed_x": 3, "fixed_x": 3, "fixed_mask": 3, "fix": [3, 8], "mask": 3, "sample_with_fixed_x": 3, "sample_with_log_prob": 3, "likelihood": [3, 8, 16], "weighting_schem": 3, "uncondit": [3, 4], "weight": [3, 4, 8, 16], "scheme": 3, "maximum_likelihood": 3, "vanilla": 3, "maximum": [3, 8, 16], "estim": [3, 8, 16], "refer": [3, 11, 13, 15], "paper": 3, "more": [3, 5, 11, 13, 15], "detail": [3, 5, 11, 13], "lambda": [3, 5, 16], "denot": [3, 16], "g": [3, 5, 8, 14, 16], "numer": [3, 8], "stabil": 3, "we": [3, 5, 8, 11, 16], "mont": 3, "carlo": 3, "approxim": [3, 5], "integr": [3, 5, 8], "p": [3, 5, 14, 16], "balanc": 3, "mse": 3, "scale": [3, 4, 5, 16], "output": [3, 4, 5, 16], "valu": [3, 5, 7, 8], "through": [3, 4, 8, 11], "stochast": [3, 5, 8, 13], "differenti": [3, 5, 8, 16], "equat": [3, 5, 8, 16], "v_": [3, 16], "energy_model": 3, "text": [3, 16], "e": [3, 5, 8, 14, 16], "c": [3, 4, 14, 16], "sim": 3, "exp": 3, "mathcal": [3, 5, 16], "z": 3, "sample_without_energy_guid": 3, "score_function_with_energy_guid": 3, "energy_guidance_loss": 3, "data_prediction_function_with_energy_guid": 3, "guidance_scal": 3, "1": [3, 4, 5, 14, 16], "0": [3, 4, 5, 14, 16], "guidanc": [3, 4], "float": [3, 4, 5], "cep": 3, "propos": 3, "exact": 3, "guid": [3, 11], "offlin": [3, 8], "reinforc": [3, 9, 12, 13, 14], "learn": [3, 4, 9, 12, 13, 14], "noise_function_with_energy_guid": 3, 
"nose": 3, "nabla": 3, "sample_with_fixed_x_without_energy_guid": 3, "without": [3, 5], "independ": [3, 8, 16], "get_typ": 3, "x0": [3, 5], "x1": 3, "flow_matching_loss_with_mask": 3, "signal": [3, 8], "either": 3, "ha": [3, 5, 12, 15, 16], "correspond": 3, "element": [3, 4, 5], "usual": [3, 16], "x_1": [3, 16], "log_prob_x_0": 3, "function_log_prob_x_0": 3, "callabl": [3, 5], "hutchinson": 3, "trace": 3, "jacobian": 3, "drift": [3, 5, 8, 16], "faster": 3, "less": 3, "accur": [3, 8], "set": [3, 7, 16], "high": [3, 5, 8], "dimension": 3, "log_likelihood": 3, "optimal_transport_flow_matching_loss": 3, "transport": 3, "plan": 3, "two": 3, "sample_with_mask": 3, "sample_with_mask_forward_process": 3, "between": [3, 5, 8, 16], "flow_matching_loss_small_batch_ot_plan": 3, "small": 3, "acceler": 3, "concaten": 4, "along": 4, "last": [4, 5], "layer": [4, 16], "hidden_s": [4, 16], "output_s": 4, "activ": 4, "dropout": 4, "layernorm": 4, "final_activ": 4, "shrink": 4, "multi": 4, "perceptron": 4, "fulli": 4, "connect": 4, "fc1": 4, "act1": 4, "fcn": 4, "actn": 4, "out": 4, "hidden": 4, "channel": 4, "option": [4, 7], "zero": 4, "default": [4, 7, 15], "block": [4, 11], "shrinkag": 4, "factor": [4, 5], "kwarg": [4, 5], "pass": 4, "mlp": [4, 16], "keyword": 4, "argument": [4, 5], "output_dim": [4, 16], "t_dim": [4, 16], "input_dim": 4, "condition_dim": 4, "condition_hidden_dim": 4, "t_condition_hidden_dim": 4, "tempor": 4, "spatial": 4, "residu": 4, "network": [4, 8, 13], "multipl": 4, "temporalspatialresblock": 4, "input_s": 4, "32": 4, "patch_siz": 4, "in_channel": 4, "4": [4, 5], "1152": 4, "depth": 4, "28": 4, "num_head": 4, "16": 4, "mlp_ratio": 4, "class_dropout_prob": 4, "num_class": 4, "1000": 4, "learn_sigma": 4, "transform": [4, 8, 16], "backbon": [4, 16], "offici": 4, "implement": [4, 8, 12, 15], "github": [4, 10, 14], "repo": 4, "http": [4, 10, 14], "com": [4, 10, 14], "facebookresearch": 4, "blob": 4, "main": 4, "py": [4, 14], "patch": 4, "attent": 4, "head": 4, "respect": 4, "timestep": 4, "imag": [4, 8, 14], "represent": 4, "label": 4, "forward_with_cfg": 4, "cfg_scale": 4, "also": [4, 5, 8, 10, 15, 16], "classifi": [4, 8], "free": [4, 8], "initialize_weight": 4, "unpatchifi": 4, "img": 4, "h": 4, "w": [4, 16], "token_s": 4, "condition_embedd": 4, "1d": 4, "3d": 4, "inform": [4, 11, 13, 15], "origin": 4, "video": [4, 8], "alia": 4, "patch_block_s": 4, "10": 4, "convolv": 4, "each": 4, "token": 4, "total_patch": 4, "ordinari": [5, 8], "defin": [5, 8, 15, 16], "dx": 5, "f": [5, 8, 16], "dt": [5, 8, 16], "term": 5, "dw": 5, "wiener": [5, 16], "order": 5, "devic": [5, 16], "atol": 5, "1e": 5, "05": 5, "rtol": 5, "dpm_solver": 5, "singlestep": 5, "solver_typ": 5, "skip_typ": 5, "time_uniform": 5, "denois": [5, 8, 16], "dpm": 5, "should": 5, "3": [5, 14], "absolut": 5, "toler": 5, "adapt": [5, 8], "rel": 5, "total": 5, "evalu": [5, 8, 9, 11, 13, 15], "nfe": 5, "multistep": 5, "singlestep_fix": 5, "taylor": 5, "slightli": 5, "impact": 5, "perform": [5, 9, 13, 15], "logsnr": 5, "time_quadrat": 5, "diffusion_process": 5, "save_intermedi": 5, "diffusionprocess": 5, "t_start": 5, "solut": 5, "t_end": 5, "x_end": 5, "ode_solv": 5, "euler": [5, 8], "01": 5, "librari": [5, 8, 9, 11, 14, 16], "torchdyn": [5, 8, 16], "torchdiffeq": [5, 8], "current": [5, 16], "addit": [5, 14], "first": [5, 11], "For": [5, 8, 11, 13, 14, 15, 16], "exampl": [5, 8, 11, 13, 14, 16], "trajectori": 5, "len": [5, 16], "sde_solv": 5, "sde_noise_typ": 5, "diagon": 5, "sde_typ": 5, "ito": 5, "001": 5, "torchsd": 5, 
"stratonovich": 5, "logqp": 5, "case": [5, 11], "mu": 5, "written": 5, "mathrm": [5, 16], "w_": 5, "sqrt": 5, "covari": 5, "matrix": 5, "standard": [5, 16], "deviat": 5, "half": 5, "differ": [5, 8, 13], "vp": [5, 16], "int_": [5, 16], "linear": [5, 16], "todo": 5, "add": 5, "cosin": 5, "ve": 5, "opt": 5, "halflogsnr": 5, "inversehalflogsnr": 5, "invers": 5, "sinc": 5, "invert": 5, "beta_1": [5, 16], "beta_0": [5, 16], "d_covariance_dt": 5, "deriv": [5, 16], "d_log_scale_dt": 5, "d_scale_dt": 5, "d_std_dt": 5, "follow": [5, 8, 14, 15, 16], "diffusion_squar": 5, "drift_coeffici": 5, "coeffici": [5, 16], "satisfi": 5, "log_scal": 5, "std": 5, "seed_valu": 7, "cudnn_determinist": 7, "cudnn_benchmark": 7, "random": [7, 8], "seed": [7, 8], "make": [7, 8, 9, 11, 12, 14, 15], "cudnn": 7, "oper": 7, "determinist": 7, "enabl": [7, 8], "benchmark": 7, "convolut": 7, "framework": [8, 9], "consist": 8, "code": [8, 16], "api": [8, 11, 13, 15, 16], "deploy": [8, 11], "generativerl": [8, 10, 12, 13, 15, 16], "user": [8, 12, 16], "friendli": 8, "deploi": [8, 9, 11, 12, 13], "rl": [8, 9, 13, 15], "agent": [8, 9, 11, 13], "In": [8, 11, 12, 15, 16], "section": [8, 11, 13, 15], "explor": 8, "core": 8, "includ": [8, 11, 16], "discuss": 8, "kei": 8, "underpin": 8, "how": [8, 9, 11, 13], "thei": 8, "leverag": 8, "address": 8, "complex": 8, "problem": [8, 9, 14], "field": [8, 16], "addition": 8, "explain": 8, "why": 8, "import": [8, 11, 12, 14, 15, 16], "what": 8, "uniqu": 8, "wide": 8, "rang": [8, 11, 12, 15, 16], "applic": [8, 16], "machin": 8, "new": [8, 11, 16], "typic": [8, 16], "dataset": [8, 9, 11, 15], "most": 8, "unsupervis": 8, "techniqu": 8, "appli": 8, "task": 8, "audio": 8, "interpol": 8, "focus": 8, "dynam": 8, "These": [8, 16], "have": [8, 14], "capac": 8, "captur": 8, "demonstr": [8, 11], "promis": 8, "varieti": 8, "its": [8, 9, 11, 13, 15, 16], "variant": [8, 16], "qualiti": 8, "solv": [8, 9, 14], "od": [8, 9], "sde": [8, 9, 16], "dx_t": [8, 16], "dw_t": [8, 16], "unifi": [8, 12, 16], "howev": 8, "vari": 8, "definit": 8, "some": [8, 14], "under": [8, 12, 15], "common": 8, "while": [8, 15, 16], "other": [8, 11, 13, 15], "mai": 8, "requir": [8, 14], "specif": [8, 11, 15], "There": 8, "four": 8, "open": 8, "neural": [8, 13], "parameter": [8, 13], "certain": 8, "part": 8, "potenti": 8, "determin": 8, "procedur": 8, "fundament": 8, "object": [8, 13], "maxim": 8, "pretrain": 8, "like": 8, "bridg": [8, 16], "fine": 8, "tune": 8, "advantag": 8, "regress": 8, "adjoint": 8, "involv": 8, "depend": [8, 9, 13], "maruyama": 8, "rung": 8, "kutta": 8, "offer": 8, "flexibl": [8, 11], "allow": 8, "custom": [8, 13], "extend": 8, "suit": [8, 11, 14], "instanc": [8, 11, 15], "easili": 8, "own": [8, 16], "architectur": [8, 16], "creat": [8, 11, 15, 16], "tailor": 8, "format": [8, 11], "decis": [8, 9], "interact": [8, 15], "receiv": 8, "reward": [8, 11, 12, 15], "penalti": 8, "cumul": 8, "take": [8, 12, 15, 16], "updat": 8, "categor": 8, "directli": [8, 16], "onlin": 8, "strategi": 8, "off": 8, "actor": 8, "critic": 8, "research": 8, "improv": 8, "effici": 8, "synthet": 8, "decoupl": 8, "littl": 8, "modif": 8, "rank": 8, "least": 8, "automat": 8, "pytorch": [8, 14], "unif": 8, "within": [8, 11], "singl": 8, "simplic": 8, "simpl": [8, 11], "intuit": 8, "extens": 8, "dictionari": [8, 15], "modular": 8, "built": 8, "mix": 8, "compon": [8, 11], "reproduc": 8, "ensur": 8, "checkpoint": 8, "possibl": 8, "run": [8, 12, 14], "minim": 8, "seek": 8, "extern": 8, "lightweight": 8, "instal": [8, 9, 13], "platform": 8, "compat": 8, 
"exist": 8, "work": [8, 16], "seamlessli": 8, "openai": [8, 11], "gym": [8, 11, 12, 14, 15], "torchrl": 8, "python": [9, 14], "aim": 9, "combin": 9, "power": [9, 11], "capabl": 9, "quick": 9, "start": 9, "explan": 9, "principl": 9, "grl": [9, 10, 11, 12, 14, 15, 16], "qgpoagent": 9, "srpoagent": 9, "gpagent": 9, "qgpocrit": 9, "qgpopolici": 9, "qgpoalgorithm": [9, 11, 15], "srpocrit": 9, "srpopolici": 9, "srpoalgorithm": 9, "gmpocrit": 9, "gmpopolici": 9, "gmpgcritic": 9, "gmpgpolici": 9, "qgpod4rldataset": 9, "qgpodataset": 9, "gpd4rldataset": 9, "gpdataset": 9, "generative_model": [9, 16], "diffusionmodel": [9, 16], "energyconditionaldiffusionmodel": 9, "independentconditionalflowmodel": 9, "optimaltransportconditionalflowmodel": 9, "neural_network": [9, 16], "concatenatelay": 9, "multilayerperceptron": 9, "concatenatemlp": 9, "temporalspatialresidualnet": [9, 16], "dit": 9, "dit1d": 9, "dit2d": 9, "dit3d": 9, "numerical_method": [9, 16], "dpmsolver": 9, "odesolv": [9, 16], "sdesolv": 9, "gaussianconditionalprobabilitypath": [9, 16], "rl_modul": 9, "gymenvsimul": 9, "oneshotvaluefunct": 9, "vnetwork": 9, "doublevnetwork": 9, "qnetwork": 9, "doubleqnetwork": 9, "util": [9, 11], "set_se": 9, "pip": [10, 14], "you": [10, 11, 14, 15], "latest": 10, "version": [10, 14], "git": [10, 14], "opendilab": [10, 14], "easi": 11, "swiss": 11, "roll": 11, "colab": 11, "usag": [11, 13], "found": 11, "folder": 11, "grl_pipelin": [11, 15], "tutori": 11, "here": [11, 13, 16], "q": 11, "halfcheetah": 11, "diffusion_model": [11, 15, 16], "d4rl_halfcheetah_qgpo": [11, 15], "def": [11, 16], "qgpo_pipelin": 11, "env": [11, 12, 15], "reset": [11, 12, 15], "_": [11, 12, 15, 16], "num_deploy_step": [11, 12, 15], "render": [11, 12, 15], "done": [11, 12, 15], "__name__": 11, "__main__": 11, "info": 11, "necessari": 11, "well": 11, "encapsul": 11, "call": [11, 16], "after": [11, 14], "obtain": [11, 16], "A": [11, 16], "loop": 11, "execut": 11, "specifi": 11, "print": 11, "consol": 11, "modifi": 11, "your": 11, "advanc": [11, 13], "pleas": [11, 13, 15], "document": [11, 13, 15], "simul": 12, "collect": 12, "9": 14, "higher": 14, "command": 14, "clone": 14, "cd": 14, "pybullet": 14, "mujoco": 14, "deepmind": 14, "control": 14, "etc": 14, "dm_control": 14, "setup": 14, "licens": 14, "special": 14, "23": 14, "anoth": 14, "thing": 14, "sudo": 14, "apt": 14, "get": 14, "libgl1": 14, "mesa": 14, "glx": 14, "libglib2": 14, "libsm6": 14, "libxext6": 14, "libxrend": 14, "dev": 14, "y": 14, "swig": 14, "gcc": 14, "local": 14, "dnsutil": 14, "cmake": 14, "build": 14, "essenti": 14, "libglew": 14, "libosmesa6": 14, "libglfw3": 14, "libsdl2": 14, "libglm": 14, "libfreetype6": 14, "patchelf": 14, "ffmpeg": 14, "mkdir": 14, "root": 14, "wget": 14, "org": 14, "download": 14, "mujoco210": 14, "linux": 14, "x86_64": 14, "tar": 14, "gz": 14, "o": 14, "xf": 14, "export": 14, "ld_library_path": 14, "mjpro210": 14, "bin": 14, "farama": 14, "foundat": 14, "lockfil": 14, "cython": 14, "check": 14, "success": 14, "everi": [15, 16], "hyperparamet": 15, "copi": 15, "trained_model": 15, "divers": 16, "describ": 16, "evolut": 16, "over": 16, "increment": 16, "probability_path": 16, "kind": 16, "varianc": 16, "preserv": 16, "gvp": 16, "usal": 16, "want": 16, "normal": 16, "target": 16, "By": 16, "fokker": 16, "planck": 16, "kolmogorov": 16, "fpk": 16, "hat": 16, "_t": 16, "s_": 16, "codebas": 16, "ddpm": 16, "compar": 16, "Or": 16, "v": 16, "nerual": 16, "therefor": 16, "intrinsicmodel": 16, "ani": 16, "cnn": 16, "u": 16, "net": 16, "x_size": 16, 
"alpha": 16, "arg": 16, "linear_vp_sd": 16, "20": 16, "t_encod": 16, "512": 16, "256": 16, "128": 16, "t_embedding_dim": 16, "register_modul": 16, "regist": 16, "so": 16, "mymodul": 16, "self": 16, "super": 16, "modulelist": 16, "append": 16, "relu": 16, "mle": 16, "onli": 16, "mean": 16, "squar": 16, "error": 16, "l": 16, "dsm": 16, "mathbb": 16, "left": 16, "right": 16, "cfm": 16, "simpli": 16}, "objects": {"grl": [[0, 0, 0, "-", "agents"], [2, 0, 0, "-", "datasets"], [3, 0, 0, "-", "generative_models"], [4, 0, 0, "-", "neural_network"], [5, 0, 0, "-", "numerical_methods"], [7, 0, 0, "-", "utils"]], "grl.agents": [[0, 1, 1, "", "GPAgent"], [0, 1, 1, "", "QGPOAgent"], [0, 1, 1, "", "SRPOAgent"]], "grl.agents.GPAgent": [[0, 2, 1, "", "__init__"], [0, 2, 1, "", "act"]], "grl.agents.QGPOAgent": [[0, 2, 1, "", "__init__"], [0, 2, 1, "", "act"]], "grl.agents.SRPOAgent": [[0, 2, 1, "", "__init__"], [0, 2, 1, "", "act"]], "grl.datasets": [[2, 1, 1, "", "GPD4RLDataset"], [2, 1, 1, "", "GPDataset"], [2, 1, 1, "", "QGPOD4RLDataset"], [2, 1, 1, "", "QGPODataset"]], "grl.datasets.GPD4RLDataset": [[2, 2, 1, "", "__init__"]], "grl.datasets.GPDataset": [[2, 2, 1, "", "__init__"]], "grl.datasets.QGPOD4RLDataset": [[2, 2, 1, "", "__init__"]], "grl.datasets.QGPODataset": [[2, 2, 1, "", "__init__"]], "grl.generative_models": [[3, 1, 1, "", "DiffusionModel"], [3, 1, 1, "", "EnergyConditionalDiffusionModel"], [3, 1, 1, "", "IndependentConditionalFlowModel"], [3, 1, 1, "", "OptimalTransportConditionalFlowModel"]], "grl.generative_models.DiffusionModel": [[3, 2, 1, "", "__init__"], [3, 2, 1, "", "data_prediction_function"], [3, 2, 1, "", "dpo_loss"], [3, 2, 1, "", "flow_matching_loss"], [3, 2, 1, "", "forward_sample"], [3, 2, 1, "", "forward_sample_process"], [3, 2, 1, "", "log_prob"], [3, 2, 1, "", "noise_function"], [3, 2, 1, "", "sample"], [3, 2, 1, "", "sample_forward_process"], [3, 2, 1, "", "sample_forward_process_with_fixed_x"], [3, 2, 1, "", "sample_with_fixed_x"], [3, 2, 1, "", "sample_with_log_prob"], [3, 2, 1, "", "score_function"], [3, 2, 1, "", "score_matching_loss"], [3, 2, 1, "", "velocity_function"]], "grl.generative_models.EnergyConditionalDiffusionModel": [[3, 2, 1, "", "__init__"], [3, 2, 1, "", "data_prediction_function"], [3, 2, 1, "", "data_prediction_function_with_energy_guidance"], [3, 2, 1, "", "energy_guidance_loss"], [3, 2, 1, "", "flow_matching_loss"], [3, 2, 1, "", "noise_function"], [3, 2, 1, "", "noise_function_with_energy_guidance"], [3, 2, 1, "", "sample"], [3, 2, 1, "", "sample_forward_process"], [3, 2, 1, "", "sample_forward_process_with_fixed_x"], [3, 2, 1, "", "sample_with_fixed_x"], [3, 2, 1, "", "sample_with_fixed_x_without_energy_guidance"], [3, 2, 1, "", "sample_without_energy_guidance"], [3, 2, 1, "", "score_function"], [3, 2, 1, "", "score_function_with_energy_guidance"], [3, 2, 1, "", "score_matching_loss"], [3, 2, 1, "", "velocity_function"]], "grl.generative_models.IndependentConditionalFlowModel": [[3, 2, 1, "", "__init__"], [3, 2, 1, "", "flow_matching_loss"], [3, 2, 1, "", "flow_matching_loss_with_mask"], [3, 2, 1, "", "forward_sample"], [3, 2, 1, "", "forward_sample_process"], [3, 2, 1, "", "log_prob"], [3, 2, 1, "", "optimal_transport_flow_matching_loss"], [3, 2, 1, "", "sample"], [3, 2, 1, "", "sample_forward_process"], [3, 2, 1, "", "sample_with_log_prob"], [3, 2, 1, "", "sample_with_mask"], [3, 2, 1, "", "sample_with_mask_forward_process"]], "grl.generative_models.OptimalTransportConditionalFlowModel": [[3, 2, 1, "", "__init__"], [3, 2, 1, "", 
"flow_matching_loss"], [3, 2, 1, "", "flow_matching_loss_small_batch_OT_plan"], [3, 2, 1, "", "sample"], [3, 2, 1, "", "sample_forward_process"]], "grl.neural_network": [[4, 1, 1, "", "ConcatenateLayer"], [4, 1, 1, "", "ConcatenateMLP"], [4, 1, 1, "", "DiT"], [4, 1, 1, "", "DiT1D"], [4, 3, 1, "", "DiT2D"], [4, 1, 1, "", "DiT3D"], [4, 1, 1, "", "MultiLayerPerceptron"], [4, 1, 1, "", "TemporalSpatialResidualNet"]], "grl.neural_network.ConcatenateLayer": [[4, 2, 1, "", "__init__"], [4, 2, 1, "", "forward"]], "grl.neural_network.ConcatenateMLP": [[4, 2, 1, "", "__init__"], [4, 2, 1, "", "forward"]], "grl.neural_network.DiT": [[4, 2, 1, "", "__init__"], [4, 2, 1, "", "forward"], [4, 2, 1, "", "forward_with_cfg"], [4, 2, 1, "", "initialize_weights"], [4, 2, 1, "", "unpatchify"]], "grl.neural_network.DiT1D": [[4, 2, 1, "", "__init__"], [4, 2, 1, "", "forward"], [4, 2, 1, "", "initialize_weights"]], "grl.neural_network.DiT3D": [[4, 2, 1, "", "__init__"], [4, 2, 1, "", "forward"], [4, 2, 1, "", "initialize_weights"], [4, 2, 1, "", "unpatchify"]], "grl.neural_network.MultiLayerPerceptron": [[4, 2, 1, "", "__init__"], [4, 2, 1, "", "forward"]], "grl.neural_network.TemporalSpatialResidualNet": [[4, 2, 1, "", "__init__"], [4, 2, 1, "", "forward"]], "grl.numerical_methods": [[5, 1, 1, "", "DPMSolver"], [5, 1, 1, "", "GaussianConditionalProbabilityPath"], [5, 1, 1, "", "ODE"], [5, 1, 1, "", "ODESolver"], [5, 1, 1, "", "SDE"], [5, 1, 1, "", "SDESolver"]], "grl.numerical_methods.DPMSolver": [[5, 2, 1, "", "__init__"], [5, 2, 1, "", "integrate"]], "grl.numerical_methods.GaussianConditionalProbabilityPath": [[5, 2, 1, "", "HalfLogSNR"], [5, 2, 1, "", "InverseHalfLogSNR"], [5, 2, 1, "", "__init__"], [5, 2, 1, "", "covariance"], [5, 2, 1, "", "d_covariance_dt"], [5, 2, 1, "", "d_log_scale_dt"], [5, 2, 1, "", "d_scale_dt"], [5, 2, 1, "", "d_std_dt"], [5, 2, 1, "", "diffusion"], [5, 2, 1, "", "diffusion_squared"], [5, 2, 1, "", "drift"], [5, 2, 1, "", "drift_coefficient"], [5, 2, 1, "", "log_scale"], [5, 2, 1, "", "scale"], [5, 2, 1, "", "std"]], "grl.numerical_methods.ODE": [[5, 2, 1, "", "__init__"]], "grl.numerical_methods.ODESolver": [[5, 2, 1, "", "__init__"], [5, 2, 1, "", "integrate"]], "grl.numerical_methods.SDE": [[5, 2, 1, "", "__init__"]], "grl.numerical_methods.SDESolver": [[5, 2, 1, "", "__init__"], [5, 2, 1, "", "integrate"]], "grl.utils": [[7, 4, 1, "", "set_seed"]]}, "objtypes": {"0": "py:module", "1": "py:class", "2": "py:method", "3": "py:attribute", "4": "py:function"}, "objnames": {"0": ["py", "module", "Python module"], "1": ["py", "class", "Python class"], "2": ["py", "method", "Python method"], "3": ["py", "attribute", "Python attribute"], "4": ["py", "function", "Python function"]}, "titleterms": {"grl": [0, 1, 2, 3, 4, 5, 6, 7], "agent": [0, 12, 15], "qgpoagent": 0, "srpoagent": 0, "gpagent": 0, "algorithm": 1, "qgpocrit": 1, "qgpopolici": 1, "qgpoalgorithm": 1, "srpocrit": 1, "srpopolici": 1, "srpoalgorithm": 1, "gmpocrit": 1, "gmpopolici": 1, "gmpoalgorithm": 1, "gmpgcritic": 1, "gmpgpolici": 1, "gmpgalgorithm": 1, "dataset": 2, "qgpod4rldataset": 2, "qgpodataset": 2, "gpd4rldataset": 2, "gpdataset": 2, "generative_model": 3, "diffusionmodel": 3, "energyconditionaldiffusionmodel": 3, "independentconditionalflowmodel": 3, "optimaltransportconditionalflowmodel": 3, "neural_network": 4, "concatenatelay": 4, "multilayerperceptron": 4, "concatenatemlp": 4, "temporalspatialresidualnet": 4, "dit": 4, "dit1d": 4, "dit2d": 4, "dit3d": 4, "numerical_method": 5, "od": 5, "sde": 5, "dpmsolver": 5, 
"odesolv": 5, "sdesolv": 5, "gaussianconditionalprobabilitypath": 5, "rl_modul": 6, "gymenvsimul": 6, "oneshotvaluefunct": 6, "vnetwork": 6, "doublevnetwork": 6, "qnetwork": 6, "doubleqnetwork": 6, "util": 7, "set_se": 7, "concept": [8, 9], "overview": [8, 9], "gener": [8, 11, 16], "model": [8, 11, 16], "reinforc": [8, 11, 15], "learn": [8, 11, 15], "design": 8, "principl": 8, "generativerl": [9, 11, 14], "document": 9, "tutori": 9, "user": [9, 13], "guid": [9, 13], "api": 9, "instal": [10, 14], "quick": 11, "start": 11, "explan": 11, "how": [12, 14, 15, 16], "evalu": 12, "rl": 12, "perform": 12, "its": 14, "depend": 14, "train": [15, 16], "deploi": 15, "stochast": 16, "path": 16, "parameter": 16, "custom": 16, "neural": 16, "network": 16, "object": 16, "differ": 16}, "envversion": {"sphinx.domains.c": 2, "sphinx.domains.changeset": 1, "sphinx.domains.citation": 1, "sphinx.domains.cpp": 8, "sphinx.domains.index": 1, "sphinx.domains.javascript": 2, "sphinx.domains.math": 2, "sphinx.domains.python": 3, "sphinx.domains.rst": 2, "sphinx.domains.std": 2, "sphinx.ext.viewcode": 1, "sphinx.ext.todo": 2, "sphinx.ext.intersphinx": 1, "nbsphinx": 4, "sphinx": 57}, "alltitles": {"grl.agents": [[0, "module-grl.agents"]], "QGPOAgent": [[0, "qgpoagent"]], "SRPOAgent": [[0, "srpoagent"]], "GPAgent": [[0, "gpagent"]], "grl.algorithms": [[1, "grl-algorithms"]], "QGPOCritic": [[1, "qgpocritic"]], "QGPOPolicy": [[1, "qgpopolicy"]], "QGPOAlgorithm": [[1, "qgpoalgorithm"]], "SRPOCritic": [[1, "srpocritic"]], "SRPOPolicy": [[1, "srpopolicy"]], "SRPOAlgorithm": [[1, "srpoalgorithm"]], "GMPOCritic": [[1, "gmpocritic"]], "GMPOPolicy": [[1, "gmpopolicy"]], "GMPOAlgorithm": [[1, "gmpoalgorithm"]], "GMPGCritic": [[1, "gmpgcritic"]], "GMPGPolicy": [[1, "gmpgpolicy"]], "GMPGAlgorithm": [[1, "gmpgalgorithm"]], "grl.datasets": [[2, "module-grl.datasets"]], "QGPOD4RLDataset": [[2, "qgpod4rldataset"]], "QGPODataset": [[2, "qgpodataset"]], "GPD4RLDataset": [[2, "gpd4rldataset"]], "GPDataset": [[2, "gpdataset"]], "grl.generative_models": [[3, "module-grl.generative_models"]], "DiffusionModel": [[3, "diffusionmodel"]], "EnergyConditionalDiffusionModel": [[3, "energyconditionaldiffusionmodel"]], "IndependentConditionalFlowModel": [[3, "independentconditionalflowmodel"]], "OptimalTransportConditionalFlowModel": [[3, "optimaltransportconditionalflowmodel"]], "grl.neural_network": [[4, "module-grl.neural_network"]], "ConcatenateLayer": [[4, "concatenatelayer"]], "MultiLayerPerceptron": [[4, "multilayerperceptron"]], "ConcatenateMLP": [[4, "concatenatemlp"]], "TemporalSpatialResidualNet": [[4, "temporalspatialresidualnet"]], "DiT": [[4, "dit"]], "DiT1D": [[4, "dit1d"]], "DiT2D": [[4, "dit2d"]], "DiT3D": [[4, "dit3d"]], "grl.numerical_methods": [[5, "module-grl.numerical_methods"]], "ODE": [[5, "ode"]], "SDE": [[5, "sde"]], "DPMSolver": [[5, "dpmsolver"]], "ODESolver": [[5, "odesolver"]], "SDESolver": [[5, "sdesolver"]], "GaussianConditionalProbabilityPath": [[5, "gaussianconditionalprobabilitypath"]], "grl.rl_modules": [[6, "grl-rl-modules"]], "GymEnvSimulator": [[6, "gymenvsimulator"]], "OneShotValueFunction": [[6, "oneshotvaluefunction"]], "VNetwork": [[6, "vnetwork"]], "DoubleVNetwork": [[6, "doublevnetwork"]], "QNetwork": [[6, "qnetwork"]], "DoubleQNetwork": [[6, "doubleqnetwork"]], "grl.utils": [[7, "module-grl.utils"]], "set_seed": [[7, "set-seed"]], "Concepts": [[8, "concepts"], [9, null]], "Concepts Overview": [[8, "concepts-overview"]], "Generative Models": [[8, "generative-models"]], "Reinforcement Learning": [[8, 
"reinforcement-learning"], [11, "reinforcement-learning"]], "Design Principles": [[8, "design-principles"]], "GenerativeRL Documentation": [[9, "generativerl-documentation"]], "Overview": [[9, "overview"]], "Tutorials": [[9, null]], "User Guide": [[9, null], [13, "user-guide"], [13, null]], "API Documentation": [[9, null]], "Installation": [[10, "installation"]], "Quick Start": [[11, "quick-start"]], "Generative model in GenerativeRL": [[11, "generative-model-in-generativerl"]], "Explanation": [[11, "explanation"]], "How to evaluate RL agents performance": [[12, "how-to-evaluate-rl-agents-performance"]], "How to install GenerativeRL and its dependencies": [[14, "how-to-install-generativerl-and-its-dependencies"]], "How to train and deploy reinforcement learning agents": [[15, "how-to-train-and-deploy-reinforcement-learning-agents"]], "How to train generative models": [[16, "how-to-train-generative-models"]], "Stochastic path": [[16, "stochastic-path"]], "Model parameterization": [[16, "model-parameterization"]], "Customized neural network": [[16, "customized-neural-network"]], "Training objective for different generative models": [[16, "training-objective-for-different-generative-models"]]}, "indexentries": {"gpagent (class in grl.agents)": [[0, "grl.agents.GPAgent"]], "qgpoagent (class in grl.agents)": [[0, "grl.agents.QGPOAgent"]], "srpoagent (class in grl.agents)": [[0, "grl.agents.SRPOAgent"]], "__init__() (grl.agents.gpagent method)": [[0, "grl.agents.GPAgent.__init__"]], "__init__() (grl.agents.qgpoagent method)": [[0, "grl.agents.QGPOAgent.__init__"]], "__init__() (grl.agents.srpoagent method)": [[0, "grl.agents.SRPOAgent.__init__"]], "act() (grl.agents.gpagent method)": [[0, "grl.agents.GPAgent.act"]], "act() (grl.agents.qgpoagent method)": [[0, "grl.agents.QGPOAgent.act"]], "act() (grl.agents.srpoagent method)": [[0, "grl.agents.SRPOAgent.act"]], "grl.agents": [[0, "module-grl.agents"]], "module": [[0, "module-grl.agents"], [2, "module-grl.datasets"], [3, "module-grl.generative_models"], [4, "module-grl.neural_network"], [5, "module-grl.numerical_methods"], [7, "module-grl.utils"]], "gpd4rldataset (class in grl.datasets)": [[2, "grl.datasets.GPD4RLDataset"]], "gpdataset (class in grl.datasets)": [[2, "grl.datasets.GPDataset"]], "qgpod4rldataset (class in grl.datasets)": [[2, "grl.datasets.QGPOD4RLDataset"]], "qgpodataset (class in grl.datasets)": [[2, "grl.datasets.QGPODataset"]], "__init__() (grl.datasets.gpd4rldataset method)": [[2, "grl.datasets.GPD4RLDataset.__init__"]], "__init__() (grl.datasets.gpdataset method)": [[2, "grl.datasets.GPDataset.__init__"]], "__init__() (grl.datasets.qgpod4rldataset method)": [[2, "grl.datasets.QGPOD4RLDataset.__init__"]], "__init__() (grl.datasets.qgpodataset method)": [[2, "grl.datasets.QGPODataset.__init__"]], "grl.datasets": [[2, "module-grl.datasets"]], "diffusionmodel (class in grl.generative_models)": [[3, "grl.generative_models.DiffusionModel"]], "energyconditionaldiffusionmodel (class in grl.generative_models)": [[3, "grl.generative_models.EnergyConditionalDiffusionModel"]], "independentconditionalflowmodel (class in grl.generative_models)": [[3, "grl.generative_models.IndependentConditionalFlowModel"]], "optimaltransportconditionalflowmodel (class in grl.generative_models)": [[3, "grl.generative_models.OptimalTransportConditionalFlowModel"]], "__init__() (grl.generative_models.diffusionmodel method)": [[3, "grl.generative_models.DiffusionModel.__init__"]], "__init__() (grl.generative_models.energyconditionaldiffusionmodel method)": 
[[3, "grl.generative_models.EnergyConditionalDiffusionModel.__init__"]], "__init__() (grl.generative_models.independentconditionalflowmodel method)": [[3, "grl.generative_models.IndependentConditionalFlowModel.__init__"]], "__init__() (grl.generative_models.optimaltransportconditionalflowmodel method)": [[3, "grl.generative_models.OptimalTransportConditionalFlowModel.__init__"]], "data_prediction_function() (grl.generative_models.diffusionmodel method)": [[3, "grl.generative_models.DiffusionModel.data_prediction_function"]], "data_prediction_function() (grl.generative_models.energyconditionaldiffusionmodel method)": [[3, "grl.generative_models.EnergyConditionalDiffusionModel.data_prediction_function"]], "data_prediction_function_with_energy_guidance() (grl.generative_models.energyconditionaldiffusionmodel method)": [[3, "grl.generative_models.EnergyConditionalDiffusionModel.data_prediction_function_with_energy_guidance"]], "dpo_loss() (grl.generative_models.diffusionmodel method)": [[3, "grl.generative_models.DiffusionModel.dpo_loss"]], "energy_guidance_loss() (grl.generative_models.energyconditionaldiffusionmodel method)": [[3, "grl.generative_models.EnergyConditionalDiffusionModel.energy_guidance_loss"]], "flow_matching_loss() (grl.generative_models.diffusionmodel method)": [[3, "grl.generative_models.DiffusionModel.flow_matching_loss"]], "flow_matching_loss() (grl.generative_models.energyconditionaldiffusionmodel method)": [[3, "grl.generative_models.EnergyConditionalDiffusionModel.flow_matching_loss"]], "flow_matching_loss() (grl.generative_models.independentconditionalflowmodel method)": [[3, "grl.generative_models.IndependentConditionalFlowModel.flow_matching_loss"]], "flow_matching_loss() (grl.generative_models.optimaltransportconditionalflowmodel method)": [[3, "grl.generative_models.OptimalTransportConditionalFlowModel.flow_matching_loss"]], "flow_matching_loss_small_batch_ot_plan() (grl.generative_models.optimaltransportconditionalflowmodel method)": [[3, "grl.generative_models.OptimalTransportConditionalFlowModel.flow_matching_loss_small_batch_OT_plan"]], "flow_matching_loss_with_mask() (grl.generative_models.independentconditionalflowmodel method)": [[3, "grl.generative_models.IndependentConditionalFlowModel.flow_matching_loss_with_mask"]], "forward_sample() (grl.generative_models.diffusionmodel method)": [[3, "grl.generative_models.DiffusionModel.forward_sample"]], "forward_sample() (grl.generative_models.independentconditionalflowmodel method)": [[3, "grl.generative_models.IndependentConditionalFlowModel.forward_sample"]], "forward_sample_process() (grl.generative_models.diffusionmodel method)": [[3, "grl.generative_models.DiffusionModel.forward_sample_process"]], "forward_sample_process() (grl.generative_models.independentconditionalflowmodel method)": [[3, "grl.generative_models.IndependentConditionalFlowModel.forward_sample_process"]], "grl.generative_models": [[3, "module-grl.generative_models"]], "log_prob() (grl.generative_models.diffusionmodel method)": [[3, "grl.generative_models.DiffusionModel.log_prob"]], "log_prob() (grl.generative_models.independentconditionalflowmodel method)": [[3, "grl.generative_models.IndependentConditionalFlowModel.log_prob"]], "noise_function() (grl.generative_models.diffusionmodel method)": [[3, "grl.generative_models.DiffusionModel.noise_function"]], "noise_function() (grl.generative_models.energyconditionaldiffusionmodel method)": [[3, "grl.generative_models.EnergyConditionalDiffusionModel.noise_function"]], 
"noise_function_with_energy_guidance() (grl.generative_models.energyconditionaldiffusionmodel method)": [[3, "grl.generative_models.EnergyConditionalDiffusionModel.noise_function_with_energy_guidance"]], "optimal_transport_flow_matching_loss() (grl.generative_models.independentconditionalflowmodel method)": [[3, "grl.generative_models.IndependentConditionalFlowModel.optimal_transport_flow_matching_loss"]], "sample() (grl.generative_models.diffusionmodel method)": [[3, "grl.generative_models.DiffusionModel.sample"]], "sample() (grl.generative_models.energyconditionaldiffusionmodel method)": [[3, "grl.generative_models.EnergyConditionalDiffusionModel.sample"]], "sample() (grl.generative_models.independentconditionalflowmodel method)": [[3, "grl.generative_models.IndependentConditionalFlowModel.sample"]], "sample() (grl.generative_models.optimaltransportconditionalflowmodel method)": [[3, "grl.generative_models.OptimalTransportConditionalFlowModel.sample"]], "sample_forward_process() (grl.generative_models.diffusionmodel method)": [[3, "grl.generative_models.DiffusionModel.sample_forward_process"]], "sample_forward_process() (grl.generative_models.energyconditionaldiffusionmodel method)": [[3, "grl.generative_models.EnergyConditionalDiffusionModel.sample_forward_process"]], "sample_forward_process() (grl.generative_models.independentconditionalflowmodel method)": [[3, "grl.generative_models.IndependentConditionalFlowModel.sample_forward_process"]], "sample_forward_process() (grl.generative_models.optimaltransportconditionalflowmodel method)": [[3, "grl.generative_models.OptimalTransportConditionalFlowModel.sample_forward_process"]], "sample_forward_process_with_fixed_x() (grl.generative_models.diffusionmodel method)": [[3, "grl.generative_models.DiffusionModel.sample_forward_process_with_fixed_x"]], "sample_forward_process_with_fixed_x() (grl.generative_models.energyconditionaldiffusionmodel method)": [[3, "grl.generative_models.EnergyConditionalDiffusionModel.sample_forward_process_with_fixed_x"]], "sample_with_fixed_x() (grl.generative_models.diffusionmodel method)": [[3, "grl.generative_models.DiffusionModel.sample_with_fixed_x"]], "sample_with_fixed_x() (grl.generative_models.energyconditionaldiffusionmodel method)": [[3, "grl.generative_models.EnergyConditionalDiffusionModel.sample_with_fixed_x"]], "sample_with_fixed_x_without_energy_guidance() (grl.generative_models.energyconditionaldiffusionmodel method)": [[3, "grl.generative_models.EnergyConditionalDiffusionModel.sample_with_fixed_x_without_energy_guidance"]], "sample_with_log_prob() (grl.generative_models.diffusionmodel method)": [[3, "grl.generative_models.DiffusionModel.sample_with_log_prob"]], "sample_with_log_prob() (grl.generative_models.independentconditionalflowmodel method)": [[3, "grl.generative_models.IndependentConditionalFlowModel.sample_with_log_prob"]], "sample_with_mask() (grl.generative_models.independentconditionalflowmodel method)": [[3, "grl.generative_models.IndependentConditionalFlowModel.sample_with_mask"]], "sample_with_mask_forward_process() (grl.generative_models.independentconditionalflowmodel method)": [[3, "grl.generative_models.IndependentConditionalFlowModel.sample_with_mask_forward_process"]], "sample_without_energy_guidance() (grl.generative_models.energyconditionaldiffusionmodel method)": [[3, "grl.generative_models.EnergyConditionalDiffusionModel.sample_without_energy_guidance"]], "score_function() (grl.generative_models.diffusionmodel method)": [[3, 
"grl.generative_models.DiffusionModel.score_function"]], "score_function() (grl.generative_models.energyconditionaldiffusionmodel method)": [[3, "grl.generative_models.EnergyConditionalDiffusionModel.score_function"]], "score_function_with_energy_guidance() (grl.generative_models.energyconditionaldiffusionmodel method)": [[3, "grl.generative_models.EnergyConditionalDiffusionModel.score_function_with_energy_guidance"]], "score_matching_loss() (grl.generative_models.diffusionmodel method)": [[3, "grl.generative_models.DiffusionModel.score_matching_loss"]], "score_matching_loss() (grl.generative_models.energyconditionaldiffusionmodel method)": [[3, "grl.generative_models.EnergyConditionalDiffusionModel.score_matching_loss"]], "velocity_function() (grl.generative_models.diffusionmodel method)": [[3, "grl.generative_models.DiffusionModel.velocity_function"]], "velocity_function() (grl.generative_models.energyconditionaldiffusionmodel method)": [[3, "grl.generative_models.EnergyConditionalDiffusionModel.velocity_function"]], "concatenatelayer (class in grl.neural_network)": [[4, "grl.neural_network.ConcatenateLayer"]], "concatenatemlp (class in grl.neural_network)": [[4, "grl.neural_network.ConcatenateMLP"]], "dit (class in grl.neural_network)": [[4, "grl.neural_network.DiT"]], "dit1d (class in grl.neural_network)": [[4, "grl.neural_network.DiT1D"]], "dit2d (in module grl.neural_network)": [[4, "grl.neural_network.DiT2D"]], "dit3d (class in grl.neural_network)": [[4, "grl.neural_network.DiT3D"]], "multilayerperceptron (class in grl.neural_network)": [[4, "grl.neural_network.MultiLayerPerceptron"]], "temporalspatialresidualnet (class in grl.neural_network)": [[4, "grl.neural_network.TemporalSpatialResidualNet"]], "__init__() (grl.neural_network.concatenatelayer method)": [[4, "grl.neural_network.ConcatenateLayer.__init__"]], "__init__() (grl.neural_network.concatenatemlp method)": [[4, "grl.neural_network.ConcatenateMLP.__init__"]], "__init__() (grl.neural_network.dit method)": [[4, "grl.neural_network.DiT.__init__"]], "__init__() (grl.neural_network.dit1d method)": [[4, "grl.neural_network.DiT1D.__init__"]], "__init__() (grl.neural_network.dit3d method)": [[4, "grl.neural_network.DiT3D.__init__"]], "__init__() (grl.neural_network.multilayerperceptron method)": [[4, "grl.neural_network.MultiLayerPerceptron.__init__"]], "__init__() (grl.neural_network.temporalspatialresidualnet method)": [[4, "grl.neural_network.TemporalSpatialResidualNet.__init__"]], "forward() (grl.neural_network.concatenatelayer method)": [[4, "grl.neural_network.ConcatenateLayer.forward"]], "forward() (grl.neural_network.concatenatemlp method)": [[4, "grl.neural_network.ConcatenateMLP.forward"]], "forward() (grl.neural_network.dit method)": [[4, "grl.neural_network.DiT.forward"]], "forward() (grl.neural_network.dit1d method)": [[4, "grl.neural_network.DiT1D.forward"]], "forward() (grl.neural_network.dit3d method)": [[4, "grl.neural_network.DiT3D.forward"]], "forward() (grl.neural_network.multilayerperceptron method)": [[4, "grl.neural_network.MultiLayerPerceptron.forward"]], "forward() (grl.neural_network.temporalspatialresidualnet method)": [[4, "grl.neural_network.TemporalSpatialResidualNet.forward"]], "forward_with_cfg() (grl.neural_network.dit method)": [[4, "grl.neural_network.DiT.forward_with_cfg"]], "grl.neural_network": [[4, "module-grl.neural_network"]], "initialize_weights() (grl.neural_network.dit method)": [[4, "grl.neural_network.DiT.initialize_weights"]], "initialize_weights() (grl.neural_network.dit1d method)": 
[[4, "grl.neural_network.DiT1D.initialize_weights"]], "initialize_weights() (grl.neural_network.dit3d method)": [[4, "grl.neural_network.DiT3D.initialize_weights"]], "unpatchify() (grl.neural_network.dit method)": [[4, "grl.neural_network.DiT.unpatchify"]], "unpatchify() (grl.neural_network.dit3d method)": [[4, "grl.neural_network.DiT3D.unpatchify"]], "dpmsolver (class in grl.numerical_methods)": [[5, "grl.numerical_methods.DPMSolver"]], "gaussianconditionalprobabilitypath (class in grl.numerical_methods)": [[5, "grl.numerical_methods.GaussianConditionalProbabilityPath"]], "halflogsnr() (grl.numerical_methods.gaussianconditionalprobabilitypath method)": [[5, "grl.numerical_methods.GaussianConditionalProbabilityPath.HalfLogSNR"]], "inversehalflogsnr() (grl.numerical_methods.gaussianconditionalprobabilitypath method)": [[5, "grl.numerical_methods.GaussianConditionalProbabilityPath.InverseHalfLogSNR"]], "ode (class in grl.numerical_methods)": [[5, "grl.numerical_methods.ODE"]], "odesolver (class in grl.numerical_methods)": [[5, "grl.numerical_methods.ODESolver"]], "sde (class in grl.numerical_methods)": [[5, "grl.numerical_methods.SDE"]], "sdesolver (class in grl.numerical_methods)": [[5, "grl.numerical_methods.SDESolver"]], "__init__() (grl.numerical_methods.dpmsolver method)": [[5, "grl.numerical_methods.DPMSolver.__init__"]], "__init__() (grl.numerical_methods.gaussianconditionalprobabilitypath method)": [[5, "grl.numerical_methods.GaussianConditionalProbabilityPath.__init__"]], "__init__() (grl.numerical_methods.ode method)": [[5, "grl.numerical_methods.ODE.__init__"]], "__init__() (grl.numerical_methods.odesolver method)": [[5, "grl.numerical_methods.ODESolver.__init__"]], "__init__() (grl.numerical_methods.sde method)": [[5, "grl.numerical_methods.SDE.__init__"]], "__init__() (grl.numerical_methods.sdesolver method)": [[5, "grl.numerical_methods.SDESolver.__init__"]], "covariance() (grl.numerical_methods.gaussianconditionalprobabilitypath method)": [[5, "grl.numerical_methods.GaussianConditionalProbabilityPath.covariance"]], "d_covariance_dt() (grl.numerical_methods.gaussianconditionalprobabilitypath method)": [[5, "grl.numerical_methods.GaussianConditionalProbabilityPath.d_covariance_dt"]], "d_log_scale_dt() (grl.numerical_methods.gaussianconditionalprobabilitypath method)": [[5, "grl.numerical_methods.GaussianConditionalProbabilityPath.d_log_scale_dt"]], "d_scale_dt() (grl.numerical_methods.gaussianconditionalprobabilitypath method)": [[5, "grl.numerical_methods.GaussianConditionalProbabilityPath.d_scale_dt"]], "d_std_dt() (grl.numerical_methods.gaussianconditionalprobabilitypath method)": [[5, "grl.numerical_methods.GaussianConditionalProbabilityPath.d_std_dt"]], "diffusion() (grl.numerical_methods.gaussianconditionalprobabilitypath method)": [[5, "grl.numerical_methods.GaussianConditionalProbabilityPath.diffusion"]], "diffusion_squared() (grl.numerical_methods.gaussianconditionalprobabilitypath method)": [[5, "grl.numerical_methods.GaussianConditionalProbabilityPath.diffusion_squared"]], "drift() (grl.numerical_methods.gaussianconditionalprobabilitypath method)": [[5, "grl.numerical_methods.GaussianConditionalProbabilityPath.drift"]], "drift_coefficient() (grl.numerical_methods.gaussianconditionalprobabilitypath method)": [[5, "grl.numerical_methods.GaussianConditionalProbabilityPath.drift_coefficient"]], "grl.numerical_methods": [[5, "module-grl.numerical_methods"]], "integrate() (grl.numerical_methods.dpmsolver method)": [[5, "grl.numerical_methods.DPMSolver.integrate"]], 
"integrate() (grl.numerical_methods.odesolver method)": [[5, "grl.numerical_methods.ODESolver.integrate"]], "integrate() (grl.numerical_methods.sdesolver method)": [[5, "grl.numerical_methods.SDESolver.integrate"]], "log_scale() (grl.numerical_methods.gaussianconditionalprobabilitypath method)": [[5, "grl.numerical_methods.GaussianConditionalProbabilityPath.log_scale"]], "scale() (grl.numerical_methods.gaussianconditionalprobabilitypath method)": [[5, "grl.numerical_methods.GaussianConditionalProbabilityPath.scale"]], "std() (grl.numerical_methods.gaussianconditionalprobabilitypath method)": [[5, "grl.numerical_methods.GaussianConditionalProbabilityPath.std"]], "grl.utils": [[7, "module-grl.utils"]], "set_seed() (in module grl.utils)": [[7, "grl.utils.set_seed"]]}}) \ No newline at end of file +Search.setIndex({"docnames": ["api_doc/agents/index", "api_doc/algorithms/index", "api_doc/datasets/index", "api_doc/generative_models/index", "api_doc/neural_network/index", "api_doc/numerical_methods/index", "api_doc/rl_modules/index", "api_doc/utils/index", "concepts/index", "index", "tutorials/installation/index", "tutorials/quick_start/index", "user_guide/evaluating_agents", "user_guide/index", "user_guide/installation", "user_guide/training_agents", "user_guide/training_generative_models"], "filenames": ["api_doc/agents/index.rst", "api_doc/algorithms/index.rst", "api_doc/datasets/index.rst", "api_doc/generative_models/index.rst", "api_doc/neural_network/index.rst", "api_doc/numerical_methods/index.rst", "api_doc/rl_modules/index.rst", "api_doc/utils/index.rst", "concepts/index.rst", "index.rst", "tutorials/installation/index.rst", "tutorials/quick_start/index.rst", "user_guide/evaluating_agents.rst", "user_guide/index.rst", "user_guide/installation.rst", "user_guide/training_agents.rst", "user_guide/training_generative_models.rst"], "titles": ["grl.agents", "grl.algorithms", "grl.datasets", "grl.generative_models", "grl.neural_network", "grl.numerical_methods", "grl.rl_modules", "grl.utils", "Concepts", "GenerativeRL Documentation", "Installation", "Quick Start", "How to evaluate RL agents performance", "User Guide", "How to install GenerativeRL and its dependencies", "How to train and deploy reinforcement learning agents", "How to train generative models"], "terms": {"class": [0, 1, 2, 3, 4, 5, 6, 8, 12, 15, 16], "config": [0, 1, 3, 5, 6, 11, 12, 15, 16], "model": [0, 1, 3, 4, 5, 9, 13, 15], "sourc": [0, 1, 2, 3, 4, 5, 6, 7, 8], "overview": [0, 1, 2, 3, 4, 5, 6, 7], "The": [0, 1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 16], "qgpo": [0, 1, 2, 11, 15], "algorithm": [0, 2, 8, 9, 11, 12, 15], "interfac": [0, 1, 2, 3, 4, 5, 6, 8, 11], "__init__": [0, 1, 2, 3, 4, 5, 6, 16], "action": [0, 1, 2, 6, 8, 12], "initi": [0, 1, 2, 3, 4, 5, 6, 8, 11, 16], "paramet": [0, 1, 2, 3, 4, 5, 6, 7, 8], "easydict": [0, 1, 3, 5, 6, 16], "configur": [0, 1, 3, 5, 6, 8, 11, 15, 16], "union": [0, 1, 3, 4, 5, 6], "torch": [0, 1, 3, 4, 5, 6], "nn": [0, 1, 3, 5, 6, 16], "modul": [0, 1, 3, 5, 6, 12, 15, 16], "moduledict": [0, 1, 3], "act": [0, 11, 12, 15], "ob": 0, "return_as_torch_tensor": 0, "fals": [0, 1, 3, 4, 5, 6, 7], "given": [0, 1, 3, 5, 6, 8], "an": [0, 1, 3, 4, 5, 8, 11, 12, 15], "observ": [0, 6, 8, 11, 12, 15], "return": [0, 1, 3, 4, 5, 6, 7, 12, 15, 16], "np": 0, "ndarrai": 0, "tensor": [0, 1, 3, 4, 5, 6, 8], "dict": [0, 1, 6, 16], "bool": [0, 1, 3, 4, 5, 6, 7], "whether": [0, 1, 3, 4, 5, 6, 7, 8, 14], "type": [0, 1, 3, 4, 5, 6, 7, 8, 16], "srpo": [0, 1], "train": [0, 1, 2, 3, 6, 8, 9, 11, 13], "gener": [0, 1, 2, 3, 5, 7, 
9, 13], "polici": [0, 1, 2, 3, 6, 8, 11], "thi": [0, 1, 3, 4, 6, 8, 9, 11, 16], "i": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14, 16], "design": [0, 3, 9], "us": [0, 1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16], "gmpgalgorithm": [0, 9], "gmpoalgorithm": [0, 9], "numpi": 0, "arrai": 0, "critic": [1, 8], "network": [1, 4, 6, 8, 13], "forward": [1, 3, 4, 6, 8, 16], "compute_double_q": [1, 6], "state": [1, 3, 5, 6, 11], "none": [1, 3, 4, 5, 6, 7, 16], "output": [1, 3, 4, 5, 6, 16], "two": [1, 3, 6], "q": [1, 6, 11], "tensordict": [1, 3, 4, 5, 6], "input": [1, 3, 4, 5, 6, 12, 15, 16], "first": [1, 5, 6, 11], "q2": [1, 6], "second": [1, 6], "q1": [1, 6], "q_loss": 1, "reward": [1, 8, 11, 12, 15], "next_stat": 1, "done": [1, 11, 12, 15], "fake_next_act": 1, "discount_factor": 1, "1": [1, 3, 4, 5, 14, 16], "0": [1, 3, 4, 5, 14, 16], "calcul": [1, 6], "loss": [1, 3, 6, 16], "next": 1, "fake": [1, 2], "float": [1, 3, 4, 5], "discount": 1, "factor": [1, 4, 5], "sampl": [1, 2, 3, 5, 8, 16], "behaviour_policy_sampl": 1, "compute_q": 1, "behaviour_policy_loss": 1, "energy_guidance_loss": [1, 3], "intern": [1, 6], "share": [1, 6], "both": [1, 3, 6, 16], "scriptmodul": [1, 6], "behaviour": [1, 2], "batch_siz": [1, 3], "solver_config": [1, 3], "t_span": [1, 3, 5], "which": [1, 2, 3, 5, 6, 8, 12, 14, 16], "condit": [1, 3, 4, 5, 6, 16], "od": [1, 8, 9], "solver": [1, 3, 5, 16], "time": [1, 3, 5, 8, 16], "span": [1, 3], "sde": [1, 8, 9, 16], "valu": [1, 3, 5, 6, 7, 8], "energi": [1, 2, 3], "guidanc": [1, 3, 4], "guidance_scal": [1, 3], "scale": [1, 3, 4, 5, 6, 16], "simul": [1, 6, 12], "dataset": [1, 8, 9, 11, 15], "guid": [1, 3, 11], "optim": [1, 3, 8, 11], "offlin": [1, 3, 8], "reinforc": [1, 3, 9, 12, 13, 14], "learn": [1, 3, 4, 9, 12, 13, 14], "base": [1, 2, 3, 5, 8], "diffus": [1, 3, 4, 5, 8, 9, 11, 16], "deploi": [1, 8, 9, 11, 12, 13], "must": 1, "contain": 1, "follow": [1, 5, 8, 14, 15, 16], "kei": [1, 8], "deploy": [1, 8, 11], "object": [1, 8, 13], "environ": [1, 2, 6, 8, 11, 12, 14, 15], "qgpodataset": [1, 9], "qgpoagent": [1, 9], "A": [1, 6, 11, 16], "weight": [1, 3, 4, 8, 16], "bia": 1, "run": [1, 6, 8, 12, 14], "creat": [1, 8, 11, 15, 16], "automat": [1, 8], "when": 1, "function": [1, 3, 4, 5, 6, 8, 11, 16], "call": [1, 11, 16], "v_loss": [1, 6], "srpo_actor_loss": 1, "srpoagent": [1, 9], "gmpo": 1, "includ": [1, 8, 11, 16], "optin": 1, "policy_optimization_loss_by_advantage_weighted_regress": 1, "policy_optimization_loss_by_advantage_weighted_regression_softmax": 1, "maximum_likelihood": [1, 3], "with_grad": [1, 3], "size": [1, 3, 4], "int": [1, 3, 4, 5, 6, 7], "tupl": [1, 3, 4], "list": [1, 3, 4, 6, 13], "batch": [1, 3, 4], "gradient": [1, 3, 8], "beta": [1, 3, 5], "weight_clamp": 1, "100": 1, "fake_act": 1, "seed": [1, 7, 8], "gpg": 1, "gpdataset": [1, 9], "random": [1, 6, 7, 8], "gpo": 1, "in_support_ql_loss": 1, "gmpg": 1, "env_id": [2, 6, 11, 12, 15], "contrast": [2, 3], "predict": [2, 3, 5], "need": [2, 6, 8, 11, 15], "true": [2, 3, 4, 5, 7], "from": [2, 3, 5, 8, 10, 11, 15, 16], "support": [2, 3, 5, 11, 16], "__getitem__": 2, "__len__": 2, "method": [2, 3, 5, 6, 8, 11, 12, 15, 16], "str": [2, 3, 5, 6], "id": [2, 6], "d4rl": [2, 14], "sometim": 2, "data": [2, 3, 4, 5, 8, 16], "augment": [2, 8], "variou": [3, 8], "continu": [3, 8], "path": [3, 5, 8, 13], "comput": [3, 5, 8], "score": [3, 8, 16], "veloc": [3, 16], "It": [3, 5, 6, 14], "can": [3, 5, 8, 10, 11, 12, 14, 15, 16], "via": 3, "nois": [3, 5, 16], "match": [3, 8, 16], "flow": [3, 5, 8, 9, 16], "ar": [3, 8, 15, 16], "score_funct": 
3, "score_matching_loss": [3, 16], "velocity_funct": 3, "flow_matching_loss": [3, 16], "data_prediction_funct": [3, 5], "t": [3, 4, 5, 8, 16], "x": [3, 4, 5, 16], "frac": [3, 5, 16], "sigma": [3, 4, 5, 8, 16], "x_t": [3, 8, 16], "2": [3, 4, 5, 14, 16], "nabla_": [3, 16], "log": [3, 5, 8, 11, 16], "p_": 3, "theta": [3, 16], "": [3, 5, 11], "treetensor": 3, "dpo_loss": 3, "ref_dm": 3, "process": [3, 5, 6, 8, 11, 16], "direct": [3, 8], "dpo": 3, "develop": [3, 10], "featur": 3, "recommend": [3, 5], "averag": 3, "across": [3, 8], "forward_sampl": 3, "note": [3, 14], "revers": [3, 8, 16], "thu": 3, "form": [3, 5, 16], "rather": 3, "encod": 3, "latent": [3, 4], "space": [3, 5], "forward_sample_process": 3, "all": [3, 4], "intermedi": [3, 5], "log_prob": 3, "using_hutchinson_trace_estim": 3, "probabl": [3, 4, 5], "noise_funct": [3, 5, 16], "x_0": [3, 16], "final": [3, 4, 5], "provid": [3, 7, 8, 9, 11, 15, 16], "gaussian": [3, 5, 16], "distribut": [3, 8, 16], "result": [3, 8], "shape": [3, 4, 5, 16], "where": [3, 5, 16], "number": [3, 4, 5, 6, 11], "step": [3, 5, 6, 11, 12, 14, 15], "b": 3, "could": 3, "scalar": [3, 5], "b1": 3, "b2": 3, "n": [3, 4, 5, 11], "d": [3, 5, 16], "dimens": [3, 4, 8], "d1": 3, "d2": 3, "extra": 3, "If": [3, 5, 7], "sample_forward_process": 3, "repeat": 3, "same": [3, 5, 16], "sample_forward_process_with_fixed_x": 3, "fixed_x": 3, "fixed_mask": 3, "fix": [3, 8], "mask": 3, "sample_with_fixed_x": 3, "sample_with_log_prob": 3, "likelihood": [3, 8, 16], "weighting_schem": 3, "uncondit": [3, 4], "scheme": 3, "vanilla": 3, "maximum": [3, 8, 16], "estim": [3, 8, 16], "refer": [3, 11, 13, 15], "paper": 3, "more": [3, 5, 11, 13, 15], "detail": [3, 5, 11, 13], "lambda": [3, 5, 16], "denot": [3, 16], "g": [3, 5, 8, 14, 16], "numer": [3, 8], "stabil": 3, "we": [3, 5, 8, 11, 16], "mont": 3, "carlo": 3, "approxim": [3, 5, 6], "integr": [3, 5, 8], "p": [3, 5, 14, 16], "balanc": 3, "mse": 3, "through": [3, 4, 8, 11], "stochast": [3, 5, 8, 13], "differenti": [3, 5, 8, 16], "equat": [3, 5, 8, 16], "v_": [3, 16], "energy_model": 3, "text": [3, 16], "e": [3, 5, 8, 14, 16], "c": [3, 4, 14, 16], "sim": 3, "exp": 3, "mathcal": [3, 5, 16], "z": 3, "sample_without_energy_guid": 3, "score_function_with_energy_guid": 3, "data_prediction_function_with_energy_guid": 3, "cep": 3, "propos": 3, "exact": 3, "noise_function_with_energy_guid": 3, "nose": 3, "nabla": 3, "sample_with_fixed_x_without_energy_guid": 3, "without": [3, 5], "independ": [3, 8, 16], "get_typ": 3, "x0": [3, 5], "x1": 3, "flow_matching_loss_with_mask": 3, "signal": [3, 8], "either": 3, "ha": [3, 5, 6, 12, 15, 16], "correspond": 3, "element": [3, 4, 5], "usual": [3, 16], "x_1": [3, 16], "log_prob_x_0": 3, "function_log_prob_x_0": 3, "callabl": [3, 5, 6], "hutchinson": 3, "trace": 3, "jacobian": 3, "drift": [3, 5, 8, 16], "faster": 3, "less": 3, "accur": [3, 8], "set": [3, 7, 16], "high": [3, 5, 8], "dimension": 3, "log_likelihood": 3, "optimal_transport_flow_matching_loss": 3, "transport": 3, "plan": 3, "sample_with_mask": 3, "sample_with_mask_forward_process": 3, "between": [3, 5, 8, 16], "flow_matching_loss_small_batch_ot_plan": 3, "small": [3, 6], "acceler": 3, "concaten": 4, "along": 4, "last": [4, 5, 6], "layer": [4, 16], "hidden_s": [4, 16], "output_s": 4, "activ": 4, "dropout": 4, "layernorm": 4, "final_activ": 4, "shrink": 4, "multi": 4, "perceptron": 4, "fulli": 4, "connect": 4, "fc1": 4, "act1": 4, "fcn": 4, "actn": 4, "out": 4, "hidden": 4, "channel": 4, "option": [4, 7], "zero": 4, "default": [4, 7, 15], "block": [4, 11], 
"shrinkag": 4, "kwarg": [4, 5], "pass": 4, "mlp": [4, 16], "keyword": 4, "argument": [4, 5], "output_dim": [4, 16], "t_dim": [4, 16], "input_dim": 4, "condition_dim": 4, "condition_hidden_dim": 4, "t_condition_hidden_dim": 4, "tempor": 4, "spatial": 4, "residu": 4, "multipl": 4, "temporalspatialresblock": 4, "input_s": 4, "32": 4, "patch_siz": 4, "in_channel": 4, "4": [4, 5], "1152": 4, "depth": 4, "28": 4, "num_head": 4, "16": 4, "mlp_ratio": 4, "class_dropout_prob": 4, "num_class": 4, "1000": 4, "learn_sigma": 4, "transform": [4, 8, 16], "backbon": [4, 16], "offici": 4, "implement": [4, 8, 12, 15], "github": [4, 10, 14], "repo": 4, "http": [4, 10, 14], "com": [4, 10, 14], "facebookresearch": 4, "blob": 4, "main": 4, "py": [4, 14], "patch": 4, "attent": 4, "head": 4, "respect": 4, "timestep": 4, "imag": [4, 8, 14], "represent": 4, "label": 4, "forward_with_cfg": 4, "cfg_scale": 4, "also": [4, 5, 8, 10, 15, 16], "classifi": [4, 8], "free": [4, 8], "initialize_weight": 4, "unpatchifi": 4, "img": 4, "h": 4, "w": [4, 16], "token_s": 4, "condition_embedd": 4, "1d": 4, "3d": 4, "inform": [4, 6, 11, 13, 15], "origin": 4, "video": [4, 8], "alia": 4, "patch_block_s": 4, "10": 4, "convolv": 4, "each": [4, 6], "token": 4, "total_patch": 4, "ordinari": [5, 8], "defin": [5, 8, 15, 16], "dx": 5, "f": [5, 8, 16], "dt": [5, 8, 16], "term": 5, "dw": 5, "wiener": [5, 16], "order": 5, "devic": [5, 16], "atol": 5, "1e": 5, "05": 5, "rtol": 5, "dpm_solver": 5, "singlestep": 5, "solver_typ": 5, "skip_typ": 5, "time_uniform": 5, "denois": [5, 8, 16], "dpm": 5, "should": 5, "3": [5, 14], "absolut": 5, "toler": 5, "adapt": [5, 8], "rel": 5, "total": 5, "evalu": [5, 6, 8, 9, 11, 13, 15], "nfe": 5, "multistep": 5, "singlestep_fix": 5, "taylor": 5, "slightli": 5, "impact": 5, "perform": [5, 9, 13, 15], "logsnr": 5, "time_quadrat": 5, "diffusion_process": 5, "save_intermedi": 5, "diffusionprocess": 5, "t_start": 5, "solut": 5, "t_end": 5, "x_end": 5, "ode_solv": 5, "euler": [5, 8], "01": 5, "librari": [5, 8, 9, 11, 14, 16], "torchdyn": [5, 8, 16], "torchdiffeq": [5, 8], "current": [5, 16], "addit": [5, 14], "For": [5, 8, 11, 13, 14, 15, 16], "exampl": [5, 8, 11, 13, 14, 16], "trajectori": 5, "len": [5, 16], "sde_solv": 5, "sde_noise_typ": 5, "diagon": 5, "sde_typ": 5, "ito": 5, "001": 5, "torchsd": 5, "stratonovich": 5, "logqp": 5, "case": [5, 6, 11], "mu": 5, "written": 5, "mathrm": [5, 16], "w_": 5, "sqrt": 5, "covari": 5, "matrix": 5, "standard": [5, 16], "deviat": 5, "half": 5, "differ": [5, 8, 13], "vp": [5, 16], "int_": [5, 16], "linear": [5, 16], "todo": 5, "add": 5, "cosin": 5, "ve": 5, "opt": 5, "halflogsnr": 5, "inversehalflogsnr": 5, "invers": 5, "sinc": 5, "invert": 5, "beta_1": [5, 16], "beta_0": [5, 16], "d_covariance_dt": 5, "deriv": [5, 16], "d_log_scale_dt": 5, "d_scale_dt": 5, "d_std_dt": 5, "diffusion_squar": 5, "drift_coeffici": 5, "coeffici": [5, 16], "satisfi": 5, "log_scal": 5, "std": 5, "simpl": [6, 8, 11], "gym": [6, 8, 11, 12, 14, 15], "generativerl": [6, 8, 10, 12, 13, 15, 16], "collect": [6, 12], "episod": 6, "singl": [6, 8], "suitabl": 6, "experi": 6, "collect_episod": 6, "collect_step": 6, "accord": 6, "num_episod": 6, "num_step": 6, "sever": 6, "reset": [6, 11, 12, 15], "begin": 6, "No": 6, "histori": 6, "store": 6, "dictionari": [6, 8, 15], "random_polici": 6, "until": 6, "end": 6, "render_arg": 6, "resultswil": 6, "one": 6, "shot": 6, "mean": [6, 16], "bellman": 6, "backup": 6, "compute_double_v": 6, "v2": 6, "v1": 6, "v": [6, 16], "doubl": 6, "compute_mininum_v": 6, "minimum": 6, 
"minimum_v": 6, "compute_mininum_q": 6, "minimum_q": 6, "seed_valu": 7, "cudnn_determinist": 7, "cudnn_benchmark": 7, "make": [7, 8, 9, 11, 12, 14, 15], "cudnn": 7, "oper": 7, "determinist": 7, "enabl": [7, 8], "benchmark": 7, "convolut": 7, "framework": [8, 9], "consist": 8, "code": [8, 16], "api": [8, 11, 13, 15, 16], "user": [8, 12, 16], "friendli": 8, "rl": [8, 9, 13, 15], "agent": [8, 9, 11, 13], "In": [8, 11, 12, 15, 16], "section": [8, 11, 13, 15], "explor": 8, "core": 8, "discuss": 8, "underpin": 8, "how": [8, 9, 11, 13], "thei": 8, "leverag": 8, "address": 8, "complex": 8, "problem": [8, 9, 14], "field": [8, 16], "addition": 8, "explain": 8, "why": 8, "import": [8, 11, 12, 14, 15, 16], "what": 8, "uniqu": 8, "wide": 8, "rang": [8, 11, 12, 15, 16], "applic": [8, 16], "machin": 8, "new": [8, 11, 16], "typic": [8, 16], "most": 8, "unsupervis": 8, "techniqu": 8, "appli": 8, "task": 8, "audio": 8, "interpol": 8, "focus": 8, "dynam": 8, "These": [8, 16], "have": [8, 14], "capac": 8, "captur": 8, "demonstr": [8, 11], "promis": 8, "varieti": 8, "its": [8, 9, 11, 13, 15, 16], "variant": [8, 16], "qualiti": 8, "solv": [8, 9, 14], "dx_t": [8, 16], "dw_t": [8, 16], "unifi": [8, 12, 16], "howev": 8, "vari": 8, "definit": 8, "some": [8, 14], "under": [8, 12, 15], "common": 8, "while": [8, 15, 16], "other": [8, 11, 13, 15], "mai": 8, "requir": [8, 14], "specif": [8, 11, 15], "There": 8, "four": 8, "open": 8, "neural": [8, 13], "parameter": [8, 13], "certain": 8, "part": 8, "potenti": 8, "determin": 8, "procedur": 8, "fundament": 8, "maxim": 8, "pretrain": 8, "like": 8, "bridg": [8, 16], "fine": 8, "tune": 8, "advantag": 8, "regress": 8, "adjoint": 8, "involv": 8, "depend": [8, 9, 13], "maruyama": 8, "rung": 8, "kutta": 8, "offer": 8, "flexibl": [8, 11], "allow": 8, "custom": [8, 13], "extend": 8, "suit": [8, 11, 14], "instanc": [8, 11, 15], "easili": 8, "own": [8, 16], "architectur": [8, 16], "tailor": 8, "format": [8, 11], "decis": [8, 9], "interact": [8, 15], "receiv": 8, "penalti": 8, "cumul": 8, "take": [8, 12, 15, 16], "updat": 8, "categor": 8, "directli": [8, 16], "onlin": 8, "strategi": 8, "off": 8, "actor": 8, "research": 8, "improv": 8, "effici": 8, "synthet": 8, "decoupl": 8, "littl": 8, "modif": 8, "rank": 8, "least": 8, "pytorch": [8, 14], "unif": 8, "within": [8, 11], "simplic": 8, "intuit": 8, "extens": 8, "modular": 8, "built": 8, "mix": 8, "compon": [8, 11], "reproduc": 8, "ensur": 8, "checkpoint": 8, "possibl": 8, "minim": 8, "seek": 8, "extern": 8, "lightweight": 8, "instal": [8, 9, 13], "platform": 8, "compat": 8, "exist": 8, "work": [8, 16], "seamlessli": 8, "openai": [8, 11], "torchrl": 8, "python": [9, 14], "aim": 9, "combin": 9, "power": [9, 11], "capabl": 9, "quick": 9, "start": 9, "explan": 9, "principl": 9, "grl": [9, 10, 11, 12, 14, 15, 16], "gpagent": 9, "qgpocrit": 9, "qgpopolici": 9, "qgpoalgorithm": [9, 11, 15], "srpocrit": 9, "srpopolici": 9, "srpoalgorithm": 9, "gmpocrit": 9, "gmpopolici": 9, "gmpgcritic": 9, "gmpgpolici": 9, "qgpod4rldataset": 9, "gpd4rldataset": 9, "generative_model": [9, 16], "diffusionmodel": [9, 16], "energyconditionaldiffusionmodel": 9, "independentconditionalflowmodel": 9, "optimaltransportconditionalflowmodel": 9, "neural_network": [9, 16], "concatenatelay": 9, "multilayerperceptron": 9, "concatenatemlp": 9, "temporalspatialresidualnet": [9, 16], "dit": 9, "dit1d": 9, "dit2d": 9, "dit3d": 9, "numerical_method": [9, 16], "dpmsolver": 9, "odesolv": [9, 16], "sdesolv": 9, "gaussianconditionalprobabilitypath": [9, 16], "rl_modul": 9, 
"gymenvsimul": 9, "oneshotvaluefunct": 9, "vnetwork": 9, "doublevnetwork": 9, "qnetwork": 9, "doubleqnetwork": 9, "util": [9, 11], "set_se": 9, "pip": [10, 14], "you": [10, 11, 14, 15], "latest": 10, "version": [10, 14], "git": [10, 14], "opendilab": [10, 14], "easi": 11, "swiss": 11, "roll": 11, "colab": 11, "usag": [11, 13], "found": 11, "folder": 11, "grl_pipelin": [11, 15], "tutori": 11, "here": [11, 13, 16], "halfcheetah": 11, "diffusion_model": [11, 15, 16], "d4rl_halfcheetah_qgpo": [11, 15], "def": [11, 16], "qgpo_pipelin": 11, "env": [11, 12, 15], "_": [11, 12, 15, 16], "num_deploy_step": [11, 12, 15], "render": [11, 12, 15], "__name__": 11, "__main__": 11, "info": 11, "necessari": 11, "well": 11, "encapsul": 11, "after": [11, 14], "obtain": [11, 16], "loop": 11, "execut": 11, "specifi": 11, "print": 11, "consol": 11, "modifi": 11, "your": 11, "advanc": [11, 13], "pleas": [11, 13, 15], "document": [11, 13, 15], "9": 14, "higher": 14, "command": 14, "clone": 14, "cd": 14, "pybullet": 14, "mujoco": 14, "deepmind": 14, "control": 14, "etc": 14, "dm_control": 14, "setup": 14, "licens": 14, "special": 14, "23": 14, "anoth": 14, "thing": 14, "sudo": 14, "apt": 14, "get": 14, "libgl1": 14, "mesa": 14, "glx": 14, "libglib2": 14, "libsm6": 14, "libxext6": 14, "libxrend": 14, "dev": 14, "y": 14, "swig": 14, "gcc": 14, "local": 14, "dnsutil": 14, "cmake": 14, "build": 14, "essenti": 14, "libglew": 14, "libosmesa6": 14, "libglfw3": 14, "libsdl2": 14, "libglm": 14, "libfreetype6": 14, "patchelf": 14, "ffmpeg": 14, "mkdir": 14, "root": 14, "wget": 14, "org": 14, "download": 14, "mujoco210": 14, "linux": 14, "x86_64": 14, "tar": 14, "gz": 14, "o": 14, "xf": 14, "export": 14, "ld_library_path": 14, "mjpro210": 14, "bin": 14, "farama": 14, "foundat": 14, "lockfil": 14, "cython": 14, "check": 14, "success": 14, "everi": [15, 16], "hyperparamet": 15, "copi": 15, "trained_model": 15, "divers": 16, "describ": 16, "evolut": 16, "over": 16, "increment": 16, "probability_path": 16, "kind": 16, "varianc": 16, "preserv": 16, "gvp": 16, "usal": 16, "want": 16, "normal": 16, "target": 16, "By": 16, "fokker": 16, "planck": 16, "kolmogorov": 16, "fpk": 16, "hat": 16, "_t": 16, "s_": 16, "codebas": 16, "ddpm": 16, "compar": 16, "Or": 16, "nerual": 16, "therefor": 16, "intrinsicmodel": 16, "ani": 16, "cnn": 16, "u": 16, "net": 16, "x_size": 16, "alpha": 16, "arg": 16, "linear_vp_sd": 16, "20": 16, "t_encod": 16, "512": 16, "256": 16, "128": 16, "t_embedding_dim": 16, "register_modul": 16, "regist": 16, "so": 16, "mymodul": 16, "self": 16, "super": 16, "modulelist": 16, "append": 16, "relu": 16, "mle": 16, "onli": 16, "squar": 16, "error": 16, "l": 16, "dsm": 16, "mathbb": 16, "left": 16, "right": 16, "cfm": 16, "simpli": 16}, "objects": {"grl": [[0, 0, 0, "-", "agents"], [1, 0, 0, "-", "algorithms"], [2, 0, 0, "-", "datasets"], [3, 0, 0, "-", "generative_models"], [4, 0, 0, "-", "neural_network"], [5, 0, 0, "-", "numerical_methods"], [6, 0, 0, "-", "rl_modules"], [7, 0, 0, "-", "utils"]], "grl.agents": [[0, 1, 1, "", "GPAgent"], [0, 1, 1, "", "QGPOAgent"], [0, 1, 1, "", "SRPOAgent"]], "grl.agents.GPAgent": [[0, 2, 1, "", "__init__"], [0, 2, 1, "", "act"]], "grl.agents.QGPOAgent": [[0, 2, 1, "", "__init__"], [0, 2, 1, "", "act"]], "grl.agents.SRPOAgent": [[0, 2, 1, "", "__init__"], [0, 2, 1, "", "act"]], "grl.algorithms": [[1, 1, 1, "", "GMPGAlgorithm"], [1, 1, 1, "", "GMPGCritic"], [1, 1, 1, "", "GMPGPolicy"], [1, 1, 1, "", "GMPOAlgorithm"], [1, 1, 1, "", "GMPOCritic"], [1, 1, 1, "", "GMPOPolicy"], [1, 1, 1, "", 
"QGPOAlgorithm"], [1, 1, 1, "", "QGPOCritic"], [1, 1, 1, "", "QGPOPolicy"], [1, 1, 1, "", "SRPOAlgorithm"], [1, 1, 1, "", "SRPOCritic"], [1, 1, 1, "", "SRPOPolicy"]], "grl.algorithms.GMPGAlgorithm": [[1, 2, 1, "", "__init__"], [1, 2, 1, "", "train"]], "grl.algorithms.GMPGCritic": [[1, 2, 1, "", "__init__"], [1, 2, 1, "", "compute_double_q"], [1, 2, 1, "", "forward"], [1, 2, 1, "", "in_support_ql_loss"]], "grl.algorithms.GMPGPolicy": [[1, 2, 1, "", "__init__"], [1, 2, 1, "", "behaviour_policy_loss"], [1, 2, 1, "", "behaviour_policy_sample"], [1, 2, 1, "", "compute_q"], [1, 2, 1, "", "forward"], [1, 2, 1, "", "sample"]], "grl.algorithms.GMPOAlgorithm": [[1, 2, 1, "", "__init__"], [1, 2, 1, "", "train"]], "grl.algorithms.GMPOCritic": [[1, 2, 1, "", "__init__"], [1, 2, 1, "", "compute_double_q"], [1, 2, 1, "", "forward"]], "grl.algorithms.GMPOPolicy": [[1, 2, 1, "", "__init__"], [1, 2, 1, "", "behaviour_policy_loss"], [1, 2, 1, "", "behaviour_policy_sample"], [1, 2, 1, "", "compute_q"], [1, 2, 1, "", "forward"], [1, 2, 1, "", "policy_optimization_loss_by_advantage_weighted_regression"], [1, 2, 1, "", "policy_optimization_loss_by_advantage_weighted_regression_softmax"], [1, 2, 1, "", "sample"]], "grl.algorithms.QGPOAlgorithm": [[1, 2, 1, "", "__init__"], [1, 2, 1, "", "deploy"], [1, 2, 1, "", "train"]], "grl.algorithms.QGPOCritic": [[1, 2, 1, "", "__init__"], [1, 2, 1, "", "compute_double_q"], [1, 2, 1, "", "forward"], [1, 2, 1, "", "q_loss"]], "grl.algorithms.QGPOPolicy": [[1, 2, 1, "", "__init__"], [1, 2, 1, "", "behaviour_policy_loss"], [1, 2, 1, "", "behaviour_policy_sample"], [1, 2, 1, "", "compute_q"], [1, 2, 1, "", "energy_guidance_loss"], [1, 2, 1, "", "forward"], [1, 2, 1, "", "q_loss"], [1, 2, 1, "", "sample"]], "grl.algorithms.SRPOAlgorithm": [[1, 2, 1, "", "__init__"], [1, 2, 1, "", "deploy"], [1, 2, 1, "", "train"]], "grl.algorithms.SRPOCritic": [[1, 2, 1, "", "__init__"], [1, 2, 1, "", "forward"]], "grl.algorithms.SRPOPolicy": [[1, 2, 1, "", "__init__"], [1, 2, 1, "", "behaviour_policy_loss"], [1, 2, 1, "", "forward"], [1, 2, 1, "", "sample"], [1, 2, 1, "", "srpo_actor_loss"]], "grl.datasets": [[2, 1, 1, "", "GPD4RLDataset"], [2, 1, 1, "", "GPDataset"], [2, 1, 1, "", "QGPOD4RLDataset"], [2, 1, 1, "", "QGPODataset"]], "grl.datasets.GPD4RLDataset": [[2, 2, 1, "", "__init__"]], "grl.datasets.GPDataset": [[2, 2, 1, "", "__init__"]], "grl.datasets.QGPOD4RLDataset": [[2, 2, 1, "", "__init__"]], "grl.datasets.QGPODataset": [[2, 2, 1, "", "__init__"]], "grl.generative_models": [[3, 1, 1, "", "DiffusionModel"], [3, 1, 1, "", "EnergyConditionalDiffusionModel"], [3, 1, 1, "", "IndependentConditionalFlowModel"], [3, 1, 1, "", "OptimalTransportConditionalFlowModel"]], "grl.generative_models.DiffusionModel": [[3, 2, 1, "", "__init__"], [3, 2, 1, "", "data_prediction_function"], [3, 2, 1, "", "dpo_loss"], [3, 2, 1, "", "flow_matching_loss"], [3, 2, 1, "", "forward_sample"], [3, 2, 1, "", "forward_sample_process"], [3, 2, 1, "", "log_prob"], [3, 2, 1, "", "noise_function"], [3, 2, 1, "", "sample"], [3, 2, 1, "", "sample_forward_process"], [3, 2, 1, "", "sample_forward_process_with_fixed_x"], [3, 2, 1, "", "sample_with_fixed_x"], [3, 2, 1, "", "sample_with_log_prob"], [3, 2, 1, "", "score_function"], [3, 2, 1, "", "score_matching_loss"], [3, 2, 1, "", "velocity_function"]], "grl.generative_models.EnergyConditionalDiffusionModel": [[3, 2, 1, "", "__init__"], [3, 2, 1, "", "data_prediction_function"], [3, 2, 1, "", "data_prediction_function_with_energy_guidance"], [3, 2, 1, "", 
"energy_guidance_loss"], [3, 2, 1, "", "flow_matching_loss"], [3, 2, 1, "", "noise_function"], [3, 2, 1, "", "noise_function_with_energy_guidance"], [3, 2, 1, "", "sample"], [3, 2, 1, "", "sample_forward_process"], [3, 2, 1, "", "sample_forward_process_with_fixed_x"], [3, 2, 1, "", "sample_with_fixed_x"], [3, 2, 1, "", "sample_with_fixed_x_without_energy_guidance"], [3, 2, 1, "", "sample_without_energy_guidance"], [3, 2, 1, "", "score_function"], [3, 2, 1, "", "score_function_with_energy_guidance"], [3, 2, 1, "", "score_matching_loss"], [3, 2, 1, "", "velocity_function"]], "grl.generative_models.IndependentConditionalFlowModel": [[3, 2, 1, "", "__init__"], [3, 2, 1, "", "flow_matching_loss"], [3, 2, 1, "", "flow_matching_loss_with_mask"], [3, 2, 1, "", "forward_sample"], [3, 2, 1, "", "forward_sample_process"], [3, 2, 1, "", "log_prob"], [3, 2, 1, "", "optimal_transport_flow_matching_loss"], [3, 2, 1, "", "sample"], [3, 2, 1, "", "sample_forward_process"], [3, 2, 1, "", "sample_with_log_prob"], [3, 2, 1, "", "sample_with_mask"], [3, 2, 1, "", "sample_with_mask_forward_process"]], "grl.generative_models.OptimalTransportConditionalFlowModel": [[3, 2, 1, "", "__init__"], [3, 2, 1, "", "flow_matching_loss"], [3, 2, 1, "", "flow_matching_loss_small_batch_OT_plan"], [3, 2, 1, "", "sample"], [3, 2, 1, "", "sample_forward_process"]], "grl.neural_network": [[4, 1, 1, "", "ConcatenateLayer"], [4, 1, 1, "", "ConcatenateMLP"], [4, 1, 1, "", "DiT"], [4, 1, 1, "", "DiT1D"], [4, 3, 1, "", "DiT2D"], [4, 1, 1, "", "DiT3D"], [4, 1, 1, "", "MultiLayerPerceptron"], [4, 1, 1, "", "TemporalSpatialResidualNet"]], "grl.neural_network.ConcatenateLayer": [[4, 2, 1, "", "__init__"], [4, 2, 1, "", "forward"]], "grl.neural_network.ConcatenateMLP": [[4, 2, 1, "", "__init__"], [4, 2, 1, "", "forward"]], "grl.neural_network.DiT": [[4, 2, 1, "", "__init__"], [4, 2, 1, "", "forward"], [4, 2, 1, "", "forward_with_cfg"], [4, 2, 1, "", "initialize_weights"], [4, 2, 1, "", "unpatchify"]], "grl.neural_network.DiT1D": [[4, 2, 1, "", "__init__"], [4, 2, 1, "", "forward"], [4, 2, 1, "", "initialize_weights"]], "grl.neural_network.DiT3D": [[4, 2, 1, "", "__init__"], [4, 2, 1, "", "forward"], [4, 2, 1, "", "initialize_weights"], [4, 2, 1, "", "unpatchify"]], "grl.neural_network.MultiLayerPerceptron": [[4, 2, 1, "", "__init__"], [4, 2, 1, "", "forward"]], "grl.neural_network.TemporalSpatialResidualNet": [[4, 2, 1, "", "__init__"], [4, 2, 1, "", "forward"]], "grl.numerical_methods": [[5, 1, 1, "", "DPMSolver"], [5, 1, 1, "", "GaussianConditionalProbabilityPath"], [5, 1, 1, "", "ODE"], [5, 1, 1, "", "ODESolver"], [5, 1, 1, "", "SDE"], [5, 1, 1, "", "SDESolver"]], "grl.numerical_methods.DPMSolver": [[5, 2, 1, "", "__init__"], [5, 2, 1, "", "integrate"]], "grl.numerical_methods.GaussianConditionalProbabilityPath": [[5, 2, 1, "", "HalfLogSNR"], [5, 2, 1, "", "InverseHalfLogSNR"], [5, 2, 1, "", "__init__"], [5, 2, 1, "", "covariance"], [5, 2, 1, "", "d_covariance_dt"], [5, 2, 1, "", "d_log_scale_dt"], [5, 2, 1, "", "d_scale_dt"], [5, 2, 1, "", "d_std_dt"], [5, 2, 1, "", "diffusion"], [5, 2, 1, "", "diffusion_squared"], [5, 2, 1, "", "drift"], [5, 2, 1, "", "drift_coefficient"], [5, 2, 1, "", "log_scale"], [5, 2, 1, "", "scale"], [5, 2, 1, "", "std"]], "grl.numerical_methods.ODE": [[5, 2, 1, "", "__init__"]], "grl.numerical_methods.ODESolver": [[5, 2, 1, "", "__init__"], [5, 2, 1, "", "integrate"]], "grl.numerical_methods.SDE": [[5, 2, 1, "", "__init__"]], "grl.numerical_methods.SDESolver": [[5, 2, 1, "", "__init__"], [5, 2, 1, "", 
"integrate"]], "grl.rl_modules": [[6, 1, 1, "", "DoubleQNetwork"], [6, 1, 1, "", "DoubleVNetwork"], [6, 1, 1, "", "GymEnvSimulator"], [6, 1, 1, "", "OneShotValueFunction"], [6, 1, 1, "", "QNetwork"], [6, 1, 1, "", "VNetwork"]], "grl.rl_modules.DoubleQNetwork": [[6, 2, 1, "", "__init__"], [6, 2, 1, "", "compute_double_q"], [6, 2, 1, "", "compute_mininum_q"], [6, 2, 1, "", "forward"]], "grl.rl_modules.DoubleVNetwork": [[6, 2, 1, "", "__init__"], [6, 2, 1, "", "compute_double_v"], [6, 2, 1, "", "compute_mininum_v"], [6, 2, 1, "", "forward"]], "grl.rl_modules.GymEnvSimulator": [[6, 2, 1, "", "__init__"], [6, 2, 1, "", "collect_episodes"], [6, 2, 1, "", "collect_steps"], [6, 2, 1, "", "evaluate"]], "grl.rl_modules.OneShotValueFunction": [[6, 2, 1, "", "__init__"], [6, 2, 1, "", "compute_double_v"], [6, 2, 1, "", "forward"], [6, 2, 1, "", "v_loss"]], "grl.rl_modules.QNetwork": [[6, 2, 1, "", "__init__"], [6, 2, 1, "", "forward"]], "grl.rl_modules.VNetwork": [[6, 2, 1, "", "__init__"], [6, 2, 1, "", "forward"]], "grl.utils": [[7, 4, 1, "", "set_seed"]]}, "objtypes": {"0": "py:module", "1": "py:class", "2": "py:method", "3": "py:attribute", "4": "py:function"}, "objnames": {"0": ["py", "module", "Python module"], "1": ["py", "class", "Python class"], "2": ["py", "method", "Python method"], "3": ["py", "attribute", "Python attribute"], "4": ["py", "function", "Python function"]}, "titleterms": {"grl": [0, 1, 2, 3, 4, 5, 6, 7], "agent": [0, 12, 15], "qgpoagent": 0, "srpoagent": 0, "gpagent": 0, "algorithm": 1, "qgpocrit": 1, "qgpopolici": 1, "qgpoalgorithm": 1, "srpocrit": 1, "srpopolici": 1, "srpoalgorithm": 1, "gmpocrit": 1, "gmpopolici": 1, "gmpoalgorithm": 1, "gmpgcritic": 1, "gmpgpolici": 1, "gmpgalgorithm": 1, "dataset": 2, "qgpod4rldataset": 2, "qgpodataset": 2, "gpd4rldataset": 2, "gpdataset": 2, "generative_model": 3, "diffusionmodel": 3, "energyconditionaldiffusionmodel": 3, "independentconditionalflowmodel": 3, "optimaltransportconditionalflowmodel": 3, "neural_network": 4, "concatenatelay": 4, "multilayerperceptron": 4, "concatenatemlp": 4, "temporalspatialresidualnet": 4, "dit": 4, "dit1d": 4, "dit2d": 4, "dit3d": 4, "numerical_method": 5, "od": 5, "sde": 5, "dpmsolver": 5, "odesolv": 5, "sdesolv": 5, "gaussianconditionalprobabilitypath": 5, "rl_modul": 6, "gymenvsimul": 6, "oneshotvaluefunct": 6, "vnetwork": 6, "doublevnetwork": 6, "qnetwork": 6, "doubleqnetwork": 6, "util": 7, "set_se": 7, "concept": [8, 9], "overview": [8, 9], "gener": [8, 11, 16], "model": [8, 11, 16], "reinforc": [8, 11, 15], "learn": [8, 11, 15], "design": 8, "principl": 8, "generativerl": [9, 11, 14], "document": 9, "tutori": 9, "user": [9, 13], "guid": [9, 13], "api": 9, "instal": [10, 14], "quick": 11, "start": 11, "explan": 11, "how": [12, 14, 15, 16], "evalu": 12, "rl": 12, "perform": 12, "its": 14, "depend": 14, "train": [15, 16], "deploi": 15, "stochast": 16, "path": 16, "parameter": 16, "custom": 16, "neural": 16, "network": 16, "object": 16, "differ": 16}, "envversion": {"sphinx.domains.c": 2, "sphinx.domains.changeset": 1, "sphinx.domains.citation": 1, "sphinx.domains.cpp": 8, "sphinx.domains.index": 1, "sphinx.domains.javascript": 2, "sphinx.domains.math": 2, "sphinx.domains.python": 3, "sphinx.domains.rst": 2, "sphinx.domains.std": 2, "sphinx.ext.viewcode": 1, "sphinx.ext.todo": 2, "sphinx.ext.intersphinx": 1, "nbsphinx": 4, "sphinx": 57}, "alltitles": {"grl.agents": [[0, "module-grl.agents"]], "QGPOAgent": [[0, "qgpoagent"]], "SRPOAgent": [[0, "srpoagent"]], "GPAgent": [[0, "gpagent"]], 
"grl.algorithms": [[1, "module-grl.algorithms"]], "QGPOCritic": [[1, "qgpocritic"]], "QGPOPolicy": [[1, "qgpopolicy"]], "QGPOAlgorithm": [[1, "qgpoalgorithm"]], "SRPOCritic": [[1, "srpocritic"]], "SRPOPolicy": [[1, "srpopolicy"]], "SRPOAlgorithm": [[1, "srpoalgorithm"]], "GMPOCritic": [[1, "gmpocritic"]], "GMPOPolicy": [[1, "gmpopolicy"]], "GMPOAlgorithm": [[1, "gmpoalgorithm"]], "GMPGCritic": [[1, "gmpgcritic"]], "GMPGPolicy": [[1, "gmpgpolicy"]], "GMPGAlgorithm": [[1, "gmpgalgorithm"]], "grl.datasets": [[2, "module-grl.datasets"]], "QGPOD4RLDataset": [[2, "qgpod4rldataset"]], "QGPODataset": [[2, "qgpodataset"]], "GPD4RLDataset": [[2, "gpd4rldataset"]], "GPDataset": [[2, "gpdataset"]], "grl.generative_models": [[3, "module-grl.generative_models"]], "DiffusionModel": [[3, "diffusionmodel"]], "EnergyConditionalDiffusionModel": [[3, "energyconditionaldiffusionmodel"]], "IndependentConditionalFlowModel": [[3, "independentconditionalflowmodel"]], "OptimalTransportConditionalFlowModel": [[3, "optimaltransportconditionalflowmodel"]], "grl.neural_network": [[4, "module-grl.neural_network"]], "ConcatenateLayer": [[4, "concatenatelayer"]], "MultiLayerPerceptron": [[4, "multilayerperceptron"]], "ConcatenateMLP": [[4, "concatenatemlp"]], "TemporalSpatialResidualNet": [[4, "temporalspatialresidualnet"]], "DiT": [[4, "dit"]], "DiT1D": [[4, "dit1d"]], "DiT2D": [[4, "dit2d"]], "DiT3D": [[4, "dit3d"]], "grl.numerical_methods": [[5, "module-grl.numerical_methods"]], "ODE": [[5, "ode"]], "SDE": [[5, "sde"]], "DPMSolver": [[5, "dpmsolver"]], "ODESolver": [[5, "odesolver"]], "SDESolver": [[5, "sdesolver"]], "GaussianConditionalProbabilityPath": [[5, "gaussianconditionalprobabilitypath"]], "grl.rl_modules": [[6, "module-grl.rl_modules"]], "GymEnvSimulator": [[6, "gymenvsimulator"]], "OneShotValueFunction": [[6, "oneshotvaluefunction"]], "VNetwork": [[6, "vnetwork"]], "DoubleVNetwork": [[6, "doublevnetwork"]], "QNetwork": [[6, "qnetwork"]], "DoubleQNetwork": [[6, "doubleqnetwork"]], "grl.utils": [[7, "module-grl.utils"]], "set_seed": [[7, "set-seed"]], "Concepts": [[8, "concepts"], [9, null]], "Concepts Overview": [[8, "concepts-overview"]], "Generative Models": [[8, "generative-models"]], "Reinforcement Learning": [[8, "reinforcement-learning"], [11, "reinforcement-learning"]], "Design Principles": [[8, "design-principles"]], "GenerativeRL Documentation": [[9, "generativerl-documentation"]], "Overview": [[9, "overview"]], "Tutorials": [[9, null]], "User Guide": [[9, null], [13, "user-guide"], [13, null]], "API Documentation": [[9, null]], "Installation": [[10, "installation"]], "Quick Start": [[11, "quick-start"]], "Generative model in GenerativeRL": [[11, "generative-model-in-generativerl"]], "Explanation": [[11, "explanation"]], "How to evaluate RL agents performance": [[12, "how-to-evaluate-rl-agents-performance"]], "How to install GenerativeRL and its dependencies": [[14, "how-to-install-generativerl-and-its-dependencies"]], "How to train and deploy reinforcement learning agents": [[15, "how-to-train-and-deploy-reinforcement-learning-agents"]], "How to train generative models": [[16, "how-to-train-generative-models"]], "Stochastic path": [[16, "stochastic-path"]], "Model parameterization": [[16, "model-parameterization"]], "Customized neural network": [[16, "customized-neural-network"]], "Training objective for different generative models": [[16, "training-objective-for-different-generative-models"]]}, "indexentries": {"gpagent (class in grl.agents)": [[0, "grl.agents.GPAgent"]], "qgpoagent (class in 
grl.agents)": [[0, "grl.agents.QGPOAgent"]], "srpoagent (class in grl.agents)": [[0, "grl.agents.SRPOAgent"]], "__init__() (grl.agents.gpagent method)": [[0, "grl.agents.GPAgent.__init__"]], "__init__() (grl.agents.qgpoagent method)": [[0, "grl.agents.QGPOAgent.__init__"]], "__init__() (grl.agents.srpoagent method)": [[0, "grl.agents.SRPOAgent.__init__"]], "act() (grl.agents.gpagent method)": [[0, "grl.agents.GPAgent.act"]], "act() (grl.agents.qgpoagent method)": [[0, "grl.agents.QGPOAgent.act"]], "act() (grl.agents.srpoagent method)": [[0, "grl.agents.SRPOAgent.act"]], "grl.agents": [[0, "module-grl.agents"]], "module": [[0, "module-grl.agents"], [1, "module-grl.algorithms"], [2, "module-grl.datasets"], [3, "module-grl.generative_models"], [4, "module-grl.neural_network"], [5, "module-grl.numerical_methods"], [6, "module-grl.rl_modules"], [7, "module-grl.utils"]], "gmpgalgorithm (class in grl.algorithms)": [[1, "grl.algorithms.GMPGAlgorithm"]], "gmpgcritic (class in grl.algorithms)": [[1, "grl.algorithms.GMPGCritic"]], "gmpgpolicy (class in grl.algorithms)": [[1, "grl.algorithms.GMPGPolicy"]], "gmpoalgorithm (class in grl.algorithms)": [[1, "grl.algorithms.GMPOAlgorithm"]], "gmpocritic (class in grl.algorithms)": [[1, "grl.algorithms.GMPOCritic"]], "gmpopolicy (class in grl.algorithms)": [[1, "grl.algorithms.GMPOPolicy"]], "qgpoalgorithm (class in grl.algorithms)": [[1, "grl.algorithms.QGPOAlgorithm"]], "qgpocritic (class in grl.algorithms)": [[1, "grl.algorithms.QGPOCritic"]], "qgpopolicy (class in grl.algorithms)": [[1, "grl.algorithms.QGPOPolicy"]], "srpoalgorithm (class in grl.algorithms)": [[1, "grl.algorithms.SRPOAlgorithm"]], "srpocritic (class in grl.algorithms)": [[1, "grl.algorithms.SRPOCritic"]], "srpopolicy (class in grl.algorithms)": [[1, "grl.algorithms.SRPOPolicy"]], "__init__() (grl.algorithms.gmpgalgorithm method)": [[1, "grl.algorithms.GMPGAlgorithm.__init__"]], "__init__() (grl.algorithms.gmpgcritic method)": [[1, "grl.algorithms.GMPGCritic.__init__"]], "__init__() (grl.algorithms.gmpgpolicy method)": [[1, "grl.algorithms.GMPGPolicy.__init__"]], "__init__() (grl.algorithms.gmpoalgorithm method)": [[1, "grl.algorithms.GMPOAlgorithm.__init__"]], "__init__() (grl.algorithms.gmpocritic method)": [[1, "grl.algorithms.GMPOCritic.__init__"]], "__init__() (grl.algorithms.gmpopolicy method)": [[1, "grl.algorithms.GMPOPolicy.__init__"]], "__init__() (grl.algorithms.qgpoalgorithm method)": [[1, "grl.algorithms.QGPOAlgorithm.__init__"]], "__init__() (grl.algorithms.qgpocritic method)": [[1, "grl.algorithms.QGPOCritic.__init__"]], "__init__() (grl.algorithms.qgpopolicy method)": [[1, "grl.algorithms.QGPOPolicy.__init__"]], "__init__() (grl.algorithms.srpoalgorithm method)": [[1, "grl.algorithms.SRPOAlgorithm.__init__"]], "__init__() (grl.algorithms.srpocritic method)": [[1, "grl.algorithms.SRPOCritic.__init__"]], "__init__() (grl.algorithms.srpopolicy method)": [[1, "grl.algorithms.SRPOPolicy.__init__"]], "behaviour_policy_loss() (grl.algorithms.gmpgpolicy method)": [[1, "grl.algorithms.GMPGPolicy.behaviour_policy_loss"]], "behaviour_policy_loss() (grl.algorithms.gmpopolicy method)": [[1, "grl.algorithms.GMPOPolicy.behaviour_policy_loss"]], "behaviour_policy_loss() (grl.algorithms.qgpopolicy method)": [[1, "grl.algorithms.QGPOPolicy.behaviour_policy_loss"]], "behaviour_policy_loss() (grl.algorithms.srpopolicy method)": [[1, "grl.algorithms.SRPOPolicy.behaviour_policy_loss"]], "behaviour_policy_sample() (grl.algorithms.gmpgpolicy method)": [[1, 
"grl.algorithms.GMPGPolicy.behaviour_policy_sample"]], "behaviour_policy_sample() (grl.algorithms.gmpopolicy method)": [[1, "grl.algorithms.GMPOPolicy.behaviour_policy_sample"]], "behaviour_policy_sample() (grl.algorithms.qgpopolicy method)": [[1, "grl.algorithms.QGPOPolicy.behaviour_policy_sample"]], "compute_double_q() (grl.algorithms.gmpgcritic method)": [[1, "grl.algorithms.GMPGCritic.compute_double_q"]], "compute_double_q() (grl.algorithms.gmpocritic method)": [[1, "grl.algorithms.GMPOCritic.compute_double_q"]], "compute_double_q() (grl.algorithms.qgpocritic method)": [[1, "grl.algorithms.QGPOCritic.compute_double_q"]], "compute_q() (grl.algorithms.gmpgpolicy method)": [[1, "grl.algorithms.GMPGPolicy.compute_q"]], "compute_q() (grl.algorithms.gmpopolicy method)": [[1, "grl.algorithms.GMPOPolicy.compute_q"]], "compute_q() (grl.algorithms.qgpopolicy method)": [[1, "grl.algorithms.QGPOPolicy.compute_q"]], "deploy() (grl.algorithms.qgpoalgorithm method)": [[1, "grl.algorithms.QGPOAlgorithm.deploy"]], "deploy() (grl.algorithms.srpoalgorithm method)": [[1, "grl.algorithms.SRPOAlgorithm.deploy"]], "energy_guidance_loss() (grl.algorithms.qgpopolicy method)": [[1, "grl.algorithms.QGPOPolicy.energy_guidance_loss"]], "forward() (grl.algorithms.gmpgcritic method)": [[1, "grl.algorithms.GMPGCritic.forward"]], "forward() (grl.algorithms.gmpgpolicy method)": [[1, "grl.algorithms.GMPGPolicy.forward"]], "forward() (grl.algorithms.gmpocritic method)": [[1, "grl.algorithms.GMPOCritic.forward"]], "forward() (grl.algorithms.gmpopolicy method)": [[1, "grl.algorithms.GMPOPolicy.forward"]], "forward() (grl.algorithms.qgpocritic method)": [[1, "grl.algorithms.QGPOCritic.forward"]], "forward() (grl.algorithms.qgpopolicy method)": [[1, "grl.algorithms.QGPOPolicy.forward"]], "forward() (grl.algorithms.srpocritic method)": [[1, "grl.algorithms.SRPOCritic.forward"]], "forward() (grl.algorithms.srpopolicy method)": [[1, "grl.algorithms.SRPOPolicy.forward"]], "grl.algorithms": [[1, "module-grl.algorithms"]], "in_support_ql_loss() (grl.algorithms.gmpgcritic method)": [[1, "grl.algorithms.GMPGCritic.in_support_ql_loss"]], "policy_optimization_loss_by_advantage_weighted_regression() (grl.algorithms.gmpopolicy method)": [[1, "grl.algorithms.GMPOPolicy.policy_optimization_loss_by_advantage_weighted_regression"]], "policy_optimization_loss_by_advantage_weighted_regression_softmax() (grl.algorithms.gmpopolicy method)": [[1, "grl.algorithms.GMPOPolicy.policy_optimization_loss_by_advantage_weighted_regression_softmax"]], "q_loss() (grl.algorithms.qgpocritic method)": [[1, "grl.algorithms.QGPOCritic.q_loss"]], "q_loss() (grl.algorithms.qgpopolicy method)": [[1, "grl.algorithms.QGPOPolicy.q_loss"]], "sample() (grl.algorithms.gmpgpolicy method)": [[1, "grl.algorithms.GMPGPolicy.sample"]], "sample() (grl.algorithms.gmpopolicy method)": [[1, "grl.algorithms.GMPOPolicy.sample"]], "sample() (grl.algorithms.qgpopolicy method)": [[1, "grl.algorithms.QGPOPolicy.sample"]], "sample() (grl.algorithms.srpopolicy method)": [[1, "grl.algorithms.SRPOPolicy.sample"]], "srpo_actor_loss() (grl.algorithms.srpopolicy method)": [[1, "grl.algorithms.SRPOPolicy.srpo_actor_loss"]], "train() (grl.algorithms.gmpgalgorithm method)": [[1, "grl.algorithms.GMPGAlgorithm.train"]], "train() (grl.algorithms.gmpoalgorithm method)": [[1, "grl.algorithms.GMPOAlgorithm.train"]], "train() (grl.algorithms.qgpoalgorithm method)": [[1, "grl.algorithms.QGPOAlgorithm.train"]], "train() (grl.algorithms.srpoalgorithm method)": [[1, "grl.algorithms.SRPOAlgorithm.train"]], 
"gpd4rldataset (class in grl.datasets)": [[2, "grl.datasets.GPD4RLDataset"]], "gpdataset (class in grl.datasets)": [[2, "grl.datasets.GPDataset"]], "qgpod4rldataset (class in grl.datasets)": [[2, "grl.datasets.QGPOD4RLDataset"]], "qgpodataset (class in grl.datasets)": [[2, "grl.datasets.QGPODataset"]], "__init__() (grl.datasets.gpd4rldataset method)": [[2, "grl.datasets.GPD4RLDataset.__init__"]], "__init__() (grl.datasets.gpdataset method)": [[2, "grl.datasets.GPDataset.__init__"]], "__init__() (grl.datasets.qgpod4rldataset method)": [[2, "grl.datasets.QGPOD4RLDataset.__init__"]], "__init__() (grl.datasets.qgpodataset method)": [[2, "grl.datasets.QGPODataset.__init__"]], "grl.datasets": [[2, "module-grl.datasets"]], "diffusionmodel (class in grl.generative_models)": [[3, "grl.generative_models.DiffusionModel"]], "energyconditionaldiffusionmodel (class in grl.generative_models)": [[3, "grl.generative_models.EnergyConditionalDiffusionModel"]], "independentconditionalflowmodel (class in grl.generative_models)": [[3, "grl.generative_models.IndependentConditionalFlowModel"]], "optimaltransportconditionalflowmodel (class in grl.generative_models)": [[3, "grl.generative_models.OptimalTransportConditionalFlowModel"]], "__init__() (grl.generative_models.diffusionmodel method)": [[3, "grl.generative_models.DiffusionModel.__init__"]], "__init__() (grl.generative_models.energyconditionaldiffusionmodel method)": [[3, "grl.generative_models.EnergyConditionalDiffusionModel.__init__"]], "__init__() (grl.generative_models.independentconditionalflowmodel method)": [[3, "grl.generative_models.IndependentConditionalFlowModel.__init__"]], "__init__() (grl.generative_models.optimaltransportconditionalflowmodel method)": [[3, "grl.generative_models.OptimalTransportConditionalFlowModel.__init__"]], "data_prediction_function() (grl.generative_models.diffusionmodel method)": [[3, "grl.generative_models.DiffusionModel.data_prediction_function"]], "data_prediction_function() (grl.generative_models.energyconditionaldiffusionmodel method)": [[3, "grl.generative_models.EnergyConditionalDiffusionModel.data_prediction_function"]], "data_prediction_function_with_energy_guidance() (grl.generative_models.energyconditionaldiffusionmodel method)": [[3, "grl.generative_models.EnergyConditionalDiffusionModel.data_prediction_function_with_energy_guidance"]], "dpo_loss() (grl.generative_models.diffusionmodel method)": [[3, "grl.generative_models.DiffusionModel.dpo_loss"]], "energy_guidance_loss() (grl.generative_models.energyconditionaldiffusionmodel method)": [[3, "grl.generative_models.EnergyConditionalDiffusionModel.energy_guidance_loss"]], "flow_matching_loss() (grl.generative_models.diffusionmodel method)": [[3, "grl.generative_models.DiffusionModel.flow_matching_loss"]], "flow_matching_loss() (grl.generative_models.energyconditionaldiffusionmodel method)": [[3, "grl.generative_models.EnergyConditionalDiffusionModel.flow_matching_loss"]], "flow_matching_loss() (grl.generative_models.independentconditionalflowmodel method)": [[3, "grl.generative_models.IndependentConditionalFlowModel.flow_matching_loss"]], "flow_matching_loss() (grl.generative_models.optimaltransportconditionalflowmodel method)": [[3, "grl.generative_models.OptimalTransportConditionalFlowModel.flow_matching_loss"]], "flow_matching_loss_small_batch_ot_plan() (grl.generative_models.optimaltransportconditionalflowmodel method)": [[3, "grl.generative_models.OptimalTransportConditionalFlowModel.flow_matching_loss_small_batch_OT_plan"]], 
"flow_matching_loss_with_mask() (grl.generative_models.independentconditionalflowmodel method)": [[3, "grl.generative_models.IndependentConditionalFlowModel.flow_matching_loss_with_mask"]], "forward_sample() (grl.generative_models.diffusionmodel method)": [[3, "grl.generative_models.DiffusionModel.forward_sample"]], "forward_sample() (grl.generative_models.independentconditionalflowmodel method)": [[3, "grl.generative_models.IndependentConditionalFlowModel.forward_sample"]], "forward_sample_process() (grl.generative_models.diffusionmodel method)": [[3, "grl.generative_models.DiffusionModel.forward_sample_process"]], "forward_sample_process() (grl.generative_models.independentconditionalflowmodel method)": [[3, "grl.generative_models.IndependentConditionalFlowModel.forward_sample_process"]], "grl.generative_models": [[3, "module-grl.generative_models"]], "log_prob() (grl.generative_models.diffusionmodel method)": [[3, "grl.generative_models.DiffusionModel.log_prob"]], "log_prob() (grl.generative_models.independentconditionalflowmodel method)": [[3, "grl.generative_models.IndependentConditionalFlowModel.log_prob"]], "noise_function() (grl.generative_models.diffusionmodel method)": [[3, "grl.generative_models.DiffusionModel.noise_function"]], "noise_function() (grl.generative_models.energyconditionaldiffusionmodel method)": [[3, "grl.generative_models.EnergyConditionalDiffusionModel.noise_function"]], "noise_function_with_energy_guidance() (grl.generative_models.energyconditionaldiffusionmodel method)": [[3, "grl.generative_models.EnergyConditionalDiffusionModel.noise_function_with_energy_guidance"]], "optimal_transport_flow_matching_loss() (grl.generative_models.independentconditionalflowmodel method)": [[3, "grl.generative_models.IndependentConditionalFlowModel.optimal_transport_flow_matching_loss"]], "sample() (grl.generative_models.diffusionmodel method)": [[3, "grl.generative_models.DiffusionModel.sample"]], "sample() (grl.generative_models.energyconditionaldiffusionmodel method)": [[3, "grl.generative_models.EnergyConditionalDiffusionModel.sample"]], "sample() (grl.generative_models.independentconditionalflowmodel method)": [[3, "grl.generative_models.IndependentConditionalFlowModel.sample"]], "sample() (grl.generative_models.optimaltransportconditionalflowmodel method)": [[3, "grl.generative_models.OptimalTransportConditionalFlowModel.sample"]], "sample_forward_process() (grl.generative_models.diffusionmodel method)": [[3, "grl.generative_models.DiffusionModel.sample_forward_process"]], "sample_forward_process() (grl.generative_models.energyconditionaldiffusionmodel method)": [[3, "grl.generative_models.EnergyConditionalDiffusionModel.sample_forward_process"]], "sample_forward_process() (grl.generative_models.independentconditionalflowmodel method)": [[3, "grl.generative_models.IndependentConditionalFlowModel.sample_forward_process"]], "sample_forward_process() (grl.generative_models.optimaltransportconditionalflowmodel method)": [[3, "grl.generative_models.OptimalTransportConditionalFlowModel.sample_forward_process"]], "sample_forward_process_with_fixed_x() (grl.generative_models.diffusionmodel method)": [[3, "grl.generative_models.DiffusionModel.sample_forward_process_with_fixed_x"]], "sample_forward_process_with_fixed_x() (grl.generative_models.energyconditionaldiffusionmodel method)": [[3, "grl.generative_models.EnergyConditionalDiffusionModel.sample_forward_process_with_fixed_x"]], "sample_with_fixed_x() (grl.generative_models.diffusionmodel method)": [[3, 
"grl.generative_models.DiffusionModel.sample_with_fixed_x"]], "sample_with_fixed_x() (grl.generative_models.energyconditionaldiffusionmodel method)": [[3, "grl.generative_models.EnergyConditionalDiffusionModel.sample_with_fixed_x"]], "sample_with_fixed_x_without_energy_guidance() (grl.generative_models.energyconditionaldiffusionmodel method)": [[3, "grl.generative_models.EnergyConditionalDiffusionModel.sample_with_fixed_x_without_energy_guidance"]], "sample_with_log_prob() (grl.generative_models.diffusionmodel method)": [[3, "grl.generative_models.DiffusionModel.sample_with_log_prob"]], "sample_with_log_prob() (grl.generative_models.independentconditionalflowmodel method)": [[3, "grl.generative_models.IndependentConditionalFlowModel.sample_with_log_prob"]], "sample_with_mask() (grl.generative_models.independentconditionalflowmodel method)": [[3, "grl.generative_models.IndependentConditionalFlowModel.sample_with_mask"]], "sample_with_mask_forward_process() (grl.generative_models.independentconditionalflowmodel method)": [[3, "grl.generative_models.IndependentConditionalFlowModel.sample_with_mask_forward_process"]], "sample_without_energy_guidance() (grl.generative_models.energyconditionaldiffusionmodel method)": [[3, "grl.generative_models.EnergyConditionalDiffusionModel.sample_without_energy_guidance"]], "score_function() (grl.generative_models.diffusionmodel method)": [[3, "grl.generative_models.DiffusionModel.score_function"]], "score_function() (grl.generative_models.energyconditionaldiffusionmodel method)": [[3, "grl.generative_models.EnergyConditionalDiffusionModel.score_function"]], "score_function_with_energy_guidance() (grl.generative_models.energyconditionaldiffusionmodel method)": [[3, "grl.generative_models.EnergyConditionalDiffusionModel.score_function_with_energy_guidance"]], "score_matching_loss() (grl.generative_models.diffusionmodel method)": [[3, "grl.generative_models.DiffusionModel.score_matching_loss"]], "score_matching_loss() (grl.generative_models.energyconditionaldiffusionmodel method)": [[3, "grl.generative_models.EnergyConditionalDiffusionModel.score_matching_loss"]], "velocity_function() (grl.generative_models.diffusionmodel method)": [[3, "grl.generative_models.DiffusionModel.velocity_function"]], "velocity_function() (grl.generative_models.energyconditionaldiffusionmodel method)": [[3, "grl.generative_models.EnergyConditionalDiffusionModel.velocity_function"]], "concatenatelayer (class in grl.neural_network)": [[4, "grl.neural_network.ConcatenateLayer"]], "concatenatemlp (class in grl.neural_network)": [[4, "grl.neural_network.ConcatenateMLP"]], "dit (class in grl.neural_network)": [[4, "grl.neural_network.DiT"]], "dit1d (class in grl.neural_network)": [[4, "grl.neural_network.DiT1D"]], "dit2d (in module grl.neural_network)": [[4, "grl.neural_network.DiT2D"]], "dit3d (class in grl.neural_network)": [[4, "grl.neural_network.DiT3D"]], "multilayerperceptron (class in grl.neural_network)": [[4, "grl.neural_network.MultiLayerPerceptron"]], "temporalspatialresidualnet (class in grl.neural_network)": [[4, "grl.neural_network.TemporalSpatialResidualNet"]], "__init__() (grl.neural_network.concatenatelayer method)": [[4, "grl.neural_network.ConcatenateLayer.__init__"]], "__init__() (grl.neural_network.concatenatemlp method)": [[4, "grl.neural_network.ConcatenateMLP.__init__"]], "__init__() (grl.neural_network.dit method)": [[4, "grl.neural_network.DiT.__init__"]], "__init__() (grl.neural_network.dit1d method)": [[4, "grl.neural_network.DiT1D.__init__"]], "__init__() 
(grl.neural_network.dit3d method)": [[4, "grl.neural_network.DiT3D.__init__"]], "__init__() (grl.neural_network.multilayerperceptron method)": [[4, "grl.neural_network.MultiLayerPerceptron.__init__"]], "__init__() (grl.neural_network.temporalspatialresidualnet method)": [[4, "grl.neural_network.TemporalSpatialResidualNet.__init__"]], "forward() (grl.neural_network.concatenatelayer method)": [[4, "grl.neural_network.ConcatenateLayer.forward"]], "forward() (grl.neural_network.concatenatemlp method)": [[4, "grl.neural_network.ConcatenateMLP.forward"]], "forward() (grl.neural_network.dit method)": [[4, "grl.neural_network.DiT.forward"]], "forward() (grl.neural_network.dit1d method)": [[4, "grl.neural_network.DiT1D.forward"]], "forward() (grl.neural_network.dit3d method)": [[4, "grl.neural_network.DiT3D.forward"]], "forward() (grl.neural_network.multilayerperceptron method)": [[4, "grl.neural_network.MultiLayerPerceptron.forward"]], "forward() (grl.neural_network.temporalspatialresidualnet method)": [[4, "grl.neural_network.TemporalSpatialResidualNet.forward"]], "forward_with_cfg() (grl.neural_network.dit method)": [[4, "grl.neural_network.DiT.forward_with_cfg"]], "grl.neural_network": [[4, "module-grl.neural_network"]], "initialize_weights() (grl.neural_network.dit method)": [[4, "grl.neural_network.DiT.initialize_weights"]], "initialize_weights() (grl.neural_network.dit1d method)": [[4, "grl.neural_network.DiT1D.initialize_weights"]], "initialize_weights() (grl.neural_network.dit3d method)": [[4, "grl.neural_network.DiT3D.initialize_weights"]], "unpatchify() (grl.neural_network.dit method)": [[4, "grl.neural_network.DiT.unpatchify"]], "unpatchify() (grl.neural_network.dit3d method)": [[4, "grl.neural_network.DiT3D.unpatchify"]], "dpmsolver (class in grl.numerical_methods)": [[5, "grl.numerical_methods.DPMSolver"]], "gaussianconditionalprobabilitypath (class in grl.numerical_methods)": [[5, "grl.numerical_methods.GaussianConditionalProbabilityPath"]], "halflogsnr() (grl.numerical_methods.gaussianconditionalprobabilitypath method)": [[5, "grl.numerical_methods.GaussianConditionalProbabilityPath.HalfLogSNR"]], "inversehalflogsnr() (grl.numerical_methods.gaussianconditionalprobabilitypath method)": [[5, "grl.numerical_methods.GaussianConditionalProbabilityPath.InverseHalfLogSNR"]], "ode (class in grl.numerical_methods)": [[5, "grl.numerical_methods.ODE"]], "odesolver (class in grl.numerical_methods)": [[5, "grl.numerical_methods.ODESolver"]], "sde (class in grl.numerical_methods)": [[5, "grl.numerical_methods.SDE"]], "sdesolver (class in grl.numerical_methods)": [[5, "grl.numerical_methods.SDESolver"]], "__init__() (grl.numerical_methods.dpmsolver method)": [[5, "grl.numerical_methods.DPMSolver.__init__"]], "__init__() (grl.numerical_methods.gaussianconditionalprobabilitypath method)": [[5, "grl.numerical_methods.GaussianConditionalProbabilityPath.__init__"]], "__init__() (grl.numerical_methods.ode method)": [[5, "grl.numerical_methods.ODE.__init__"]], "__init__() (grl.numerical_methods.odesolver method)": [[5, "grl.numerical_methods.ODESolver.__init__"]], "__init__() (grl.numerical_methods.sde method)": [[5, "grl.numerical_methods.SDE.__init__"]], "__init__() (grl.numerical_methods.sdesolver method)": [[5, "grl.numerical_methods.SDESolver.__init__"]], "covariance() (grl.numerical_methods.gaussianconditionalprobabilitypath method)": [[5, "grl.numerical_methods.GaussianConditionalProbabilityPath.covariance"]], "d_covariance_dt() (grl.numerical_methods.gaussianconditionalprobabilitypath method)": 
[[5, "grl.numerical_methods.GaussianConditionalProbabilityPath.d_covariance_dt"]], "d_log_scale_dt() (grl.numerical_methods.gaussianconditionalprobabilitypath method)": [[5, "grl.numerical_methods.GaussianConditionalProbabilityPath.d_log_scale_dt"]], "d_scale_dt() (grl.numerical_methods.gaussianconditionalprobabilitypath method)": [[5, "grl.numerical_methods.GaussianConditionalProbabilityPath.d_scale_dt"]], "d_std_dt() (grl.numerical_methods.gaussianconditionalprobabilitypath method)": [[5, "grl.numerical_methods.GaussianConditionalProbabilityPath.d_std_dt"]], "diffusion() (grl.numerical_methods.gaussianconditionalprobabilitypath method)": [[5, "grl.numerical_methods.GaussianConditionalProbabilityPath.diffusion"]], "diffusion_squared() (grl.numerical_methods.gaussianconditionalprobabilitypath method)": [[5, "grl.numerical_methods.GaussianConditionalProbabilityPath.diffusion_squared"]], "drift() (grl.numerical_methods.gaussianconditionalprobabilitypath method)": [[5, "grl.numerical_methods.GaussianConditionalProbabilityPath.drift"]], "drift_coefficient() (grl.numerical_methods.gaussianconditionalprobabilitypath method)": [[5, "grl.numerical_methods.GaussianConditionalProbabilityPath.drift_coefficient"]], "grl.numerical_methods": [[5, "module-grl.numerical_methods"]], "integrate() (grl.numerical_methods.dpmsolver method)": [[5, "grl.numerical_methods.DPMSolver.integrate"]], "integrate() (grl.numerical_methods.odesolver method)": [[5, "grl.numerical_methods.ODESolver.integrate"]], "integrate() (grl.numerical_methods.sdesolver method)": [[5, "grl.numerical_methods.SDESolver.integrate"]], "log_scale() (grl.numerical_methods.gaussianconditionalprobabilitypath method)": [[5, "grl.numerical_methods.GaussianConditionalProbabilityPath.log_scale"]], "scale() (grl.numerical_methods.gaussianconditionalprobabilitypath method)": [[5, "grl.numerical_methods.GaussianConditionalProbabilityPath.scale"]], "std() (grl.numerical_methods.gaussianconditionalprobabilitypath method)": [[5, "grl.numerical_methods.GaussianConditionalProbabilityPath.std"]], "doubleqnetwork (class in grl.rl_modules)": [[6, "grl.rl_modules.DoubleQNetwork"]], "doublevnetwork (class in grl.rl_modules)": [[6, "grl.rl_modules.DoubleVNetwork"]], "gymenvsimulator (class in grl.rl_modules)": [[6, "grl.rl_modules.GymEnvSimulator"]], "oneshotvaluefunction (class in grl.rl_modules)": [[6, "grl.rl_modules.OneShotValueFunction"]], "qnetwork (class in grl.rl_modules)": [[6, "grl.rl_modules.QNetwork"]], "vnetwork (class in grl.rl_modules)": [[6, "grl.rl_modules.VNetwork"]], "__init__() (grl.rl_modules.doubleqnetwork method)": [[6, "grl.rl_modules.DoubleQNetwork.__init__"]], "__init__() (grl.rl_modules.doublevnetwork method)": [[6, "grl.rl_modules.DoubleVNetwork.__init__"]], "__init__() (grl.rl_modules.gymenvsimulator method)": [[6, "grl.rl_modules.GymEnvSimulator.__init__"]], "__init__() (grl.rl_modules.oneshotvaluefunction method)": [[6, "grl.rl_modules.OneShotValueFunction.__init__"]], "__init__() (grl.rl_modules.qnetwork method)": [[6, "grl.rl_modules.QNetwork.__init__"]], "__init__() (grl.rl_modules.vnetwork method)": [[6, "grl.rl_modules.VNetwork.__init__"]], "collect_episodes() (grl.rl_modules.gymenvsimulator method)": [[6, "grl.rl_modules.GymEnvSimulator.collect_episodes"]], "collect_steps() (grl.rl_modules.gymenvsimulator method)": [[6, "grl.rl_modules.GymEnvSimulator.collect_steps"]], "compute_double_q() (grl.rl_modules.doubleqnetwork method)": [[6, "grl.rl_modules.DoubleQNetwork.compute_double_q"]], "compute_double_v() 
(grl.rl_modules.doublevnetwork method)": [[6, "grl.rl_modules.DoubleVNetwork.compute_double_v"]], "compute_double_v() (grl.rl_modules.oneshotvaluefunction method)": [[6, "grl.rl_modules.OneShotValueFunction.compute_double_v"]], "compute_mininum_q() (grl.rl_modules.doubleqnetwork method)": [[6, "grl.rl_modules.DoubleQNetwork.compute_mininum_q"]], "compute_mininum_v() (grl.rl_modules.doublevnetwork method)": [[6, "grl.rl_modules.DoubleVNetwork.compute_mininum_v"]], "evaluate() (grl.rl_modules.gymenvsimulator method)": [[6, "grl.rl_modules.GymEnvSimulator.evaluate"]], "forward() (grl.rl_modules.doubleqnetwork method)": [[6, "grl.rl_modules.DoubleQNetwork.forward"]], "forward() (grl.rl_modules.doublevnetwork method)": [[6, "grl.rl_modules.DoubleVNetwork.forward"]], "forward() (grl.rl_modules.oneshotvaluefunction method)": [[6, "grl.rl_modules.OneShotValueFunction.forward"]], "forward() (grl.rl_modules.qnetwork method)": [[6, "grl.rl_modules.QNetwork.forward"]], "forward() (grl.rl_modules.vnetwork method)": [[6, "grl.rl_modules.VNetwork.forward"]], "grl.rl_modules": [[6, "module-grl.rl_modules"]], "v_loss() (grl.rl_modules.oneshotvaluefunction method)": [[6, "grl.rl_modules.OneShotValueFunction.v_loss"]], "grl.utils": [[7, "module-grl.utils"]], "set_seed() (in module grl.utils)": [[7, "grl.utils.set_seed"]]}}) \ No newline at end of file