"""Utility functions for vector environments to share memory between processes."""
from __future__ import annotations
import multiprocessing as mp
from collections.abc import Mapping
from ctypes import c_bool, c_int32, c_int64
from functools import singledispatch
from multiprocessing.sharedctypes import SynchronizedArray
from types import ModuleType
from typing import TYPE_CHECKING, Any, TypeAlias
import numpy as np
from gymnasium.error import CustomSpaceError
from gymnasium.spaces import (
Box,
Dict,
Discrete,
Graph,
MultiBinary,
MultiDiscrete,
OneOf,
Sequence,
Space,
Text,
Tuple,
flatten,
)
if TYPE_CHECKING:
from typing_extensions import Never, Unpack
# Public API of this module; the space-specific implementations below are private.
__all__ = ["create_shared_memory", "read_from_shared_memory", "write_to_shared_memory"]
# Possible shapes of a shared memory object: a dict (created for `Dict` spaces),
# a tuple (created for `Tuple` and `OneOf` spaces) or a single
# `multiprocessing` array (created for the leaf spaces such as `Box` or `Text`).
_SharedMemory: TypeAlias = dict[str, Any] | tuple[Any, ...] | SynchronizedArray
# Layout of the shared memory created for `OneOf` spaces: an int64 array holding
# one chosen-subspace index per environment, followed by one shared memory
# object per subspace. String annotations are used because `Unpack` is only
# imported under `TYPE_CHECKING`.
_SharedMemoryOneOf: TypeAlias = tuple[
    "SynchronizedArray[c_int64]",
    "Unpack[tuple[SynchronizedArray[Any], ...]]",
]
[docs]
@singledispatch
def create_shared_memory(
    space: Space[Any], n: int = 1, ctx: ModuleType = mp
) -> _SharedMemory:
    """Create a shared memory object, to be shared across processes.

    The object eventually holds the observations of the vectorized environment.
    Concrete behaviour is provided by implementations registered per space type
    with ``create_shared_memory.register``; this fallback is only reached when
    no implementation matches and therefore always raises.

    Args:
        space: Observation space of a single environment in the vectorized environment.
        n: Number of environments in the vectorized environment (i.e. the number of processes).
        ctx: The multiprocess module (or context) used to allocate the shared arrays.

    Returns:
        The shared memory object for the space (from a registered implementation).

    Raises:
        CustomSpaceError: ``space`` is a :class:`gymnasium.Space` whose type has no
            registered ``create_shared_memory`` implementation.
        TypeError: ``space`` is not a :class:`gymnasium.Space` instance.
    """
    space_type = type(space)
    if not isinstance(space, Space):
        raise TypeError(
            f"The space provided to `create_shared_memory` is not a gymnasium Space instance, type: {space_type}, {space}"
        )
    raise CustomSpaceError(
        f"Space of type `{space_type}` doesn't have an registered `create_shared_memory` function. Register `{space_type}` for `create_shared_memory` to support it."
    )
@create_shared_memory.register(Box)
@create_shared_memory.register(Discrete)
@create_shared_memory.register(MultiDiscrete)
@create_shared_memory.register(MultiBinary)
def _create_base_shared_memory(
    space: Box | Discrete | MultiDiscrete | MultiBinary,
    n: int = 1,
    ctx: ModuleType = mp,
) -> SynchronizedArray[Any]:
    """Allocate one flat shared array holding ``n`` samples of a fixed-shape space.

    The array stores ``n * prod(space.shape)`` values of ``space.dtype``. The read and
    write functions view it through ``np.frombuffer(..., dtype=space.dtype)``, so only
    the total byte size of the allocation has to match the dtype.

    Args:
        space: A ``Box``, ``Discrete``, ``MultiDiscrete`` or ``MultiBinary`` space.
        n: Number of environments (samples) to reserve room for.
        ctx: The multiprocess module (or context) used to allocate the array.

    Returns:
        A synchronized ``multiprocessing`` array. Booleans use ``c_bool``. Dtypes that
        ``multiprocessing`` has no typecode for (e.g. ``float16`` or ``longdouble``)
        get an unsigned-byte array of the same total byte size.
    """
    assert space.dtype is not None
    assert space.shape is not None
    num_values = n * int(np.prod(space.shape))
    typecode = space.dtype.char

    if typecode == "?":
        # `multiprocessing` has no typecode for booleans, pass the ctypes type directly
        return ctx.Array(c_bool, num_values)

    # numeric typecodes that `multiprocessing.sharedctypes` maps to a ctypes type
    # with the same meaning as the corresponding numpy dtype character
    supported_typecodes = "bBhHiIlLqQfd"
    if typecode not in supported_typecodes:
        # e.g. float16 ('e') or longdouble ('g'): `ctx.Array(typecode, ...)` would fail
        # with "this type has no size", so reserve the raw bytes instead. The buffer is
        # always reinterpreted with `np.frombuffer(dtype=space.dtype)`.
        return ctx.Array("B", num_values * space.dtype.itemsize)

    return ctx.Array(typecode, num_values)
@create_shared_memory.register(Tuple)
def _create_tuple_shared_memory(
    space: Tuple, n: int = 1, ctx: ModuleType = mp
) -> tuple[Any, ...]:
    """Create one shared memory object per subspace of a ``Tuple`` space, in subspace order."""
    memories = [create_shared_memory(child, n=n, ctx=ctx) for child in space.spaces]
    return tuple(memories)
@create_shared_memory.register(Dict)
def _create_dict_shared_memory(
    space: Dict, n: int = 1, ctx: ModuleType = mp
) -> dict[str, Any]:
    """Create one shared memory object per key of a ``Dict`` space, preserving key order."""
    memories: dict[str, Any] = {}
    for name, child in space.spaces.items():
        memories[name] = create_shared_memory(child, n=n, ctx=ctx)
    return memories
@create_shared_memory.register(Text)
def _create_text_shared_memory(
    space: Text, n: int = 1, ctx: ModuleType = mp
) -> SynchronizedArray[c_int32]:
    """Allocate an int32 shared array with ``space.max_length`` slots per environment.

    Each slot holds an index into ``space.character_list``.
    """
    int32_typecode = np.dtype(np.int32).char
    total_slots = space.max_length * n
    return ctx.Array(int32_typecode, total_slots)
@create_shared_memory.register(OneOf)
def _create_oneof_shared_memory(
    space: OneOf, n: int = 1, ctx: ModuleType = mp
) -> _SharedMemoryOneOf:
    """Create shared memory for a ``OneOf`` space.

    The result starts with an int64 array with one slot per environment, which records
    the index of the chosen subspace. It is followed by one shared memory object per
    subspace, in subspace order.
    """
    chosen_index_memory = ctx.Array(np.dtype(np.int64).char, n)
    per_subspace_memories = [
        create_shared_memory(child, n=n, ctx=ctx) for child in space.spaces
    ]
    return (chosen_index_memory, *per_subspace_memories)
@create_shared_memory.register(Graph)
@create_shared_memory.register(Sequence)
def _create_dynamic_shared_memory(
    space: Graph | Sequence, n: int = 1, ctx: ModuleType = mp
) -> Never:
    """Reject ``Graph`` and ``Sequence`` spaces.

    Samples of these spaces have no fixed size, so no static shared buffer can hold them.

    Raises:
        TypeError: always.
    """
    message = f"As {space} has a dynamic shape so its not possible to make a static shared memory. For `AsyncVectorEnv`, disable `shared_memory`."
    raise TypeError(message)
[docs]
@singledispatch
def read_from_shared_memory(
    space: Space, shared_memory: _SharedMemory, n: int = 1
) -> dict[str, Any] | tuple[Any, ...] | np.ndarray:
    """Read the batch of observations from shared memory as a numpy array.

    .. note::
        Numpy arrays in the returned batch are views onto ``shared_memory``: changes to
        ``shared_memory`` are visible in the observations and vice-versa. Use
        ``np.copy`` to avoid such side effects.

    Concrete behaviour is provided by implementations registered per space type with
    ``read_from_shared_memory.register``. This fallback always raises.

    Args:
        space: Observation space of a single environment in the vectorized environment.
        shared_memory: Shared object across processes, created with ``create_shared_memory``.
            It contains the observations from the vectorized environment.
        n: Number of environments in the vectorized environment (i.e. the number of processes).

    Returns:
        Batch of observations as a (possibly nested) numpy array.

    Raises:
        CustomSpaceError: ``space`` is a :class:`gymnasium.Space` whose type has no
            registered ``read_from_shared_memory`` implementation.
        TypeError: ``space`` is not a :class:`gymnasium.Space` instance.
    """
    space_type = type(space)
    if not isinstance(space, Space):
        raise TypeError(
            f"The space provided to `read_from_shared_memory` is not a gymnasium Space instance, type: {space_type}, {space}"
        )
    raise CustomSpaceError(
        f"Space of type `{space_type}` doesn't have an registered `read_from_shared_memory` function. Register `{space_type}` for `read_from_shared_memory` to support it."
    )
@read_from_shared_memory.register(Box)
@read_from_shared_memory.register(Discrete)
@read_from_shared_memory.register(MultiDiscrete)
@read_from_shared_memory.register(MultiBinary)
def _read_base_from_shared_memory(
    space: Box | Discrete | MultiDiscrete | MultiBinary,
    shared_memory: SynchronizedArray[Any],
    n: int = 1,
) -> np.ndarray:
    """Return a ``(n, *space.shape)`` numpy view of the flat shared array (no copy)."""
    assert space.shape is not None
    batch_shape = (n, *space.shape)
    # the `ty:ignore` is needed because of a bug in the typeshed `multiprocessing` stubs
    flat_view = np.frombuffer(shared_memory.get_obj(), dtype=space.dtype)  # ty:ignore[no-matching-overload]
    return flat_view.reshape(batch_shape)
@read_from_shared_memory.register(Tuple)
def _read_tuple_from_shared_memory(
    space: Tuple, shared_memory: tuple[_SharedMemory, ...], n: int = 1
) -> tuple[Any, ...]:
    """Read each ``Tuple`` subspace from its positional shared memory object."""
    batches = []
    for memory, child in zip(shared_memory, space.spaces, strict=True):
        batches.append(read_from_shared_memory(child, memory, n=n))
    return tuple(batches)
@read_from_shared_memory.register(Dict)
def _read_dict_from_shared_memory(
    space: Dict, shared_memory: dict[str, _SharedMemory], n: int = 1
) -> dict[str, Any]:
    """Read each ``Dict`` subspace from the shared memory stored under the same key."""
    batch: dict[str, Any] = {}
    for name, child in space.spaces.items():
        batch[name] = read_from_shared_memory(child, shared_memory[name], n=n)
    return batch
@read_from_shared_memory.register(Text)
def _read_text_from_shared_memory(
    space: Text, shared_memory: SynchronizedArray[c_int32], n: int = 1
) -> tuple[str, ...]:
    """Decode ``n`` strings from the shared int32 array of character indices.

    Each row of ``space.max_length`` indices is mapped through ``space.character_list``.
    Indices that are not smaller than ``len(space.character_set)`` are skipped;
    presumably the writer uses such values to pad unused slots.
    """
    # the `ty:ignore` is needed because of a bug in the typeshed `multiprocessing` stubs
    flat_codes = np.frombuffer(  # ty:ignore[no-matching-overload]
        shared_memory.get_obj(),
        dtype=np.int32,
    )
    rows = flat_codes.reshape((n, space.max_length))
    num_characters = len(space.character_set)
    characters = space.character_list

    def decode(row: np.ndarray) -> str:
        return "".join(characters[code] for code in row if code < num_characters)

    return tuple(decode(row) for row in rows)
def _select_env_sample(space: Space[Any], batch: Any, index: int) -> Any:
    """Extract environment ``index``'s sample from a batch returned by ``read_from_shared_memory``.

    ``Dict`` and ``Tuple`` batches are containers of per-subspace batches, so the
    environment index must be applied to every leaf rather than to the container itself.
    Every other batch (numpy arrays, and the per-environment tuples produced for ``Text``
    and ``OneOf``) is indexed by environment directly.
    """
    if isinstance(space, Dict):
        return {
            key: _select_env_sample(subspace, batch[key], index)
            for key, subspace in space.spaces.items()
        }
    if isinstance(space, Tuple):
        return tuple(
            _select_env_sample(subspace, sub_batch, index)
            for subspace, sub_batch in zip(space.spaces, batch, strict=True)
        )
    return batch[index]


@read_from_shared_memory.register(OneOf)
def _read_one_of_from_shared_memory(
    space: OneOf, shared_memory: _SharedMemoryOneOf, n: int = 1
) -> tuple[Any, ...]:
    """Read ``n`` ``OneOf`` samples from shared memory.

    Every subspace memory is read as a full batch. Each environment then takes its
    sample from the subspace recorded in the leading int64 index array.

    Returns:
        A tuple with one ``(subspace_index, sample)`` pair per environment.
        ``subspace_index`` is the numpy int64 read from the index array.
    """
    # typeshed bug: `Array[_SimpleCData[c_int64]]` is missing `__buffer__` method stubs
    sample_indexes = np.frombuffer(shared_memory[0].get_obj(), dtype=np.int64)  # ty:ignore[no-matching-overload]
    subspace_samples = tuple(
        read_from_shared_memory(subspace, memory, n=n)
        for (memory, subspace) in zip(shared_memory[1:], space.spaces, strict=True)
    )
    # `Dict`/`Tuple` subspace batches are containers, so indexing them by environment
    # directly would raise `KeyError` or return another component's batch.
    return tuple(
        (
            sample_index,
            _select_env_sample(
                space.spaces[sample_index], subspace_samples[sample_index], index
            ),
        )
        for index, sample_index in enumerate(sample_indexes)
    )
[docs]
@singledispatch
def write_to_shared_memory(
    space: Space,
    index: int,
    value: np.ndarray,
    shared_memory: dict[str, Any] | tuple[Any, ...] | SynchronizedArray,
) -> None:
    """Write the observation of a single environment into shared memory.

    Concrete behaviour is provided by implementations registered per space type with
    ``write_to_shared_memory.register``. This fallback always raises.

    Args:
        space: Observation space of a single environment in the vectorized environment.
        index: Index of the environment (must be in ``[0, num_envs)``).
        value: Observation of the single environment to write to shared memory.
        shared_memory: Shared object across processes, created with ``create_shared_memory``.
            It contains the observations from the vectorized environment.

    Raises:
        CustomSpaceError: ``space`` is a :class:`gymnasium.Space` whose type has no
            registered ``write_to_shared_memory`` implementation.
        TypeError: ``space`` is not a :class:`gymnasium.Space` instance.
    """
    space_type = type(space)
    if not isinstance(space, Space):
        raise TypeError(
            f"The space provided to `write_to_shared_memory` is not a gymnasium Space instance, type: {space_type}, {space}"
        )
    raise CustomSpaceError(
        f"Space of type `{space_type}` doesn't have an registered `write_to_shared_memory` function. Register `{space_type}` for `write_to_shared_memory` to support it."
    )
@write_to_shared_memory.register(Box)
@write_to_shared_memory.register(Discrete)
@write_to_shared_memory.register(MultiDiscrete)
@write_to_shared_memory.register(MultiBinary)
def _write_base_to_shared_memory(
    space: Box | Discrete | MultiDiscrete | MultiBinary,
    index: int,
    value: np.typing.ArrayLike,
    shared_memory: SynchronizedArray[Any],
) -> None:
    """Copy one environment's observation into its slice of the flat shared array.

    The value is cast to ``space.dtype`` and flattened. It is written to the
    ``prod(space.shape)`` elements starting at ``index * prod(space.shape)``.
    """
    assert space.shape is not None
    stride = int(np.prod(space.shape))
    start = index * stride
    # the `ty:ignore` is needed because of a bug in the typeshed `multiprocessing` stubs
    flat_view = np.frombuffer(shared_memory.get_obj(), dtype=space.dtype)  # ty:ignore[no-matching-overload]
    target = flat_view[start : start + stride]
    np.copyto(target, np.asarray(value, dtype=space.dtype).flatten())
@write_to_shared_memory.register(Tuple)
def _write_tuple_to_shared_memory(
    space: Tuple,
    index: int,
    values: tuple[Any, ...],
    shared_memory: tuple[_SharedMemory, ...],
) -> None:
    """Write each component of a ``Tuple`` sample into the matching subspace memory."""
    triples = zip(values, shared_memory, space.spaces, strict=True)
    for component, memory, child in triples:
        write_to_shared_memory(child, index, component, memory)
@write_to_shared_memory.register(Dict)
def _write_dict_to_shared_memory(
    space: Dict, index: int, values: dict[str, Any], shared_memory: Mapping[str, Any]
) -> None:
    """Write each entry of a ``Dict`` sample into the shared memory stored under the same key."""
    for name, child in space.spaces.items():
        entry_value = values[name]
        write_to_shared_memory(child, index, entry_value, shared_memory[name])
@write_to_shared_memory.register(Text)
def _write_text_to_shared_memory(
    space: Text, index: int, values: str, shared_memory: SynchronizedArray[c_int32]
) -> None:
    """Encode a string with ``flatten`` and copy it into environment ``index``'s row.

    The row is ``space.max_length`` int32 slots wide.
    """
    width = space.max_length
    # the `ty:ignore` is needed because of a bug in the typeshed `multiprocessing` stubs
    flat_codes = np.frombuffer(shared_memory.get_obj(), dtype=np.int32)  # ty:ignore[no-matching-overload]
    offset = index * width
    row = flat_codes[offset : offset + width]
    np.copyto(
        row,
        flatten(space, values),  # ty:ignore[invalid-argument-type]
    )
@write_to_shared_memory.register(OneOf)
def _write_oneof_to_shared_memory(
    space: OneOf, index: int, values: tuple[int, Any], shared_memory: _SharedMemoryOneOf
) -> None:
    """Record the chosen subspace index and write the sample into that subspace's memory.

    ``values`` is a ``(subspace_index, sample)`` pair. The index goes into slot ``index``
    of the leading int64 array.
    """
    chosen, sample = values
    # typeshed bug: `Array[_SimpleCData[c_int64]]` is missing `__buffer__` method stubs
    chosen_view = np.frombuffer(shared_memory[0].get_obj(), dtype=np.int64)  # ty:ignore[no-matching-overload]
    np.copyto(chosen_view[index : index + 1], chosen)
    # The other subspaces' memories are left untouched, because this sample's data
    # would not match their layout; subspace memories start at position 1.
    target_space = space.spaces[chosen]
    write_to_shared_memory(target_space, index, sample, shared_memory[chosen + 1])