# Source code for gymnasium.spaces.graph

"""Implementation of a space that represents graph information where nodes and edges can be represented with euclidean space."""

from __future__ import annotations

from typing import Any, NamedTuple, Sequence

import numpy as np
from numpy.typing import NDArray

import gymnasium as gym
from gymnasium.spaces.box import Box
from gymnasium.spaces.discrete import Discrete
from gymnasium.spaces.multi_discrete import MultiDiscrete
from gymnasium.spaces.space import Space


class GraphInstance(NamedTuple):
    """A Graph space instance.

    * nodes (np.ndarray): an (n x ...) sized array representing the features for n nodes, (...) must adhere to the shape of the node space.
    * edges (Optional[np.ndarray]): an (m x ...) sized array representing the features for m edges, (...) must adhere to the shape of the edge space.
    * edge_links (Optional[np.ndarray]): an (m x 2) sized array of ints representing the indices of the two nodes that each edge connects.
    """

    # Per-node feature array; first axis indexes nodes.
    nodes: NDArray[Any]
    # Per-edge feature array, or None when the graph carries no edges.
    edges: NDArray[Any] | None
    # (m x 2) integer node-index pairs, or None when the graph carries no edges.
    edge_links: NDArray[Any] | None

class Graph(Space[GraphInstance]):
    r"""A space representing graph information as a series of ``nodes`` connected with ``edges`` according to an adjacency matrix represented as a series of ``edge_links``.

    Example:
        >>> from gymnasium.spaces import Graph, Box, Discrete
        >>> observation_space = Graph(node_space=Box(low=-100, high=100, shape=(3,)), edge_space=Discrete(3), seed=123)
        >>> observation_space.sample(num_nodes=4, num_edges=8)
        GraphInstance(nodes=array([[ 36.47037 , -89.235794, -55.928024],
               [-63.125637, -64.81882 ,  62.4189  ],
               [ 84.669   , -44.68512 ,  63.950912],
               [ 77.97854 ,   2.594091, -51.00708 ]], dtype=float32), edges=array([2, 0, 2, 1, 2, 0, 2, 1]), edge_links=array([[3, 0],
               [0, 0],
               [0, 1],
               [0, 2],
               [1, 0],
               [1, 0],
               [0, 1],
               [0, 2]], dtype=int32))
    """

    def __init__(
        self,
        node_space: Box | Discrete,
        edge_space: None | Box | Discrete,
        seed: int | np.random.Generator | None = None,
    ) -> None:
        r"""Constructor of :class:`Graph`.

        The argument ``node_space`` specifies the base space that each node feature will use.
        This argument must be either a Box or Discrete instance.

        The argument ``edge_space`` specifies the base space that each edge feature will use.
        This argument must be either a None, Box or Discrete instance.

        Args:
            node_space (Union[Box, Discrete]): space of the node features.
            edge_space (Union[None, Box, Discrete]): space of the edge features.
            seed: Optionally, you can use this argument to seed the RNG that is used to sample from the space.
        """
        assert isinstance(
            node_space, (Box, Discrete)
        ), f"Values of the node_space should be instances of Box or Discrete, got {type(node_space)}"
        if edge_space is not None:
            assert isinstance(
                edge_space, (Box, Discrete)
            ), f"Values of the edge_space should be instances of None Box or Discrete, got {type(edge_space)}"

        self.node_space = node_space
        self.edge_space = edge_space

        # The Graph itself has no fixed shape/dtype; only its node/edge subspaces do.
        super().__init__(None, None, seed)

    @property
    def is_np_flattenable(self) -> bool:
        """Checks whether this space can be flattened to a :class:`spaces.Box`."""
        return False

    def _generate_sample_space(
        self, base_space: None | Box | Discrete, num: int
    ) -> Box | MultiDiscrete | None:
        """Build a space whose samples are ``num`` stacked elements of ``base_space``.

        Args:
            base_space: The per-element (node or edge) space, or None.
            num: How many elements to stack.

        Returns:
            ``None`` when ``num == 0`` or ``base_space`` is None. Otherwise a Box of
            shape ``(num,) + base_space.shape`` or a MultiDiscrete with ``num`` entries.
            It is constructed with ``seed=self.np_random``.

        Raises:
            TypeError: if ``base_space`` is neither Box nor Discrete.
        """
        if num == 0 or base_space is None:
            return None

        if isinstance(base_space, Box):
            return Box(
                low=np.array(max(1, num) * [base_space.low]),
                high=np.array(max(1, num) * [base_space.high]),
                shape=(num,) + base_space.shape,
                dtype=base_space.dtype,
                seed=self.np_random,
            )
        elif isinstance(base_space, Discrete):
            return MultiDiscrete(nvec=[base_space.n] * num, seed=self.np_random)
        else:
            raise TypeError(
                f"Expects base space to be Box and Discrete, actual space: {type(base_space)}."
            )

    def seed(
        self, seed: int | tuple[int, int] | tuple[int, int, int] | None = None
    ) -> tuple[int, int] | tuple[int, int, int]:
        """Seeds the PRNG of this space and node / edge subspace.

        Depending on the type of seed, the subspaces will be seeded differently

        * ``None`` - The root, node and edge spaces PRNG are randomly initialized
        * ``Int`` - The integer is used to seed the :class:`Graph` space that is used to generate seed values for the node and edge subspaces.
        * ``Tuple[int, int]`` - Seeds the :class:`Graph` and node subspace with a particular value. Only if edge subspace isn't specified
        * ``Tuple[int, int, int]`` - Seeds the :class:`Graph`, node and edge subspaces with a particular value.

        Args:
            seed: An optional int or tuple of ints for this space and the node / edge subspaces. See above for more details.

        Returns:
            A tuple of two or three ints depending on if the edge subspace is specified.

        Raises:
            ValueError: if a tuple/list seed has the wrong length for this space.
            TypeError: if ``seed`` is not None, an int, or a tuple/list.
        """
        if seed is None:
            if self.edge_space is None:
                return super().seed(None), self.node_space.seed(None)
            else:
                return (
                    super().seed(None),
                    self.node_space.seed(None),
                    self.edge_space.seed(None),
                )
        elif isinstance(seed, int):
            if self.edge_space is None:
                super_seed = super().seed(seed)
                node_seed = int(self.np_random.integers(np.iinfo(np.int32).max))
                # Re-seed so the Graph PRNG state matches what tuple seeding with
                # ``seed`` as the first element would produce.
                super().seed(seed)
                return super_seed, self.node_space.seed(node_seed)
            else:
                super_seed = super().seed(seed)
                node_seed, edge_seed = self.np_random.integers(
                    np.iinfo(np.int32).max, size=(2,)
                )
                # Re-seed so the Graph PRNG state matches what tuple seeding with
                # ``seed`` as the first element would produce.
                super().seed(seed)
                return (
                    super_seed,
                    self.node_space.seed(int(node_seed)),
                    self.edge_space.seed(int(edge_seed)),
                )
        elif isinstance(seed, (list, tuple)):
            if self.edge_space is None:
                if len(seed) != 2:
                    raise ValueError(
                        f"Expects a tuple of two values for Graph and node space, actual length: {len(seed)}"
                    )

                return super().seed(seed[0]), self.node_space.seed(seed[1])
            else:
                if len(seed) != 3:
                    raise ValueError(
                        f"Expects a tuple of three values for Graph, node and edge space, actual length: {len(seed)}"
                    )

                return (
                    super().seed(seed[0]),
                    self.node_space.seed(seed[1]),
                    self.edge_space.seed(seed[2]),
                )
        else:
            raise TypeError(
                f"Expects `None`, int or tuple of ints, actual type: {type(seed)}"
            )

    def sample(
        self,
        mask: None
        | (
            tuple[
                NDArray[Any] | tuple[Any, ...] | None,
                NDArray[Any] | tuple[Any, ...] | None,
            ]
        ) = None,
        num_nodes: int = 10,
        num_edges: int | None = None,
    ) -> GraphInstance:
        """Generates a single sample graph with exactly ``num_nodes`` nodes from the Graph.

        Args:
            mask: An optional tuple of optional node and edge mask that is only possible with Discrete spaces
                (Box spaces don't support sample masks). The node mask is passed unchanged to the sampled
                node space. If no ``num_edges`` is provided then the ``edge_mask`` is repeated once per
                sampled edge; otherwise it is passed unchanged.
            num_nodes: The number of nodes that will be sampled, the default is `10` nodes
            num_edges: An optional number of edges. Otherwise, it is drawn uniformly from
                ``[0, num_nodes * (num_nodes - 1))`` (upper bound exclusive), and is ``0`` when ``num_nodes == 1``.

        Returns:
            A :class:`GraphInstance` with attributes `.nodes`, `.edges`, and `.edge_links`.
            ``edges``/``edge_links`` are None when there is no edge space or no edges were sampled.
            Edge links are drawn independently per edge, so duplicate links and self-loops may occur.
        """
        assert (
            num_nodes > 0
        ), f"The number of nodes is expected to be greater than 0, actual value: {num_nodes}"

        if mask is not None:
            node_space_mask, edge_space_mask = mask
        else:
            node_space_mask, edge_space_mask = None, None

        if num_edges is None:
            # Edges are only drawn when there are at least 2 nodes. The count is drawn from
            # [0, n * (n - 1)); the upper bound is exclusive, as with ``Generator.integers``.
            if num_nodes > 1:
                num_edges = self.np_random.integers(num_nodes * (num_nodes - 1))
            else:
                num_edges = 0

            # A single per-edge mask is repeated once for every sampled edge.
            if edge_space_mask is not None:
                edge_space_mask = tuple(edge_space_mask for _ in range(num_edges))
        else:
            if self.edge_space is None:
                gym.logger.warn(
                    f"The number of edges is set ({num_edges}) but the edge space is None."
                )
            assert (
                num_edges >= 0
            ), f"Expects the number of edges to be greater than or equal to 0, actual value: {num_edges}"
        assert num_edges is not None

        sampled_node_space = self._generate_sample_space(self.node_space, num_nodes)
        sampled_edge_space = self._generate_sample_space(self.edge_space, num_edges)

        assert sampled_node_space is not None
        sampled_nodes = sampled_node_space.sample(node_space_mask)
        sampled_edges = (
            sampled_edge_space.sample(edge_space_mask)
            if sampled_edge_space is not None
            else None
        )

        sampled_edge_links = None
        if sampled_edges is not None and num_edges > 0:
            sampled_edge_links = self.np_random.integers(
                low=0, high=num_nodes, size=(num_edges, 2), dtype=np.int32
            )

        return GraphInstance(sampled_nodes, sampled_edges, sampled_edge_links)

    def contains(self, x: GraphInstance) -> bool:
        """Return boolean specifying if x is a valid member of this space.

        ``x`` is a member when all of the following hold:

        * it is a :class:`GraphInstance`;
        * its ``nodes`` is an ndarray whose every row is in ``node_space``;
        * ``edges`` and ``edge_links`` are either both None, or both ndarrays with these properties:
          this space has an ``edge_space``, every edge is in it, and ``edge_links`` is an integer
          array of shape ``(len(edges), 2)`` whose indices lie in ``[0, len(nodes))``.
        """
        if not isinstance(x, GraphInstance):
            return False

        # Nodes are mandatory and every node feature must belong to the node space.
        if not isinstance(x.nodes, np.ndarray):
            return False
        if not all(node in self.node_space for node in x.nodes):
            return False

        # Edges and edge links are optional, but must be present (or absent) together.
        if not (
            isinstance(x.edges, np.ndarray) and isinstance(x.edge_links, np.ndarray)
        ):
            return x.edges is None and x.edge_links is None

        if self.edge_space is None:
            return False
        if not all(edge in self.edge_space for edge in x.edges):
            return False
        if not np.issubdtype(x.edge_links.dtype, np.integer):
            return False
        if x.edge_links.shape != (len(x.edges), 2):
            return False

        # Every link endpoint must be a valid node index.
        return bool(
            np.all(
                np.logical_and(
                    x.edge_links >= 0,
                    x.edge_links < len(x.nodes),
                )
            )
        )

    def __repr__(self) -> str:
        """A string representation of this space.

        The representation will include ``node_space`` and ``edge_space``

        Returns:
            A representation of the space
        """
        return f"Graph({self.node_space}, {self.edge_space})"

    def __eq__(self, other: Any) -> bool:
        """Check whether `other` is equivalent to this instance."""
        return (
            isinstance(other, Graph)
            and (self.node_space == other.node_space)
            and (self.edge_space == other.edge_space)
        )

    def to_jsonable(
        self, sample_n: Sequence[GraphInstance]
    ) -> list[dict[str, list[int | float]]]:
        """Convert a batch of samples from this space to a JSONable data type.

        Each sample becomes a dict with a ``"nodes"`` key. The ``"edges"`` and
        ``"edge_links"`` keys are added only when both are present on the sample.
        """
        json_samples = []
        for sample in sample_n:
            entry = {"nodes": sample.nodes.tolist()}
            if sample.edges is not None and sample.edge_links is not None:
                entry["edges"] = sample.edges.tolist()
                entry["edge_links"] = sample.edge_links.tolist()
            json_samples.append(entry)
        return json_samples

    def from_jsonable(
        self, sample_n: Sequence[dict[str, list[list[int] | list[float]]]]
    ) -> list[GraphInstance]:
        """Convert a JSONable data type to a batch of samples from this space.

        Node (and edge) values are cast to the node (and edge) space dtype. Edge links
        are cast to ``np.int32``. Dicts without an ``"edges"`` key produce a
        :class:`GraphInstance` with ``edges`` and ``edge_links`` set to None.
        """
        samples: list[GraphInstance] = []
        for entry in sample_n:
            if "edges" in entry:
                assert self.edge_space is not None
                instance = GraphInstance(
                    np.asarray(entry["nodes"], dtype=self.node_space.dtype),
                    np.asarray(entry["edges"], dtype=self.edge_space.dtype),
                    np.asarray(entry["edge_links"], dtype=np.int32),
                )
            else:
                instance = GraphInstance(
                    np.asarray(entry["nodes"], dtype=self.node_space.dtype),
                    None,
                    None,
                )
            samples.append(instance)
        return samples