# Source code for gymnasium.wrappers.stateful_action

"""Stateful action wrappers - ``StickyAction`` and ``RepeatAction``."""

from __future__ import annotations

from typing import Any, SupportsFloat

import numpy as np

import gymnasium as gym
from gymnasium.core import ActType, ObsType, WrapperActType, WrapperObsType
from gymnasium.error import InvalidBound, InvalidProbability

__all__ = ["StickyAction", "RepeatAction"]


class StickyAction(
    gym.ActionWrapper[ObsType, ActType, ActType], gym.utils.RecordConstructorArgs
):
    """Adds a probability that the action is repeated for the same ``step`` function.

    This wrapper follows the implementation proposed by `Machado et al., 2018
    <https://arxiv.org/pdf/1709.06009.pdf>`_ in Section 5.2 on page 12, and adds
    the possibility to repeat the action for more than one step.

    No vector version of the wrapper exists.

    Example:
        >>> import gymnasium as gym
        >>> env = gym.make("CartPole-v1")
        >>> env = StickyAction(env, repeat_action_probability=0.9)
        >>> env.reset(seed=123)
        (array([ 0.01823519, -0.0446179 , -0.02796401, -0.03156282], dtype=float32), {})
        >>> env.step(1)
        (array([ 0.01734283,  0.15089367, -0.02859527, -0.33293587], dtype=float32), 1.0, False, False, {})
        >>> env.step(0)
        (array([ 0.0203607 ,  0.34641072, -0.03525399, -0.6344974 ], dtype=float32), 1.0, False, False, {})
        >>> env.step(1)
        (array([ 0.02728892,  0.5420062 , -0.04794393, -0.9380709 ], dtype=float32), 1.0, False, False, {})
        >>> env.step(0)
        (array([ 0.03812904,  0.34756234, -0.06670535, -0.6608303 ], dtype=float32), 1.0, False, False, {})

    Change logs:
     * v1.0.0 - Initially added
     * v1.1.0 - Add `repeat_action_duration` argument for dynamic number of sticky actions
    """

    def __init__(
        self,
        env: gym.Env[ObsType, ActType],
        repeat_action_probability: float,
        repeat_action_duration: int | tuple[int, int] = 1,
    ):
        """Initialize StickyAction wrapper.

        Args:
            env (Env): the wrapped environment,
            repeat_action_probability (int | float): a probability of repeating the old action,
            repeat_action_duration (int | tuple[int, int]): the number of steps
                the action is repeated. It can be either an int (for deterministic
                repeats) or a tuple[int, int] for a range of stochastic number of repeats.

        Raises:
            InvalidProbability: if ``repeat_action_probability`` is outside ``[0, 1)``.
            InvalidBound: if the duration range's lower bound exceeds its upper bound.
            ValueError: if ``repeat_action_duration`` has the wrong type, length, or
                contains values below 1.
        """
        if not 0 <= repeat_action_probability < 1:
            raise InvalidProbability(
                f"`repeat_action_probability` should be in the interval [0,1). Received {repeat_action_probability}"
            )

        # Normalize a deterministic (int) duration into a degenerate range.
        if isinstance(repeat_action_duration, int):
            repeat_action_duration = (repeat_action_duration, repeat_action_duration)

        if not isinstance(repeat_action_duration, tuple):
            raise ValueError(
                f"`repeat_action_duration` should be either an integer or a tuple. Received {repeat_action_duration}"
            )
        elif len(repeat_action_duration) != 2:
            raise ValueError(
                f"`repeat_action_duration` should be a tuple or a list of two integers. Received {repeat_action_duration}"
            )
        elif repeat_action_duration[0] > repeat_action_duration[1]:
            raise InvalidBound(
                f"`repeat_action_duration` is not a valid bound. Received {repeat_action_duration}"
            )
        elif np.any(np.array(repeat_action_duration) < 1):
            raise ValueError(
                f"`repeat_action_duration` should be larger or equal than 1. Received {repeat_action_duration}"
            )

        # BUG FIX: `repeat_action_duration` was previously not recorded, so
        # reconstructing the wrapper from its spec silently dropped the duration.
        gym.utils.RecordConstructorArgs.__init__(
            self,
            repeat_action_probability=repeat_action_probability,
            repeat_action_duration=repeat_action_duration,
        )
        gym.ActionWrapper.__init__(self, env)

        self.repeat_action_probability = repeat_action_probability
        self.repeat_action_duration_range = repeat_action_duration

        self.last_action: ActType | None = None
        self.is_sticky_actions: bool = False  # if sticky actions are taken
        self.num_repeats: int = 0  # number of sticky action repeats
        self.repeats_taken: int = 0  # number of sticky actions taken

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[ObsType, dict[str, Any]]:
        """Reset the environment, clearing all sticky-action bookkeeping."""
        self.last_action = None
        self.is_sticky_actions = False
        self.num_repeats = 0
        self.repeats_taken = 0
        return super().reset(seed=seed, options=options)

    def action(self, action: ActType) -> ActType:
        """Execute the action, possibly replacing it with the previous (sticky) one."""
        # either the agent was already "stuck" into repeats, or a new series of repeats is triggered
        if self.is_sticky_actions or (
            self.last_action is not None
            and self.np_random.uniform() < self.repeat_action_probability
        ):
            # if a new series starts, randomly sample its duration
            if self.num_repeats == 0:
                self.num_repeats = self.np_random.integers(
                    self.repeat_action_duration_range[0],
                    self.repeat_action_duration_range[1] + 1,
                )

            action = self.last_action
            self.is_sticky_actions = True
            self.repeats_taken += 1

        # repeats are done, reset "stuck" status
        if self.is_sticky_actions and self.num_repeats == self.repeats_taken:
            self.is_sticky_actions = False
            self.num_repeats = 0
            self.repeats_taken = 0

        self.last_action = action
        return action
class RepeatAction(
    gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
    """Repeats the given action for ``num_repeats`` steps, accumulating the reward.

    This is useful for environments where the agent does not need to make a decision
    at every time step. By repeating actions, the effective decision frequency is
    reduced, which can speed up training and make exploration more efficient.

    Unlike :class:`StickyAction`, which *stochastically* replaces the agent's action
    with the *previous* action, ``RepeatAction`` *deterministically* executes the
    *current* action multiple times and returns the accumulated reward together with
    the final observation. If the episode terminates or is truncated during
    repetition, the loop stops early.

    No vector version of the wrapper exists.

    Example:
        >>> import gymnasium as gym
        >>> env = gym.make("CartPole-v1")
        >>> env = RepeatAction(env, num_repeats=4)
        >>> obs, info = env.reset(seed=123)
        >>> obs, reward, terminated, truncated, info = env.step(1)
        >>> reward  # sum of 4 inner-step rewards
        4.0

    Change logs:
     * v1.3.0 - Initially added
    """

    def __init__(self, env: gym.Env[ObsType, ActType], num_repeats: int):
        """Initialize RepeatAction wrapper.

        Args:
            env (Env): The wrapped environment.
            num_repeats (int): The number of times to repeat each action.

        Raises:
            TypeError: if ``num_repeats`` is not an integer.
            ValueError: if ``num_repeats`` is smaller than 1.
        """
        # Accept both Python ints and numpy integer scalars.
        if not np.issubdtype(type(num_repeats), np.integer):
            raise TypeError(
                f"The num_repeats is expected to be an integer, actual type: {type(num_repeats)}"
            )
        if num_repeats < 1:
            raise ValueError(
                f"The num_repeats value needs to be equal or greater than one, actual value: {num_repeats}"
            )

        gym.utils.RecordConstructorArgs.__init__(self, num_repeats=num_repeats)
        gym.Wrapper.__init__(self, env)

        self.num_repeats = num_repeats

    def step(
        self, action: WrapperActType
    ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
        """Step the environment, repeating the action for ``num_repeats`` steps.

        The reward returned is the sum of rewards from all inner steps.
        If the episode terminates or is truncated during repetition, the loop
        stops early.

        Args:
            action: The action to repeat.

        Returns:
            The last observation, accumulated reward, terminated, truncated, and info.
        """
        total_reward = 0.0
        terminated = truncated = False
        info: dict[str, Any] = {}

        # num_repeats >= 1 is guaranteed by __init__, so `obs` is always bound.
        for _ in range(self.num_repeats):
            obs, reward, terminated, truncated, info = self.env.step(action)
            total_reward += float(reward)
            if terminated or truncated:
                break

        return obs, total_reward, terminated, truncated, info