Commit 16fa142

Merge pull request #541 from yandexdataschool/spring_2024_week_04_approx_rl
spring_2024_week_04_approx_rl: a major update
2 parents 7aa7069 + 71a5612 commit 16fa142

16 files changed: +2312 / -1723 lines

week04_approx_rl/dqn/__init__.py

Whitespace-only changes.

week04_approx_rl/dqn/analysis.py

Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
+from __future__ import annotations
+
+from collections.abc import Reversible
+import numpy as np
+
+
+def play_and_log_episode(env, agent, t_max=10000):
+    """
+    Plays an episode using the greedy policy and logs for each timestep:
+    - state
+    - qvalues (estimated by the agent)
+    - actions
+    - rewards
+
+    Also logs:
+    - the final (usually terminal) state
+    - whether the episode was terminated
+
+    Uses the greedy policy.
+    """
+    assert t_max > 0, t_max
+
+    states = []
+    qvalues_all = []
+    actions = []
+    rewards = []
+
+    s, _ = env.reset()
+    for step in range(t_max):
+        s = np.array(s)
+        states.append(s)
+        qvalues = agent.get_qvalues(s[None])[0]
+        qvalues_all.append(qvalues)
+        action = np.argmax(qvalues)
+        actions.append(action)
+        s, r, terminated, truncated, _ = env.step(action)
+        rewards.append(r)
+        if terminated or truncated:
+            break
+    states.append(s)  # the last state
+
+    return_pack = {
+        "states": np.array(states),
+        "qvalues": np.array(qvalues_all),
+        "actions": np.array(actions),
+        "rewards": np.array(rewards),
+        "episode_finished": terminated,
+    }
+
+    return return_pack
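For orientation, a minimal usage sketch of the new helper. `make_env` and `my_agent` are placeholders for whatever the week's notebook constructs; only `play_and_log_episode` itself comes from this commit.

    import numpy as np

    env = make_env()    # hypothetical factory for the gymnasium env used in the notebook
    agent = my_agent    # hypothetical object exposing get_qvalues(batch_of_states) -> (batch, n_actions)

    log = play_and_log_episode(env, agent, t_max=10_000)
    print("episode return:", log["rewards"].sum())
    print("terminated:", log["episode_finished"])

    # "states" holds one extra entry (the final state), so states[:-1] aligns
    # with qvalues, actions and rewards timestep by timestep.
    assert len(log["states"]) == len(log["actions"]) + 1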
Lines changed: 3 additions & 59 deletions
@@ -1,53 +1,13 @@
 # taken from stable_baselines3.

-import numpy as np
-from gymnasium import Wrapper, RewardWrapper, ObservationWrapper
-from gymnasium.spaces import Box
-
-
-class MaxAndSkipEnv(Wrapper):
-    def __init__(self, env, skip=4):
-        """Return only every `skip`-th frame"""
-        super().__init__(env)
-        # most recent raw observations (for max pooling across time steps)
-        self._obs_buffer = np.zeros(
-            (2,) + env.observation_space.shape, dtype=np.uint8)
-        self._skip = skip
-
-    def step(self, action):
-        """Repeat action, sum reward, and max over last observations."""
-        total_reward = 0.0
-        terminated = truncated = False
-        for i in range(self._skip):
-            obs, reward, terminated, truncated, info = self.env.step(action)
-            if i == self._skip - 2:
-                self._obs_buffer[0] = obs
-            if i == self._skip - 1:
-                self._obs_buffer[1] = obs
-            total_reward += reward
-            if terminated or truncated:
-                break
-        # Note that the observation on the terminated=True frame
-        # doesn't matter
-        max_frame = self._obs_buffer.max(axis=0)
-
-        return max_frame, total_reward, terminated, truncated, info
-
-
-class ClipRewardEnv(RewardWrapper):
-    def __init__(self, env):
-        super().__init__(env)
-
-    def reward(self, reward):
-        """Bin reward to {+1, 0, -1} by its sign."""
-        return np.sign(reward)
+from gymnasium import Wrapper


 class FireResetEnv(Wrapper):
     def __init__(self, env):
         """Take action on reset for environments that are fixed until firing."""
         super().__init__(env)
-        assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
+        assert env.unwrapped.get_action_meanings()[1] == "FIRE"
         assert len(env.unwrapped.get_action_meanings()) >= 3

     def reset(self, **kwargs):
@@ -94,27 +54,11 @@ def reset(self, **kwargs):
         else:
             # no-op step to advance from terminal/lost life state
             obs, _, terminated, truncated, info = self.env.step(0)
-
+
             # The no-op step can lead to a game over, so we need to check it again
             # to see if we should reset the environment and avoid the
             # monitor.py `RuntimeError: Tried to step environment that needs reset`
             if terminated or truncated:
                 obs, info = self.env.reset(**kwargs)
         self.lives = self.env.unwrapped.ale.lives()
         return obs, info
-
-
-# in torch imgs have shape [c, h, w] instead of common [h, w, c]
-class AntiTorchWrapper(ObservationWrapper):
-    def __init__(self, env):
-        super().__init__(env)
-
-        self.img_size = [env.observation_space.shape[i]
-                         for i in [1, 2, 0]
-                         ]
-        self.observation_space = Box(0.0, 1.0, self.img_size)
-
-    def observation(self, img):
-        """what happens to each observation"""
-        img = img.transpose(1, 2, 0)
-        return img
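For context, a minimal sketch of how FireResetEnv is typically applied. The game choice is an assumption, not something named in this diff; any ALE game whose action 1 is FIRE works.

    import gymnasium as gym

    # assumes ale-py is installed so that the ALE/* environments are registered
    env = gym.make("ALE/Breakout-v5")   # hypothetical example game
    env = FireResetEnv(env)

    obs, info = env.reset()  # the wrapper presses FIRE so the episode starts immediately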

week04_approx_rl/replay_buffer.py renamed to week04_approx_rl/dqn/replay_buffer.py

Lines changed: 51 additions & 5 deletions
@@ -44,7 +44,7 @@ def _encode_sample(self, idxes):
             np.array(actions),
             np.array(rewards),
             np.array(obses_tp1),
-            np.array(dones)
+            np.array(dones),
         )

     def sample(self, batch_size):
@@ -67,8 +67,54 @@ def sample(self, batch_size):
             done_mask[i] = 1 if executing act_batch[i] resulted in
                 the end of an episode and 0 otherwise.
         """
-        idxes = [
-            random.randint(0, len(self._storage) - 1)
-            for _ in range(batch_size)
-        ]
+        idxes = [random.randint(0, len(self._storage) - 1) for _ in range(batch_size)]
         return self._encode_sample(idxes)
+
+
+class LazyFramesVectorReplayBuffer(ReplayBuffer):
+    """
+    ReplayBuffer for vectorized environments which are wrapped into FrameBuffers.
+
+    If an environment is first wrapped into a FrameBuffer and then vectorized,
+    the resulting VecEnv will not use LazyFrames but plain np.ndarrays,
+    which greatly increases the buffer's RAM consumption.
+
+    Instead, we first vectorize the environment and only then wrap it into FrameBuffers.
+    It's not as convenient, but it keeps the memory advantage of LazyFrames.
+
+    So:
+    observations and next_observations are stored as LazyFrames
+        of shape (n_frames, n_envs, ...);
+    actions, rewards and dones are stored as np.ndarrays of shape (n_envs,).
+
+    """
+
+    # (n_frames, n_envs, *)
+
+    def _encode_sample(self, idxes):
+        """
+        For each index in idxes, samples a (s, a, r, s', done) transition
+        from a randomly chosen environment of the corresponding VecEnv.
+        """
+        obses_t, actions, rewards, obses_tp1, dones = [], [], [], [], []
+        for i in idxes:
+            data = self._storage[i]
+            obs_t, action, reward, obs_tp1, done = data
+            n_envs = action.shape[0]
+            env_idx_chosen_for_sample = random.randint(0, n_envs - 1)
+            obses_t.append(
+                np.array(obs_t, copy=False)[:, env_idx_chosen_for_sample],
+            )
+            actions.append(np.array(action, copy=False)[env_idx_chosen_for_sample])
+            rewards.append(reward[env_idx_chosen_for_sample])
+            obses_tp1.append(
+                np.array(obs_tp1, copy=False)[:, env_idx_chosen_for_sample],
+            )
+            dones.append(done[env_idx_chosen_for_sample])
+        return (
+            np.array(obses_t),
+            np.array(actions),
+            np.array(rewards),
+            np.array(obses_tp1),
+            np.array(dones),
+        )
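To make the expected storage layout concrete, here is a small illustrative sketch; plain np.ndarrays stand in for LazyFrames, and all shapes below are examples rather than values from this commit.

    import random
    import numpy as np

    n_frames, n_envs, h, w = 4, 8, 84, 84

    # one stored transition: observations are stacked over frames and batched over envs
    obs_t   = np.zeros((n_frames, n_envs, h, w), dtype=np.uint8)  # would be a LazyFrames object
    obs_tp1 = np.zeros((n_frames, n_envs, h, w), dtype=np.uint8)
    action  = np.zeros(n_envs, dtype=np.int64)
    reward  = np.zeros(n_envs, dtype=np.float32)
    done    = np.zeros(n_envs, dtype=bool)

    # _encode_sample picks one environment per sampled transition:
    env_idx = random.randint(0, n_envs - 1)
    single_obs = np.array(obs_t, copy=False)[:, env_idx]
    assert single_obs.shape == (n_frames, h, w)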

week04_approx_rl/dqn/utils.py

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+import psutil  # type: ignore
+
+
+def is_enough_ram(min_available_gb=0.1):
+    mem = psutil.virtual_memory()
+    return mem.available >= min_available_gb * (1024**3)
+
+
+def linear_decay(
+    init_val: float, final_val: float, cur_step: int, total_steps: int
+) -> float:
+    if cur_step >= total_steps:
+        return final_val
+    return (init_val * (total_steps - cur_step) + final_val * cur_step) / total_steps
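linear_decay interpolates linearly from init_val to final_val over total_steps and then holds final_val; a typical use is an epsilon-greedy exploration schedule. The concrete numbers below are illustrative only.

    total_steps = 100_000
    for step in (0, 50_000, 100_000, 200_000):
        eps = linear_decay(init_val=1.0, final_val=0.1, cur_step=step, total_steps=total_steps)
        print(step, round(eps, 3))
    # 0       1.0
    # 50000   0.55
    # 100000  0.1
    # 200000  0.1   (clamped once cur_step >= total_steps)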

week04_approx_rl/framebuffer.py

Lines changed: 0 additions & 45 deletions
This file was deleted.
