Source code for gymnasium.wrappers.vector.jax_to_numpy
"""Vector wrapper for converting between NumPy and Jax."""
from __future__ import annotations
import jax.numpy as jnp
import numpy as np
from gymnasium.error import DependencyNotInstalled
from gymnasium.vector import VectorEnv
from gymnasium.wrappers.vector.array_conversion import ArrayConversion
# Public names exported by this module (used by `from ... import *`).
__all__ = ["JaxToNumpy"]
[docs]
class JaxToNumpy(ArrayConversion):
    """Expose a jax-based vector environment through a numpy interface.

    Notes:
        This is the vectorized counterpart of :class:`gymnasium.wrappers.JaxToNumpy`.

        Callers supply actions as numpy arrays; observations, rewards,
        terminations and truncations are handed back as numpy arrays.

    Example:
        >>> import gymnasium as gym                 # doctest: +SKIP
        >>> envs = gym.make_vec("JaxEnv-vx", 3)     # doctest: +SKIP
        >>> envs = JaxToNumpy(envs)                 # doctest: +SKIP
    """

    def __init__(self, env: VectorEnv):
        """Set up the jax <-> numpy conversion around ``env``.

        Args:
            env: the vector jax environment to wrap

        Raises:
            DependencyNotInstalled: if the module-level ``jnp`` name is ``None``.
        """
        # NOTE(review): ``jax.numpy`` is imported unconditionally at module
        # level, so this guard only fires if ``jnp`` is rebound to ``None``
        # after import -- confirm whether the import itself should be guarded.
        jax_unavailable = jnp is None
        if jax_unavailable:
            install_hint = 'Jax is not installed, run `pip install "gymnasium[jax]"`'
            raise DependencyNotInstalled(install_hint)

        # jax is passed as the environment-side array namespace and numpy as
        # the target namespace.
        super().__init__(env, env_xp=jnp, target_xp=np)