Source code for gymnasium.experimental.wrappers.common

"""A collection of common wrappers.

* ``AutoresetV0`` - Auto-resets the environment
* ``PassiveEnvCheckerV0`` - Passive environment checker that does not modify any environment data
* ``OrderEnforcingV0`` - Enforces the order of function calls to environments
* ``RecordEpisodeStatisticsV0`` - Records the episode statistics
"""
from __future__ import annotations

import time
from collections import deque
from typing import Any, SupportsFloat

import numpy as np

import gymnasium as gym
from gymnasium.core import ActType, ObsType, RenderFrame
from gymnasium.error import ResetNeeded
from gymnasium.utils.passive_env_checker import (
    check_action_space,
    check_observation_space,
    env_render_passive_checker,
    env_reset_passive_checker,
    env_step_passive_checker,
)


__all__ = [
    "AutoresetV0",
    "PassiveEnvCheckerV0",
    "OrderEnforcingV0",
    "RecordEpisodeStatisticsV0",
]


class AutoresetV0(
    gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
    """A class for providing an automatic reset functionality for gymnasium environments when calling :meth:`self.step`."""

    def __init__(self, env: gym.Env[ObsType, ActType]):
        """A class for providing an automatic reset functionality for gymnasium environments when calling :meth:`self.step`.

        Args:
            env (gym.Env): The environment to apply the wrapper
        """
        gym.utils.RecordConstructorArgs.__init__(self)
        gym.Wrapper.__init__(self, env)

        self._episode_ended: bool = False
        self._reset_options: dict[str, Any] | None = None

    def step(
        self, action: ActType
    ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
        """Steps through the environment with action and resets the environment if a terminated or truncated signal is encountered in the previous step.

        Args:
            action: The action to take

        Returns:
            The autoreset environment :meth:`step`
        """
        if self._episode_ended:
            obs, info = self.env.reset(options=self._reset_options)
            # a new episode has just started, so clear the ended flag
            # (leaving it set would reset the environment on every later step)
            self._episode_ended = False
            return obs, 0, False, False, info
        else:
            obs, reward, terminated, truncated, info = super().step(action)
            self._episode_ended = terminated or truncated
            return obs, reward, terminated, truncated, info

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[ObsType, dict[str, Any]]:
        """Resets the environment, saving the options used."""
        self._episode_ended = False
        self._reset_options = options
        return super().reset(seed=seed, options=self._reset_options)
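
A minimal usage sketch of the autoreset behaviour (assuming ``CartPole-v1`` is registered, as in the :class:`OrderEnforcingV0` example below): once an episode has terminated or truncated, the next :meth:`step` call resets instead of stepping and reports a zero reward.

    >>> import gymnasium as gym
    >>> from gymnasium.experimental.wrappers import AutoresetV0
    >>> env = AutoresetV0(gym.make("CartPole-v1"))
    >>> obs, info = env.reset(seed=42)
    >>> terminated = truncated = False
    >>> while not (terminated or truncated):
    ...     obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    >>> obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    >>> reward, terminated, truncated  # the step above reset the environment
    (0, False, False)
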
class PassiveEnvCheckerV0(
    gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
    """A passive environment checker wrapper that surrounds the step, reset and render functions to check they follow the gymnasium API."""

    def __init__(self, env: gym.Env[ObsType, ActType]):
        """Initialises the wrapper with the environment, running the observation and action space tests."""
        gym.utils.RecordConstructorArgs.__init__(self)
        gym.Wrapper.__init__(self, env)

        assert hasattr(
            env, "action_space"
        ), "The environment must specify an action space. https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/"
        check_action_space(env.action_space)
        assert hasattr(
            env, "observation_space"
        ), "The environment must specify an observation space. https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/"
        check_observation_space(env.observation_space)

        self._checked_reset: bool = False
        self._checked_step: bool = False
        self._checked_render: bool = False

    def step(
        self, action: ActType
    ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
        """Steps through the environment, running `env_step_passive_checker` on the first call."""
        if self._checked_step is False:
            self._checked_step = True
            return env_step_passive_checker(self.env, action)
        else:
            return self.env.step(action)

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[ObsType, dict[str, Any]]:
        """Resets the environment, running `env_reset_passive_checker` on the first call."""
        if self._checked_reset is False:
            self._checked_reset = True
            return env_reset_passive_checker(self.env, seed=seed, options=options)
        else:
            return self.env.reset(seed=seed, options=options)

    def render(self) -> RenderFrame | list[RenderFrame] | None:
        """Renders the environment, running `env_render_passive_checker` on the first call."""
        if self._checked_render is False:
            self._checked_render = True
            return env_render_passive_checker(self.env)
        else:
            return self.env.render()
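
Each of :meth:`reset`, :meth:`step` and :meth:`render` is validated only on its first call, so the wrapper is cheap afterwards. A minimal sketch, again assuming ``CartPole-v1``:

    >>> import gymnasium as gym
    >>> from gymnasium.experimental.wrappers import PassiveEnvCheckerV0
    >>> env = PassiveEnvCheckerV0(gym.make("CartPole-v1"))
    >>> obs, info = env.reset(seed=0)            # validated by env_reset_passive_checker
    >>> _ = env.step(env.action_space.sample())  # validated by env_step_passive_checker
    >>> _ = env.step(env.action_space.sample())  # later calls pass straight through
    >>> env.close()
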
class OrderEnforcingV0(
    gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
    """A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`.

    Example:
        >>> import gymnasium as gym
        >>> from gymnasium.experimental.wrappers import OrderEnforcingV0
        >>> env = gym.make("CartPole-v1", render_mode="human")
        >>> env = OrderEnforcingV0(env)
        >>> env.step(0)
        Traceback (most recent call last):
            ...
        gymnasium.error.ResetNeeded: Cannot call env.step() before calling env.reset()
        >>> env.render()
        Traceback (most recent call last):
            ...
        gymnasium.error.ResetNeeded: Cannot call `env.render()` before calling `env.reset()`, if this is an intended action, set `disable_render_order_enforcing=True` on the OrderEnforcer wrapper.
        >>> _ = env.reset()
        >>> env.render()
        >>> _ = env.step(0)
        >>> env.close()
    """

    def __init__(
        self,
        env: gym.Env[ObsType, ActType],
        disable_render_order_enforcing: bool = False,
    ):
        """A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`.

        Args:
            env: The environment to wrap
            disable_render_order_enforcing: Whether to disable render order enforcing
        """
        gym.utils.RecordConstructorArgs.__init__(
            self, disable_render_order_enforcing=disable_render_order_enforcing
        )
        gym.Wrapper.__init__(self, env)

        self._has_reset: bool = False
        self._disable_render_order_enforcing: bool = disable_render_order_enforcing

    def step(self, action: ActType) -> tuple[ObsType, SupportsFloat, bool, bool, dict]:
        """Steps through the environment, raising :class:`ResetNeeded` if :meth:`reset` has not been called."""
        if not self._has_reset:
            raise ResetNeeded("Cannot call env.step() before calling env.reset()")
        return super().step(action)

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[ObsType, dict[str, Any]]:
        """Resets the environment, recording that :meth:`reset` has been called."""
        self._has_reset = True
        return super().reset(seed=seed, options=options)

    def render(self) -> RenderFrame | list[RenderFrame] | None:
        """Renders the environment, enforcing that :meth:`reset` has been called first unless disabled."""
        if not self._disable_render_order_enforcing and not self._has_reset:
            raise ResetNeeded(
                "Cannot call `env.render()` before calling `env.reset()`, if this is an intended action, "
                "set `disable_render_order_enforcing=True` on the OrderEnforcer wrapper."
            )
        return super().render()

    @property
    def has_reset(self):
        """Returns if the environment has been reset before."""
        return self._has_reset
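
A small sketch of the constructor flag and the :attr:`has_reset` property (with ``disable_render_order_enforcing=True`` only the :meth:`step` check remains active):

    >>> import gymnasium as gym
    >>> from gymnasium.experimental.wrappers import OrderEnforcingV0
    >>> env = OrderEnforcingV0(gym.make("CartPole-v1"), disable_render_order_enforcing=True)
    >>> env.has_reset
    False
    >>> _ = env.reset()
    >>> env.has_reset
    True
    >>> env.close()
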
class RecordEpisodeStatisticsV0(
    gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
    """This wrapper will keep track of cumulative rewards and episode lengths.

    At the end of an episode, the statistics of the episode will be added to ``info``
    using the key ``episode``. If using a vectorized environment, the key ``_episode`` is also used,
    indicating whether the env at the respective index has the episode statistics.

    After the completion of an episode, ``info`` will look like this::

        >>> info = {
        ...     "episode": {
        ...         "r": "<cumulative reward>",
        ...         "l": "<episode length>",
        ...         "t": "<elapsed time since beginning of episode>"
        ...     },
        ... }

    For vectorized environments the output will be in the form of::

        >>> infos = {
        ...     "final_observation": "<array of length num-envs>",
        ...     "_final_observation": "<boolean array of length num-envs>",
        ...     "final_info": "<array of length num-envs>",
        ...     "_final_info": "<boolean array of length num-envs>",
        ...     "episode": {
        ...         "r": "<array of cumulative reward>",
        ...         "l": "<array of episode length>",
        ...         "t": "<array of elapsed time since beginning of episode>"
        ...     },
        ...     "_episode": "<boolean array of length num-envs>"
        ... }

    Moreover, the most recent rewards and episode lengths are stored in buffers that can be accessed via
    :attr:`wrapped_env.episode_reward_buffer` and :attr:`wrapped_env.episode_length_buffer` respectively.

    Attributes:
        episode_reward_buffer: The cumulative rewards of the last ``buffer_length``-many episodes
        episode_length_buffer: The lengths of the last ``buffer_length``-many episodes
        episode_time_length_buffer: The elapsed times of the last ``buffer_length``-many episodes
    """

    def __init__(
        self,
        env: gym.Env[ObsType, ActType],
        buffer_length: int | None = 100,
        stats_key: str = "episode",
    ):
        """This wrapper will keep track of cumulative rewards and episode lengths.

        Args:
            env (Env): The environment to apply the wrapper
            buffer_length: The size of the buffers :attr:`episode_reward_buffer` and :attr:`episode_length_buffer`
            stats_key: The info key for the episode statistics
        """
        gym.utils.RecordConstructorArgs.__init__(self)
        gym.Wrapper.__init__(self, env)

        self._stats_key = stats_key

        self.episode_count = 0
        self.episode_start_time: float = -1
        self.episode_reward: float = -1
        self.episode_length: int = -1

        self.episode_time_length_buffer: deque[float] = deque(maxlen=buffer_length)
        self.episode_reward_buffer: deque[float] = deque(maxlen=buffer_length)
        self.episode_length_buffer: deque[int] = deque(maxlen=buffer_length)

    def step(
        self, action: ActType
    ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
        """Steps through the environment, recording the episode statistics."""
        obs, reward, terminated, truncated, info = super().step(action)

        self.episode_reward += reward
        self.episode_length += 1

        if terminated or truncated:
            assert self._stats_key not in info

            episode_time_length = np.round(
                time.perf_counter() - self.episode_start_time, 6
            )
            info[self._stats_key] = {
                "r": self.episode_reward,
                "l": self.episode_length,
                "t": episode_time_length,
            }

            self.episode_time_length_buffer.append(episode_time_length)
            self.episode_reward_buffer.append(self.episode_reward)
            self.episode_length_buffer.append(self.episode_length)

            self.episode_count += 1

        return obs, reward, terminated, truncated, info

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[ObsType, dict[str, Any]]:
        """Resets the environment using seed and options, and resets the episode rewards and lengths."""
        obs, info = super().reset(seed=seed, options=options)

        self.episode_start_time = time.perf_counter()
        self.episode_reward = 0
        self.episode_length = 0

        return obs, info
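
A sketch of reading the recorded statistics for a single episode, once more assuming ``CartPole-v1``:

    >>> import gymnasium as gym
    >>> from gymnasium.experimental.wrappers import RecordEpisodeStatisticsV0
    >>> env = RecordEpisodeStatisticsV0(gym.make("CartPole-v1"), buffer_length=10)
    >>> obs, info = env.reset(seed=1)
    >>> terminated = truncated = False
    >>> while not (terminated or truncated):
    ...     obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    >>> sorted(info["episode"])         # statistics appear only at the end of an episode
    ['l', 'r', 't']
    >>> len(env.episode_reward_buffer)  # the finished episode's return is also buffered
    1
    >>> env.close()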