diff --git a/hive/envs/minigrid/__init__.py b/hive/envs/minigrid/__init__.py
deleted file mode 100644
index d4a89b9b..00000000
--- a/hive/envs/minigrid/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from hive.envs.minigrid.minigrid import MiniGridEnv
diff --git a/notebooks/env_tutorial.ipynb b/notebooks/env_tutorial.ipynb
new file mode 100644
index 00000000..2b2093b4
--- /dev/null
+++ b/notebooks/env_tutorial.ipynb
@@ -0,0 +1,609 @@
+{
+  "nbformat": 4,
+  "nbformat_minor": 0,
+  "metadata": {
+    "colab": {
+      "provenance": [],
+      "toc_visible": true
+    },
+    "kernelspec": {
+      "name": "python3",
+      "display_name": "Python 3"
+    },
+    "language_info": {
+      "name": "python"
+    },
+    "accelerator": "GPU",
+    "gpuClass": "standard"
+  },
+  "cells": [
+    {
+      "cell_type": "markdown",
+      "source": [
+        "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/assets/colab-badge.svg)"
+      ],
+      "metadata": {
+        "id": "iBPurVluSqjt"
+      }
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "# About the tutorial\n",
+        "Datasets are essential in both supervised and unsupervised machine learning. In a typical reinforcement learning (RL) setting, the agent must interact with the environment in order to collect data for learning. Environments therefore play a role in RL similar to the one datasets play in supervised and unsupervised learning. In this tutorial, we explain how to use RLHive environments. Note that this tutorial focuses on single-agent environments."
+      ],
+      "metadata": {
+        "id": "TiiOCUnwDlXB"
+      }
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "qy1F1g74qfo_"
+      },
+      "source": [
+        "# Introduction and Setup"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "### RLHive Installation"
+      ],
+      "metadata": {
+        "id": "EJutnqZVjNTu"
+      }
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "xi-qOSphrCH7"
+      },
+      "source": [
+        "For installation, you can check [this notebook](https://colab.research.google.com/drive/11YirxgoVD7gjN02TdAeyFXOL1qH7Eydv?usp=sharing)."
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "### How to install environments\n",
+        "\n",
+        "RLHive currently supports the following environments:\n",
+        "\n",
+        "* Gym classic control\n",
+        "* Atari\n",
+        "* Minigrid (single-agent grid world)\n",
+        "* Marlgrid (multi-agent)\n",
+        "* Pettingzoo (multi-agent)\n",
+        "\n",
+        "To install Gym, simply run `pip install gymnasium`. You can also install the dependencies for the environments that RLHive comes with by running `pip install rlhive[<extras>]`, where `<extras>` is a comma-separated list made up of `atari`, `gym_minigrid`, and `pettingzoo`.\n",
+        "\n",
+        "Marlgrid is also supported, but must be installed separately. Moreover, MinAtar can be accessed directly through Gym.\n",
+        "\n",
+        "* To install Marlgrid, run `pip install marlgrid@https://github.com/kandouss/marlgrid/archive/refs/heads/master.zip`"
+      ],
+      "metadata": {
+        "id": "eGZyL2zzGKEt"
+      }
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "id": "JWmMKQBPFoO8"
+      },
+      "outputs": [],
+      "source": [
+        "!pip install git+https://github.com/chandar-lab/RLHive.git@dev\n",
+        "!pip install RLHive['minigrid']"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "from hive import envs\n",
+        "from hive.utils.registry import registry\n",
+        "from hive.envs.base import BaseEnv\n",
+        "from hive.envs.gym_env import GymEnv\n",
+        "from hive.envs.env_spec import EnvSpec\n",
+        "import gymnasium as gym\n",
+        "import minigrid\n",
+        "from minigrid.wrappers import ReseedWrapper\n",
+        "from gymnasium.spaces import Discrete"
+      ],
+      "metadata": {
+        "id": "VKastN5fSqsP"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "# Creating environments"
+      ],
+      "metadata": {
+        "id": "Bjn9nUcSnvPa"
+      }
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "Every environment used in RLHive should be a subclass of `hive.envs.base.BaseEnv`. It should provide a `reset()` function that resets the environment to a new episode and returns a tuple of `(observation, turn)`, and a `step()` function that takes in an action, performs the step in the environment, and returns a tuple of `(observation, reward, terminated, truncated, turn, info)`. The `terminated` variable is `True` when the environment terminates (e.g., when a task is completed). The `truncated` variable is `True` if the episode is truncated due to a time limit or another reason that is not defined as part of the task MDP. The `info` variable is a dictionary containing auxiliary diagnostic information for debugging, learning, and logging; for instance, it could contain the individual reward terms that are combined to produce the total reward. The `turn` variable corresponds to the index of the agent whose turn it is (in multi-agent environments).\n",
+        "\n",
+        "The `reward` can be a single number, an array, or a dictionary. If it is a number, that same reward is given to every agent. If it is an array, each agent gets the reward corresponding to its index in the runner. If it is a dictionary, the keys should be the agent ids and the values the rewards for those agents."
+      ],
+      "metadata": {
+        "id": "37-E7egZn8oc"
+      }
+    },
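+    {
+      "cell_type": "markdown",
+      "source": [
+        "For example, in a hypothetical two-agent environment, `step()` could report its reward in any of the following forms (an illustrative sketch only; the agent ids below are made up):\n",
+        "\n",
+        "```python\n",
+        "reward = 1.0                                # a single number: every agent receives 1.0\n",
+        "reward = [1.0, 0.0]                         # an array: indexed by each agent's position in the runner\n",
+        "reward = {\"agent_0\": 1.0, \"agent_1\": 0.0}   # a dictionary: keyed by agent id\n",
+        "```"
+      ],
+      "metadata": {}
+    },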
+    {
+      "cell_type": "markdown",
+      "source": [
+        "### `GymEnv`"
+      ],
+      "metadata": {
+        "id": "5ZuasJNftrIp"
+      }
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "The [OpenAI gym](https://gymnasium.farama.org/) interface, which provides a flexible way of designing environments, initializing them, and interacting with them, has become well known among RL researchers.\n",
+        "\n",
+        "If your environment is a gym environment and you do not need to preprocess the observations it generates, then you can use `hive.envs.gym_env.GymEnv` directly."
+      ],
+      "metadata": {
+        "id": "Sjhcfhu9twX4"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "env = GymEnv(\"CartPole-v0\")"
+      ],
+      "metadata": {
+        "id": "-6K303tyS6Fv"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "### `EnvSpec`\n",
+        "\n",
+        "Each environment should also provide an `EnvSpec` object that specifies the observation and action spaces. These should be lists with one element for each agent. The agent uses this information to build its networks in a format compatible with the valid observations and actions."
+      ],
+      "metadata": {
+        "id": "FdppbbQfoVRg"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "env_spec = env.env_spec\n",
+        "obs_spec, act_spec = env_spec.observation_space[0], env_spec.action_space[0]\n",
+        "print(\"Environment name : \\n\", env_spec.env_name)\n",
+        "print(\"Environment observation space: \\n\", obs_spec)\n",
+        "print(\"Environment action space: \\n\", act_spec)\n",
+        "print(\"Environment info: \\n\", env_spec.env_info)"
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "id": "GlUB4vYSNfB8",
+        "outputId": "61918772-5919-4c39-bf97-dd465d25bc3b"
+      },
+      "execution_count": null,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "Environment name : \n",
+            " CartPole-v0\n",
+            "Environment observation space: \n",
+            " Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32)\n",
+            "Environment action space: \n",
+            " Discrete(2)\n",
+            "Environment info: \n",
+            " {}\n"
+          ]
+        }
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "### Environment basic methods"
+      ],
+      "metadata": {
+        "id": "ujC05sQURorb"
+      }
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "To work with any environment, we `reset` the environment to a new initial state, and then use `step` to perform the specified action and return the updated information collected from the environment. For image-based environments, where rendering is important, you can also use the `render` function.\n",
+        "Finally, when we're done with the environment, we can `close` it."
+      ],
+      "metadata": {
+        "id": "drzzPq_ISBr7"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "obs, turn = env.reset()\n",
+        "print(\"Environment initial observation : \\n\", obs)\n",
+        "print(\"Environment initial turn: \\n\", turn)"
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "id": "mhnAQ_63R2Am",
+        "outputId": "9fec792e-fcbe-4c54-acb7-9b82d32b0afe"
+      },
+      "execution_count": null,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "Environment initial observation : \n",
+            " [-0.02180136 0.02328087 -0.01036017 0.00271058]\n",
+            "Environment initial turn: \n",
+            " 0\n"
+          ]
+        }
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "The `turn` indicates the agent ID, which is 0 in a single-agent setting."
+      ],
+      "metadata": {
+        "id": "a4JeyDM7UwaQ"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "num_steps = 100\n",
+        "\n",
+        "for t in range(num_steps):\n",
+        "    obs, reward, terminated, truncated, turn, info = env.step(act_spec.sample())  # Random policy\n",
+        "\n",
+        "    if terminated or truncated:\n",
+        "        break\n",
+        "\n",
+        "env.close()"
+      ],
+      "metadata": {
+        "id": "PAde34AbUiLG"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "### Custom environment\n",
+        "\n",
+        "You can also create your own custom environment based on `GymEnv`. If you need to add extra preprocessing, or want to change how the environment or its `EnvSpec` is created, simply subclass `GymEnv` and override `create_env()` and/or `create_env_spec()`.\n"
+      ],
+      "metadata": {
+        "id": "lBK7hXAnHHht"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "class MiniGridEnv(GymEnv):\n",
+        "    def __init__(self, env_name, num_players=1, seed=42, **kwargs):\n",
+        "        super().__init__(env_name, num_players, seed=seed, **kwargs)\n",
+        "\n",
+        "    def create_env(self, env_name, seed, **kwargs):\n",
+        "        self._env = gym.make(env_name, **kwargs)\n",
+        "        self._env = ReseedWrapper(self._env, seeds=[seed])\n",
+        "\n",
+        "    def create_env_spec(self, env_name, **kwargs):\n",
+        "        env_spec = super().create_env_spec(env_name, **kwargs)\n",
+        "        return env_spec\n",
+        "\n",
+        "    def step(self, action):\n",
+        "        return super().step(action)"
+      ],
+      "metadata": {
+        "id": "Nw8WqXWfisji"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
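+    {
+      "cell_type": "markdown",
+      "source": [
+        "As a quick usage sketch (assuming the `minigrid` package registers the `MiniGrid-Empty-5x5-v0` task when imported), the subclass can then be used like any other RLHive environment:\n",
+        "\n",
+        "```python\n",
+        "env = MiniGridEnv(\"MiniGrid-Empty-5x5-v0\")   # env id assumed to come from `import minigrid`\n",
+        "obs, turn = env.reset()\n",
+        "action = env.env_spec.action_space[0].sample()\n",
+        "obs, reward, terminated, truncated, turn, info = env.step(action)\n",
+        "```"
+      ],
+      "metadata": {}
+    },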
+    {
+      "cell_type": "markdown",
+      "source": [
+        "When you're using gym-based environments, like `MiniGridEnv` above, you can conveniently register the environment for future use by calling `gym.register`:"
+      ],
+      "metadata": {
+        "id": "eyIzVdOgdM9x"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "gym.register(id='MyMiniGrid', entry_point=MiniGridEnv)"
+      ],
+      "metadata": {
+        "id": "CP3qqjPFTHQr"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "We can also create an environment from scratch by inheriting from `hive.envs.base.BaseEnv`. For instance, the following cell defines `GridEnv`: a 1$\\times$7 grid, indexed from -3 to 3 from left to right. The agent always starts in cell 0, and at each step it can choose to walk right (if possible), walk left (if possible), or stay in the current cell. The agent is rewarded only when it is in cell 1."
+      ],
+      "metadata": {
+        "id": "BA3SZKwTbMGd"
+      }
+    },
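+    {
+      "cell_type": "markdown",
+      "source": [
+        "In other words, a single step of the dynamics described above can be sketched as follows (illustrative only; the full `BaseEnv` subclass is defined in the next cell):\n",
+        "\n",
+        "```python\n",
+        "cell, action = 0, 1                          # e.g. start in cell 0 and walk right\n",
+        "next_cell = max(-3, min(3, cell + action))   # clip to the 7 cells, indexed -3..3\n",
+        "reward = 1 if next_cell == 1 else 0          # reward is given only in cell 1\n",
+        "```"
+      ],
+      "metadata": {}
+    },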
+    {
+      "cell_type": "code",
+      "source": [
+        "class GridEnv(BaseEnv):\n",
+        "    def __init__(self, env_name='GridEnv', max_steps=20, **kwargs):\n",
+        "        self._num_grid = 7\n",
+        "        self._observation = 0\n",
+        "        self._num_steps = 0\n",
+        "        self._max_steps = max_steps\n",
+        "\n",
+        "        super().__init__(self.create_env_spec(env_name, **kwargs), 1)\n",
+        "\n",
+        "    def create_env_spec(self, env_name, **kwargs):\n",
+        "        # Cells are indexed -3..3, so the observation space starts at -3.\n",
+        "        observation_spaces = [Discrete(self._num_grid, start=-(self._num_grid // 2))]\n",
+        "        action_spaces = [Discrete(3, start=-1)]\n",
+        "        return EnvSpec(\n",
+        "            env_name=env_name,\n",
+        "            observation_space=observation_spaces,\n",
+        "            action_space=action_spaces,\n",
+        "        )\n",
+        "\n",
+        "    def reset(self):\n",
+        "        self._observation = self._num_steps = 0\n",
+        "        return self._observation, self._turn\n",
+        "\n",
+        "    def step(self, action):\n",
+        "        self._num_steps += 1\n",
+        "\n",
+        "        # Move right (+1) or left (-1), clipped to the grid bounds -3..3; otherwise stay.\n",
+        "        if action == 1:\n",
+        "            self._observation = min(self._num_grid // 2, self._observation + 1)\n",
+        "        elif action == -1:\n",
+        "            self._observation = max(-(self._num_grid // 2), self._observation - 1)\n",
+        "\n",
+        "        if self._observation == 1:\n",
+        "            reward = 1\n",
+        "        else:\n",
+        "            reward = 0\n",
+        "\n",
+        "        truncated = self._num_steps == self._max_steps\n",
+        "        info = {}\n",
+        "\n",
+        "        return self._observation, reward, False, truncated, self._turn, info\n",
+        "\n",
+        "    def render(self):\n",
+        "        pass\n",
+        "\n",
+        "    def close(self):\n",
+        "        pass\n",
+        "\n",
+        "    def save(self):\n",
+        "        pass\n",
+        "\n",
+        "    def seed(self):\n",
+        "        pass"
+      ],
+      "metadata": {
+        "id": "ntMi-6cmbQ18"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "env = GridEnv()\n",
+        "env_spec = env.env_spec\n",
+        "obs_spec, act_spec = env_spec.observation_space[0], env_spec.action_space[0]\n",
+        "print(\"Environment name : \\n\", env_spec.env_name)\n",
+        "print(\"Environment observation space: \\n\", obs_spec)\n",
+        "print(\"Environment action space: \\n\", act_spec)\n",
+        "print(\"Environment info: \\n\", env_spec.env_info)"
+      ],
+      "metadata": {
+        "id": "wBNYTJw1xxFD",
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "outputId": "8c442fb6-aedc-410e-8877-f0dcc519375a"
+      },
+      "execution_count": null,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "Environment name : \n",
+            " GridEnv\n",
+            "Environment observation space: \n",
+            " Discrete(7, start=-3)\n",
+            "Environment action space: \n",
+            " Discrete(3, start=-1)\n",
+            "Environment info: \n",
+            " {}\n"
+          ]
+        }
+      ]
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "terminated = truncated = False\n",
+        "env.reset()\n",
+        "\n",
+        "while not terminated and not truncated:\n",
+        "    obs, reward, terminated, truncated, turn, info = env.step(act_spec.sample())\n",
+        "    print(\"Cell {}, Reward {}\".format(obs, reward))\n",
+        "\n",
+        "env.close()"
+      ],
+      "metadata": {
+        "id": "9zEL32h-yTgx",
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "outputId": "7287698c-1582-4b6c-c1ab-321f9f55a449"
+      },
+      "execution_count": null,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "Cell -1, Reward 0\n",
+            "Cell 0, Reward 0\n",
+            "Cell 0, Reward 0\n",
+            "Cell 1, Reward 1\n",
+            "Cell 0, Reward 0\n",
+            "Cell 0, Reward 0\n",
+            "Cell -1, Reward 0\n",
+            "Cell -1, Reward 0\n",
+            "Cell 0, Reward 0\n",
+            "Cell 0, Reward 0\n",
+            "Cell -1, Reward 0\n",
+            "Cell -1, Reward 0\n",
+            "Cell -2, Reward 0\n",
+            "Cell -3, Reward 0\n",
+            "Cell -2, Reward 0\n",
+            "Cell -1, Reward 0\n",
+            "Cell 0, Reward 0\n",
+            "Cell 0, Reward 0\n",
+            "Cell 0, Reward 0\n",
+            "Cell 1, Reward 1\n"
+          ]
+        }
+      ]
+    },
0\n", + "Cell 0, Reward 0\n", + "Cell 0, Reward 0\n", + "Cell 0, Reward 0\n", + "Cell 1, Reward 1\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "#### Registering environments\n", + "You can register your environment to create it in one line. The registry module `hive.utils.registry` is used to register classes in the RLHive Registry. Consider registering `GridEnv` we created before:" + ], + "metadata": { + "id": "FHr4g3ljmo8N" + } + }, + { + "cell_type": "code", + "source": [ + "registry.register(name = 'GridEnv', constructor = GridEnv, type = GridEnv)" + ], + "metadata": { + "id": "3TxnV6Zdn-XX" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Then, you can pass a config `dict` to the getter function to create an object of the environment. The configs should have two fields; The `name`, which is the name used when registering a class in the registry, and `**kwargs`, keyword arguments that will be passed to the constructor." + ], + "metadata": { + "id": "5zJEbZ_WEs-s" + } + }, + { + "cell_type": "code", + "source": [ + "environment = {'name': 'GridEnv', 'kwargs': {'env_name': 'GridEnv'}}\n", + "grid_env_fn, full_configs = envs.get_env(environment, 'environment')\n", + "grid_env = grid_env_fn()" + ], + "metadata": { + "id": "_iTWBzDHEj27" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "More than one environment can be registered at once using the `register_all` method. Consider registering two environments `Env1` and `Env2` (inheriting `BaseEnv`):" + ], + "metadata": { + "id": "zdXJ0cOlnP1s" + } + }, + { + "cell_type": "code", + "source": [ + "class Env1(BaseEnv):\n", + " def __init__(self, env_name = 'Env1', **kwargs):\n", + " pass\n", + " def reset(self):\n", + " pass\n", + " def step(self):\n", + " pass\n", + " def render(self):\n", + " pass\n", + " def close(self):\n", + " pass\n", + " def save(self):\n", + " pass\n", + "\n", + "class Env2(BaseEnv):\n", + " def __init__(self, env_name = 'Env2', **kwargs):\n", + " pass\n", + " def reset(self):\n", + " pass\n", + " def step(self):\n", + " pass\n", + " def render(self):\n", + " pass\n", + " def close(self):\n", + " pass\n", + " def save(self):\n", + " pass" + ], + "metadata": { + "id": "XvfqR3uXmoQm" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "registry.register_all(\n", + " BaseEnv,\n", + " {\n", + " \"Env1\": Env2,\n", + " \"Env1\": Env2,\n", + " },\n", + ")" + ], + "metadata": { + "id": "PXkYMQrjnKGJ" + }, + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file