Source code for gymnasium.wrappers.vector.stateful_observation
"""A collection of stateful observation wrappers.
* ``NormalizeObservation`` - Normalize the observations
"""
from __future__ import annotations
import numpy as np
import gymnasium as gym
from gymnasium.core import ObsType
from gymnasium.vector.vector_env import VectorEnv, VectorObservationWrapper
from gymnasium.wrappers.utils import RunningMeanStd
__all__ = ["NormalizeObservation"]
[docs]
class NormalizeObservation(VectorObservationWrapper, gym.utils.RecordConstructorArgs):
    """Normalizes vector observations s.t. each coordinate is centered with unit variance.

    The property ``update_running_mean`` allows to freeze/continue the running mean
    calculation of the observation statistics. If ``True`` (default), the
    ``RunningMeanStd`` will get updated every step and reset call. If ``False``, the
    calculated statistics are used but not updated anymore; this may be used during
    evaluation.

    Note:
        The normalization depends on past trajectories and observations will not be
        normalized correctly if the wrapper was newly instantiated or the policy was
        changed recently.

    Example without the normalize observation wrapper:
        >>> import gymnasium as gym
        >>> envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync")
        >>> obs, info = envs.reset(seed=123)
        >>> _ = envs.action_space.seed(123)
        >>> for _ in range(100):
        ...     obs, *_ = envs.step(envs.action_space.sample())
        >>> np.mean(obs)
        np.float32(0.024251968)
        >>> np.std(obs)
        np.float32(0.62259156)
        >>> envs.close()

    Example with the normalize observation wrapper:
        >>> import gymnasium as gym
        >>> envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync")
        >>> envs = NormalizeObservation(envs)
        >>> obs, info = envs.reset(seed=123)
        >>> _ = envs.action_space.seed(123)
        >>> for _ in range(100):
        ...     obs, *_ = envs.step(envs.action_space.sample())
        >>> np.mean(obs)
        np.float32(-0.2359734)
        >>> np.std(obs)
        np.float32(1.1938739)
        >>> envs.close()
    """

    def __init__(self, env: VectorEnv, epsilon: float = 1e-8):
        """Wraps ``env`` so that its observations are normalized with running statistics.

        Args:
            env (VectorEnv): The vector environment to apply the wrapper to
            epsilon: A stability parameter that is added to the running variance
                before taking the square root when scaling the observations.
        """
        gym.utils.RecordConstructorArgs.__init__(self, epsilon=epsilon)
        VectorObservationWrapper.__init__(self, env)

        # Statistics are tracked per coordinate of a *single* sub-environment
        # observation (shape and dtype come from ``single_observation_space``).
        single_space = self.single_observation_space
        self.obs_rms = RunningMeanStd(shape=single_space.shape, dtype=single_space.dtype)

        self.epsilon = epsilon
        # Statistics are updated by default; see ``update_running_mean``.
        self._update_running_mean = True

    @property
    def update_running_mean(self) -> bool:
        """Whether the running observation statistics are currently being updated."""
        return self._update_running_mean

    @update_running_mean.setter
    def update_running_mean(self, setting: bool):
        """Freezes (``False``) or continues (``True``) the running statistics update."""
        self._update_running_mean = setting

    def observations(self, observations: ObsType) -> ObsType:
        """Normalizes a vector observation using the running mean and variance.

        When ``update_running_mean`` is enabled, the running statistics are first
        updated with ``observations`` and the updated statistics are then used
        for the normalization.

        Args:
            observations: A vector observation from the environment

        Returns:
            The normalized observation
        """
        if self._update_running_mean:
            self.obs_rms.update(observations)

        centered = observations - self.obs_rms.mean
        # ``epsilon`` keeps the denominator away from zero for (near-)constant coordinates.
        scale = np.sqrt(self.obs_rms.var + self.epsilon)
        return centered / scale