# Source code for gymnasium.wrappers.flatten_observation

"""Wrapper for flattening observations of an environment."""
import gymnasium as gym
from gymnasium import spaces


class FlattenObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
    """Observation wrapper whose observations (and observation space) are flattened.

    The wrapped environment's observation space is replaced by
    ``spaces.flatten_space(...)`` of the original space, and every observation
    is passed through ``spaces.flatten`` using the original space.

    Example:
        >>> import gymnasium as gym
        >>> from gymnasium.wrappers import FlattenObservation
        >>> env = gym.make("CarRacing-v2")
        >>> env.observation_space.shape
        (96, 96, 3)
        >>> env = FlattenObservation(env)
        >>> env.observation_space.shape
        (27648,)
        >>> obs, _ = env.reset()
        >>> obs.shape
        (27648,)
    """

    def __init__(self, env: gym.Env) -> None:
        """Wrap ``env`` and expose a flattened version of its observation space.

        Args:
            env: The environment whose observations should be flattened.
        """
        # Both base classes are initialised explicitly and in this order.
        # ObservationWrapper.__init__ presumably attaches ``env`` as ``self.env``
        # (``observation`` below relies on it), so it must run before the
        # observation space is overridden.
        gym.utils.RecordConstructorArgs.__init__(self)
        gym.ObservationWrapper.__init__(self, env)

        flat_space = spaces.flatten_space(env.observation_space)
        self.observation_space = flat_space

    def observation(self, observation):
        """Convert one observation from the wrapped env into its flattened form.

        Args:
            observation: An observation drawn from the wrapped environment's
                original (unflattened) observation space.

        Returns:
            The value produced by ``spaces.flatten`` for ``observation``,
            interpreted against the wrapped environment's observation space.
        """
        # Flatten against the *inner* env's space, not this wrapper's
        # already-flattened ``observation_space``.
        original_space = self.env.observation_space
        flat_obs = spaces.flatten(original_space, observation)
        return flat_obs