Source code for gymnasium.wrappers.gray_scale_observation

"""Wrapper that converts a color observation to grayscale."""
import numpy as np

import gymnasium as gym
from gymnasium.spaces import Box


class GrayScaleObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
    """Convert the image observation from RGB to gray scale.

    The wrapped environment must have a ``Box`` observation space of shape
    ``(height, width, 3)``. The resulting observation space keeps the input's
    dtype, and its bounds are the overall minimum/maximum of the input bounds.
    For the usual ``Box(0, 255, (H, W, 3), uint8)`` input this is
    ``Box(0, 255, (H, W), uint8)``.

    Example:
        >>> import gymnasium as gym
        >>> from gymnasium.wrappers import GrayScaleObservation
        >>> env = gym.make("CarRacing-v2")
        >>> env.observation_space
        Box(0, 255, (96, 96, 3), uint8)
        >>> env = GrayScaleObservation(gym.make("CarRacing-v2"))
        >>> env.observation_space
        Box(0, 255, (96, 96), uint8)
        >>> env = GrayScaleObservation(gym.make("CarRacing-v2"), keep_dim=True)
        >>> env.observation_space
        Box(0, 255, (96, 96, 1), uint8)
    """

    def __init__(self, env: gym.Env, keep_dim: bool = False):
        """Convert the image observation from RGB to gray scale.

        Args:
            env (Env): The environment to apply the wrapper to. Its observation
                space must be a ``Box`` of shape ``(height, width, 3)``.
            keep_dim (bool): If `True`, a singleton dimension will be added, i.e.
                observations are of the shape AxBx1. Otherwise, they are of shape AxB.

        Raises:
            AssertionError: If the environment's observation space is not a
                ``Box`` of shape ``(height, width, 3)``.
        """
        gym.utils.RecordConstructorArgs.__init__(self, keep_dim=keep_dim)
        gym.ObservationWrapper.__init__(self, env)

        self.keep_dim = keep_dim

        in_space = self.observation_space
        # Raised explicitly instead of via `assert` so the validation still runs
        # under `python -O`; AssertionError is kept so callers see the same type.
        if not isinstance(in_space, Box):
            raise AssertionError(
                "GrayScaleObservation expects a Box observation space, "
                f"got {in_space!r}"
            )
        if len(in_space.shape) != 3 or in_space.shape[-1] != 3:
            raise AssertionError(
                "GrayScaleObservation expects an observation shape of "
                f"(height, width, 3), got {in_space.shape}"
            )

        height, width = in_space.shape[:2]
        out_shape = (height, width, 1) if keep_dim else (height, width)

        # cv2's RGB->GRAY conversion is a weighted average of the three channels
        # and returns the same dtype as its input, so every gray pixel lies
        # within [min(low), max(high)] of the input space and keeps its dtype.
        # For a uint8 [0, 255] input this reproduces Box(0, 255, ..., uint8).
        self.observation_space = Box(
            low=np.min(in_space.low),
            high=np.max(in_space.high),
            shape=out_shape,
            dtype=in_space.dtype,
        )

    def observation(self, observation: np.ndarray) -> np.ndarray:
        """Converts the colour observation to greyscale.

        Args:
            observation: Color observation of shape ``(height, width, 3)`` with
                channels in RGB order.

        Returns:
            Grayscale observation of shape ``(height, width)``, or
            ``(height, width, 1)`` when ``keep_dim`` is set.
        """
        # Imported here rather than at module level, so constructing the
        # wrapper does not require cv2 to be installed.
        import cv2

        gray = cv2.cvtColor(observation, cv2.COLOR_RGB2GRAY)
        if self.keep_dim:
            gray = np.expand_dims(gray, -1)
        return gray