diff --git a/abcdrl/ddpg.py b/abcdrl/ddpg.py
index ee5f36b..e49e9b0 100644
--- a/abcdrl/ddpg.py
+++ b/abcdrl/ddpg.py
@@ -227,7 +227,10 @@ def sample(self, obs: np.ndarray) -> np.ndarray:
                 )
                 act_ts += noise
             act_np = act_ts.cpu().numpy().clip(self.kwargs["act_space"].low, self.kwargs["act_space"].high)
+
         self.sample_step += self.kwargs["num_envs"]
+        if self.sample_step % self.kwargs["policy_frequency"] == 0:
+            self.alg.sync_target()
         return act_np
 
     def learn(self, data: ReplayBuffer.Samples[np.ndarray]) -> dict[str, Any]:
@@ -242,8 +245,6 @@ def learn(self, data: ReplayBuffer.Samples[np.ndarray]) -> dict[str, Any]:
         )
 
         log_data = self.alg.learn(data_ts, self.sample_step % self.kwargs["policy_frequency"] == 0)
-        if self.sample_step % self.kwargs["policy_frequency"] == 0:
-            self.alg.sync_target()
         self.learn_step += 1
         return log_data
 
diff --git a/abcdrl/ddqn.py b/abcdrl/ddqn.py
index 4d877cf..7f3835d 100644
--- a/abcdrl/ddqn.py
+++ b/abcdrl/ddqn.py
@@ -172,7 +172,10 @@ def sample(self, obs: np.ndarray) -> np.ndarray:
             with torch.no_grad():
                 _, act_ts = self.alg.predict(obs_ts).max(dim=1)
             act_np = act_ts.cpu().numpy()
+
         self.sample_step += self.kwargs["num_envs"]
+        if self.sample_step % self.kwargs["target_network_frequency"] == 0:
+            self.alg.sync_target()
         return act_np
 
     def learn(self, data: ReplayBuffer.Samples[np.ndarray]) -> dict[str, Any]:
@@ -187,8 +190,6 @@ def learn(self, data: ReplayBuffer.Samples[np.ndarray]) -> dict[str, Any]:
         )
 
         log_data = self.alg.learn(data_ts)
-        if self.sample_step % self.kwargs["target_network_frequency"] == 0:
-            self.alg.sync_target()
         self.learn_step += 1
         log_data["epsilon"] = self._get_epsilon()
         return log_data
diff --git a/abcdrl/dqn.py b/abcdrl/dqn.py
index a774d38..a51dff4 100644
--- a/abcdrl/dqn.py
+++ b/abcdrl/dqn.py
@@ -170,7 +170,10 @@ def sample(self, obs: np.ndarray) -> np.ndarray:
             with torch.no_grad():
                 _, act_ts = self.alg.predict(obs_ts).max(dim=1)
             act_np = act_ts.cpu().numpy()
+
         self.sample_step += self.kwargs["num_envs"]
+        if self.sample_step % self.kwargs["target_network_frequency"] == 0:
+            self.alg.sync_target()
         return act_np
 
     def learn(self, data: ReplayBuffer.Samples[np.ndarray]) -> dict[str, Any]:
@@ -185,8 +188,6 @@ def learn(self, data: ReplayBuffer.Samples[np.ndarray]) -> dict[str, Any]:
         )
 
         log_data = self.alg.learn(data_ts)
-        if self.sample_step % self.kwargs["target_network_frequency"] == 0:
-            self.alg.sync_target()
         self.learn_step += 1
         log_data["epsilon"] = self._get_epsilon()
         return log_data
diff --git a/abcdrl/pdqn.py b/abcdrl/pdqn.py
index 14b344e..48d8f93 100644
--- a/abcdrl/pdqn.py
+++ b/abcdrl/pdqn.py
@@ -281,7 +281,10 @@ def sample(self, obs: np.ndarray) -> np.ndarray:
             with torch.no_grad():
                 _, act_ts = self.alg.predict(obs_ts).max(dim=1)
             act_np = act_ts.cpu().numpy()
+
         self.sample_step += self.kwargs["num_envs"]
+        if self.sample_step % self.kwargs["target_network_frequency"] == 0:
+            self.alg.sync_target()
         return act_np
 
     def learn(self, data: PrioritizedReplayBuffer.Samples[np.ndarray]) -> dict[str, Any]:
@@ -296,8 +299,6 @@ def learn(self, data: PrioritizedReplayBuffer.Samples[np.ndarray]) -> dict[str,
         )
 
         log_data = self.alg.learn(data_ts)
-        if self.sample_step % self.kwargs["target_network_frequency"] == 0:
-            self.alg.sync_target()
         self.learn_step += 1
         log_data["epsilon"] = self._get_epsilon()
         return log_data
diff --git a/abcdrl/sac.py b/abcdrl/sac.py
index 687c7b7..88de024 100644
--- a/abcdrl/sac.py
+++ b/abcdrl/sac.py
@@ -274,7 +274,10 @@ def sample(self, obs: np.ndarray) -> np.ndarray:
             with torch.no_grad():
                 act_ts = self.alg.predict(obs_ts)
             act_np = act_ts.cpu().numpy()
+
         self.sample_step += self.kwargs["num_envs"]
+        if self.sample_step % self.kwargs["target_network_frequency"] == 0:
+            self.alg.sync_target()
         return act_np
 
     def learn(self, data: ReplayBuffer.Samples[np.ndarray]) -> dict[str, Any]:
@@ -289,8 +292,6 @@ def learn(self, data: ReplayBuffer.Samples[np.ndarray]) -> dict[str, Any]:
         )
 
         log_data = self.alg.learn(data_ts, self.sample_step % self.kwargs["policy_frequency"] == 0)
-        if self.sample_step % self.kwargs["target_network_frequency"] == 0:
-            self.alg.sync_target()
         self.learn_step += 1
         return log_data
 
diff --git a/abcdrl/td3.py b/abcdrl/td3.py
index 9a96720..bdc11c3 100644
--- a/abcdrl/td3.py
+++ b/abcdrl/td3.py
@@ -253,7 +253,10 @@ def sample(self, obs: np.ndarray) -> np.ndarray:
                 )
                 act_ts += noise
             act_np = act_ts.cpu().numpy().clip(self.kwargs["act_space"].low, self.kwargs["act_space"].high)
+
         self.sample_step += self.kwargs["num_envs"]
+        if self.sample_step % self.kwargs["policy_frequency"] == 0:
+            self.alg.sync_target()
         return act_np
 
     def learn(self, data: ReplayBuffer.Samples[np.ndarray]) -> dict[str, Any]:
@@ -268,8 +271,6 @@ def learn(self, data: ReplayBuffer.Samples[np.ndarray]) -> dict[str, Any]:
         )
 
         log_data = self.alg.learn(data_ts, self.sample_step % self.kwargs["policy_frequency"] == 0)
-        if self.sample_step % self.kwargs["policy_frequency"] == 0:
-            self.alg.sync_target()
         self.learn_step += 1
         return log_data
 
diff --git a/abcdrl_copy_from/dqn_all_wrappers.py b/abcdrl_copy_from/dqn_all_wrappers.py
index 22d7410..b033ec0 100644
--- a/abcdrl_copy_from/dqn_all_wrappers.py
+++ b/abcdrl_copy_from/dqn_all_wrappers.py
@@ -171,7 +171,10 @@ def sample(self, obs: np.ndarray) -> np.ndarray:
             with torch.no_grad():
                 _, act_ts = self.alg.predict(obs_ts).max(dim=1)
             act_np = act_ts.cpu().numpy()
+
         self.sample_step += self.kwargs["num_envs"]
+        if self.sample_step % self.kwargs["target_network_frequency"] == 0:
+            self.alg.sync_target()
         return act_np
 
     def learn(self, data: ReplayBuffer.Samples[np.ndarray]) -> dict[str, Any]:
@@ -186,8 +189,6 @@ def learn(self, data: ReplayBuffer.Samples[np.ndarray]) -> dict[str, Any]:
         )
 
         log_data = self.alg.learn(data_ts)
-        if self.sample_step % self.kwargs["target_network_frequency"] == 0:
-            self.alg.sync_target()
         self.learn_step += 1
         log_data["epsilon"] = self._get_epsilon()
         return log_data
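All seven files receive the same mechanical change: the target-network sync moves from `learn()` into `sample()`, keyed to the existing `sample_step` counter, so `sync_target()` fires on a fixed schedule of environment steps instead of depending on when `learn()` runs. Below is a minimal, self-contained sketch of the resulting pattern; the `SketchAgent` class and its numbers are hypothetical illustrations, not the abcdrl API.

```python
class SketchAgent:
    """Hypothetical stand-in for an abcdrl-style Agent; shows the sync schedule only."""

    def __init__(self, num_envs: int, target_network_frequency: int) -> None:
        self.num_envs = num_envs
        self.target_network_frequency = target_network_frequency
        self.sample_step = 0
        self.syncs = 0

    def sync_target(self) -> None:
        # Stand-in for copying online-network weights into the target network.
        self.syncs += 1

    def sample(self) -> None:
        # ... action selection omitted ...
        # The sync check keys off the environment-step counter here in
        # sample(), so it fires on schedule even when learn() is skipped
        # (e.g. before learning_starts) or runs at a different cadence.
        self.sample_step += self.num_envs
        if self.sample_step % self.target_network_frequency == 0:
            self.sync_target()


agent = SketchAgent(num_envs=2, target_network_frequency=10)
for _ in range(20):
    agent.sample()
assert agent.syncs == 4  # 40 env steps -> syncs at steps 10, 20, 30, 40
```

One property of the modulo test, in the patch as in the sketch: since `sample_step` advances in increments of `num_envs`, syncs land on multiples of the least common multiple of the two settings, which is worth keeping in mind when changing either one.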