-
Notifications
You must be signed in to change notification settings - Fork 936
Huggingface Integration #292
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 26 commits
Commits
Show all changes
68 commits
Select commit
Hold shift + click to select a range
1b585d6
initial commit
vwxyzjn fa82356
pre-commit
vwxyzjn 4074eee
Add hub integration
vwxyzjn 4436ce4
pre-commit
vwxyzjn df41e3d
use CommitOperation
vwxyzjn a98383d
Fix pre-commit
vwxyzjn b430540
refactor
vwxyzjn dd8ee86
Merge branch 'master' into hf-integration
vwxyzjn 8144562
push changes
vwxyzjn 2f20e17
refactor
vwxyzjn fdfc2a5
fix pre-commit
vwxyzjn 56413f8
pre-commit
vwxyzjn b1b1dbd
Merge branch 'master' into hf-integration
vwxyzjn f6865d4
close the env and writer after eval
vwxyzjn fbe986c
support dqn jax
vwxyzjn 83aa010
pre-commit
vwxyzjn ba1bfdb
Update cleanrl_utils/huggingface.py
vwxyzjn aee6809
address comments
vwxyzjn 80a460f
update docs
vwxyzjn 40be7d8
support dqn_atari_jax
vwxyzjn 65ded2a
bug fix and docs
vwxyzjn 133e6bd
Add cleanrl to the hf's `metadata`
vwxyzjn 10d0b79
Merge branch 'master' into hf-integration
vwxyzjn ca60f24
include huggingface integration
vwxyzjn b165e35
test for enjoy.py
vwxyzjn 7163d0d
bump version, pip install extra hack
vwxyzjn 27d9b3d
Update cleanrl_utils/huggingface.py
vwxyzjn 2a2208f
Update cleanrl_utils/huggingface.py
vwxyzjn 4ac5631
Update cleanrl_utils/huggingface.py
vwxyzjn 40358b1
Update cleanrl_utils/huggingface.py
vwxyzjn df68d57
Update cleanrl_utils/huggingface.py
vwxyzjn 7dddfbd
Update cleanrl_utils/huggingface.py
vwxyzjn 954723f
update docs
vwxyzjn fb858ae
update pre-commit
vwxyzjn b508f66
quick fix
vwxyzjn 7d5193b
bug fix
vwxyzjn c390b8d
lazy load modules to avoid dependency issues
vwxyzjn cc456d6
Add huggingface shields
vwxyzjn fd5a737
Add emoji
vwxyzjn 3b0af25
Update docs
vwxyzjn ff0be11
pre-commit
vwxyzjn 9bd034e
Update docs
vwxyzjn 78022d7
Update docs
vwxyzjn aae8d4d
Merge branch 'master' into hf-integration
kinalmehta 1c2cd40
fix: use `algorithm_variant_filename` in model card reproduction script
kinalmehta e172a0c
typo fix
kinalmehta c733514
feat: add hf support for c51
kinalmehta 15be698
formatting fix
kinalmehta 8fac8e3
support pulling variant depdencies directly
vwxyzjn 35d6fc7
support model saving for `ppo_atari_envpool_xla_jax_scan`
vwxyzjn 1ce42c9
Merge branch 'master' into hf-integration
vwxyzjn 8990794
support `ppo_atari_envpool_xla_jax_scan`
vwxyzjn ea4a71d
quick change
vwxyzjn 7493ae4
support 'c51_jax'
kinalmehta fe34419
formatting fix
kinalmehta 4a1f72a
support capture video
vwxyzjn 7f22c25
Add notebook
vwxyzjn 5331287
update docs
vwxyzjn 9aec97e
support `c51_atari` and `c51_atari_jax`
kinalmehta bc8c014
Merge remote-tracking branch 'origin/hf-integration' into hf-integration
kinalmehta b202985
typo fix
kinalmehta 54fd64a
add c51 to zoo docs
kinalmehta 9e5841b
add colab badge
vwxyzjn 9178763
fix broken colab svg
vwxyzjn 07961f4
pypi release
vwxyzjn c09a80d
typo fix
vwxyzjn a18ffdb
update pre-commit
vwxyzjn ba7053a
remove hf-integration reference
vwxyzjn File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,73 @@ | ||
| import random | ||
| from typing import Callable | ||
|
|
||
| import gym | ||
| import numpy as np | ||
| import torch | ||
|
|
||
|
|
||
| def evaluate( | ||
| model_path: str, | ||
| make_env: Callable, | ||
| env_id: str, | ||
| eval_episodes: int, | ||
| run_name: str, | ||
| Model: torch.nn.Module, | ||
| device: torch.device = torch.device("cpu"), | ||
| epsilon: float = 0.05, | ||
| capture_video: bool = True, | ||
| ): | ||
| envs = gym.vector.SyncVectorEnv([make_env(env_id, 0, 0, capture_video, run_name)]) | ||
| model = Model(envs).to(device) | ||
| model.load_state_dict(torch.load(model_path, map_location=device)) | ||
| model.eval() | ||
|
|
||
| obs = envs.reset() | ||
| episodic_returns = [] | ||
| while len(episodic_returns) < eval_episodes: | ||
| if random.random() < epsilon: | ||
| actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)]) | ||
| else: | ||
| q_values = model(torch.Tensor(obs).to(device)) | ||
| actions = torch.argmax(q_values, dim=1).cpu().numpy() | ||
| next_obs, _, _, infos = envs.step(actions) | ||
| for info in infos: | ||
| if "episode" in info.keys(): | ||
| print(f"eval_episode={len(episodic_returns)}, episodic_return={info['episode']['r']}") | ||
| episodic_returns += [info["episode"]["r"]] | ||
| obs = next_obs | ||
|
|
||
| return episodic_returns | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| from huggingface_hub import hf_hub_download | ||
|
|
||
| from cleanrl.dqn import QNetwork, make_env | ||
|
|
||
| model_path = hf_hub_download(repo_id="cleanrl/CartPole-v1-dqn-seed1", filename="q_network.pth") | ||
| evaluate( | ||
| model_path, | ||
| make_env, | ||
| "CartPole-v1", | ||
| eval_episodes=10, | ||
| run_name=f"eval", | ||
| Model=QNetwork, | ||
| device="cpu", | ||
| capture_video=False, | ||
| ) | ||
|
|
||
| # from cleanrl.dqn_atari import QNetwork, make_env | ||
|
|
||
| # model_path = hf_hub_download(repo_id="vwxyzjn/BreakoutNoFrameskip-v4-dqn_atari-seed1", filename="q_network.pth") | ||
| # evaluate( | ||
| # model_path, | ||
| # make_env, | ||
| # "BreakoutNoFrameskip-v4", | ||
| # eval_episodes=10, | ||
| # run_name=f"eval", | ||
| # Model=QNetwork, | ||
| # device="cpu", | ||
| # epsilon=0.05, | ||
| # capture_video=False, | ||
| # ) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,78 @@ | ||
| import random | ||
| from typing import Callable | ||
|
|
||
| import flax | ||
| import flax.linen as nn | ||
| import gym | ||
| import jax | ||
| import numpy as np | ||
|
|
||
|
|
||
| def evaluate( | ||
| model_path: str, | ||
| make_env: Callable, | ||
| env_id: str, | ||
| eval_episodes: int, | ||
| run_name: str, | ||
| Model: nn.Module, | ||
| epsilon: float = 0.05, | ||
| capture_video: bool = True, | ||
| seed=1, | ||
| ): | ||
| envs = gym.vector.SyncVectorEnv([make_env(env_id, 0, 0, capture_video, run_name)]) | ||
| obs = envs.reset() | ||
| model = Model(action_dim=envs.single_action_space.n) | ||
| q_key = jax.random.PRNGKey(seed) | ||
| params = model.init(q_key, obs) | ||
| with open(model_path, "rb") as f: | ||
| params = flax.serialization.from_bytes(params, f.read()) | ||
| model.apply = jax.jit(model.apply) | ||
|
|
||
| episodic_returns = [] | ||
| while len(episodic_returns) < eval_episodes: | ||
| if random.random() < epsilon: | ||
| actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)]) | ||
| else: | ||
| q_values = model.apply(params, obs) | ||
| actions = q_values.argmax(axis=-1) | ||
| actions = jax.device_get(actions) | ||
| next_obs, _, _, infos = envs.step(actions) | ||
| for info in infos: | ||
| if "episode" in info.keys(): | ||
| print(f"eval_episode={len(episodic_returns)}, episodic_return={info['episode']['r']}") | ||
| episodic_returns += [info["episode"]["r"]] | ||
| obs = next_obs | ||
|
|
||
| return episodic_returns | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| from huggingface_hub import hf_hub_download | ||
|
|
||
| from cleanrl.dqn_jax import QNetwork, make_env | ||
|
|
||
| model_path = hf_hub_download(repo_id="vwxyzjn/CartPole-v1-dqn_jax-seed1", filename="dqn_jax.cleanrl_model") | ||
| evaluate( | ||
| model_path, | ||
| make_env, | ||
| "CartPole-v1", | ||
| eval_episodes=10, | ||
| run_name=f"eval", | ||
| Model=QNetwork, | ||
| capture_video=False, | ||
| ) | ||
|
|
||
| # from cleanrl.dqn_atari import QNetwork, make_env | ||
|
|
||
| # model_path = hf_hub_download(repo_id="vwxyzjn/BreakoutNoFrameskip-v4-dqn_atari-seed1", filename="q_network.pth") | ||
| # evaluate( | ||
| # model_path, | ||
| # make_env, | ||
| # "BreakoutNoFrameskip-v4", | ||
| # eval_episodes=10, | ||
| # run_name=f"eval", | ||
| # Model=QNetwork, | ||
| # device="cpu", | ||
| # epsilon=0.05, | ||
| # capture_video=False, | ||
| # ) | ||
vwxyzjn marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.