# Source code for gymnasium.spaces.text

"""Implementation of a space that represents textual strings."""
from __future__ import annotations

from typing import Any

import numpy as np
from numpy.typing import NDArray

from gymnasium.spaces.space import Space


# Default ``charset`` of the ``Text`` space: the 26 lowercase ASCII letters,
# the 26 uppercase ASCII letters and the 10 decimal digits (62 characters).
alphanumeric: frozenset[str] = frozenset(
    "abcdefghijklmnopqrstuvwxyz"  # lowercase letters
    "ABCDEFGHIJKLMNOPQRSTUVWXYZ"  # uppercase letters
    "0123456789"  # decimal digits
)


class Text(Space[str]):
    r"""A space representing a string comprised of characters from a given charset.

    Example:
        >>> from gymnasium.spaces import Text
        >>> # {"B5", "hello", ...}
        >>> Text(5)
        Text(1, 5, charset=0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz)
        >>> # {"0", "42", "0123456789", ...}
        >>> import string
        >>> Text(min_length = 1,
        ...      max_length = 10,
        ...      charset = string.digits)
        Text(1, 10, charset=0123456789)
    """

    def __init__(
        self,
        max_length: int,
        *,
        min_length: int = 1,
        charset: frozenset[str] | str = alphanumeric,
        seed: int | np.random.Generator | None = None,
    ):
        r"""Constructor of :class:`Text` space.

        Both bounds for text length are inclusive.

        Args:
            min_length (int): Minimum text length (in characters). Defaults to 1 to prevent empty strings.
            max_length (int): Maximum text length (in characters).
            charset (frozenset[str] | str): Character set, defaults to the lower and upper english alphabet
                plus latin digits. Duplicate characters are ignored. For an ordered charset (e.g. a ``str``)
                the order of first occurrence defines :attr:`character_list` and :meth:`character_index`.
                A ``set``/``frozenset`` is sorted, so that the indices do not depend on Python's
                per-process string-hash randomisation.
            seed: The seed for sampling from the space.
        """
        assert np.issubdtype(
            type(min_length), np.integer
        ), f"Expects the min_length to be an integer, actual type: {type(min_length)}"
        assert np.issubdtype(
            type(max_length), np.integer
        ), f"Expects the max_length to be an integer, actual type: {type(max_length)}"
        assert (
            0 <= min_length
        ), f"Minimum text length must be non-negative, actual value: {min_length}"
        assert (
            min_length <= max_length
        ), f"The min_length must be less than or equal to the max_length, min_length: {min_length}, max_length: {max_length}"

        self.min_length: int = int(min_length)
        self.max_length: int = int(max_length)

        # Build a duplicate-free, deterministically ordered character tuple.
        # - A set has no stable iteration order across processes, so sort it.
        #   Otherwise seeded samples and mask indices are not reproducible.
        # - Any other iterable keeps its order of first occurrence.
        # The charset is materialised once, so one-shot iterables also work.
        if isinstance(charset, (set, frozenset)):
            char_list: tuple[str, ...] = tuple(sorted(charset))
        else:
            char_list = tuple(dict.fromkeys(charset))

        self._char_set: frozenset[str] = frozenset(char_list)
        # Without duplicates, len(character_list) == len(character_set).
        # The sample mask is sized by character_set but indexes character_list,
        # so the two must agree.
        self._char_list: tuple[str, ...] = char_list
        self._char_index: dict[str, np.int32] = {
            val: np.int32(i) for i, val in enumerate(char_list)
        }
        self._char_str: str = "".join(sorted(self._char_set))

        # The shape is dynamic (anywhere between min_length and max_length),
        # so no ``shape`` is passed to ``Space.__init__``.
        super().__init__(dtype=str, seed=seed)

    def sample(
        self,
        mask: None | (tuple[int | None, NDArray[np.int8] | None]) = None,
    ) -> str:
        """Generates a single random sample from this space.

        By default the length is drawn uniformly between ``min_length`` and ``max_length`` (inclusive).
        Each character is drawn uniformly from the ``charset``.

        Args:
            mask: An optional tuple ``(length, charlist_mask)``.
                ``length`` must be an integer between ``min_length`` and ``max_length`` (inclusive).
                If it is ``None``, a random integer between ``min_length`` and ``max_length`` is selected.
                ``charlist_mask`` must be a numpy array of shape ``(len(character_set),)`` with
                ``dtype == np.int8`` and values in ``{0, 1}``. Entry ``i`` enables ``character_list[i]``.
                If it is ``None``, every character may be sampled.
                If the charlist mask is all zero, an empty string is returned when ``min_length == 0``,
                regardless of the requested length.

        Returns:
            A sampled string from the space

        Raises:
            ValueError: If the charlist mask is all zero and ``min_length > 0``. No string in the space
                could then be produced.
        """
        if mask is not None:
            assert isinstance(
                mask, tuple
            ), f"Expects the mask type to be a tuple, actual type: {type(mask)}"
            assert (
                len(mask) == 2
            ), f"Expects the mask length to be two, actual length: {len(mask)}"
            length, charlist_mask = mask

            if length is not None:
                assert np.issubdtype(
                    type(length), np.integer
                ), f"Expects the Text sample length to be an integer, actual type: {type(length)}"
                assert (
                    self.min_length <= length <= self.max_length
                ), f"Expects the Text sample length be between {self.min_length} and {self.max_length}, actual length: {length}"

            if charlist_mask is not None:
                assert isinstance(
                    charlist_mask, np.ndarray
                ), f"Expects the Text sample mask to be an np.ndarray, actual type: {type(charlist_mask)}"
                assert (
                    charlist_mask.dtype == np.int8
                ), f"Expects the Text sample mask to be an np.ndarray, actual dtype: {charlist_mask.dtype}"
                assert charlist_mask.shape == (
                    len(self.character_set),
                ), f"expects the Text sample mask to be {(len(self.character_set),)}, actual shape: {charlist_mask.shape}"
                assert np.all(
                    np.logical_or(charlist_mask == 0, charlist_mask == 1)
                ), f"Expects all masks values to 0 or 1, actual values: {charlist_mask}"
        else:
            length, charlist_mask = None, None

        if length is None:
            length = self.np_random.integers(self.min_length, self.max_length + 1)

        if charlist_mask is None:
            string = self.np_random.choice(self.character_list, size=length)
        else:
            valid_mask = charlist_mask == 1
            valid_indexes = np.where(valid_mask)[0]
            if len(valid_indexes) == 0:
                if self.min_length == 0:
                    string = ""
                else:
                    # Otherwise the string will not be contained in the space
                    raise ValueError(
                        f"Trying to sample with a minimum length > 0 ({self.min_length}) but the character mask is all zero meaning that no character could be sampled."
                    )
            else:
                string = "".join(
                    self.character_list[index]
                    for index in self.np_random.choice(valid_indexes, size=length)
                )

        # ``string`` is either a numpy array of characters or a str; join both.
        return "".join(string)

    def contains(self, x: Any) -> bool:
        """Return boolean specifying if x is a valid member of this space.

        ``x`` must be a ``str`` whose length is within ``[min_length, max_length]``. Every character
        must belong to :attr:`character_set`.
        """
        if isinstance(x, str):
            if self.min_length <= len(x) <= self.max_length:
                return all(c in self.character_set for c in x)
        return False

    def __repr__(self) -> str:
        """Gives a string representation of this space."""
        return f"Text({self.min_length}, {self.max_length}, charset={self.characters})"

    def __eq__(self, other: Any) -> bool:
        """Check whether ``other`` is equivalent to this instance.

        Two ``Text`` spaces are equal when their length bounds and character sets match. The seed and
        character order are not compared.
        """
        return (
            isinstance(other, Text)
            and self.min_length == other.min_length
            and self.max_length == other.max_length
            and self.character_set == other.character_set
        )

    @property
    def character_set(self) -> frozenset[str]:
        """Returns the character set for the space."""
        return self._char_set

    @property
    def character_list(self) -> tuple[str, ...]:
        """Returns a tuple of the unique characters in the space, in index order."""
        return self._char_list

    def character_index(self, char: str) -> np.int32:
        """Returns a unique index for each character in the space's character set.

        The index is the position of ``char`` in :attr:`character_list`.

        Raises:
            KeyError: If ``char`` is not in the character set.
        """
        return self._char_index[char]

    @property
    def characters(self) -> str:
        """Returns a string with all Text characters, sorted."""
        return self._char_str

    @property
    def is_np_flattenable(self) -> bool:
        """The flattened version is an integer array for each character, padded to the max character length."""
        return True