"""Utilities of visualising an environment."""
from __future__ import annotations
from collections import deque
from typing import TYPE_CHECKING, Callable, List
import numpy as np
import gymnasium as gym
from gymnasium import Env, logger
from gymnasium.core import ActType, ObsType
from gymnasium.error import DependencyNotInstalled
if TYPE_CHECKING:
from matplotlib.axes import Axes
try:
import pygame
from pygame import Surface
from pygame.event import Event
except ImportError as e:
raise gym.error.DependencyNotInstalled(
'pygame is not installed, run `pip install "gymnasium[classic_control]"`'
) from e
try:
import matplotlib
matplotlib.use("TkAgg")
import matplotlib.pyplot as plt
except ImportError:
logger.warn('matplotlib is not installed, run `pip install "gymnasium[other]"`')
matplotlib, plt = None, None
class MissingKeysToAction(Exception):
    """Signals that no ``keys_to_action`` mapping is available for an environment.

    Raised when the caller did not supply a mapping and the environment
    does not expose a ``get_keys_to_action`` attribute to provide a default.
    """
[docs]
class PlayableGame:
    """Wraps an environment allowing keyboard inputs to interact with the environment."""

    def __init__(
        self,
        env: Env,
        keys_to_action: dict[tuple[int, ...], int] | None = None,
        zoom: float | None = None,
    ):
        """Wraps an environment with a dictionary of keyboard buttons to action and if to zoom in on the environment.

        Args:
            env: The environment to play
            keys_to_action: The dictionary of keyboard tuples and action value
            zoom: If to zoom in on the environment render

        Raises:
            ValueError: If ``env.render_mode`` is neither ``"rgb_array"`` nor ``"rgb_array_list"``.
            MissingKeysToAction: If ``keys_to_action`` is ``None`` and the environment
                does not provide ``get_keys_to_action``.
        """
        if env.render_mode not in {"rgb_array", "rgb_array_list"}:
            raise ValueError(
                "PlayableGame wrapper works only with rgb_array and rgb_array_list render modes, "
                f"but your environment render_mode = {env.render_mode}."
            )

        self.env = env
        self.relevant_keys = self._get_relevant_keys(keys_to_action)
        # self.video_size is the size of the video that is being displayed.
        # The window size may be larger, in that case we will add black bars
        self.video_size = self._get_video_size(zoom)
        self.screen = pygame.display.set_mode(self.video_size, pygame.RESIZABLE)
        # Key codes currently held down, restricted to ``self.relevant_keys``.
        self.pressed_keys = []
        self.running = True

    def _get_relevant_keys(
        self, keys_to_action: dict[tuple[int, ...], int] | None = None
    ) -> set:
        """Collects every key code that appears in any key combination.

        Args:
            keys_to_action: Mapping from key-code tuples to actions. If ``None``,
                the environment's ``get_keys_to_action`` attribute is called instead.

        Returns:
            The set of all key codes used by the mapping.

        Raises:
            MissingKeysToAction: If no mapping was given and the environment has none.
        """
        if keys_to_action is None:
            if self.env.has_wrapper_attr("get_keys_to_action"):
                keys_to_action = self.env.get_wrapper_attr("get_keys_to_action")()
            else:
                assert self.env.spec is not None
                raise MissingKeysToAction(
                    f"{self.env.spec.id} does not have explicit key to action mapping, "
                    "please specify one manually, `play(env, keys_to_action=...)`"
                )
        assert isinstance(keys_to_action, dict)
        # Flatten all combinations in a single pass (avoids repeated list concatenation).
        return {key for combination in keys_to_action for key in combination}

    def _get_video_size(self, zoom: float | None = None) -> tuple[int, int]:
        """Derives the display size ``(width, height)`` from a rendered frame.

        Args:
            zoom: Optional factor applied to both dimensions; the results are truncated to ``int``.

        Returns:
            The (optionally zoomed) width and height of the rendered frame.
        """
        rendered = self.env.render()
        # When a list of frames is returned, only the most recent one is used.
        if isinstance(rendered, list):
            rendered = rendered[-1]
        assert rendered is not None and isinstance(rendered, np.ndarray)

        # The frame's first axis is used as the height, the second as the width.
        video_size = (rendered.shape[1], rendered.shape[0])
        if zoom is not None:
            video_size = (int(video_size[0] * zoom), int(video_size[1] * zoom))
        return video_size

    def process_event(self, event: Event):
        """Processes a PyGame event.

        In particular, this function is used to keep track of which buttons are currently pressed
        and to exit the :func:`play` function when the PyGame window is closed.

        Args:
            event: The event to process
        """
        if event.type == pygame.KEYDOWN:
            if event.key in self.relevant_keys:
                # Skip keys already tracked so a repeated KEYDOWN cannot
                # create duplicates that would break the key-combination lookup.
                if event.key not in self.pressed_keys:
                    self.pressed_keys.append(event.key)
            elif event.key == pygame.K_ESCAPE:
                self.running = False
        elif event.type == pygame.KEYUP:
            # A KEYUP may arrive for a key whose KEYDOWN was never seen;
            # only remove keys that are actually tracked to avoid a ValueError.
            if event.key in self.relevant_keys and event.key in self.pressed_keys:
                self.pressed_keys.remove(event.key)
        elif event.type == pygame.QUIT:
            self.running = False
        elif event.type == pygame.WINDOWRESIZED:
            # Compute the maximum video size that fits into the new window
            scale_width = event.x / self.video_size[0]
            scale_height = event.y / self.video_size[1]
            scale = min(scale_height, scale_width)
            # NOTE: true division makes ``scale`` a float, so ``video_size`` holds floats after a resize.
            self.video_size = (scale * self.video_size[0], scale * self.video_size[1])
def display_arr(
    screen: Surface, arr: np.ndarray, video_size: tuple[int, int], transpose: bool
):
    """Draws an image array onto a PyGame surface, centred on a black background.

    Args:
        screen: The surface that is cleared and drawn onto
        arr: The ``uint8`` image array to draw
        video_size: The size the image is scaled to before drawing
        transpose: Whether the first two axes of ``arr`` are swapped before drawing
    """
    assert isinstance(arr, np.ndarray) and arr.dtype == np.uint8

    frame = arr.swapaxes(0, 1) if transpose else arr
    image = pygame.transform.scale(pygame.surfarray.make_surface(frame), video_size)

    # The window can be larger than the video; centre the image so the
    # remaining area shows up as black bars.
    screen_width, screen_height = screen.get_size()
    x_offset = (screen_width - video_size[0]) / 2
    y_offset = (screen_height - video_size[1]) / 2

    screen.fill((0, 0, 0))
    screen.blit(image, (x_offset, y_offset))
[docs]
def play(
    env: Env,
    transpose: bool | None = True,
    fps: int | None = None,
    zoom: float | None = None,
    callback: Callable | None = None,
    keys_to_action: dict[tuple[str | int, ...] | str | int, ActType] | None = None,
    seed: int | None = None,
    noop: ActType = 0,
    wait_on_player: bool = False,
):
    """Allows the user to play the environment using a keyboard.

    If playing in a turn-based environment, set wait_on_player to True.

    Args:
        env: Environment to use for playing.
        transpose: If this is ``True``, the output of observation is transposed. Defaults to ``True``.
        fps: Maximum number of steps of the environment executed every second. If ``None`` (the default),
            ``env.metadata["render_fps"]`` (or 30, if the environment does not specify "render_fps") is used.
        zoom: Zoom the observation in, ``zoom`` amount, should be positive float
        callback: If a callback is provided, it will be executed after every step. It takes the following input:

            * obs_t: observation before performing action
            * obs_tp1: observation after performing action
            * action: action that was executed
            * rew: reward that was received
            * terminated: whether the environment is terminated or not
            * truncated: whether the environment is truncated or not
            * info: debug info
        keys_to_action: Mapping from keys pressed to action performed.
            Different formats are supported: Key combinations can either be expressed as a tuple of unicode code
            points of the keys, as a tuple of characters, or as a string where each character of the string represents
            one key. A single key may also be given as a bare integer key code.
            For example if pressing 'w' and space at the same time is supposed
            to trigger action number 2 then ``keys_to_action`` dict could look like this:

                >>> keys_to_action = {
                ...    # ...
                ...    (ord('w'), ord(' ')): 2
                ...    # ...
                ... }

            or like this:

                >>> keys_to_action = {
                ...    # ...
                ...    ("w", " "): 2
                ...    # ...
                ... }

            or like this:

                >>> keys_to_action = {
                ...    # ...
                ...    "w ": 2
                ...    # ...
                ... }

            If ``None``, default ``keys_to_action`` mapping for that environment is used, if provided.
        seed: Random seed used when resetting the environment. If None, no seed is used.
        noop: The action used when no key input has been entered, or the entered key combination is unknown.
        wait_on_player: Play should wait for a user action

    Raises:
        MissingKeysToAction: If ``keys_to_action`` is ``None`` and the environment does not
            provide a ``get_keys_to_action`` attribute.

    Example:
        >>> import gymnasium as gym
        >>> import numpy as np
        >>> from gymnasium.utils.play import play
        >>> play(gym.make("CarRacing-v3", render_mode="rgb_array"),  # doctest: +SKIP
        ...     keys_to_action={
        ...         "w": np.array([0, 0.7, 0], dtype=np.float32),
        ...         "a": np.array([-1, 0, 0], dtype=np.float32),
        ...         "s": np.array([0, 0, 1], dtype=np.float32),
        ...         "d": np.array([1, 0, 0], dtype=np.float32),
        ...         "wa": np.array([-1, 0.7, 0], dtype=np.float32),
        ...         "dw": np.array([1, 0.7, 0], dtype=np.float32),
        ...         "ds": np.array([1, 0, 1], dtype=np.float32),
        ...         "as": np.array([-1, 0, 1], dtype=np.float32),
        ...     },
        ...     noop=np.array([0, 0, 0], dtype=np.float32)
        ... )

        Above code works also if the environment is wrapped, so it's particularly useful in
        verifying that the frame-level preprocessing does not render the game
        unplayable.

        If you wish to plot real time statistics as you play, you can use
        :class:`PlayPlot`. Here's a sample code for plotting the reward
        for last 150 steps.

        >>> from gymnasium.utils.play import PlayPlot, play
        >>> def callback(obs_t, obs_tp1, action, rew, terminated, truncated, info):
        ...        return [rew,]
        >>> plotter = PlayPlot(callback, 150, ["reward"])             # doctest: +SKIP
        >>> play(gym.make("CartPole-v1"), callback=plotter.callback)  # doctest: +SKIP
    """
    env.reset(seed=seed)

    if keys_to_action is None:
        if env.has_wrapper_attr("get_keys_to_action"):
            keys_to_action = env.get_wrapper_attr("get_keys_to_action")()
        else:
            assert env.spec is not None
            raise MissingKeysToAction(
                f"{env.spec.id} does not have explicit key to action mapping, "
                "please specify one manually"
            )
    assert keys_to_action is not None

    # validate the `keys_to_action` set provided
    assert isinstance(keys_to_action, dict)
    for key, action in keys_to_action.items():
        if isinstance(key, tuple):
            assert len(key) > 0
            assert all(isinstance(k, (str, int)) for k in key)
        else:
            assert isinstance(key, (str, int))

        assert action in env.action_space

    # Normalise every accepted key format to a sorted tuple of integer key codes.
    key_code_to_action = {}
    for key_combination, action in keys_to_action.items():
        # A bare int passes the validation above but is not iterable,
        # so treat it as a combination consisting of that single key code.
        if isinstance(key_combination, int):
            key_combination = (key_combination,)
        key_code = tuple(
            sorted(ord(key) if isinstance(key, str) else key for key in key_combination)
        )
        key_code_to_action[key_code] = action

    game = PlayableGame(env, key_code_to_action, zoom)

    if fps is None:
        fps = env.metadata.get("render_fps", 30)

    # ``done`` starts as True so the first loop iteration resets the environment.
    done, obs = True, None
    clock = pygame.time.Clock()

    try:
        while game.running:
            if done:
                done = False
                # ``reset`` returns ``(observation, info)``; keep only the observation
                # so the callback's ``obs_t`` is an observation rather than a tuple.
                obs, _ = env.reset(seed=seed)
            elif not wait_on_player or game.pressed_keys:
                # Unknown combinations (and no keys at all) fall back to ``noop``.
                action = key_code_to_action.get(tuple(sorted(game.pressed_keys)), noop)
                prev_obs = obs
                obs, rew, terminated, truncated, info = env.step(action)
                done = terminated or truncated
                if callback is not None:
                    callback(prev_obs, obs, action, rew, terminated, truncated, info)
            if obs is not None:
                rendered = env.render()
                # When a list of frames is returned, show only the most recent one.
                if isinstance(rendered, list):
                    rendered = rendered[-1]
                assert rendered is not None and isinstance(rendered, np.ndarray)
                display_arr(
                    game.screen, rendered, transpose=transpose, video_size=game.video_size
                )

            # process pygame events
            for event in pygame.event.get():
                game.process_event(event)

            pygame.display.flip()
            clock.tick(fps)
    finally:
        # Always shut PyGame down, even if the environment or the callback raised.
        pygame.quit()
[docs]
class PlayPlot:
    """Live-plots user-defined metrics while an environment is played with :func:`play`.

    An instance is built around a function that receives the data of a single environment transition:

        - obs_t: observation before performing action
        - obs_tp1: observation after performing action
        - action: action that was executed
        - rew: reward that was received
        - terminated: whether the environment is terminated or not
        - truncated: whether the environment is truncated or not
        - info: debug info

    and returns a list of metrics computed from it, for example::

        >>> def compute_metrics(obs_t, obs_tp, action, reward, terminated, truncated, info):
        ...     return [reward, info["cumulative_reward"], np.linalg.norm(action)]

    The method :meth:`callback` forwards its arguments to that function and redraws one
    scatter plot per metric with the returned values. It is typically handed to
    :func:`play` so the metrics can be watched while playing::

        >>> plotter = PlayPlot(compute_metrics, horizon_timesteps=200,                               # doctest: +SKIP
        ...                    plot_names=["Immediate Rew.", "Cumulative Rew.", "Action Magnitude"])
        >>> play(your_env, callback=plotter.callback)                                                # doctest: +SKIP
    """

    def __init__(
        self, callback: Callable, horizon_timesteps: int, plot_names: list[str]
    ):
        """Creates one titled subplot and one bounded data buffer per metric.

        ``callback`` is expected to return a list of metrics with ``len(plot_names)`` entries.

        Args:
            callback: Function that computes metrics from environment transitions
            horizon_timesteps: The time horizon used for the live plots
            plot_names: List of plot titles

        Raises:
            DependencyNotInstalled: If matplotlib is not installed
        """
        self.data_callback = callback
        self.horizon_timesteps = horizon_timesteps
        self.plot_names = plot_names

        if plt is None:
            raise DependencyNotInstalled(
                'matplotlib is not installed, run `pip install "gymnasium[other]"`'
            )

        plot_count = len(self.plot_names)
        self.fig, self.ax = plt.subplots(plot_count)
        # Wrap the single-plot result in a list so ``self.ax`` can always be indexed per plot.
        if plot_count == 1:
            self.ax = [self.ax]

        for subplot, title in zip(self.ax, plot_names):
            subplot.set_title(title)

        self.t = 0
        # Handle of the scatter currently drawn on each subplot (None until first drawn).
        self.cur_plot: list[Axes | None] = [None] * plot_count
        # Each buffer keeps at most ``horizon_timesteps`` of the most recent values.
        self.data = [deque(maxlen=horizon_timesteps) for _ in range(plot_count)]

    def callback(
        self,
        obs_t: ObsType,
        obs_tp1: ObsType,
        action: ActType,
        rew: float,
        terminated: bool,
        truncated: bool,
        info: dict,
    ):
        """Computes the metrics for one transition and redraws every plot.

        Args:
            obs_t: The observation at time step t
            obs_tp1: The observation at time step t+1
            action: The action
            rew: The reward
            terminated: If the environment is terminated
            truncated: If the environment is truncated
            info: The information from the environment
        """
        metrics = self.data_callback(
            obs_t, obs_tp1, action, rew, terminated, truncated, info
        )
        for value, series in zip(metrics, self.data):
            series.append(value)
        self.t += 1

        # Visible window on the x-axis: the last ``horizon_timesteps`` steps.
        xmax = self.t
        xmin = max(0, xmax - self.horizon_timesteps)

        for index, previous in enumerate(self.cur_plot):
            # Drop the previously drawn scatter before drawing the updated one.
            if previous is not None:
                previous.remove()
            subplot = self.ax[index]
            self.cur_plot[index] = subplot.scatter(
                range(xmin, xmax), list(self.data[index]), c="blue"
            )
            subplot.set_xlim(xmin, xmax)

        if plt is None:
            raise DependencyNotInstalled(
                'matplotlib is not installed, run `pip install "gymnasium[other]"`'
            )
        plt.pause(0.000001)