Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
cf153f6
Add Hopper and Walker2D models for v5
Kallinteris-Andreas May 2, 2023
bc92449
Merge branch 'Farama-Foundation:main' into main
Kallinteris-Andreas May 9, 2023
0cbdd72
Delete hopper_v5.xml
Kallinteris-Andreas May 9, 2023
db3734e
Delete walker2d_v5.xml
Kallinteris-Andreas May 9, 2023
a2d2e64
General MuJoCo Env Documention Cleanup
Kallinteris-Andreas May 9, 2023
f58bb5e
typofix
Kallinteris-Andreas May 9, 2023
7a4bc32
typo fix
Kallinteris-Andreas May 9, 2023
2418631
update following @pseudo-rnd-thoughts reviews
Kallinteris-Andreas May 9, 2023
3b9080b
Merge branch 'Farama-Foundation:main' into main
Kallinteris-Andreas Jun 5, 2023
77bcb8b
Merge branch 'Farama-Foundation:main' into main
Kallinteris-Andreas Jun 16, 2023
7639d18
refactor `tests/env/test_mojoco.py` ->
Kallinteris-Andreas Jun 16, 2023
8eb1b11
Merge branch 'Farama-Foundation:main' into main
Kallinteris-Andreas Jun 27, 2023
61d0848
Update setup.py
Kallinteris-Andreas Oct 23, 2023
5831a19
do nothing
Kallinteris-Andreas Oct 23, 2023
803dc49
Merge branch 'Farama-Foundation:main' into main
Kallinteris-Andreas Nov 3, 2023
d99cc5d
[MuJoCo] add action space figures
Kallinteris-Andreas Nov 3, 2023
f788bb3
Merge branch 'Farama-Foundation:main' into main
Kallinteris-Andreas Nov 10, 2023
450b471
Merge branch 'Farama-Foundation:main' into mjx
Kallinteris-Andreas Nov 30, 2023
14fb4d8
replace `flat.copy()` with `flatten()`
Kallinteris-Andreas Dec 5, 2023
1583839
Merge branch 'Farama-Foundation:main' into mjx
Kallinteris-Andreas Dec 5, 2023
47a7059
add `MuJoCo.test_model_sensors`
Kallinteris-Andreas Dec 6, 2023
9dc31e2
`test_model_sensors` remove check for standup `v3`
Kallinteris-Andreas Dec 6, 2023
bededa3
factorize `_get_rew()` out of `step`
Kallinteris-Andreas Dec 6, 2023
999d888
some cleanup
Kallinteris-Andreas Dec 6, 2023
0f59baa
support `python==3.8`
Kallinteris-Andreas Dec 6, 2023
76f5e17
fix for real this time
Kallinteris-Andreas Dec 6, 2023
724e47f
`black`
Kallinteris-Andreas Dec 6, 2023
32c1cb8
add prototype
Kallinteris-Andreas Dec 10, 2023
e1772bc
Merge branch 'Farama-Foundation:main' into mjx
Kallinteris-Andreas Dec 10, 2023
30cc231
cleanup
Kallinteris-Andreas Dec 15, 2023
925fcdc
update mjx envs
Kallinteris-Andreas Feb 1, 2024
b7f8806
Merge branch 'Farama-Foundation:main' into mjx
Kallinteris-Andreas Feb 1, 2024
08299e7
huge update
Kallinteris-Andreas Feb 5, 2024
be527c4
Merge branch 'Farama-Foundation:main' into mjx
Kallinteris-Andreas Feb 5, 2024
04ed837
`pre-commit`
Kallinteris-Andreas Feb 5, 2024
72d87ba
Merge branch 'Farama-Foundation:main' into mjx
Kallinteris-Andreas Feb 15, 2024
696b0a0
update
Kallinteris-Andreas Feb 15, 2024
a7c614a
`pre-commit`
Kallinteris-Andreas Feb 15, 2024
3e56f40
fix reacher
Kallinteris-Andreas Feb 22, 2024
9fa9225
Merge branch 'Farama-Foundation:main' into mjx
Kallinteris-Andreas Oct 12, 2024
a5b9bba
`pre-commit`
Kallinteris-Andreas Oct 12, 2024
4e28d8a
update func_env
Kallinteris-Andreas Oct 12, 2024
9b56314
Merge branch 'Farama-Foundation:main' into mjx
Kallinteris-Andreas Dec 16, 2025
6f7603a
Change action space dtype from float32 to float64 in MJXEnv
Kallinteris-Andreas Dec 17, 2025
ea0299d
Refactor MJXEnv step method to streamline state control handling
Kallinteris-Andreas Dec 19, 2025
4d55ed2
Add TypedDict hints for MJX environment parameters
Kallinteris-Andreas Dec 20, 2025
fb96e02
fix ant's is_healthy
Kallinteris-Andreas Dec 20, 2025
73aa059
Update MJX environments to use new cfrc_ext access API to avoid depre…
Kallinteris-Andreas Dec 20, 2025
071f44e
Add render metadata to MJXEnv
Kallinteris-Andreas Dec 20, 2025
62ebe3e
pre-commit
Kallinteris-Andreas Dec 20, 2025
ba3d26a
Add MJX dependency to pyproject.toml
Kallinteris-Andreas Dec 20, 2025
806fced
typo fix
Kallinteris-Andreas Dec 20, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions gymnasium/envs/mjx/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""Contains the base class and environments for MJX."""

from gymnasium.envs.mjx.mjx_env import MJXEnv
213 changes: 213 additions & 0 deletions gymnasium/envs/mjx/ant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
"""Contains the class for the `Ant` environment."""

import gymnasium


try:
import jax
from jax import numpy as jnp
from mujoco import mjx
except ImportError as e:
MJX_IMPORT_ERROR = e
else:
MJX_IMPORT_ERROR = None

from typing import TypedDict

import numpy as np

from gymnasium.envs.mjx.mjx_env import MJXEnv
from gymnasium.envs.mujoco.ant_v5 import DEFAULT_CAMERA_CONFIG


class AntMJXEnvParams(TypedDict):
    """Parameters for the Ant environment.

    This is the typed schema of the `params` dictionary that is passed to
    `Ant_MJXEnv.__init__` and to the per-step methods of `Ant_MJXEnv`.
    Defaults quoted below are the ones set by `Ant_MJXEnv.get_default_params`.
    """

    # Not read directly by `Ant_MJXEnv`; presumably consumed by the `MJXEnv` base class - TODO confirm.
    xml_file: str  # default: "ant.xml"
    frame_skip: int  # default: 5
    default_camera_config: dict[str, float | int | str | None]  # default: `DEFAULT_CAMERA_CONFIG` of `ant_v5`

    # Reward terms, see `Ant_MJXEnv._get_reward`.
    forward_reward_weight: float  # multiplies the x velocity of `main_body` (default: 1)
    ctrl_cost_weight: float  # multiplies the sum of the squared action (default: 0.5)
    contact_cost_weight: float  # multiplies the sum of the squared clipped contact forces (default: 5e-4)
    healthy_reward: float  # granted when `_gen_is_healthy` holds (default: 1.0)
    main_body: int  # row index into `xpos` of the body whose xy displacement is tracked (default: 1)

    # Health / termination, see `Ant_MJXEnv._gen_is_healthy` and `Ant_MJXEnv.terminal`.
    terminate_when_unhealthy: bool  # if set, an unhealthy state is terminal (default: True)
    healthy_z_range: tuple[float, float]  # inclusive (min, max) bounds on `qpos[2]` (default: (0.2, 1.0))

    # (min, max) bounds used to clip `cfrc_ext` in `_get_contact_forces` (default: (-1.0, 1.0)).
    contact_force_range: tuple[float, float]

    # Magnitude of the noise added to `qpos`/`qvel` in `_gen_init_physics_state` (default: 0.1).
    reset_noise_scale: float

    # Observation layout, see `Ant_MJXEnv.__init__` and `Ant_MJXEnv.observation`.
    exclude_current_positions_from_observation: bool  # drop the first two `qpos` entries (default: True)
    include_cfrc_ext_in_observation: bool  # append the clipped contact forces (default: True)

    # Not read by `Ant_MJXEnv` and not set by its `get_default_params`;
    # presumably rendering options supplied by `MJXEnv.get_default_params()` - TODO confirm.
    camera_id: int | None
    camera_name: str | None
    max_geom: int
    width: int
    height: int
    render_mode: str | None


class Ant_MJXEnv(MJXEnv):
    """Class for Ant.

    Every per-step method receives the physics state (`mjx.Data`) and the
    `params` dictionary (see `AntMJXEnvParams`) explicitly.

    The observation is the concatenation of `qpos` (optionally without its first
    two entries), `qvel` and, optionally, the clipped external contact forces
    (`cfrc_ext`) of every body except body 0 (the MuJoCo world body).

    Annotations that mention `jax`, `jnp` or `mjx` are written as strings, so that
    defining this class does not raise a `NameError` when the guarded import at
    the top of the module has failed.
    """

    # NOTE: MJX does not yet support cfrc_ext and therefore this class can not be instantiated
    # NOTE(review): `_get_contact_forces` reads `mjx_data._impl.cfrc_ext`;
    # confirm whether the note above is still accurate.

    def __init__(
        self,
        params: AntMJXEnvParams,
    ):
        """Initialises the base environment and sets the `observation_space`.

        Args:
            params: The environment parameters, see `AntMJXEnvParams`.
        """
        MJXEnv.__init__(self, params=params)

        # The two flags are booleans: multiplying by them zeroes the entry of a disabled feature.
        skipped_qpos = 2 * params["exclude_current_positions_from_observation"]
        self.observation_structure = {
            "skipped_qpos": skipped_qpos,
            "qpos": self.mjx_model.nq - skipped_qpos,
            "qvel": self.mjx_model.nv,
            # 6 force/torque components for every body except body 0
            "cfrc_ext": (self.mjx_model.nbody - 1)
            * 6
            * params["include_cfrc_ext_in_observation"],
        }

        obs_size = (
            self.observation_structure["qpos"]
            + self.observation_structure["qvel"]
            + self.observation_structure["cfrc_ext"]
        )

        self.observation_space = gymnasium.spaces.Box(
            low=-np.inf, high=np.inf, shape=(obs_size,), dtype=np.float64
        )

    def _gen_init_physics_state(
        self, rng, params: AntMJXEnvParams
    ) -> "tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]":
        """Generates the initial `qpos`, `qvel` and `act` of an episode.

        `qpos` is `qpos0` plus uniform noise in `[-reset_noise_scale, reset_noise_scale)`,
        `qvel` is standard gaussian noise scaled by `reset_noise_scale` and
        `act` is all zeros.

        Args:
            rng: The JAX PRNG key used to draw the noise.
            params: The environment parameters.

        Returns:
            The tuple `(qpos, qvel, act)`.
        """
        noise_scale = params["reset_noise_scale"]

        # A PRNG key must be consumed only once, so each draw gets its own sub-key;
        # reusing `rng` for both draws would correlate the position and velocity noise.
        qpos_rng, qvel_rng = jax.random.split(rng)

        qpos = self.mjx_model.qpos0 + jax.random.uniform(
            key=qpos_rng,
            minval=-noise_scale,
            maxval=noise_scale,
            shape=(self.mjx_model.nq,),
        )
        qvel = noise_scale * jax.random.normal(
            key=qvel_rng, shape=(self.mjx_model.nv,)
        )
        # `jnp.empty` also yields zeros in JAX; `jnp.zeros` states that intent explicitly.
        act = jnp.zeros(self.mjx_model.na)

        return qpos, qvel, act

    def observation(
        self, state: "mjx.Data", rng: "jax.Array", params: AntMJXEnvParams
    ) -> "jnp.ndarray":
        """Observes the `qpos` (positional elements) and `qvel` (velocity elements) and `cfrc_ext` (external contact forces) of the robot.

        Args:
            state: The physics state to observe.
            rng: Unused by this method.
            params: The environment parameters.

        Returns:
            A 1-D array laid out as described by `self.observation_structure`.
        """
        mjx_data = state

        position = mjx_data.qpos.flatten()
        velocity = mjx_data.qvel.flatten()

        if params["exclude_current_positions_from_observation"]:
            position = position[2:]

        # Truthiness (not `is True`) keeps this branch consistent with `__init__`,
        # which sizes the observation space by multiplying with the same flag.
        if params["include_cfrc_ext_in_observation"]:
            # `cfrc_ext` holds one row of 6 values per body: drop body 0 and flatten,
            # so the result is 1-D with `(nbody - 1) * 6` entries as declared in `__init__`.
            external_contact_forces = self._get_contact_forces(mjx_data, params)[
                1:
            ].flatten()
        else:
            external_contact_forces = jnp.array([])

        return jnp.concatenate((position, velocity, external_contact_forces))

    def _get_contact_forces(
        self, mjx_data: "mjx.Data", params: AntMJXEnvParams
    ) -> "jnp.ndarray":
        """Get External Contact Forces (`cfrc_ext`) clipped by `contact_force_range`.

        Args:
            mjx_data: The physics state to read `cfrc_ext` from.
            params: The environment parameters.

        Returns:
            `cfrc_ext` with every element clipped to `contact_force_range`, shape unchanged.
        """
        min_value, max_value = params["contact_force_range"]
        return jnp.clip(mjx_data._impl.cfrc_ext, min_value, max_value)

    def _get_reward(
        self,
        state: "mjx.Data",
        action: "jnp.ndarray",
        next_state: "mjx.Data",
        params: AntMJXEnvParams,
    ) -> "tuple[jnp.ndarray, dict]":
        """Reward = forward_reward + healthy_reward - ctrl_cost - contact cost.

        Args:
            state: The physics state before the step.
            action: The action that was applied.
            next_state: The physics state after the step.
            params: The environment parameters.

        Returns:
            The scalar reward and a dictionary with its individual terms.
        """
        mjx_data_old = state
        mjx_data_new = next_state

        xy_position_before = mjx_data_old.xpos[params["main_body"], :2]
        xy_position_after = mjx_data_new.xpos[params["main_body"], :2]

        # finite-difference velocity over the time span returned by `self.dt(params)`
        xy_velocity = (xy_position_after - xy_position_before) / self.dt(params)
        x_velocity = xy_velocity[0]  # only the forward (x) component is rewarded

        forward_reward = x_velocity * params["forward_reward_weight"]
        healthy_reward = (
            self._gen_is_healthy(mjx_data_new, params) * params["healthy_reward"]
        )
        rewards = forward_reward + healthy_reward

        ctrl_cost = params["ctrl_cost_weight"] * jnp.sum(jnp.square(action))
        contact_cost = params["contact_cost_weight"] * jnp.sum(
            jnp.square(self._get_contact_forces(mjx_data_new, params))
        )
        costs = ctrl_cost + contact_cost

        reward = rewards - costs

        reward_info = {
            "reward_forward": forward_reward,
            "reward_ctrl": -ctrl_cost,
            "reward_contact": -contact_cost,
            "reward_survive": healthy_reward,
        }

        return reward, reward_info

    def _gen_is_healthy(
        self, state: "mjx.Data", params: AntMJXEnvParams
    ) -> "jnp.ndarray":
        """Checks if the robot is in a healthy position.

        The robot is healthy when every element of `qpos`, `qvel` and `act` is
        finite and `qpos[2]` lies inside `healthy_z_range` (bounds included).

        Args:
            state: The physics state to check.
            params: The environment parameters.

        Returns:
            A boolean scalar array.
        """
        mjx_data = state

        z = mjx_data.qpos[2]
        min_z, max_z = params["healthy_z_range"]
        is_finite = jnp.isfinite(
            jnp.concatenate((mjx_data.qpos, mjx_data.qvel, mjx_data.act))
        ).all()
        z_ok = jnp.logical_and(z >= min_z, z <= max_z)
        return jnp.logical_and(is_finite, z_ok)

    def state_info(
        self, state: "mjx.Data", params: AntMJXEnvParams
    ) -> "dict[str, jnp.ndarray]":
        """Includes state information excluded from `observation()`.

        Args:
            state: The physics state to report on.
            params: The environment parameters (unused by this method).

        Returns:
            `qpos[0]`, `qpos[1]` and the euclidean norm of `qpos[0:2]`.
        """
        mjx_data = state

        return {
            "x_position": mjx_data.qpos[0],
            "y_position": mjx_data.qpos[1],
            "distance_from_origin": jnp.linalg.norm(mjx_data.qpos[0:2], ord=2),
        }

    def terminal(
        self, state: "mjx.Data", rng: "jax.Array", params: AntMJXEnvParams
    ) -> "jnp.ndarray":
        """Terminates if unhealthy.

        Args:
            state: The physics state to check.
            rng: Unused by this method.
            params: The environment parameters.

        Returns:
            A boolean scalar array, true when the robot is unhealthy and
            `terminate_when_unhealthy` is set.
        """
        return jnp.logical_and(
            jnp.logical_not(self._gen_is_healthy(state, params)),
            params["terminate_when_unhealthy"],
        )

    # The function takes no `self`; `@staticmethod` keeps class-level calls
    # unchanged and additionally makes calls through an instance work.
    @staticmethod
    def get_default_params(**kwargs) -> AntMJXEnvParams:
        """Get the default parameter for the Ant environment.

        Args:
            **kwargs: Overrides that take precedence over every default.

        Returns:
            `MJXEnv.get_default_params()` updated with the Ant defaults and then with `kwargs`.
        """
        default = {
            "xml_file": "ant.xml",
            "frame_skip": 5,
            "default_camera_config": DEFAULT_CAMERA_CONFIG,
            "forward_reward_weight": 1,
            "ctrl_cost_weight": 0.5,
            "contact_cost_weight": 5e-4,
            "healthy_reward": 1.0,
            "main_body": 1,
            "terminate_when_unhealthy": True,
            "healthy_z_range": (0.2, 1.0),
            "contact_force_range": (-1.0, 1.0),
            "reset_noise_scale": 0.1,
            "exclude_current_positions_from_observation": True,
            "include_cfrc_ext_in_observation": True,
        }
        return {**MJXEnv.get_default_params(), **default, **kwargs}
1 change: 1 addition & 0 deletions gymnasium/envs/mjx/assets
Loading
Loading