Source code for gymnasium.wrappers.vector.jax_to_torch

"""Vector wrapper class for converting between PyTorch and Jax."""

from __future__ import annotations

import jax.numpy as jnp
import torch

from gymnasium.vector import VectorEnv
from gymnasium.wrappers.jax_to_torch import Device
from gymnasium.wrappers.vector.array_conversion import ArrayConversion


__all__ = ["JaxToTorch"]


class JaxToTorch(ArrayConversion):
    """Expose a Jax-based vector environment through PyTorch Tensors.

    Actions are expected as PyTorch Tensors; observations, rewards,
    terminations and truncations are handed back as PyTorch Tensors.

    Example:
        >>> import gymnasium as gym                                         # doctest: +SKIP
        >>> envs = gym.make_vec("JaxEnv-vx", 3)                             # doctest: +SKIP
        >>> envs = JaxToTorch(envs)                                         # doctest: +SKIP
    """

    def __init__(self, env: VectorEnv, device: Device | None = None):
        """Set up the Jax -> PyTorch conversion around ``env``.

        Args:
            env: The Jax-based vector environment to wrap
            device: The device the torch Tensors should be moved to
        """
        # The environment side works with jax.numpy arrays, the caller side
        # with torch; the base class receives both array namespaces and the
        # requested target device.
        super().__init__(
            env,
            env_xp=jnp,
            target_xp=torch,
            target_device=device,
        )
        # Keep the requested device available as a public attribute.
        self.device: Device | None = device