Source code for gymnasium.wrappers.numpy_to_torch
"""Helper functions and wrapper class for converting between PyTorch and NumPy."""
from __future__ import annotations
import functools
from typing import Union
import numpy as np
import gymnasium as gym
from gymnasium.error import DependencyNotInstalled
from gymnasium.wrappers.array_conversion import (
ArrayConversion,
array_conversion,
module_namespace,
)
# PyTorch is an optional dependency: fail at import time with an actionable
# message instead of a bare ImportError.
try:
    import torch
except ImportError as err:
    # Chain explicitly so the traceback reports the failed torch import as the
    # direct cause rather than as an error raised while handling another one.
    raise DependencyNotInstalled(
        'Torch is not installed therefore cannot call `torch_to_numpy`, run `pip install "gymnasium[torch]"`'
    ) from err

# Type alias for a device specification: either a string or a ``torch.device``.
# Defined after the guard because it can only be built once torch has imported.
Device = Union[str, torch.device]
# Names exported by ``from gymnasium.wrappers.numpy_to_torch import *``.
__all__ = ["NumpyToTorch", "torch_to_numpy", "numpy_to_torch", "Device"]
# ``array_conversion`` with ``xp`` pre-bound to the namespace returned by ``module_namespace(np)``.
# NOTE(review): per its name, presumably converts torch Tensors to NumPy arrays — confirm against ``array_conversion``.
torch_to_numpy = functools.partial(array_conversion, xp=module_namespace(np))
# Same partial application, with ``xp`` bound to ``module_namespace(torch)`` instead.
# NOTE(review): per its name, presumably converts NumPy arrays to torch Tensors — confirm against ``array_conversion``.
numpy_to_torch = functools.partial(array_conversion, xp=module_namespace(torch))
[docs]
class NumpyToTorch(ArrayConversion):
    """Expose a NumPy-based environment through a PyTorch Tensor interface.

    Actions are passed in as PyTorch Tensors and observations are handed back as PyTorch Tensors.
    For vector environments, see :class:`gymnasium.wrappers.vector.NumpyToTorch`.

    Note:
        Rendered output is returned as a NumPy array, not as a PyTorch Tensor.

    Example:
        >>> import torch
        >>> import gymnasium as gym
        >>> env = gym.make("CartPole-v1")
        >>> env = NumpyToTorch(env)
        >>> obs, _ = env.reset(seed=123)
        >>> type(obs)
        <class 'torch.Tensor'>
        >>> action = torch.tensor(env.action_space.sample())
        >>> obs, reward, terminated, truncated, info = env.step(action)
        >>> type(obs)
        <class 'torch.Tensor'>
        >>> type(reward)
        <class 'float'>
        >>> type(terminated)
        <class 'bool'>
        >>> type(truncated)
        <class 'bool'>

    Change logs:
        * v1.0.0 - Initially added
    """

    def __init__(self, env: gym.Env, device: Device | None = None):
        """Wrap ``env`` so that its inputs and outputs are PyTorch tensors.

        Args:
            env: The NumPy-based environment to wrap
            device: The device the torch Tensors should be moved to; also stored on ``self.device``
        """
        # NumPy is the environment-side array namespace, torch the user-facing one.
        super().__init__(
            env=env,
            env_xp=np,
            target_xp=torch,
            target_device=device,
        )
        # Set after the parent initialiser has run, keeping the original statement order.
        self.device: Device | None = device