From c505d42e030a42e21629ac51cb53e2c55a24f21f Mon Sep 17 00:00:00 2001 From: AkashKarnatak Date: Sat, 10 Jun 2023 14:57:56 +0530 Subject: [PATCH 1/2] After [this](https://github.com/DLR-RM/stable-baselines3/commit/9c338f917a822c8be13b1aa9f7b2319770481b62) by `stable-baselines3`, DummyVecEnv.reset() method passes a `seed` argument to its list of envs. Since StockTradingEnv.reset() does not expect any argument it results in an error. This commit fixes #1022. --- finrl/meta/env_stock_trading/env_stocktrading.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/finrl/meta/env_stock_trading/env_stocktrading.py b/finrl/meta/env_stock_trading/env_stocktrading.py index a1698d5292..57a18cc2ba 100644 --- a/finrl/meta/env_stock_trading/env_stocktrading.py +++ b/finrl/meta/env_stock_trading/env_stocktrading.py @@ -355,7 +355,7 @@ def step(self, actions): return self.state, self.reward, self.terminal, False, {} - def reset(self): + def reset(self, seed=None): # initiate state self.state = self._initiate_state() From 311b1f423c88dd6d2f30ecb772ecca2674e3c3b8 Mon Sep 17 00:00:00 2001 From: AkashKarnatak Date: Sat, 10 Jun 2023 15:15:31 +0530 Subject: [PATCH 2/2] DummyVecEnv.render() no longer returns the expected state after [this](https://github.com/DLR-RM/stable-baselines3/pull/1327) pull request got merged in `stable-baselines3`. This commit fixes #1010 --- finrl/agents/stablebaselines3/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/finrl/agents/stablebaselines3/models.py b/finrl/agents/stablebaselines3/models.py index 3d627b11b5..32e317a573 100644 --- a/finrl/agents/stablebaselines3/models.py +++ b/finrl/agents/stablebaselines3/models.py @@ -316,7 +316,7 @@ def DRL_prediction( trade_obs, rewards, dones, info = trade_env.step(action) if i == (len(trade_data.index.unique()) - 2): # print(env_test.render()) - last_state = trade_env.render() + last_state = trade_env.envs[0].render() df_last_state = pd.DataFrame({"last_state": last_state}) df_last_state.to_csv(f"results/last_state_{name}_{i}.csv", index=False)