Source code for gymnasium.wrappers.vector.stateful_observation

"""A collection of stateful observation wrappers.

* ``NormalizeObservation`` - Normalize the observations
"""
from __future__ import annotations

import numpy as np

import gymnasium as gym
from gymnasium.core import ObsType
from gymnasium.vector.vector_env import VectorEnv, VectorObservationWrapper
from gymnasium.wrappers.utils import RunningMeanStd


__all__ = ["NormalizeObservation"]


[docs] class NormalizeObservation(VectorObservationWrapper, gym.utils.RecordConstructorArgs): """This wrapper will normalize observations s.t. each coordinate is centered with unit variance. The property `_update_running_mean` allows to freeze/continue the running mean calculation of the observation statistics. If `True` (default), the `RunningMeanStd` will get updated every step and reset call. If `False`, the calculated statistics are used but not updated anymore; this may be used during evaluation. Note: The normalization depends on past trajectories and observations will not be normalized correctly if the wrapper was newly instantiated or the policy was changed recently. Example without the normalize reward wrapper: >>> import gymnasium as gym >>> envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync") >>> obs, info = envs.reset(seed=123) >>> _ = envs.action_space.seed(123) >>> for _ in range(100): ... obs, *_ = envs.step(envs.action_space.sample()) >>> np.mean(obs) 0.024251968 >>> np.std(obs) 0.62259156 >>> envs.close() Example with the normalize reward wrapper: >>> import gymnasium as gym >>> envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync") >>> envs = NormalizeObservation(envs) >>> obs, info = envs.reset(seed=123) >>> _ = envs.action_space.seed(123) >>> for _ in range(100): ... obs, *_ = envs.step(envs.action_space.sample()) >>> np.mean(obs) -0.2359734 >>> np.std(obs) 1.1938739 >>> envs.close() """ def __init__(self, env: VectorEnv, epsilon: float = 1e-8): """This wrapper will normalize observations s.t. each coordinate is centered with unit variance. Args: env (Env): The environment to apply the wrapper epsilon: A stability parameter that is used when scaling the observations. """ gym.utils.RecordConstructorArgs.__init__(self, epsilon=epsilon) VectorObservationWrapper.__init__(self, env) self.obs_rms = RunningMeanStd( shape=self.single_observation_space.shape, dtype=self.single_observation_space.dtype, ) self.epsilon = epsilon self._update_running_mean = True @property def update_running_mean(self) -> bool: """Property to freeze/continue the running mean calculation of the observation statistics.""" return self._update_running_mean @update_running_mean.setter def update_running_mean(self, setting: bool): """Sets the property to freeze/continue the running mean calculation of the observation statistics.""" self._update_running_mean = setting def observations(self, observations: ObsType) -> ObsType: """Defines the vector observation normalization function. Args: observations: A vector observation from the environment Returns: the normalized observation """ if self._update_running_mean: self.obs_rms.update(observations) return (observations - self.obs_rms.mean) / np.sqrt( self.obs_rms.var + self.epsilon )