# Source code for gymnasium.experimental.wrappers.rendering

"""A collections of rendering-based wrappers.

* ``RenderCollectionV0`` - Collects rendered frames into a list
* ``RecordVideoV0`` - Records a video of the environments
* ``HumanRenderingV0`` - Provides human rendering of environments with ``"rgb_array"``
"""
from __future__ import annotations

import os
from copy import deepcopy
from typing import Any, Callable, List, SupportsFloat

import numpy as np

import gymnasium as gym
from gymnasium import error, logger
from gymnasium.core import ActType, ObsType, RenderFrame
from gymnasium.error import DependencyNotInstalled


__all__ = ["RenderCollectionV0", "RecordVideoV0", "HumanRenderingV0"]


class RenderCollectionV0(
    gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
    """Buffer every rendered frame so that ``render`` returns a ``list[RenderFrame]``.

    One frame is grabbed from the wrapped environment after each ``reset`` and
    each ``step``; calling ``render`` hands back the frames gathered so far.
    """

    def __init__(
        self,
        env: gym.Env[ObsType, ActType],
        pop_frames: bool = True,
        reset_clean: bool = True,
    ):
        """Wrap ``env`` so that its frames are collected into a list.

        Args:
            env: The environment to wrap. Its ``render_mode`` must be set and
                must not already end with ``"_list"``.
            pop_frames (bool): When ``True`` (the default), ``render`` empties the
                buffer after returning it.
            reset_clean (bool): When ``True`` (the default), ``reset`` discards the
                frames buffered before it.
        """
        gym.utils.RecordConstructorArgs.__init__(
            self, pop_frames=pop_frames, reset_clean=reset_clean
        )
        gym.Wrapper.__init__(self, env)

        assert env.render_mode is not None
        assert not env.render_mode.endswith("_list")

        self.frame_list: list[RenderFrame] = []
        self.pop_frames = pop_frames
        self.reset_clean = reset_clean

        # Work on a copy so that advertising the extra "<mode>_list" render mode
        # does not modify the wrapped environment's own metadata dictionary.
        self.metadata = deepcopy(self.env.metadata)
        supported_modes = self.metadata["render_modes"]
        list_mode = f"{self.env.render_mode}_list"
        if list_mode not in supported_modes:
            supported_modes.append(list_mode)

    @property
    def render_mode(self):
        """The wrapped environment's render mode with a ``"_list"`` suffix."""
        return f"{self.env.render_mode}_list"

    def step(
        self, action: ActType
    ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
        """Step the wrapped environment, then buffer the frame it renders."""
        transition = super().step(action)
        self._collect_frame()
        return transition

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[ObsType, dict[str, Any]]:
        """Reset the wrapped environment, optionally drop old frames, then buffer the first frame."""
        reset_result = super().reset(seed=seed, options=options)
        if self.reset_clean:
            # Rebind instead of clear(): with pop_frames=False, render() returned
            # this very list object, and the caller's copy must stay intact.
            self.frame_list = []
        self._collect_frame()
        return reset_result

    def render(self) -> list[RenderFrame]:
        """Return the buffered frames; when ``pop_frames`` is set, start a fresh buffer."""
        if not self.pop_frames:
            return self.frame_list
        collected, self.frame_list = self.frame_list, []
        return collected

    def _collect_frame(self) -> None:
        """Render the wrapped environment once and append that frame to the buffer."""
        self.frame_list.append(super().render())


class RecordVideoV0(
    gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
    """This wrapper records videos of rollouts.

    Usually, you only want to record episodes intermittently, say every hundredth episode.
    To do this, you can specify ``episode_trigger`` or ``step_trigger``.
    They should be functions returning a boolean that indicates whether a recording should be started at the
    current episode or step, respectively.
    If neither ``episode_trigger`` nor ``step_trigger`` is passed, a default ``episode_trigger`` will be employed,
    i.e. ``capped_cubic_video_schedule``. This function starts a video at every episode whose index is a perfect
    cube below 1000 (0, 1, 8, 27, 64, ...) and then at every 1000th episode.
    By default, the recording will be stopped once reset is called. However, you can also create recordings of
    fixed length (possibly spanning several episodes) by passing a strictly positive value for ``video_length``.

    Videos are written as ``{name_prefix}-episode-{episode_id}.mp4`` or ``{name_prefix}-step-{step_id}.mp4``
    inside ``video_folder``.
    """

    def __init__(
        self,
        env: gym.Env[ObsType, ActType],
        video_folder: str,
        episode_trigger: Callable[[int], bool] | None = None,
        step_trigger: Callable[[int], bool] | None = None,
        video_length: int = 0,
        name_prefix: str = "rl-video",
        fps: int | None = None,
        disable_logger: bool = False,
    ):
        """Wrapper records videos of rollouts.

        Args:
            env: The environment that will be wrapped
            video_folder (str): The folder where the recordings will be stored
            episode_trigger: Function that accepts an integer and returns ``True`` iff a recording should be started at this episode
            step_trigger: Function that accepts an integer and returns ``True`` iff a recording should be started at this step
            video_length (int): The length of recorded episodes. If 0, entire episodes are recorded.
                Otherwise, snippets of the specified length are captured
            name_prefix (str): Will be prepended to the filename of the recordings
            fps (int): The frames per second in the video. The default value is the one specified in the environment metadata.
                If the environment metadata doesn't specify `render_fps`, the value 30 is used.
            disable_logger (bool): Whether to disable moviepy logger or not

        Raises:
            ValueError: If the environment's render mode does not produce images.
            DependencyNotInstalled: If moviepy cannot be imported.
        """
        gym.utils.RecordConstructorArgs.__init__(
            self,
            video_folder=video_folder,
            episode_trigger=episode_trigger,
            step_trigger=step_trigger,
            video_length=video_length,
            name_prefix=name_prefix,
            # Previously omitted, so the recorded constructor args lost a custom fps.
            fps=fps,
            disable_logger=disable_logger,
        )
        gym.Wrapper.__init__(self, env)

        if env.render_mode in {None, "human", "ansi"}:
            # A single message string; passing two positional args to ValueError
            # produced a tuple-shaped error message.
            raise ValueError(
                f"Render mode is {env.render_mode}, which is incompatible with RecordVideo. "
                "Initialize your environment with a render_mode that returns an image, such as rgb_array."
            )

        if episode_trigger is None and step_trigger is None:

            def capped_cubic_video_schedule(episode_id: int) -> bool:
                # Perfect cubes (0, 1, 8, 27, ...) below 1000, then every 1000th episode.
                if episode_id < 1000:
                    return int(round(episode_id ** (1.0 / 3))) ** 3 == episode_id
                else:
                    return episode_id % 1000 == 0

            episode_trigger = capped_cubic_video_schedule

        self.episode_trigger = episode_trigger
        self.step_trigger = step_trigger
        self.disable_logger = disable_logger

        self.video_folder = os.path.abspath(video_folder)
        if os.path.isdir(self.video_folder):
            logger.warn(
                f"Overwriting existing videos at {self.video_folder} folder "
                f"(try specifying a different `video_folder` for the `RecordVideo` wrapper if this is not desired)"
            )
        os.makedirs(self.video_folder, exist_ok=True)

        if fps is None:
            fps = self.metadata.get("render_fps", 30)
        self.frames_per_sec: int = fps
        self.name_prefix: str = name_prefix
        self._video_name: str | None = None
        # ``inf`` disables the length cap; reset() then ends the current recording.
        self.video_length: int | float = (
            video_length if video_length != 0 else float("inf")
        )
        self.recording: bool = False
        self.recorded_frames: list[RenderFrame] = []
        # Frames pulled from a list-mode env during capture that the user's
        # render() has not yet received; render() returns and clears them.
        self.render_history: list[RenderFrame] = []

        self.step_id = -1
        self.episode_id = -1

        try:
            import moviepy  # noqa: F401
        except ImportError as e:
            raise error.DependencyNotInstalled(
                "MoviePy is not installed, run `pip install moviepy`"
            ) from e

    def _capture_frame(self):
        """Render the wrapped env and append the latest image to the current recording."""
        assert self.recording, "Cannot capture a frame, recording wasn't started."

        frame = self.env.render()
        if isinstance(frame, list):
            if len(frame) == 0:  # frames were already consumed by an earlier render() call
                return
            self.render_history += frame
            frame = frame[-1]

        if isinstance(frame, np.ndarray):
            self.recorded_frames.append(frame)
        else:
            self.stop_recording()
            # One message string: the second fragment used to be passed as a
            # separate positional argument to logger.warn.
            logger.warn(
                "Recording stopped: expected type of frame returned by render "
                f"to be a numpy array, got instead {type(frame)}."
            )

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[ObsType, dict[str, Any]]:
        """Reset the environment and eventually starts a new recording."""
        obs, info = super().reset(seed=seed, options=options)
        self.episode_id += 1

        # Without a length cap, a recording spans exactly one episode.
        if self.recording and self.video_length == float("inf"):
            self.stop_recording()

        if self.episode_trigger and self.episode_trigger(self.episode_id):
            self.start_recording(f"{self.name_prefix}-episode-{self.episode_id}")
        if self.recording:
            self._capture_frame()
            if len(self.recorded_frames) > self.video_length:
                self.stop_recording()

        return obs, info

    def step(
        self, action: ActType
    ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
        """Steps through the environment using action, recording observations if :attr:`self.recording`."""
        obs, rew, terminated, truncated, info = self.env.step(action)
        self.step_id += 1

        if self.step_trigger and self.step_trigger(self.step_id):
            self.start_recording(f"{self.name_prefix}-step-{self.step_id}")
        if self.recording:
            self._capture_frame()

            if len(self.recorded_frames) > self.video_length:
                self.stop_recording()

        return obs, rew, terminated, truncated, info

    def start_recording(self, video_name: str):
        """Start a new recording. If it is already recording, stops the current recording before starting the new one."""
        if self.recording:
            self.stop_recording()

        self.recording = True
        self._video_name = video_name

    def stop_recording(self):
        """Stop current recording and saves the video."""
        assert self.recording, "stop_recording was called, but no recording was started"

        if len(self.recorded_frames) == 0:
            logger.warn("Ignored saving a video as there were zero frames to save.")
        else:
            try:
                from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
            except ImportError as e:
                raise error.DependencyNotInstalled(
                    "MoviePy is not installed, run `pip install moviepy`"
                ) from e

            clip = ImageSequenceClip(self.recorded_frames, fps=self.frames_per_sec)
            moviepy_logger = None if self.disable_logger else "bar"
            path = os.path.join(self.video_folder, f"{self._video_name}.mp4")
            clip.write_videofile(path, logger=moviepy_logger)

        self.recorded_frames = []
        self.recording = False
        self._video_name = None

    def render(self) -> RenderFrame | list[RenderFrame]:
        """Compute the render frames as specified by render_mode attribute during initialization of the environment."""
        render_out = super().render()
        if self.recording and isinstance(render_out, list):
            self.recorded_frames += render_out

        if len(self.render_history) > 0:
            tmp_history = self.render_history
            self.render_history = []
            return tmp_history + render_out
        else:
            return render_out

    def close(self):
        """Closes the wrapper then the video recorder."""
        super().close()
        if self.recording:
            self.stop_recording()

    def __del__(self):
        """Warn the user in case last video wasn't saved."""
        # Read through __dict__ so that a partially constructed instance (e.g. the
        # render-mode ValueError raised before recorded_frames was assigned)
        # does not raise from inside __del__.
        if self.__dict__.get("recorded_frames"):
            logger.warn("Unable to save last video! Did you call close()?")


class HumanRenderingV0(
    gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
    """Performs human rendering for an environment that only supports ``"rgb_array"`` rendering.

    This wrapper is particularly useful when you have implemented an environment that can produce
    RGB images but haven't implemented any code to render the images to the screen.
    If you want to use this wrapper with your environments, remember to specify ``"render_fps"``
    in the metadata of your environment.

    The ``render_mode`` of the wrapped environment must be either ``'rgb_array'`` or ``'rgb_array_list'``.

    Example:
        >>> import gymnasium as gym
        >>> from gymnasium.experimental.wrappers import HumanRenderingV0
        >>> env = gym.make("LunarLander-v2", render_mode="rgb_array")
        >>> wrapped = HumanRenderingV0(env)
        >>> obs, _ = wrapped.reset()     # This will start rendering to the screen

    The wrapper can also be applied directly when the environment is instantiated, simply by passing
    ``render_mode="human"`` to ``make``. The wrapper will only be applied if the environment does not
    implement human-rendering natively (i.e. ``render_mode`` does not contain ``"human"``).

        >>> env = gym.make("phys2d/CartPole-v1", render_mode="human")  # the phys2d CartPole env doesn't implement human-rendering natively
        >>> obs, _ = env.reset()     # This will start rendering to the screen

    Warning: If the base environment uses ``render_mode="rgb_array_list"``, its (i.e. the *base environment's*) render method
        will always return an empty list:

        >>> env = gym.make("LunarLander-v2", render_mode="rgb_array_list")
        >>> wrapped = HumanRenderingV0(env)
        >>> obs, _ = wrapped.reset()
        >>> env.render()  # env.render() will always return an empty list!
        []
    """

    def __init__(self, env: gym.Env[ObsType, ActType]):
        """Initialize a :class:`HumanRendering` instance.

        Args:
            env: The environment that is being wrapped. Its ``render_mode`` must be
                ``"rgb_array"`` or ``"rgb_array_list"`` and its metadata must define
                ``"render_fps"``.
        """
        gym.utils.RecordConstructorArgs.__init__(self)
        gym.Wrapper.__init__(self, env)

        assert env.render_mode in [
            "rgb_array",
            "rgb_array_list",
        ], f"Expected env.render_mode to be one of 'rgb_array' or 'rgb_array_list' but got '{env.render_mode}'"
        assert (
            "render_fps" in env.metadata
        ), "The base environment must specify 'render_fps' to be used with the HumanRendering wrapper"

        # Window state is created lazily on the first frame (see _render_frame).
        self.screen_size = None
        self.window = None
        self.clock = None

        if "human" not in self.metadata["render_modes"]:
            # Copy before appending so the base env's metadata is left untouched.
            self.metadata = deepcopy(self.env.metadata)
            self.metadata["render_modes"].append("human")

    @property
    def render_mode(self):
        """Always returns ``'human'``."""
        return "human"

    def step(self, action: ActType) -> tuple[ObsType, SupportsFloat, bool, bool, dict]:
        """Perform a step in the base environment and render a frame to the screen."""
        result = super().step(action)
        self._render_frame()
        return result

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[ObsType, dict[str, Any]]:
        """Reset the base environment and render a frame to the screen."""
        result = super().reset(seed=seed, options=options)
        self._render_frame()
        return result

    def render(self) -> None:
        """This method doesn't do much, actual rendering is performed in :meth:`step` and :meth:`reset`."""
        return None

    def _render_frame(self):
        """Fetch the last frame from the base environment and render it to the screen.

        Raises:
            DependencyNotInstalled: If pygame cannot be imported.
        """
        try:
            import pygame
        except ImportError as e:
            raise DependencyNotInstalled(
                "pygame is not installed, run `pip install gymnasium[box2d]`"
            ) from e

        if self.env.render_mode == "rgb_array_list":
            frames = self.env.render()
            assert isinstance(frames, list)
            if len(frames) == 0:
                # Nothing new to display (the frames were already consumed by
                # another render() call); skip instead of raising IndexError.
                return
            last_rgb_array = frames[-1]
        elif self.env.render_mode == "rgb_array":
            last_rgb_array = self.env.render()
        else:
            raise ValueError(
                f"Wrapped environment must have mode 'rgb_array' or 'rgb_array_list', actual render mode: {self.env.render_mode}"
            )
        assert isinstance(last_rgb_array, np.ndarray)

        # Swap the first two axes: pygame.surfarray indexes surfaces as (x, y).
        rgb_array = np.transpose(last_rgb_array, axes=(1, 0, 2))

        if self.screen_size is None:
            self.screen_size = rgb_array.shape[:2]

        assert (
            self.screen_size == rgb_array.shape[:2]
        ), f"The shape of the rgb array has changed from {self.screen_size} to {rgb_array.shape[:2]}"

        if self.window is None:
            pygame.init()
            pygame.display.init()
            self.window = pygame.display.set_mode(self.screen_size)

        if self.clock is None:
            self.clock = pygame.time.Clock()

        surf = pygame.surfarray.make_surface(rgb_array)
        self.window.blit(surf, (0, 0))
        pygame.event.pump()
        self.clock.tick(self.metadata["render_fps"])
        pygame.display.flip()

    def close(self):
        """Close the rendering window, then the wrapped environment."""
        if self.window is not None:
            import pygame

            pygame.display.quit()
            pygame.quit()
            # Forget the torn-down handles so a later reset()/step() re-creates the
            # window instead of reusing a display that pygame.display.quit() closed.
            self.window = None
            self.clock = None
        super().close()