Source code for gymnasium.wrappers.vector.jax_to_torch

"""Vector wrapper class for converting between PyTorch and Jax."""

from __future__ import annotations

from typing import Any

from gymnasium.core import ActType, ObsType
from gymnasium.vector import VectorEnv, VectorWrapper
from gymnasium.vector.vector_env import ArrayType
from gymnasium.wrappers.jax_to_torch import Device, jax_to_torch, torch_to_jax


# Names exported by a star-import of this module: only the wrapper class.
__all__ = ["JaxToTorch"]


class JaxToTorch(VectorWrapper):
    """Expose a Jax-based vector environment through a PyTorch Tensor interface.

    Callers pass actions as PyTorch Tensors; the observations, rewards,
    terminations, truncations and info coming back from the wrapped
    environment are handed to ``jax_to_torch`` before being returned.

    Example:
        >>> import gymnasium as gym                                         # doctest: +SKIP
        >>> envs = gym.make_vec("JaxEnv-vx", 3)                             # doctest: +SKIP
        >>> envs = JaxToTorch(envs)                                         # doctest: +SKIP
    """

    def __init__(self, env: VectorEnv, device: Device | None = None):
        """Wrap ``env`` so that its inputs and outputs are PyTorch tensors.

        Args:
            env: The Jax-based vector environment to wrap
            device: The device the torch Tensors should be moved to
        """
        super().__init__(env)
        # Forwarded to every ``jax_to_torch`` call made by ``step`` and ``reset``.
        self.device: Device | None = device

    def step(
        self, actions: ActType
    ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
        """Step the wrapped environment with PyTorch actions.

        Args:
            actions: The action to perform as a PyTorch Tensor

        Returns:
            Torch-based Tensors of the next observation, reward, termination,
            truncation, and extra info
        """
        # Convert on the way in, step the Jax environment, convert on the way out.
        step_result = self.env.step(torch_to_jax(actions))
        next_obs, rewards, terminations, truncations, infos = step_result

        target_device = self.device
        return (
            jax_to_torch(next_obs, target_device),
            jax_to_torch(rewards, target_device),
            jax_to_torch(terminations, target_device),
            jax_to_torch(truncations, target_device),
            jax_to_torch(infos, target_device),
        )

    def reset(
        self,
        *,
        seed: int | list[int] | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[ObsType, dict[str, Any]]:
        """Reset the wrapped environment and return PyTorch-based results.

        Args:
            seed: The seed for resetting the environment
            options: The options for resetting the environment, these are
                converted to jax arrays.

        Returns:
            PyTorch-based observations and info
        """
        # Falsy options (``None`` or an empty dict) are forwarded untouched;
        # only non-empty options go through ``torch_to_jax``.
        jax_options = torch_to_jax(options) if options else options

        # The complete return value of ``env.reset`` is converted in one call.
        reset_result = self.env.reset(seed=seed, options=jax_options)
        return jax_to_torch(reset_result, self.device)