"""Wrapper for recording videos."""
import os
from typing import Callable, Optional
import gymnasium as gym
from gymnasium import logger
from gymnasium.wrappers.monitoring import video_recorder
def capped_cubic_video_schedule(episode_id: int) -> bool:
"""The default episode trigger.
This function will trigger recordings at the episode indices 0, 1, 8, 27, ..., :math:`k^3`, ..., 729, 1000, 2000, 3000, ...
Args:
episode_id: The episode number
Returns:
If to apply a video schedule number
"""
if episode_id < 1000:
return int(round(episode_id ** (1.0 / 3))) ** 3 == episode_id
else:
return episode_id % 1000 == 0
[docs]
class RecordVideo(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""This wrapper records videos of rollouts.
Usually, you only want to record episodes intermittently, say every hundredth episode.
To do this, you can specify **either** ``episode_trigger`` **or** ``step_trigger`` (not both).
They should be functions returning a boolean that indicates whether a recording should be started at the
current episode or step, respectively.
If neither :attr:`episode_trigger` nor ``step_trigger`` is passed, a default ``episode_trigger`` will be employed.
By default, the recording will be stopped once a `terminated` or `truncated` signal has been emitted by the environment. However, you can
also create recordings of fixed length (possibly spanning several episodes) by passing a strictly positive value for
``video_length``.
"""
def __init__(
self,
env: gym.Env,
video_folder: str,
episode_trigger: Callable[[int], bool] = None,
step_trigger: Callable[[int], bool] = None,
video_length: int = 0,
name_prefix: str = "rl-video",
disable_logger: bool = False,
):
"""Wrapper records videos of rollouts.
Args:
env: The environment that will be wrapped
video_folder (str): The folder where the recordings will be stored
episode_trigger: Function that accepts an integer and returns ``True`` iff a recording should be started at this episode
step_trigger: Function that accepts an integer and returns ``True`` iff a recording should be started at this step
video_length (int): The length of recorded episodes. If 0, entire episodes are recorded.
Otherwise, snippets of the specified length are captured
name_prefix (str): Will be prepended to the filename of the recordings
disable_logger (bool): Whether to disable moviepy logger or not.
"""
gym.utils.RecordConstructorArgs.__init__(
self,
video_folder=video_folder,
episode_trigger=episode_trigger,
step_trigger=step_trigger,
video_length=video_length,
name_prefix=name_prefix,
disable_logger=disable_logger,
)
gym.Wrapper.__init__(self, env)
if env.render_mode in {None, "human", "ansi", "ansi_list"}:
raise ValueError(
f"Render mode is {env.render_mode}, which is incompatible with"
f" RecordVideo. Initialize your environment with a render_mode"
f" that returns an image, such as rgb_array."
)
if episode_trigger is None and step_trigger is None:
episode_trigger = capped_cubic_video_schedule
trigger_count = sum(x is not None for x in [episode_trigger, step_trigger])
assert trigger_count == 1, "Must specify exactly one trigger"
self.episode_trigger = episode_trigger
self.step_trigger = step_trigger
self.video_recorder: Optional[video_recorder.VideoRecorder] = None
self.disable_logger = disable_logger
self.video_folder = os.path.abspath(video_folder)
# Create output folder if needed
if os.path.isdir(self.video_folder):
logger.warn(
f"Overwriting existing videos at {self.video_folder} folder "
f"(try specifying a different `video_folder` for the `RecordVideo` wrapper if this is not desired)"
)
os.makedirs(self.video_folder, exist_ok=True)
self.name_prefix = name_prefix
self.step_id = 0
self.video_length = video_length
self.recording = False
self.terminated = False
self.truncated = False
self.recorded_frames = 0
self.episode_id = 0
try:
self.is_vector_env = self.get_wrapper_attr("is_vector_env")
except AttributeError:
self.is_vector_env = False
def reset(self, **kwargs):
"""Reset the environment using kwargs and then starts recording if video enabled."""
observations = super().reset(**kwargs)
self.terminated = False
self.truncated = False
if self.recording:
assert self.video_recorder is not None
self.video_recorder.recorded_frames = []
self.video_recorder.capture_frame()
self.recorded_frames += 1
if self.video_length > 0:
if self.recorded_frames > self.video_length:
self.close_video_recorder()
elif self._video_enabled():
self.start_video_recorder()
return observations
def start_video_recorder(self):
"""Starts video recorder using :class:`video_recorder.VideoRecorder`."""
self.close_video_recorder()
video_name = f"{self.name_prefix}-step-{self.step_id}"
if self.episode_trigger:
video_name = f"{self.name_prefix}-episode-{self.episode_id}"
base_path = os.path.join(self.video_folder, video_name)
self.video_recorder = video_recorder.VideoRecorder(
env=self.env,
base_path=base_path,
metadata={"step_id": self.step_id, "episode_id": self.episode_id},
disable_logger=self.disable_logger,
)
self.video_recorder.capture_frame()
self.recorded_frames = 1
self.recording = True
def _video_enabled(self):
if self.step_trigger:
return self.step_trigger(self.step_id)
else:
return self.episode_trigger(self.episode_id)
def step(self, action):
"""Steps through the environment using action, recording observations if :attr:`self.recording`."""
(
observations,
rewards,
terminateds,
truncateds,
infos,
) = self.env.step(action)
if not (self.terminated or self.truncated):
# increment steps and episodes
self.step_id += 1
if not self.is_vector_env:
if terminateds or truncateds:
self.episode_id += 1
self.terminated = terminateds
self.truncated = truncateds
elif terminateds[0] or truncateds[0]:
self.episode_id += 1
self.terminated = terminateds[0]
self.truncated = truncateds[0]
if self.recording:
assert self.video_recorder is not None
self.video_recorder.capture_frame()
self.recorded_frames += 1
if self.video_length > 0:
if self.recorded_frames > self.video_length:
self.close_video_recorder()
else:
if not self.is_vector_env:
if terminateds or truncateds:
self.close_video_recorder()
elif terminateds[0] or truncateds[0]:
self.close_video_recorder()
elif self._video_enabled():
self.start_video_recorder()
return observations, rewards, terminateds, truncateds, infos
def close_video_recorder(self):
"""Closes the video recorder if currently recording."""
if self.recording:
assert self.video_recorder is not None
self.video_recorder.close()
self.recording = False
self.recorded_frames = 1
def render(self, *args, **kwargs):
"""Compute the render frames as specified by render_mode attribute during initialization of the environment or as specified in kwargs."""
if self.video_recorder is None or not self.video_recorder.enabled:
return super().render(*args, **kwargs)
if len(self.video_recorder.render_history) > 0:
recorded_frames = [
self.video_recorder.render_history.pop()
for _ in range(len(self.video_recorder.render_history))
]
if self.recording:
return recorded_frames
else:
return recorded_frames + super().render(*args, **kwargs)
else:
return super().render(*args, **kwargs)
def close(self):
"""Closes the wrapper then the video recorder."""
super().close()
self.close_video_recorder()