| """ |
| Author: Minh Pham-Dinh |
| Created: Jan 26th, 2024 |
| Last Modified: Feb 5th, 2024 |
| Email: [email protected] |
| |
| Description: |
| File containing the ReplayBuffer that will be used in Dreamer. |
| |
| The implementation is based on: |
| Hafner et al., "Dream to Control: Learning Behaviors by Latent Imagination," 2019. |
| [Online]. Available: https://arxiv.org/abs/1912.01603 |
| """ |
|
|
| import numpy as np |
| from gymnasium import Env |
| import torch |
| from addict import Dict |


class ReplayBuffer:
    def __init__(self, capacity, obs_size, action_size):
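        """Fixed-capacity circular replay buffer backed by pre-allocated numpy arrays.

        Args:
            capacity (int): maximum number of transitions the buffer can hold.
            obs_size (tuple): shape of a single observation, e.g. (3, 64, 64) for images.
            action_size (tuple): shape of a single action, e.g. (action_dim,).
        """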
        self.obs_size = obs_size
        self.action_size = action_size

        # Image observations (3+ dims, e.g. (C, H, W)) are stored as uint8 to save
        # memory; low-dimensional vector observations are stored as float32.
        state_type = np.uint8 if len(self.obs_size) >= 3 else np.float32

        self.observation = np.zeros((capacity,) + self.obs_size, dtype=state_type)
        self.actions = np.zeros((capacity,) + self.action_size, dtype=np.float32)
        self.rewards = np.zeros((capacity, 1), dtype=np.float32)
        self.dones = np.zeros((capacity, 1), dtype=np.float32)

        self.pointer = 0
        self.full = False

        print(f'''
        -----------initialized memory----------

        obs_buffer_shape: {self.observation.shape}
        actions_buffer_shape: {self.actions.shape}
        rewards_buffer_shape: {self.rewards.shape}
        dones_buffer_shape: {self.dones.shape}

        ----------------------------------------
        ''')

    def add(self, obs, action, reward, done):
        """Add a single transition to the buffer, overwriting the oldest entry when full.

        Args:
            obs (np.array): current observation
            action (np.array): action taken
            reward (float): reward received after the action
            done (bool): whether the episode terminated or was truncated
        """
        self.observation[self.pointer] = obs
        self.actions[self.pointer] = action
        self.rewards[self.pointer] = reward
        self.dones[self.pointer] = done
        # Advance the write pointer circularly; once it wraps back to 0, every
        # slot has been written at least once and the buffer is marked full.
        self.pointer = (self.pointer + 1) % self.observation.shape[0]
        if self.pointer == 0:
            self.full = True

    def sample(self, batch_size, seq_len, device):
        """
        Samples batches of experiences of fixed sequence length from the replay
        buffer, taking into account its circular nature.

        Sampled sequences are temporally continuous: they never cross the write
        pointer, which marks the boundary between the newest and oldest data once
        the buffer is full. Wrapping around the end of the underlying arrays is
        allowed in that case, since the stored data is circularly continuous there.

        Args:
            batch_size (int): The number of sequences to sample.
            seq_len (int): The length of each sequence to sample.
            device (torch.device): The device on which the sampled data will be loaded.

        Raises:
            ValueError: If there is not enough data in the buffer to sample a full
                sequence.

        Returns:
            Dict: A dictionary containing the sampled sequences of observations,
            actions, rewards, and dones. Observations and actions are tensors of
            shape (batch_size, seq_len, feature_dimension); rewards and dones have
            shape (batch_size, seq_len, 1).

        Notes:
            - When the buffer is not full, sequences can start anywhere from index 0
              up to the last index from which `seq_len` steps fit before the pointer.
            - When the buffer is full, sequences may not start where they would wrap
              past the pointer, crossing the boundary between newest and oldest data.
            - These constraints preserve the temporal order and continuity required
              by algorithms that learn from sequences of experiences.
        """
        if self.pointer < seq_len and not self.full:
            raise ValueError('not enough data to sample a full sequence')

        capacity = self.observation.shape[0]
        if self.full:
            if self.pointer - seq_len < 0:
                # Every valid window wraps from the tail of the arrays into the
                # head; starts run from `pointer` up to `capacity + pointer - seq_len`,
                # the last start whose window stops short of the pointer.
                valid_range = np.arange(self.pointer, capacity + self.pointer - seq_len + 1)
            else:
                range_1 = np.arange(0, self.pointer - seq_len + 1)
                range_2 = np.arange(self.pointer, capacity)
                valid_range = np.concatenate((range_1, range_2), -1)
        else:
            valid_range = np.arange(0, self.pointer - seq_len + 1)

        start_index = np.random.choice(valid_range, (batch_size, 1))

        # Offsets 0..seq_len-1 broadcast against the start indices to form full
        # index windows; the modulo lets windows wrap around the array end.
        offsets = np.arange(seq_len)
        sample_idcs = (start_index + offsets) % capacity

        batch = Dict()
        batch.obs = torch.from_numpy(self.observation[sample_idcs]).to(device)
        batch.actions = torch.from_numpy(self.actions[sample_idcs]).to(device)
        batch.rewards = torch.from_numpy(self.rewards[sample_idcs]).to(device)
        batch.dones = torch.from_numpy(self.dones[sample_idcs]).to(device)

        return batch

    def clear(self):
        self.pointer = 0
        self.full = False

    def __len__(self):
        # Once the buffer has wrapped, every slot holds valid data, so the length
        # is the full capacity rather than the (wrapped) pointer position.
        return self.observation.shape[0] if self.full else self.pointer
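

# A minimal usage sketch. The environment-free setup below is illustrative: the
# capacity, observation shape (3, 64, 64), and action size are assumptions made
# for this example only, not values used elsewhere in the project.
if __name__ == '__main__':
    buffer = ReplayBuffer(capacity=100, obs_size=(3, 64, 64), action_size=(6,))

    # Write more transitions than the capacity so the buffer wraps and marks
    # itself full, exercising the circular-sampling logic above.
    for _ in range(150):
        obs = np.random.randint(0, 256, size=(3, 64, 64), dtype=np.uint8)
        action = np.random.randn(6).astype(np.float32)
        buffer.add(obs, action, reward=float(np.random.randn()), done=False)

    batch = buffer.sample(batch_size=8, seq_len=10, device=torch.device('cpu'))
    print(batch.obs.shape)      # torch.Size([8, 10, 3, 64, 64])
    print(batch.actions.shape)  # torch.Size([8, 10, 6])
    print(batch.rewards.shape)  # torch.Size([8, 10, 1])
    print(batch.dones.shape)    # torch.Size([8, 10, 1])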