"""A collection of wrappers for modifying the reward.
* ``TransformReward`` - Transforms the reward by a function
* ``ClipReward`` - Clips the reward between a minimum and maximum value
"""
from __future__ import annotations
from typing import Callable, SupportsFloat
import numpy as np
import gymnasium as gym
from gymnasium.core import ActType, ObsType
from gymnasium.error import InvalidBound
__all__ = ["TransformReward", "ClipReward"]
[docs]
class ClipReward(TransformReward[ObsType, ActType], gym.utils.RecordConstructorArgs):
"""Clips the rewards for an environment between an upper and lower bound.
A vector version of the wrapper exists :class:`gymnasium.wrappers.vector.ClipReward`.
Example:
>>> import gymnasium as gym
>>> from gymnasium.wrappers import ClipReward
>>> env = gym.make("CartPole-v1")
>>> env = ClipReward(env, 0, 0.5)
>>> _ = env.reset()
>>> _, rew, _, _, _ = env.step(1)
>>> rew
0.5
Change logs:
* v1.0.0 - Initially added
"""
def __init__(
self,
env: gym.Env[ObsType, ActType],
min_reward: float | np.ndarray | None = None,
max_reward: float | np.ndarray | None = None,
):
"""Initialize ClipRewards wrapper.
Args:
env (Env): The environment to wrap
min_reward (Union[float, np.ndarray]): lower bound to apply
max_reward (Union[float, np.ndarray]): higher bound to apply
"""
if min_reward is None and max_reward is None:
raise InvalidBound("Both `min_reward` and `max_reward` cannot be None")
elif max_reward is not None and min_reward is not None:
if np.any(max_reward - min_reward < 0):
raise InvalidBound(
f"Min reward ({min_reward}) must be smaller than max reward ({max_reward})"
)
gym.utils.RecordConstructorArgs.__init__(
self, min_reward=min_reward, max_reward=max_reward
)
TransformReward.__init__(
self, env=env, func=lambda x: np.clip(x, a_min=min_reward, a_max=max_reward)
)