From bf2238a91dd896e318deee608ab61e380064b31c Mon Sep 17 00:00:00 2001 From: Gary Lvov Date: Wed, 9 Oct 2024 12:50:18 -0400 Subject: [PATCH] convert obs to class --- pyproject.toml | 1 + .../omni/isaac/lab/envs/mdp/observations.py | 114 ++++++++++-------- 2 files changed, 63 insertions(+), 52 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 51d4375907..9a82624b8b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,7 @@ known_third_party = [ "warp", "carb", "Semantics", + "torchvision" ] # Imports from this repository known_first_party = "omni.isaac.lab" diff --git a/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/mdp/observations.py b/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/mdp/observations.py index cc06afa5fb..5973e13ae1 100644 --- a/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/mdp/observations.py +++ b/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/mdp/observations.py @@ -17,6 +17,8 @@ import omni.isaac.lab.utils.math as math_utils from omni.isaac.lab.assets import Articulation, RigidObject from omni.isaac.lab.managers import SceneEntityCfg +from omni.isaac.lab.managers.manager_base import ManagerTermBase +from omni.isaac.lab.managers.manager_term_cfg import ObservationTermCfg from omni.isaac.lab.sensors import Camera, RayCaster, RayCasterCamera, TiledCamera if TYPE_CHECKING: @@ -233,61 +235,69 @@ def image( return images.clone() -def image_features( - env: ManagerBasedEnv, - sensor_cfg: SceneEntityCfg = SceneEntityCfg("tiled_camera"), - data_type: str = "rgb", - convert_perspective_to_orthogonal: bool = True, - model_name: str = "Theia", - model_zoo_cfg: dict | None = None, -) -> torch.Tensor: - """Extracted image features with a frozen encoder from Images of a specific datatype from the camera sensor. - - Args: - env: The environment the cameras are placed within. - sensor_cfg: The desired sensor to read from. Defaults to SceneEntityCfg("tiled_camera"). - data_type: The data type to pull from the desired camera. Defaults to "rgb". - model_name: The name of which model to use from the model_zoo_cfg to use to extract features. - model_zoo_cfg: A dictionary with string keys and callable values. Should include "model", - (mapped to a callable with no arguments to return the model), "preprocess" (mapped to - a callable which consumes the images and returns the preprocessed images), - and "inference" (mapped to a callable that provided the model, and the preproccessed images, - returns the features.) +class image_features(ManagerTermBase): + """Extracted image features with a frozen encoder from images of a specific datatype from the camera sensor. - Returns: - The features from the images produced at the last timestep + Calls :meth:`image` to get the images, then performs inference. On initialization, + for a model zoo different from the default, define model_zoo_cfg: A dictionary with string keys and callable values. + Should include "model", (mapped to a callable with no arguments to return the model), "preprocess" (mapped to + a callable which consumes the images and returns the preprocessed images), + and "inference" (mapped to a callable that provided the model, and the preproccessed images, returns the features.) """ - if not hasattr(image_features, "model_zoo"): - image_features.model_zoo = {} - - if model_zoo_cfg is None: - model_zoo_cfg = { - "ResNet18": { - "model": lambda: models.resnet18(pretrained=True).eval().to("cuda:0"), - "preprocess": lambda img: ( - img.permute(0, 3, 1, 2) # Convert [batch, height, width, 3] -> [batch, 3, height, width] - - torch.tensor([0.485, 0.456, 0.406], device=img.device).view(1, 3, 1, 1) - ) / torch.tensor([0.229, 0.224, 0.225], device=img.device).view(1, 3, 1, 1), - "inference": lambda model, images: model(images), - }, - } - - if model_name not in image_features.model_zoo: - print(f"[INFO]: Adding {model_name} to persistent frozen feature extraction model zoo...") - image_features.model_zoo[model_name] = model_zoo_cfg[model_name]["model"]() - - images = image( - env=env, - sensor_cfg=sensor_cfg, - data_type=data_type, - convert_perspective_to_orthogonal=convert_perspective_to_orthogonal, - normalize=True, # want this for training stability - ) - - proc_images = model_zoo_cfg[model_name]["preprocess"](images) - features = model_zoo_cfg[model_name]["inference"](image_features.model_zoo[model_name], proc_images) - return features + def __init__( + self, + cfg: ObservationTermCfg, + env: ManagerBasedEnv, + model_zoo_cfg: dict | None = None, + initialize_all: bool = False, + ): + super().__init__(cfg, env) + if model_zoo_cfg is None: + self.model_zoo_cfg = { + "ResNet18": { + "model": lambda: models.resnet18(pretrained=True).eval().to("cuda:0"), + "preprocess": lambda img: ( + img.permute(0, 3, 1, 2) # Convert [batch, height, width, 3] -> [batch, 3, height, width] + # Normalize in the format expected by pytorch; https://pytorch.org/hub/pytorch_vision_resnet/ + - torch.tensor([0.485, 0.456, 0.406], device=img.device).view(1, 3, 1, 1) + ) / torch.tensor([0.229, 0.224, 0.225], device=img.device).view(1, 3, 1, 1), + "inference": lambda model, images: model(images), + }, + } + self.reset_model(initialize_all=initialize_all) + + # The following is named reset_model instead of reset as otherwise it's called at the end of every episode + def reset_model(self, initialize_all=False): + self.model_zoo = {} + if initialize_all: + for model_name, model_callables in self.model_zoo_cfg.items(): + self.model_zoo[model_name] = model_callables["model"]() + + def __call__( + self, + env: ManagerBasedEnv, + sensor_cfg: SceneEntityCfg = SceneEntityCfg("tiled_camera"), + data_type: str = "rgb", + convert_perspective_to_orthogonal: bool = False, + model_name: str = "ResNet18", + ): + if model_name not in self.model_zoo: + print(f"[INFO]: Adding {model_name} to the model zoo") + self.model_zoo[model_name] = self.model_zoo_cfg[model_name]["model"]() + + images = image( + env=env, + sensor_cfg=sensor_cfg, + data_type=data_type, + convert_perspective_to_orthogonal=convert_perspective_to_orthogonal, + normalize=True, # want this for training stability + ) + + proc_images = self.model_zoo_cfg[model_name]["preprocess"](images) + features = self.model_zoo_cfg[model_name]["inference"](self.model_zoo[model_name], proc_images) + + return features """