# Source code for gymnasium.wrappers.atari_preprocessing

"""Implementation of Atari 2600 Preprocessing following the guidelines of Machado et al., 2018."""
import numpy as np

import gymnasium as gym
from gymnasium.spaces import Box


try:
    import cv2
except ImportError:
    cv2 = None


class AtariPreprocessing(gym.Wrapper, gym.utils.RecordConstructorArgs):
    """Atari 2600 preprocessing wrapper.

    This class follows the guidelines in Machado et al. (2018),
    "Revisiting the Arcade Learning Environment: Evaluation Protocols and Open Problems for General Agents".

    Specifically, the following preprocessing stages are applied to the Atari environment:

    - Noop Reset: Obtains the initial state by taking a random number (between 1 and ``noop_max``)
      of no-ops on reset; default max 30 no-ops.
    - Frame skipping: The number of frames skipped between steps, 4 by default.
    - Max-pooling: Pools over the two most recent frames of the frame skip.
    - Termination signal when a life is lost: When the agent loses a life, the environment is
      terminated. Turned off by default. Not recommended by Machado et al. (2018).
    - Resize to a square image: Resizes the original Atari observation (210x160) to 84x84 by default.
    - Grayscale observation: Whether the observation is colour or greyscale; greyscale by default.
    - Scale observation: Whether the observation is scaled to [0, 1] (float32) or kept in [0, 255]
      (uint8); not scaled by default.
    """

    def __init__(
        self,
        env: gym.Env,
        noop_max: int = 30,
        frame_skip: int = 4,
        screen_size: int = 84,
        terminal_on_life_loss: bool = False,
        grayscale_obs: bool = True,
        grayscale_newaxis: bool = False,
        scale_obs: bool = False,
    ):
        """Wrapper for Atari 2600 preprocessing.

        Args:
            env (Env): The environment to apply the preprocessing to.
            noop_max (int): For No-op reset, the maximum number of no-op actions taken at reset;
                set to 0 to turn it off.
            frame_skip (int): The number of frames the wrapped environment advances per call to
                :meth:`step`, i.e. how often the agent observes the game and chooses an action.
            screen_size (int): Side length of the square frame the observation is resized to.
            terminal_on_life_loss (bool): If True, :meth:`step` returns ``terminated=True``
                whenever a life is lost.
            grayscale_obs (bool): If True, a grayscale observation is returned, otherwise an RGB
                observation is returned.
            grayscale_newaxis (bool): If True and ``grayscale_obs=True``, a channel axis is added
                to grayscale observations to make them 3-dimensional.
            scale_obs (bool): If True, the observation is normalized to the range [0, 1] (float32).
                This limits the memory optimization benefits of the FrameStack wrapper.

        Raises:
            DependencyNotInstalled: opencv-python package not installed.
            ValueError: ``frame_skip > 1`` while frame-skipping is still enabled in the original env.
        """
        gym.utils.RecordConstructorArgs.__init__(
            self,
            noop_max=noop_max,
            frame_skip=frame_skip,
            screen_size=screen_size,
            terminal_on_life_loss=terminal_on_life_loss,
            grayscale_obs=grayscale_obs,
            grayscale_newaxis=grayscale_newaxis,
            scale_obs=scale_obs,
        )
        gym.Wrapper.__init__(self, env)

        if cv2 is None:
            raise gym.error.DependencyNotInstalled(
                "opencv-python package not installed, run `pip install gymnasium[other]` to get dependencies for atari"
            )

        assert frame_skip > 0
        assert screen_size > 0
        assert noop_max >= 0
        if frame_skip > 1:
            # Refuse to stack this wrapper's frame skip on top of the env's own frame skip.
            if (
                env.spec is not None
                and "NoFrameskip" not in env.spec.id
                and getattr(env.unwrapped, "_frameskip", None) != 1
            ):
                raise ValueError(
                    "Disable frame-skipping in the original env. Otherwise, more than one "
                    "frame-skip will happen as through this wrapper"
                )
        self.noop_max = noop_max
        # No-op reset steps with action 0, so action 0 must be NOOP.
        assert env.unwrapped.get_action_meanings()[0] == "NOOP"

        self.frame_skip = frame_skip
        self.screen_size = screen_size
        self.terminal_on_life_loss = terminal_on_life_loss
        self.grayscale_obs = grayscale_obs
        self.grayscale_newaxis = grayscale_newaxis
        self.scale_obs = scale_obs

        # buffer of most recent two observations for max pooling
        assert isinstance(env.observation_space, Box)
        if grayscale_obs:
            self.obs_buffer = [
                np.empty(env.observation_space.shape[:2], dtype=np.uint8),
                np.empty(env.observation_space.shape[:2], dtype=np.uint8),
            ]
        else:
            self.obs_buffer = [
                np.empty(env.observation_space.shape, dtype=np.uint8),
                np.empty(env.observation_space.shape, dtype=np.uint8),
            ]

        self.lives = 0
        self.game_over = False

        _low, _high, _obs_dtype = (
            (0, 255, np.uint8) if not scale_obs else (0, 1, np.float32)
        )
        _shape = (screen_size, screen_size, 1 if grayscale_obs else 3)
        if grayscale_obs and not grayscale_newaxis:
            _shape = _shape[:-1]  # Remove channel axis
        self.observation_space = Box(
            low=_low, high=_high, shape=_shape, dtype=_obs_dtype
        )

    @property
    def ale(self):
        """Make ale as a class property to avoid serialization error."""
        return self.env.unwrapped.ale

    def _capture_screen(self, buffer: np.ndarray) -> None:
        """Write the current ALE screen into ``buffer`` in place (grayscale or RGB per ``grayscale_obs``)."""
        if self.grayscale_obs:
            self.ale.getScreenGrayscale(buffer)
        else:
            self.ale.getScreenRGB(buffer)

    def step(self, action):
        """Applies the preprocessing for an :meth:`env.step`.

        Repeats ``action`` for up to ``frame_skip`` frames and sums the rewards. Returns the
        max-pool of the last two frames. If the episode terminates or is truncated part-way
        through, stepping stops early and the observation reflects the final frame reached.
        """
        total_reward, terminated, truncated, info = 0.0, False, False, {}

        for t in range(self.frame_skip):
            _, reward, terminated, truncated, info = self.env.step(action)
            total_reward += reward
            self.game_over = terminated

            if self.terminal_on_life_loss:
                new_lives = self.ale.lives()
                terminated = terminated or new_lives < self.lives
                self.game_over = terminated
                self.lives = new_lives

            if terminated or truncated:
                # Record the final frame. Without this, the observation returned for a
                # terminal/truncated step would be left over from the previous step.
                self._capture_screen(self.obs_buffer[0])
                if t < self.frame_skip - 1:
                    # The second-to-last frame was not captured during this step; blank it
                    # (as reset() does) so max-pooling returns the final frame unchanged.
                    self.obs_buffer[1].fill(0)
                break

            if t == self.frame_skip - 2:
                self._capture_screen(self.obs_buffer[1])
            elif t == self.frame_skip - 1:
                self._capture_screen(self.obs_buffer[0])

        return self._get_obs(), total_reward, terminated, truncated, info

    def reset(self, **kwargs):
        """Resets the environment using preprocessing.

        After the underlying reset, takes a random number of no-op steps in [1, ``noop_max``]
        (none if ``noop_max == 0``). If the episode ends during those no-ops, the env is reset again.
        """
        # NoopReset
        _, reset_info = self.env.reset(**kwargs)

        noops = (
            self.env.unwrapped.np_random.integers(1, self.noop_max + 1)
            if self.noop_max > 0
            else 0
        )
        for _ in range(noops):
            _, _, terminated, truncated, step_info = self.env.step(0)
            reset_info.update(step_info)
            if terminated or truncated:
                _, reset_info = self.env.reset(**kwargs)

        self.lives = self.ale.lives()
        self._capture_screen(self.obs_buffer[0])
        # Only one frame exists after reset; a zero second buffer makes pooling a no-op.
        self.obs_buffer[1].fill(0)

        return self._get_obs(), reset_info

    def _get_obs(self):
        """Build the observation from the frame buffers.

        Max-pools the two buffers (when ``frame_skip > 1``), resizes to ``screen_size``,
        optionally scales to [0, 1], and optionally adds a trailing channel axis.
        """
        if self.frame_skip > 1:  # more efficient in-place pooling
            np.maximum(self.obs_buffer[0], self.obs_buffer[1], out=self.obs_buffer[0])

        assert cv2 is not None
        obs = cv2.resize(
            self.obs_buffer[0],
            (self.screen_size, self.screen_size),
            interpolation=cv2.INTER_AREA,
        )

        if self.scale_obs:
            obs = np.asarray(obs, dtype=np.float32) / 255.0
        else:
            obs = np.asarray(obs, dtype=np.uint8)

        if self.grayscale_obs and self.grayscale_newaxis:
            obs = np.expand_dims(obs, axis=-1)  # Add a channel axis

        return obs