diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 3006d0c74..6297793ea 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -18,7 +18,10 @@ jobs: env: CACHE_BRANCH: ${{ github.head_ref || github.ref_name }} DEFAULT_BRANCH: ${{ github.event.repository.default_branch }} - BAZEL_CI_RESOURCE_OPTS: --jobs=HOST_CPUS-1 --local_cpu_resources=HOST_CPUS-1 --local_ram_resources=HOST_RAM*.5 + # Keep hosted Linux runners responsive during the full MuJoCo/MyoSuite coverage + # build. The default HOST_CPUS-1 parallelism can starve the runner hard + # enough for GitHub Actions to lose communication. + BAZEL_CI_RESOURCE_OPTS: --jobs=2 --local_cpu_resources=2 --local_ram_resources=HOST_RAM*.35 --experimental_disk_cache_gc_max_size=2G --experimental_disk_cache_gc_idle_delay=0s MUJOCO_GL: egl EGL_PLATFORM: surfaceless steps: @@ -128,7 +131,10 @@ jobs: env: CACHE_BRANCH: ${{ github.head_ref || github.ref_name }} DEFAULT_BRANCH: ${{ github.event.repository.default_branch }} - BAZEL_CI_RESOURCE_OPTS: --jobs=HOST_CPUS-1 --local_cpu_resources=HOST_CPUS-1 --local_ram_resources=HOST_RAM*.5 + # Keep hosted Linux runners responsive during the full MuJoCo/MyoSuite test + # build. The default HOST_CPUS-1 parallelism can starve the runner hard + # enough for GitHub Actions to lose communication. + BAZEL_CI_RESOURCE_OPTS: --jobs=2 --local_cpu_resources=2 --local_ram_resources=HOST_RAM*.35 --experimental_disk_cache_gc_max_size=2G --experimental_disk_cache_gc_idle_delay=0s MUJOCO_GL: egl EGL_PLATFORM: surfaceless steps: @@ -280,12 +286,13 @@ jobs: test-windows: name: Test (windows) runs-on: windows-2022 - timeout-minutes: 180 + timeout-minutes: 300 env: BAZEL_SH: C:/Program Files/Git/usr/bin/bash.exe CACHE_BRANCH: ${{ github.head_ref || github.ref_name }} DEFAULT_BRANCH: ${{ github.event.repository.default_branch }} - BAZEL_CI_RESOURCE_OPTS: --jobs=HOST_CPUS-1 --local_cpu_resources=HOST_CPUS-1 --local_ram_resources=HOST_RAM*.5 + BAZEL_CI_RESOURCE_OPTS: --jobs=HOST_CPUS-1 --local_cpu_resources=HOST_CPUS-1 --local_ram_resources=HOST_RAM*.5 --experimental_disk_cache_gc_max_size=2G --experimental_disk_cache_gc_idle_delay=0s + BAZEL_CACHE_PREFIX: bazel-test-v5-windows-capped MESA_GL_VERSION_OVERRIDE: 4.5COMPAT MSYS2_ARG_CONV_EXCL: "*" MSYS_NO_PATHCONV: "1" @@ -345,14 +352,16 @@ jobs: - name: Restore Bazel caches uses: actions/cache/restore@v4 with: + # Keep the Windows disk cache capped. The old uncapped archive grew + # past 7 GB and exhausted the hosted runner disk during extraction. path: | ~/.cache/bazelisk ~/.cache/envpool-bazel-repo ~/.cache/envpool-bazel-disk - key: bazel-test-v3-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('.bazelversion') }}-${{ env.CACHE_BRANCH }} + key: ${{ env.BAZEL_CACHE_PREFIX }}-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('.bazelversion') }}-${{ env.CACHE_BRANCH }} restore-keys: | - bazel-test-v3-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('.bazelversion') }}-${{ env.DEFAULT_BRANCH }}- - bazel-test-v3-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('.bazelversion') }}- + ${{ env.BAZEL_CACHE_PREFIX }}-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('.bazelversion') }}-${{ env.DEFAULT_BRANCH }}- + ${{ env.BAZEL_CACHE_PREFIX }}-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('.bazelversion') }}- - name: Install bazelisk shell: pwsh run: | @@ -384,7 +393,7 @@ jobs: ~/.cache/bazelisk ~/.cache/envpool-bazel-repo ~/.cache/envpool-bazel-disk - key: bazel-test-v3-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('.bazelversion') }}-${{ env.CACHE_BRANCH }}-${{ github.run_id }} + key: ${{ env.BAZEL_CACHE_PREFIX }}-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('.bazelversion') }}-${{ env.CACHE_BRANCH }}-${{ github.run_id }} - name: Prune older Bazel caches if: ${{ always() && !cancelled() && (steps.save-bazel-caches.outcome == 'success' || steps.save-bazel-caches.outcome == 'skipped') }} continue-on-error: true @@ -393,7 +402,7 @@ jobs: run: >- python scripts/prune_actions_caches.py --repo "${{ github.repository }}" - --prefix "bazel-test-v3-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('.bazelversion') }}-" + --prefix "${{ env.BAZEL_CACHE_PREFIX }}-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('.bazelversion') }}-" --group-by-branch --delete-missing-branches --keep 1 diff --git a/BUILD b/BUILD index 149635e72..1a9bd5c38 100644 --- a/BUILD +++ b/BUILD @@ -10,6 +10,22 @@ config_setting( constraint_values = ["@platforms//os:linux"], ) +config_setting( + name = "linux_x86_64", + constraint_values = [ + "@platforms//cpu:x86_64", + "@platforms//os:linux", + ], +) + +config_setting( + name = "linux_arm64", + constraint_values = [ + "@platforms//cpu:arm64", + "@platforms//os:linux", + ], +) + config_setting( name = "windows", constraint_values = ["@platforms//os:windows"], diff --git a/Makefile b/Makefile index 1c736caab..7326564ec 100644 --- a/Makefile +++ b/Makefile @@ -154,10 +154,10 @@ buildifier: buildifier-install # bazel build/test bazel-pip-requirement-dev: - cd third_party/pip_requirements && (cmp -s requirements.txt requirements-dev-lock.txt || cp -f requirements-dev-lock.txt requirements.txt) + cd third_party/pip_requirements && (cmp -s requirements.txt requirements-dev-lock.txt || (rm -f requirements.txt && cp -f requirements-dev-lock.txt requirements.txt)) bazel-pip-requirement-release: - cd third_party/pip_requirements && (cmp -s requirements.txt requirements-release-lock.txt || cp -f requirements-release-lock.txt requirements.txt) + cd third_party/pip_requirements && (cmp -s requirements.txt requirements-release-lock.txt || (rm -f requirements.txt && cp -f requirements-release-lock.txt requirements.txt)) clang-tidy: clang-tidy-install bazel-pip-requirement-dev targets="$${CLANG_TIDY_TARGETS:-$$($(CLANG_TIDY_TARGET_RESOLVER) | tr '\n' ' ')}"; \ @@ -258,6 +258,7 @@ pypi-wheel: $(PYPI_WHEEL_PREREQS) bazel-release release-test1: tmpdir=$$(python3 -c 'import tempfile; print(tempfile.mkdtemp(prefix="envpool-release-test-"))'); \ + cd "$$tmpdir" && PYTHONPATH= python3 "$(CURDIR)/scripts/release_installed_wheel_smoke.py" --source-root "$(CURDIR)" && \ cd "$$tmpdir" && PYTHONPATH= python3 "$(CURDIR)/envpool/make_test.py" release-test2: diff --git a/README.md b/README.md index 913d50262..ed0eddf52 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,6 @@ - [x] [Atari games](https://envpool.readthedocs.io/en/latest/env/atari.html) - [x] [MuJoCo (Gymnasium)](https://envpool.readthedocs.io/en/latest/env/mujoco_gym.html) -- [x] [MetaWorld](https://envpool.readthedocs.io/en/latest/env/metaworld.html) - [x] [Classic control RL envs](https://envpool.readthedocs.io/en/latest/env/classic_control.html): CartPole, MountainCar, Pendulum, Acrobot - [x] [Toy text RL envs](https://envpool.readthedocs.io/en/latest/env/toy_text.html): Catch, FrozenLake, Taxi, NChain, CliffWalking, Blackjack - [x] [ViZDoom single player](https://envpool.readthedocs.io/en/latest/env/vizdoom.html) @@ -20,6 +19,8 @@ - [x] [Procgen](https://envpool.readthedocs.io/en/latest/env/procgen.html) - [x] [Minigrid](https://envpool.readthedocs.io/en/latest/env/minigrid.html) - [x] [Highway](https://envpool.readthedocs.io/en/latest/env/highway.html) +- [x] [MetaWorld](https://envpool.readthedocs.io/en/latest/env/metaworld.html) +- [x] [MyoSuite](https://envpool.readthedocs.io/en/latest/env/myosuite.html) Here are EnvPool's several highlights: diff --git a/WORKSPACE b/WORKSPACE index 438001516..b79362295 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -45,3 +45,7 @@ pip_workspace() load("@pip_requirements//:requirements.bzl", "install_deps") install_deps() + +load("@myosuite_oracle_requirements//:requirements.bzl", myosuite_oracle_install_deps = "install_deps") + +myosuite_oracle_install_deps() diff --git a/docs/_static/render_samples/myosuite_myobase_official_compare.png b/docs/_static/render_samples/myosuite_myobase_official_compare.png new file mode 100644 index 000000000..285ee66c5 Binary files /dev/null and b/docs/_static/render_samples/myosuite_myobase_official_compare.png differ diff --git a/docs/_static/render_samples/myosuite_myochallenge_official_compare.png b/docs/_static/render_samples/myosuite_myochallenge_official_compare.png new file mode 100644 index 000000000..0ed592bc7 Binary files /dev/null and b/docs/_static/render_samples/myosuite_myochallenge_official_compare.png differ diff --git a/docs/_static/render_samples/myosuite_myodm_official_compare.png b/docs/_static/render_samples/myosuite_myodm_official_compare.png new file mode 100644 index 000000000..36c6bf0b5 Binary files /dev/null and b/docs/_static/render_samples/myosuite_myodm_official_compare.png differ diff --git a/docs/content/cpp_interface.rst b/docs/content/cpp_interface.rst index 05b7a957f..d19c45f54 100644 --- a/docs/content/cpp_interface.rst +++ b/docs/content/cpp_interface.rst @@ -28,17 +28,13 @@ The compile-time dictionary helpers used by these types still live in Environment Authoring --------------------- -.. doxygenvariable:: common_config - :project: envpool_cpp_api - -.. doxygenvariable:: common_action_spec - :project: envpool_cpp_api - -.. doxygenvariable:: common_state_spec - :project: envpool_cpp_api +The shared dictionaries in ``envpool/core/env_spec.h`` define the base +configuration, action, and state entries that every environment family extends: +``common_config``, ``common_action_spec``, and ``common_state_spec``. .. doxygenclass:: EnvSpec :project: envpool_cpp_api + :members: config, state_spec, action_spec, EnvSpec .. doxygenclass:: Env :project: envpool_cpp_api diff --git a/docs/content/python_interface.rst b/docs/content/python_interface.rst index 071b06ded..bda9cd395 100644 --- a/docs/content/python_interface.rst +++ b/docs/content/python_interface.rst @@ -55,6 +55,11 @@ The observation space and action space of resulted environment describe a single environment's space, but each time the observation/action's first dimension is always equal to ``num_envs`` (sync mode) or equal to ``batch_size`` (async mode). +For Gymnasium compatibility, the Gymnasium wrapper also exposes +``num_envs``, ``is_vector_env``, ``single_observation_space``, and +``single_action_space`` for vector-aware wrappers such as Gymnasium's vector +``NormalizeObservation``. EnvPool keeps ``observation_space`` and +``action_space`` as the single-environment spaces for backward compatibility. ``envpool.make_gym``, ``envpool.make_dm``, and ``envpool.make_gymnasium`` are shortcuts for ``envpool.make(..., env_type="gym" | "dm" | "gymnasium")``, @@ -149,11 +154,11 @@ Example: Representative first-frame compares for EnvPool families that support rendering. In each panel, EnvPool is on the left and the reference output is on the right. For -Box2D, Classic Control, MiniGrid, MuJoCo, and Gymnasium-Robotics, the -reference is the upstream Python renderer. For Atari, Procgen, and VizDoom, -the reference is the exact in-tree render oracle used by the test suite. -Google Research Football is intentionally excluded here because its render API -is unsupported. +Box2D, Classic Control, MiniGrid, MuJoCo, Gymnasium-Robotics, and MyoSuite, +the reference is the upstream Python renderer. For Atari, Procgen, and +VizDoom, the reference is the exact in-tree render oracle used by the test +suite. Google Research Football is intentionally excluded here because its +render API is unsupported. .. image:: ../_static/render_samples/atari_oracle_compare.png :width: 900px @@ -187,6 +192,18 @@ is unsupported. :width: 900px :align: center +.. image:: ../_static/render_samples/myosuite_myobase_official_compare.png + :width: 900px + :align: center + +.. image:: ../_static/render_samples/myosuite_myochallenge_official_compare.png + :width: 900px + :align: center + +.. image:: ../_static/render_samples/myosuite_myodm_official_compare.png + :width: 900px + :align: center + .. image:: ../_static/render_samples/vizdoom_oracle_compare.png :width: 900px :align: center diff --git a/docs/env/myosuite.rst b/docs/env/myosuite.rst new file mode 100644 index 000000000..6a1bed23e --- /dev/null +++ b/docs/env/myosuite.rst @@ -0,0 +1,469 @@ +MyoSuite +======== + +EnvPool's MyoSuite integration uses ``myosuite==2.11.6`` pinned at commit +``05cb84678373f91271004f99602ebbf01e57d1a1`` with ``mujoco==3.6.0``. +The runtime implementation is native C++; the official Python package is used +only by oracle tests and doc-generation tooling. + +The generated upstream registry and task metadata live under +``third_party/myosuite/``. Runtime C++ consumes those generated assets instead +of keeping a handwritten task list in ``envpool/mujoco/myosuite/``. + + +Env IDs +------- + +EnvPool registers all 398 official MyoSuite task IDs from the pinned oracle. +Every official ID also has an EnvPool alias of the form +``MyoSuite/``, for example: + +:: + + envpool.make_gymnasium("myoFingerReachFixed-v0") + envpool.make_gymnasium("MyoSuite/myoFingerReachFixed-v0") + +The full registered official IDs and EnvPool aliases are: + +:: + + Official ID EnvPool alias + ----------- ------------- + MyoHandAirplaneFixed-v0 MyoSuite/MyoHandAirplaneFixed-v0 + MyoHandAirplaneFly-v0 MyoSuite/MyoHandAirplaneFly-v0 + MyoHandAirplaneLift-v0 MyoSuite/MyoHandAirplaneLift-v0 + MyoHandAirplanePass-v0 MyoSuite/MyoHandAirplanePass-v0 + MyoHandAirplaneRandom-v0 MyoSuite/MyoHandAirplaneRandom-v0 + MyoHandAlarmclockFixed-v0 MyoSuite/MyoHandAlarmclockFixed-v0 + MyoHandAlarmclockLift-v0 MyoSuite/MyoHandAlarmclockLift-v0 + MyoHandAlarmclockPass-v0 MyoSuite/MyoHandAlarmclockPass-v0 + MyoHandAlarmclockRandom-v0 MyoSuite/MyoHandAlarmclockRandom-v0 + MyoHandAlarmclockSee-v0 MyoSuite/MyoHandAlarmclockSee-v0 + MyoHandAppleFixed-v0 MyoSuite/MyoHandAppleFixed-v0 + MyoHandAppleLift-v0 MyoSuite/MyoHandAppleLift-v0 + MyoHandApplePass-v0 MyoSuite/MyoHandApplePass-v0 + MyoHandAppleRandom-v0 MyoSuite/MyoHandAppleRandom-v0 + MyoHandBananaFixed-v0 MyoSuite/MyoHandBananaFixed-v0 + MyoHandBananaPass-v0 MyoSuite/MyoHandBananaPass-v0 + MyoHandBananaRandom-v0 MyoSuite/MyoHandBananaRandom-v0 + MyoHandBinocularsFixed-v0 MyoSuite/MyoHandBinocularsFixed-v0 + MyoHandBinocularsPass-v0 MyoSuite/MyoHandBinocularsPass-v0 + MyoHandBinocularsRandom-v0 MyoSuite/MyoHandBinocularsRandom-v0 + MyoHandBowlDrink2-v0 MyoSuite/MyoHandBowlDrink2-v0 + MyoHandBowlFixed-v0 MyoSuite/MyoHandBowlFixed-v0 + MyoHandBowlPass-v0 MyoSuite/MyoHandBowlPass-v0 + MyoHandBowlRandom-v0 MyoSuite/MyoHandBowlRandom-v0 + MyoHandCameraFixed-v0 MyoSuite/MyoHandCameraFixed-v0 + MyoHandCameraPass-v0 MyoSuite/MyoHandCameraPass-v0 + MyoHandCameraRandom-v0 MyoSuite/MyoHandCameraRandom-v0 + MyoHandCoffeemugFixed-v0 MyoSuite/MyoHandCoffeemugFixed-v0 + MyoHandCoffeemugRandom-v0 MyoSuite/MyoHandCoffeemugRandom-v0 + MyoHandCubelargeFixed-v0 MyoSuite/MyoHandCubelargeFixed-v0 + MyoHandCubelargePass-v0 MyoSuite/MyoHandCubelargePass-v0 + MyoHandCubelargeRandom-v0 MyoSuite/MyoHandCubelargeRandom-v0 + MyoHandCubemediumFixed-v0 MyoSuite/MyoHandCubemediumFixed-v0 + MyoHandCubemediumLInspect-v0 MyoSuite/MyoHandCubemediumLInspect-v0 + MyoHandCubemediumRandom-v0 MyoSuite/MyoHandCubemediumRandom-v0 + MyoHandCubesmallFixed-v0 MyoSuite/MyoHandCubesmallFixed-v0 + MyoHandCubesmallLift-v0 MyoSuite/MyoHandCubesmallLift-v0 + MyoHandCubesmallPass-v0 MyoSuite/MyoHandCubesmallPass-v0 + MyoHandCubesmallRandom-v0 MyoSuite/MyoHandCubesmallRandom-v0 + MyoHandCupDrink-v0 MyoSuite/MyoHandCupDrink-v0 + MyoHandCupFixed-v0 MyoSuite/MyoHandCupFixed-v0 + MyoHandCupPass-v0 MyoSuite/MyoHandCupPass-v0 + MyoHandCupPour-v0 MyoSuite/MyoHandCupPour-v0 + MyoHandCupRandom-v0 MyoSuite/MyoHandCupRandom-v0 + MyoHandCylinderlargeFixed-v0 MyoSuite/MyoHandCylinderlargeFixed-v0 + MyoHandCylinderlargeInspect-v0 MyoSuite/MyoHandCylinderlargeInspect-v0 + MyoHandCylinderlargeRandom-v0 MyoSuite/MyoHandCylinderlargeRandom-v0 + MyoHandCylindermediumFixed-v0 MyoSuite/MyoHandCylindermediumFixed-v0 + MyoHandCylindermediumLift-v0 MyoSuite/MyoHandCylindermediumLift-v0 + MyoHandCylindermediumPass-v0 MyoSuite/MyoHandCylindermediumPass-v0 + MyoHandCylindermediumRandom-v0 MyoSuite/MyoHandCylindermediumRandom-v0 + MyoHandCylindersmallFixed-v0 MyoSuite/MyoHandCylindersmallFixed-v0 + MyoHandCylindersmallInspect-v0 MyoSuite/MyoHandCylindersmallInspect-v0 + MyoHandCylindersmallPass-v0 MyoSuite/MyoHandCylindersmallPass-v0 + MyoHandCylindersmallRandom-v0 MyoSuite/MyoHandCylindersmallRandom-v0 + MyoHandDuckFixed-v0 MyoSuite/MyoHandDuckFixed-v0 + MyoHandDuckInspect-v0 MyoSuite/MyoHandDuckInspect-v0 + MyoHandDuckLift-v0 MyoSuite/MyoHandDuckLift-v0 + MyoHandDuckPass-v0 MyoSuite/MyoHandDuckPass-v0 + MyoHandDuckRandom-v0 MyoSuite/MyoHandDuckRandom-v0 + MyoHandElephantFixed-v0 MyoSuite/MyoHandElephantFixed-v0 + MyoHandElephantLift-v0 MyoSuite/MyoHandElephantLift-v0 + MyoHandElephantPass-v0 MyoSuite/MyoHandElephantPass-v0 + MyoHandElephantRandom-v0 MyoSuite/MyoHandElephantRandom-v0 + MyoHandEyeglassesFixed-v0 MyoSuite/MyoHandEyeglassesFixed-v0 + MyoHandEyeglassesPass-v0 MyoSuite/MyoHandEyeglassesPass-v0 + MyoHandEyeglassesRandom-v0 MyoSuite/MyoHandEyeglassesRandom-v0 + MyoHandFlashlight1On-v0 MyoSuite/MyoHandFlashlight1On-v0 + MyoHandFlashlight2On-v0 MyoSuite/MyoHandFlashlight2On-v0 + MyoHandFlashlightFixed-v0 MyoSuite/MyoHandFlashlightFixed-v0 + MyoHandFlashlightLift-v0 MyoSuite/MyoHandFlashlightLift-v0 + MyoHandFlashlightPass-v0 MyoSuite/MyoHandFlashlightPass-v0 + MyoHandFlashlightRandom-v0 MyoSuite/MyoHandFlashlightRandom-v0 + MyoHandFluteFixed-v0 MyoSuite/MyoHandFluteFixed-v0 + MyoHandFlutePass-v0 MyoSuite/MyoHandFlutePass-v0 + MyoHandFluteRandom-v0 MyoSuite/MyoHandFluteRandom-v0 + MyoHandGamecontrollerFixed-v0 MyoSuite/MyoHandGamecontrollerFixed-v0 + MyoHandGamecontrollerPass-v0 MyoSuite/MyoHandGamecontrollerPass-v0 + MyoHandGamecontrollerRandom-v0 MyoSuite/MyoHandGamecontrollerRandom-v0 + MyoHandHammerFixed-v0 MyoSuite/MyoHandHammerFixed-v0 + MyoHandHammerPass-v0 MyoSuite/MyoHandHammerPass-v0 + MyoHandHammerRandom-v0 MyoSuite/MyoHandHammerRandom-v0 + MyoHandHammerUse-v0 MyoSuite/MyoHandHammerUse-v0 + MyoHandHandFixed-v0 MyoSuite/MyoHandHandFixed-v0 + MyoHandHandInspect-v0 MyoSuite/MyoHandHandInspect-v0 + MyoHandHandRandom-v0 MyoSuite/MyoHandHandRandom-v0 + MyoHandHeadphonesFixed-v0 MyoSuite/MyoHandHeadphonesFixed-v0 + MyoHandHeadphonesPass-v0 MyoSuite/MyoHandHeadphonesPass-v0 + MyoHandHeadphonesRandom-v0 MyoSuite/MyoHandHeadphonesRandom-v0 + MyoHandKnifeChop-v0 MyoSuite/MyoHandKnifeChop-v0 + MyoHandKnifeFixed-v0 MyoSuite/MyoHandKnifeFixed-v0 + MyoHandKnifeRandom-v0 MyoSuite/MyoHandKnifeRandom-v0 + MyoHandLightbulbFixed-v0 MyoSuite/MyoHandLightbulbFixed-v0 + MyoHandLightbulbPass-v0 MyoSuite/MyoHandLightbulbPass-v0 + MyoHandLightbulbRandom-v0 MyoSuite/MyoHandLightbulbRandom-v0 + MyoHandMouseFixed-v0 MyoSuite/MyoHandMouseFixed-v0 + MyoHandMouseLift-v0 MyoSuite/MyoHandMouseLift-v0 + MyoHandMousePass-v0 MyoSuite/MyoHandMousePass-v0 + MyoHandMouseRandom-v0 MyoSuite/MyoHandMouseRandom-v0 + MyoHandMouseUse-v0 MyoSuite/MyoHandMouseUse-v0 + MyoHandMugDrink3-v0 MyoSuite/MyoHandMugDrink3-v0 + MyoHandMugFixed-v0 MyoSuite/MyoHandMugFixed-v0 + MyoHandMugLift-v0 MyoSuite/MyoHandMugLift-v0 + MyoHandMugPass-v0 MyoSuite/MyoHandMugPass-v0 + MyoHandMugRandom-v0 MyoSuite/MyoHandMugRandom-v0 + MyoHandPhoneFixed-v0 MyoSuite/MyoHandPhoneFixed-v0 + MyoHandPhoneLift-v0 MyoSuite/MyoHandPhoneLift-v0 + MyoHandPhoneRandom-v0 MyoSuite/MyoHandPhoneRandom-v0 + MyoHandPiggybankFixed-v0 MyoSuite/MyoHandPiggybankFixed-v0 + MyoHandPiggybankPass-v0 MyoSuite/MyoHandPiggybankPass-v0 + MyoHandPiggybankRandom-v0 MyoSuite/MyoHandPiggybankRandom-v0 + MyoHandPiggybankUse-v0 MyoSuite/MyoHandPiggybankUse-v0 + MyoHandPyramidlargeFixed-v0 MyoSuite/MyoHandPyramidlargeFixed-v0 + MyoHandPyramidlargePass-v0 MyoSuite/MyoHandPyramidlargePass-v0 + MyoHandPyramidlargeRandom-v0 MyoSuite/MyoHandPyramidlargeRandom-v0 + MyoHandPyramidmediumFixed-v0 MyoSuite/MyoHandPyramidmediumFixed-v0 + MyoHandPyramidmediumPass-v0 MyoSuite/MyoHandPyramidmediumPass-v0 + MyoHandPyramidmediumRandom-v0 MyoSuite/MyoHandPyramidmediumRandom-v0 + MyoHandPyramidsmallFixed-v0 MyoSuite/MyoHandPyramidsmallFixed-v0 + MyoHandPyramidsmallInspect-v0 MyoSuite/MyoHandPyramidsmallInspect-v0 + MyoHandPyramidsmallRandom-v0 MyoSuite/MyoHandPyramidsmallRandom-v0 + MyoHandScissorsFixed-v0 MyoSuite/MyoHandScissorsFixed-v0 + MyoHandScissorsRandom-v0 MyoSuite/MyoHandScissorsRandom-v0 + MyoHandScissorsUse-v0 MyoSuite/MyoHandScissorsUse-v0 + MyoHandSpherelargeFixed-v0 MyoSuite/MyoHandSpherelargeFixed-v0 + MyoHandSpherelargePass-v0 MyoSuite/MyoHandSpherelargePass-v0 + MyoHandSpherelargeRandom-v0 MyoSuite/MyoHandSpherelargeRandom-v0 + MyoHandSpheremediumFixed-v0 MyoSuite/MyoHandSpheremediumFixed-v0 + MyoHandSpheremediumInspect-v0 MyoSuite/MyoHandSpheremediumInspect-v0 + MyoHandSpheremediumLift-v0 MyoSuite/MyoHandSpheremediumLift-v0 + MyoHandSpheremediumRandom-v0 MyoSuite/MyoHandSpheremediumRandom-v0 + MyoHandSpheresmallFixed-v0 MyoSuite/MyoHandSpheresmallFixed-v0 + MyoHandSpheresmallInspect-v0 MyoSuite/MyoHandSpheresmallInspect-v0 + MyoHandSpheresmallLift-v0 MyoSuite/MyoHandSpheresmallLift-v0 + MyoHandSpheresmallPass-v0 MyoSuite/MyoHandSpheresmallPass-v0 + MyoHandSpheresmallRandom-v0 MyoSuite/MyoHandSpheresmallRandom-v0 + MyoHandStampFixed-v0 MyoSuite/MyoHandStampFixed-v0 + MyoHandStampLift-v0 MyoSuite/MyoHandStampLift-v0 + MyoHandStampRandom-v0 MyoSuite/MyoHandStampRandom-v0 + MyoHandStampStamp-v0 MyoSuite/MyoHandStampStamp-v0 + MyoHandStanfordbunnyFixed-v0 MyoSuite/MyoHandStanfordbunnyFixed-v0 + MyoHandStanfordbunnyInspect-v0 MyoSuite/MyoHandStanfordbunnyInspect-v0 + MyoHandStanfordbunnyPass-v0 MyoSuite/MyoHandStanfordbunnyPass-v0 + MyoHandStanfordbunnyRandom-v0 MyoSuite/MyoHandStanfordbunnyRandom-v0 + MyoHandStaplerFixed-v0 MyoSuite/MyoHandStaplerFixed-v0 + MyoHandStaplerLift-v0 MyoSuite/MyoHandStaplerLift-v0 + MyoHandStaplerRandom-v0 MyoSuite/MyoHandStaplerRandom-v0 + MyoHandStaplerStaple1-v0 MyoSuite/MyoHandStaplerStaple1-v0 + MyoHandStaplerStaple2-v0 MyoSuite/MyoHandStaplerStaple2-v0 + MyoHandTeapotFixed-v0 MyoSuite/MyoHandTeapotFixed-v0 + MyoHandTeapotPour2-v0 MyoSuite/MyoHandTeapotPour2-v0 + MyoHandTeapotRandom-v0 MyoSuite/MyoHandTeapotRandom-v0 + MyoHandToothbrushBrush1-v0 MyoSuite/MyoHandToothbrushBrush1-v0 + MyoHandToothbrushFixed-v0 MyoSuite/MyoHandToothbrushFixed-v0 + MyoHandToothbrushLift-v0 MyoSuite/MyoHandToothbrushLift-v0 + MyoHandToothbrushRandom-v0 MyoSuite/MyoHandToothbrushRandom-v0 + MyoHandToothpasteFixed-v0 MyoSuite/MyoHandToothpasteFixed-v0 + MyoHandToothpasteLift-v0 MyoSuite/MyoHandToothpasteLift-v0 + MyoHandToothpasteRandom-v0 MyoSuite/MyoHandToothpasteRandom-v0 + MyoHandToothpasteSqueeze1-v0 MyoSuite/MyoHandToothpasteSqueeze1-v0 + MyoHandToruslargeFixed-v0 MyoSuite/MyoHandToruslargeFixed-v0 + MyoHandToruslargeInspect-v0 MyoSuite/MyoHandToruslargeInspect-v0 + MyoHandToruslargeLift-v0 MyoSuite/MyoHandToruslargeLift-v0 + MyoHandToruslargeRandom-v0 MyoSuite/MyoHandToruslargeRandom-v0 + MyoHandTorusmediumFixed-v0 MyoSuite/MyoHandTorusmediumFixed-v0 + MyoHandTorusmediumLift-v0 MyoSuite/MyoHandTorusmediumLift-v0 + MyoHandTorusmediumPass-v0 MyoSuite/MyoHandTorusmediumPass-v0 + MyoHandTorusmediumRandom-v0 MyoSuite/MyoHandTorusmediumRandom-v0 + MyoHandTorussmallFixed-v0 MyoSuite/MyoHandTorussmallFixed-v0 + MyoHandTorussmallLift-v0 MyoSuite/MyoHandTorussmallLift-v0 + MyoHandTorussmallPass-v0 MyoSuite/MyoHandTorussmallPass-v0 + MyoHandTorussmallRandom-v0 MyoSuite/MyoHandTorussmallRandom-v0 + MyoHandTrainFixed-v0 MyoSuite/MyoHandTrainFixed-v0 + MyoHandTrainPlay-v0 MyoSuite/MyoHandTrainPlay-v0 + MyoHandTrainRandom-v0 MyoSuite/MyoHandTrainRandom-v0 + MyoHandWatchFixed-v0 MyoSuite/MyoHandWatchFixed-v0 + MyoHandWatchLift-v0 MyoSuite/MyoHandWatchLift-v0 + MyoHandWatchPass-v0 MyoSuite/MyoHandWatchPass-v0 + MyoHandWatchRandom-v0 MyoSuite/MyoHandWatchRandom-v0 + MyoHandWaterbottleFixed-v0 MyoSuite/MyoHandWaterbottleFixed-v0 + MyoHandWaterbottleLift-v0 MyoSuite/MyoHandWaterbottleLift-v0 + MyoHandWaterbottlePass-v0 MyoSuite/MyoHandWaterbottlePass-v0 + MyoHandWaterbottleRandom-v0 MyoSuite/MyoHandWaterbottleRandom-v0 + MyoHandWaterbottleShake-v0 MyoSuite/MyoHandWaterbottleShake-v0 + MyoHandWineglassDrink2-v0 MyoSuite/MyoHandWineglassDrink2-v0 + MyoHandWineglassFixed-v0 MyoSuite/MyoHandWineglassFixed-v0 + MyoHandWineglassLift-v0 MyoSuite/MyoHandWineglassLift-v0 + MyoHandWineglassPass-v0 MyoSuite/MyoHandWineglassPass-v0 + MyoHandWineglassRandom-v0 MyoSuite/MyoHandWineglassRandom-v0 + MyoHandWineglassToast1-v0 MyoSuite/MyoHandWineglassToast1-v0 + motorFingerPoseFixed-v0 MyoSuite/motorFingerPoseFixed-v0 + motorFingerPoseRandom-v0 MyoSuite/motorFingerPoseRandom-v0 + motorFingerReachFixed-v0 MyoSuite/motorFingerReachFixed-v0 + motorFingerReachRandom-v0 MyoSuite/motorFingerReachRandom-v0 + myoArmReachFixed-v0 MyoSuite/myoArmReachFixed-v0 + myoArmReachRandom-v0 MyoSuite/myoArmReachRandom-v0 + myoChallengeBaodingP1-v1 MyoSuite/myoChallengeBaodingP1-v1 + myoChallengeBaodingP2-v1 MyoSuite/myoChallengeBaodingP2-v1 + myoChallengeBimanual-v0 MyoSuite/myoChallengeBimanual-v0 + myoChallengeChaseTagP1-v0 MyoSuite/myoChallengeChaseTagP1-v0 + myoChallengeChaseTagP2-v0 MyoSuite/myoChallengeChaseTagP2-v0 + myoChallengeChaseTagP2eval-v0 MyoSuite/myoChallengeChaseTagP2eval-v0 + myoChallengeDieReorientDemo-v0 MyoSuite/myoChallengeDieReorientDemo-v0 + myoChallengeDieReorientP1-v0 MyoSuite/myoChallengeDieReorientP1-v0 + myoChallengeDieReorientP2-v0 MyoSuite/myoChallengeDieReorientP2-v0 + myoChallengeOslRunFixed-v0 MyoSuite/myoChallengeOslRunFixed-v0 + myoChallengeOslRunRandom-v0 MyoSuite/myoChallengeOslRunRandom-v0 + myoChallengeRelocateP1-v0 MyoSuite/myoChallengeRelocateP1-v0 + myoChallengeRelocateP2-v0 MyoSuite/myoChallengeRelocateP2-v0 + myoChallengeRelocateP2eval-v0 MyoSuite/myoChallengeRelocateP2eval-v0 + myoChallengeSoccerP1-v0 MyoSuite/myoChallengeSoccerP1-v0 + myoChallengeSoccerP2-v0 MyoSuite/myoChallengeSoccerP2-v0 + myoChallengeTableTennisP0-v0 MyoSuite/myoChallengeTableTennisP0-v0 + myoChallengeTableTennisP1-v0 MyoSuite/myoChallengeTableTennisP1-v0 + myoChallengeTableTennisP2-v0 MyoSuite/myoChallengeTableTennisP2-v0 + myoElbowPose1D6MExoFixed-v0 MyoSuite/myoElbowPose1D6MExoFixed-v0 + myoElbowPose1D6MExoRandom-v0 MyoSuite/myoElbowPose1D6MExoRandom-v0 + myoElbowPose1D6MFixed-v0 MyoSuite/myoElbowPose1D6MFixed-v0 + myoElbowPose1D6MRandom-v0 MyoSuite/myoElbowPose1D6MRandom-v0 + myoFatiArmReachFixed-v0 MyoSuite/myoFatiArmReachFixed-v0 + myoFatiArmReachRandom-v0 MyoSuite/myoFatiArmReachRandom-v0 + myoFatiChallengeBaodingP1-v1 MyoSuite/myoFatiChallengeBaodingP1-v1 + myoFatiChallengeBaodingP2-v1 MyoSuite/myoFatiChallengeBaodingP2-v1 + myoFatiChallengeBimanual-v0 MyoSuite/myoFatiChallengeBimanual-v0 + myoFatiChallengeChaseTagP1-v0 MyoSuite/myoFatiChallengeChaseTagP1-v0 + myoFatiChallengeChaseTagP2-v0 MyoSuite/myoFatiChallengeChaseTagP2-v0 + myoFatiChallengeChaseTagP2eval-v0 MyoSuite/myoFatiChallengeChaseTagP2eval-v0 + myoFatiChallengeDieReorientDemo-v0 MyoSuite/myoFatiChallengeDieReorientDemo-v0 + myoFatiChallengeDieReorientP1-v0 MyoSuite/myoFatiChallengeDieReorientP1-v0 + myoFatiChallengeDieReorientP2-v0 MyoSuite/myoFatiChallengeDieReorientP2-v0 + myoFatiChallengeOslRunFixed-v0 MyoSuite/myoFatiChallengeOslRunFixed-v0 + myoFatiChallengeOslRunRandom-v0 MyoSuite/myoFatiChallengeOslRunRandom-v0 + myoFatiChallengeRelocateP1-v0 MyoSuite/myoFatiChallengeRelocateP1-v0 + myoFatiChallengeRelocateP2-v0 MyoSuite/myoFatiChallengeRelocateP2-v0 + myoFatiChallengeRelocateP2eval-v0 MyoSuite/myoFatiChallengeRelocateP2eval-v0 + myoFatiChallengeSoccerP1-v0 MyoSuite/myoFatiChallengeSoccerP1-v0 + myoFatiChallengeSoccerP2-v0 MyoSuite/myoFatiChallengeSoccerP2-v0 + myoFatiChallengeTableTennisP0-v0 MyoSuite/myoFatiChallengeTableTennisP0-v0 + myoFatiChallengeTableTennisP1-v0 MyoSuite/myoFatiChallengeTableTennisP1-v0 + myoFatiChallengeTableTennisP2-v0 MyoSuite/myoFatiChallengeTableTennisP2-v0 + myoFatiElbowPose1D6MExoFixed-v0 MyoSuite/myoFatiElbowPose1D6MExoFixed-v0 + myoFatiElbowPose1D6MExoRandom-v0 MyoSuite/myoFatiElbowPose1D6MExoRandom-v0 + myoFatiElbowPose1D6MFixed-v0 MyoSuite/myoFatiElbowPose1D6MFixed-v0 + myoFatiElbowPose1D6MRandom-v0 MyoSuite/myoFatiElbowPose1D6MRandom-v0 + myoFatiFingerPoseFixed-v0 MyoSuite/myoFatiFingerPoseFixed-v0 + myoFatiFingerPoseRandom-v0 MyoSuite/myoFatiFingerPoseRandom-v0 + myoFatiFingerReachFixed-v0 MyoSuite/myoFatiFingerReachFixed-v0 + myoFatiFingerReachRandom-v0 MyoSuite/myoFatiFingerReachRandom-v0 + myoFatiHandKeyTurnFixed-v0 MyoSuite/myoFatiHandKeyTurnFixed-v0 + myoFatiHandKeyTurnRandom-v0 MyoSuite/myoFatiHandKeyTurnRandom-v0 + myoFatiHandObjHoldFixed-v0 MyoSuite/myoFatiHandObjHoldFixed-v0 + myoFatiHandObjHoldRandom-v0 MyoSuite/myoFatiHandObjHoldRandom-v0 + myoFatiHandPenTwirlFixed-v0 MyoSuite/myoFatiHandPenTwirlFixed-v0 + myoFatiHandPenTwirlRandom-v0 MyoSuite/myoFatiHandPenTwirlRandom-v0 + myoFatiHandPose0Fixed-v0 MyoSuite/myoFatiHandPose0Fixed-v0 + myoFatiHandPose1Fixed-v0 MyoSuite/myoFatiHandPose1Fixed-v0 + myoFatiHandPose2Fixed-v0 MyoSuite/myoFatiHandPose2Fixed-v0 + myoFatiHandPose3Fixed-v0 MyoSuite/myoFatiHandPose3Fixed-v0 + myoFatiHandPose4Fixed-v0 MyoSuite/myoFatiHandPose4Fixed-v0 + myoFatiHandPose5Fixed-v0 MyoSuite/myoFatiHandPose5Fixed-v0 + myoFatiHandPose6Fixed-v0 MyoSuite/myoFatiHandPose6Fixed-v0 + myoFatiHandPose7Fixed-v0 MyoSuite/myoFatiHandPose7Fixed-v0 + myoFatiHandPose8Fixed-v0 MyoSuite/myoFatiHandPose8Fixed-v0 + myoFatiHandPose9Fixed-v0 MyoSuite/myoFatiHandPose9Fixed-v0 + myoFatiHandPoseFixed-v0 MyoSuite/myoFatiHandPoseFixed-v0 + myoFatiHandPoseRandom-v0 MyoSuite/myoFatiHandPoseRandom-v0 + myoFatiHandReachFixed-v0 MyoSuite/myoFatiHandReachFixed-v0 + myoFatiHandReachRandom-v0 MyoSuite/myoFatiHandReachRandom-v0 + myoFatiHandReorient100-v0 MyoSuite/myoFatiHandReorient100-v0 + myoFatiHandReorient8-v0 MyoSuite/myoFatiHandReorient8-v0 + myoFatiHandReorientID-v0 MyoSuite/myoFatiHandReorientID-v0 + myoFatiHandReorientOOD-v0 MyoSuite/myoFatiHandReorientOOD-v0 + myoFatiLegHillyTerrainWalk-v0 MyoSuite/myoFatiLegHillyTerrainWalk-v0 + myoFatiLegRoughTerrainWalk-v0 MyoSuite/myoFatiLegRoughTerrainWalk-v0 + myoFatiLegStairTerrainWalk-v0 MyoSuite/myoFatiLegStairTerrainWalk-v0 + myoFatiLegStandRandom-v0 MyoSuite/myoFatiLegStandRandom-v0 + myoFatiLegWalk-v0 MyoSuite/myoFatiLegWalk-v0 + myoFatiTorsoExoPoseFixed-v0 MyoSuite/myoFatiTorsoExoPoseFixed-v0 + myoFatiTorsoPoseFixed-v0 MyoSuite/myoFatiTorsoPoseFixed-v0 + myoFingerPoseFixed-v0 MyoSuite/myoFingerPoseFixed-v0 + myoFingerPoseRandom-v0 MyoSuite/myoFingerPoseRandom-v0 + myoFingerReachFixed-v0 MyoSuite/myoFingerReachFixed-v0 + myoFingerReachRandom-v0 MyoSuite/myoFingerReachRandom-v0 + myoHandKeyTurnFixed-v0 MyoSuite/myoHandKeyTurnFixed-v0 + myoHandKeyTurnRandom-v0 MyoSuite/myoHandKeyTurnRandom-v0 + myoHandObjHoldFixed-v0 MyoSuite/myoHandObjHoldFixed-v0 + myoHandObjHoldRandom-v0 MyoSuite/myoHandObjHoldRandom-v0 + myoHandPenTwirlFixed-v0 MyoSuite/myoHandPenTwirlFixed-v0 + myoHandPenTwirlRandom-v0 MyoSuite/myoHandPenTwirlRandom-v0 + myoHandPose0Fixed-v0 MyoSuite/myoHandPose0Fixed-v0 + myoHandPose1Fixed-v0 MyoSuite/myoHandPose1Fixed-v0 + myoHandPose2Fixed-v0 MyoSuite/myoHandPose2Fixed-v0 + myoHandPose3Fixed-v0 MyoSuite/myoHandPose3Fixed-v0 + myoHandPose4Fixed-v0 MyoSuite/myoHandPose4Fixed-v0 + myoHandPose5Fixed-v0 MyoSuite/myoHandPose5Fixed-v0 + myoHandPose6Fixed-v0 MyoSuite/myoHandPose6Fixed-v0 + myoHandPose7Fixed-v0 MyoSuite/myoHandPose7Fixed-v0 + myoHandPose8Fixed-v0 MyoSuite/myoHandPose8Fixed-v0 + myoHandPose9Fixed-v0 MyoSuite/myoHandPose9Fixed-v0 + myoHandPoseFixed-v0 MyoSuite/myoHandPoseFixed-v0 + myoHandPoseRandom-v0 MyoSuite/myoHandPoseRandom-v0 + myoHandReachFixed-v0 MyoSuite/myoHandReachFixed-v0 + myoHandReachRandom-v0 MyoSuite/myoHandReachRandom-v0 + myoHandReorient100-v0 MyoSuite/myoHandReorient100-v0 + myoHandReorient8-v0 MyoSuite/myoHandReorient8-v0 + myoHandReorientID-v0 MyoSuite/myoHandReorientID-v0 + myoHandReorientOOD-v0 MyoSuite/myoHandReorientOOD-v0 + myoLegHillyTerrainWalk-v0 MyoSuite/myoLegHillyTerrainWalk-v0 + myoLegRoughTerrainWalk-v0 MyoSuite/myoLegRoughTerrainWalk-v0 + myoLegStairTerrainWalk-v0 MyoSuite/myoLegStairTerrainWalk-v0 + myoLegStandRandom-v0 MyoSuite/myoLegStandRandom-v0 + myoLegWalk-v0 MyoSuite/myoLegWalk-v0 + myoReafHandKeyTurnFixed-v0 MyoSuite/myoReafHandKeyTurnFixed-v0 + myoReafHandKeyTurnRandom-v0 MyoSuite/myoReafHandKeyTurnRandom-v0 + myoReafHandObjHoldFixed-v0 MyoSuite/myoReafHandObjHoldFixed-v0 + myoReafHandObjHoldRandom-v0 MyoSuite/myoReafHandObjHoldRandom-v0 + myoReafHandPenTwirlFixed-v0 MyoSuite/myoReafHandPenTwirlFixed-v0 + myoReafHandPenTwirlRandom-v0 MyoSuite/myoReafHandPenTwirlRandom-v0 + myoReafHandPose0Fixed-v0 MyoSuite/myoReafHandPose0Fixed-v0 + myoReafHandPose1Fixed-v0 MyoSuite/myoReafHandPose1Fixed-v0 + myoReafHandPose2Fixed-v0 MyoSuite/myoReafHandPose2Fixed-v0 + myoReafHandPose3Fixed-v0 MyoSuite/myoReafHandPose3Fixed-v0 + myoReafHandPose4Fixed-v0 MyoSuite/myoReafHandPose4Fixed-v0 + myoReafHandPose5Fixed-v0 MyoSuite/myoReafHandPose5Fixed-v0 + myoReafHandPose6Fixed-v0 MyoSuite/myoReafHandPose6Fixed-v0 + myoReafHandPose7Fixed-v0 MyoSuite/myoReafHandPose7Fixed-v0 + myoReafHandPose8Fixed-v0 MyoSuite/myoReafHandPose8Fixed-v0 + myoReafHandPose9Fixed-v0 MyoSuite/myoReafHandPose9Fixed-v0 + myoReafHandPoseFixed-v0 MyoSuite/myoReafHandPoseFixed-v0 + myoReafHandPoseRandom-v0 MyoSuite/myoReafHandPoseRandom-v0 + myoReafHandReachFixed-v0 MyoSuite/myoReafHandReachFixed-v0 + myoReafHandReachRandom-v0 MyoSuite/myoReafHandReachRandom-v0 + myoReafHandReorient100-v0 MyoSuite/myoReafHandReorient100-v0 + myoReafHandReorient8-v0 MyoSuite/myoReafHandReorient8-v0 + myoReafHandReorientID-v0 MyoSuite/myoReafHandReorientID-v0 + myoReafHandReorientOOD-v0 MyoSuite/myoReafHandReorientOOD-v0 + myoSarcArmReachFixed-v0 MyoSuite/myoSarcArmReachFixed-v0 + myoSarcArmReachRandom-v0 MyoSuite/myoSarcArmReachRandom-v0 + myoSarcChallengeBaodingP1-v1 MyoSuite/myoSarcChallengeBaodingP1-v1 + myoSarcChallengeBaodingP2-v1 MyoSuite/myoSarcChallengeBaodingP2-v1 + myoSarcChallengeBimanual-v0 MyoSuite/myoSarcChallengeBimanual-v0 + myoSarcChallengeChaseTagP1-v0 MyoSuite/myoSarcChallengeChaseTagP1-v0 + myoSarcChallengeChaseTagP2-v0 MyoSuite/myoSarcChallengeChaseTagP2-v0 + myoSarcChallengeChaseTagP2eval-v0 MyoSuite/myoSarcChallengeChaseTagP2eval-v0 + myoSarcChallengeDieReorientDemo-v0 MyoSuite/myoSarcChallengeDieReorientDemo-v0 + myoSarcChallengeDieReorientP1-v0 MyoSuite/myoSarcChallengeDieReorientP1-v0 + myoSarcChallengeDieReorientP2-v0 MyoSuite/myoSarcChallengeDieReorientP2-v0 + myoSarcChallengeOslRunFixed-v0 MyoSuite/myoSarcChallengeOslRunFixed-v0 + myoSarcChallengeOslRunRandom-v0 MyoSuite/myoSarcChallengeOslRunRandom-v0 + myoSarcChallengeRelocateP1-v0 MyoSuite/myoSarcChallengeRelocateP1-v0 + myoSarcChallengeRelocateP2-v0 MyoSuite/myoSarcChallengeRelocateP2-v0 + myoSarcChallengeRelocateP2eval-v0 MyoSuite/myoSarcChallengeRelocateP2eval-v0 + myoSarcChallengeSoccerP1-v0 MyoSuite/myoSarcChallengeSoccerP1-v0 + myoSarcChallengeSoccerP2-v0 MyoSuite/myoSarcChallengeSoccerP2-v0 + myoSarcChallengeTableTennisP0-v0 MyoSuite/myoSarcChallengeTableTennisP0-v0 + myoSarcChallengeTableTennisP1-v0 MyoSuite/myoSarcChallengeTableTennisP1-v0 + myoSarcChallengeTableTennisP2-v0 MyoSuite/myoSarcChallengeTableTennisP2-v0 + myoSarcElbowPose1D6MExoFixed-v0 MyoSuite/myoSarcElbowPose1D6MExoFixed-v0 + myoSarcElbowPose1D6MExoRandom-v0 MyoSuite/myoSarcElbowPose1D6MExoRandom-v0 + myoSarcElbowPose1D6MFixed-v0 MyoSuite/myoSarcElbowPose1D6MFixed-v0 + myoSarcElbowPose1D6MRandom-v0 MyoSuite/myoSarcElbowPose1D6MRandom-v0 + myoSarcFingerPoseFixed-v0 MyoSuite/myoSarcFingerPoseFixed-v0 + myoSarcFingerPoseRandom-v0 MyoSuite/myoSarcFingerPoseRandom-v0 + myoSarcFingerReachFixed-v0 MyoSuite/myoSarcFingerReachFixed-v0 + myoSarcFingerReachRandom-v0 MyoSuite/myoSarcFingerReachRandom-v0 + myoSarcHandKeyTurnFixed-v0 MyoSuite/myoSarcHandKeyTurnFixed-v0 + myoSarcHandKeyTurnRandom-v0 MyoSuite/myoSarcHandKeyTurnRandom-v0 + myoSarcHandObjHoldFixed-v0 MyoSuite/myoSarcHandObjHoldFixed-v0 + myoSarcHandObjHoldRandom-v0 MyoSuite/myoSarcHandObjHoldRandom-v0 + myoSarcHandPenTwirlFixed-v0 MyoSuite/myoSarcHandPenTwirlFixed-v0 + myoSarcHandPenTwirlRandom-v0 MyoSuite/myoSarcHandPenTwirlRandom-v0 + myoSarcHandPose0Fixed-v0 MyoSuite/myoSarcHandPose0Fixed-v0 + myoSarcHandPose1Fixed-v0 MyoSuite/myoSarcHandPose1Fixed-v0 + myoSarcHandPose2Fixed-v0 MyoSuite/myoSarcHandPose2Fixed-v0 + myoSarcHandPose3Fixed-v0 MyoSuite/myoSarcHandPose3Fixed-v0 + myoSarcHandPose4Fixed-v0 MyoSuite/myoSarcHandPose4Fixed-v0 + myoSarcHandPose5Fixed-v0 MyoSuite/myoSarcHandPose5Fixed-v0 + myoSarcHandPose6Fixed-v0 MyoSuite/myoSarcHandPose6Fixed-v0 + myoSarcHandPose7Fixed-v0 MyoSuite/myoSarcHandPose7Fixed-v0 + myoSarcHandPose8Fixed-v0 MyoSuite/myoSarcHandPose8Fixed-v0 + myoSarcHandPose9Fixed-v0 MyoSuite/myoSarcHandPose9Fixed-v0 + myoSarcHandPoseFixed-v0 MyoSuite/myoSarcHandPoseFixed-v0 + myoSarcHandPoseRandom-v0 MyoSuite/myoSarcHandPoseRandom-v0 + myoSarcHandReachFixed-v0 MyoSuite/myoSarcHandReachFixed-v0 + myoSarcHandReachRandom-v0 MyoSuite/myoSarcHandReachRandom-v0 + myoSarcHandReorient100-v0 MyoSuite/myoSarcHandReorient100-v0 + myoSarcHandReorient8-v0 MyoSuite/myoSarcHandReorient8-v0 + myoSarcHandReorientID-v0 MyoSuite/myoSarcHandReorientID-v0 + myoSarcHandReorientOOD-v0 MyoSuite/myoSarcHandReorientOOD-v0 + myoSarcLegHillyTerrainWalk-v0 MyoSuite/myoSarcLegHillyTerrainWalk-v0 + myoSarcLegRoughTerrainWalk-v0 MyoSuite/myoSarcLegRoughTerrainWalk-v0 + myoSarcLegStairTerrainWalk-v0 MyoSuite/myoSarcLegStairTerrainWalk-v0 + myoSarcLegStandRandom-v0 MyoSuite/myoSarcLegStandRandom-v0 + myoSarcLegWalk-v0 MyoSuite/myoSarcLegWalk-v0 + myoSarcTorsoExoPoseFixed-v0 MyoSuite/myoSarcTorsoExoPoseFixed-v0 + myoSarcTorsoPoseFixed-v0 MyoSuite/myoSarcTorsoPoseFixed-v0 + myoTorsoExoPoseFixed-v0 MyoSuite/myoTorsoExoPoseFixed-v0 + myoTorsoPoseFixed-v0 MyoSuite/myoTorsoPoseFixed-v0 + +The covered surface includes MyoBase reach, pose, key-turn, object-hold, +pen-twirl, reorient, walk, and terrain tasks; MyoChallenge tasks; MyoDM track +tasks; and the corresponding normal, sarcopenia, fatigue, and +reafferentation variants exposed by upstream. + + +Render Compare +-------------- + +Reset and first-three-step render comparisons for every pinned official task: +398 tasks total, split into 151 MyoBase/Reorient/Walk/Terrain tasks, 57 +MyoChallenge tasks, and 190 MyoDM TrackEnv tasks. For each step pair, EnvPool +is on the left and the pinned MyoSuite renderer is on the right. The images are +generated by +``third_party/myosuite/generate_render_sample.py`` from the pinned official +oracle and the same action sequence used by the render test. The render test +keeps a visual-alignment gate based on frame-level and low-frequency image +delta, so wrong cameras, models, and scenes fail without requiring bitwise +renderer identity. If EnvPool's public API auto-resets a task within those +three calls, the official oracle is reset at that same reset boundary and +synchronized only to the corresponding reset state. + +MyoBase/Reorient/Walk/Terrain: 151 tasks. + +.. image:: ../_static/render_samples/myosuite_myobase_official_compare.png + :width: 900px + :align: center + +MyoChallenge: 57 tasks. + +.. image:: ../_static/render_samples/myosuite_myochallenge_official_compare.png + :width: 900px + :align: center + +MyoDM TrackEnv: 190 tasks. + +.. image:: ../_static/render_samples/myosuite_myodm_official_compare.png + :width: 900px + :align: center diff --git a/docs/index.rst b/docs/index.rst index 0a4df9f7d..4f8ee8b8a 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -97,6 +97,7 @@ stable version through `envpool.readthedocs.io/en/stable/ env/gymnasium_robotics env/metaworld env/mujoco_gym + env/myosuite env/procgen env/toy_text env/vizdoom diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 3a7973b12..b86abd400 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -52,6 +52,12 @@ timestep RGB py Mujoco +MyoBase +MyoChallenge +MyoDM +MyoSuite +reafferentation +sarcopenia tran Reacher golang @@ -83,6 +89,7 @@ backends pygame rollout bitwise +Bimanual dtype Subclassed deleter diff --git a/envpool/BUILD b/envpool/BUILD index 2b0d6fa25..ccc4d15d7 100644 --- a/envpool/BUILD +++ b/envpool/BUILD @@ -42,6 +42,7 @@ py_library( "//envpool/mujoco:metaworld_registration", "//envpool/mujoco:mujoco_dmc_registration", "//envpool/mujoco:mujoco_gym_registration", + "//envpool/mujoco:myosuite_registration", "//envpool/mujoco:robotics_registration", "//envpool/procgen:procgen_registration", "//envpool/toy_text:toy_text_registration", diff --git a/envpool/__init__.py b/envpool/__init__.py index 20fcc71c3..d1e321b76 100644 --- a/envpool/__init__.py +++ b/envpool/__init__.py @@ -37,7 +37,7 @@ register, ) -__version__ = "1.2.0" +__version__ = "1.2.2" __all__ = [ "register", "make", diff --git a/envpool/classic_control/classic_control_test.py b/envpool/classic_control/classic_control_test.py index 15073693e..2fdc9c99c 100644 --- a/envpool/classic_control/classic_control_test.py +++ b/envpool/classic_control/classic_control_test.py @@ -13,7 +13,7 @@ # limitations under the License. """Unit tests for classic control environments.""" -from typing import Any, no_type_check +from typing import Any, cast, no_type_check import gymnasium as gym import numpy as np @@ -82,6 +82,44 @@ def test_cartpole(self) -> None: self.run_space_check(env0, env1) self.run_deterministic_check("CartPole-v1") + def test_cartpole_gymnasium_vector_wrapper(self) -> None: + num_envs = 4 + env = make_gym("CartPole-v1", num_envs=num_envs) + self.assertEqual(env.num_envs, num_envs) + self.assertTrue(env.is_vector_env) + self.assertIs(env.single_observation_space, env.observation_space) + self.assertIs(env.single_action_space, env.action_space) + if hasattr(gym, "vector") and hasattr(gym.vector, "VectorEnv"): + self.assertIsInstance(env, gym.vector.VectorEnv) + + vector_wrappers = getattr(gym.wrappers, "vector", None) + normalize_observation = getattr( + vector_wrappers, "NormalizeObservation", None + ) + if normalize_observation is None: + normalize_observation = gym.wrappers.NormalizeObservation + wrapped = cast(Any, normalize_observation)(env) + try: + obs, _ = wrapped.reset() + if hasattr(wrapped, "num_envs"): + self.assertEqual(wrapped.num_envs, num_envs) + if hasattr(wrapped, "is_vector_env"): + self.assertTrue(wrapped.is_vector_env) + self.assertEqual( + obs.shape, (num_envs, *env.single_observation_space.shape) + ) + obs, rew, term, trunc, _ = wrapped.step( + np.zeros(num_envs, dtype=np.int32) + ) + self.assertEqual( + obs.shape, (num_envs, *env.single_observation_space.shape) + ) + self.assertEqual(rew.shape, (num_envs,)) + self.assertEqual(term.shape, (num_envs,)) + self.assertEqual(trunc.shape, (num_envs,)) + finally: + wrapped.close() + def test_pendulum(self) -> None: env0 = gym.make("Pendulum-v1") env1 = make_gym("Pendulum-v1") diff --git a/envpool/entry.py b/envpool/entry.py index 5ee50d517..04a6fa806 100644 --- a/envpool/entry.py +++ b/envpool/entry.py @@ -22,6 +22,7 @@ import envpool.mujoco.dmc.registration # noqa: F401 import envpool.mujoco.gym.registration # noqa: F401 import envpool.mujoco.metaworld.registration # noqa: F401 +import envpool.mujoco.myosuite.registration # noqa: F401 import envpool.mujoco.robotics.registration # noqa: F401 import envpool.procgen.registration # noqa: F401 import envpool.toy_text.registration # noqa: F401 diff --git a/envpool/mujoco/BUILD b/envpool/mujoco/BUILD index bd2706853..30f4bae0d 100644 --- a/envpool/mujoco/BUILD +++ b/envpool/mujoco/BUILD @@ -13,11 +13,14 @@ # limitations under the License. load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") +load("@python_versions//3.12:defs.bzl", py_binary_312 = "py_binary") load("@rules_cc//cc:defs.bzl", "cc_library", "cc_test") load("@rules_python//python:defs.bzl", "py_library", "py_test") load("//envpool:requirements.bzl", "requirement") load("//third_party:common.bzl", "copy_to_directory") load("//third_party/metaworld_assets:defs.bzl", "metaworld_runtime_assets") +load("//third_party/myosuite:defs.bzl", "myosuite_runtime_assets") +load("//third_party/myosuite:oracle_requirements.bzl", "oracle_requirement") package(default_visibility = ["//visibility:public"]) @@ -55,6 +58,7 @@ py_library( ":metaworld", ":mujoco_dmc", ":mujoco_gym", + ":myosuite", ":robotics", ], ) @@ -94,6 +98,20 @@ metaworld_runtime_assets( strip_prefix = "metaworld_assets/metaworld/assets/", ) +myosuite_runtime_assets( + name = "gen_myosuite_assets", + srcs = [ + "@myosuite_furniture_sim//:source", + "@myosuite_mpl_sim//:source", + "@myosuite_myo_sim//:source", + "@myosuite_object_sim//:source", + "@myosuite_source//:runtime_assets", + "@myosuite_ycb_sim//:source", + ], + out = "myosuite/assets", + metadata_srcs = ["//third_party/myosuite:myosuite_generated_json"], +) + cc_library( name = "mujoco_gym_env", hdrs = [ @@ -278,6 +296,145 @@ py_library( deps = ["//envpool/python:api"], ) +cc_library( + name = "myosuite_env", + hdrs = [ + "frame_stack.h", + "myosuite/myosuite_env.h", + "offscreen_renderer.h", + ], + data = [":gen_myosuite_assets"], + deps = [ + ":mujoco_render", + ":robotics_env", + "//envpool/core:async_envpool", + "//third_party/myosuite:myosuite_tasks", + "@mujoco//:mujoco_lib", + "@mujoco//:mujoco_obj_decoder_plugin_lib", + "@mujoco//:mujoco_stl_decoder_plugin_lib", + ], +) + +cc_library( + name = "myosuite_clang_tidy", + srcs = ["myosuite/myosuite_clang_tidy.cc"], + deps = [":myosuite_env"], +) + +cc_library( + name = "myosuite_envpool_module", + srcs = [ + "myosuite/mujoco_plugin_init.cc", + "myosuite/myosuite_envpool.cc", + ], + # The native MyoSuite runtime is header-only under myosuite_env.h and is + # compiled into this module object. Coverage instrumentation changes Linux + # floating-point codegen enough to make long oracle rollouts diverge; the + # pybind_extension wrapper still links under coverage so instrumented + # transitive deps resolve their coverage runtime symbols. + copts = select({ + "//:windows": [ + # CI fastbuild otherwise compiles the header-only MyoSuite runtime + # at /Od while the pinned official oracle wheel is release-built. + "/O2", + ], + "//:linux_x86_64": [ + # CI fastbuild otherwise compiles the header-only MyoSuite runtime + # at -O0 while the pinned official oracle wheel is release-built. + # This is x86_64-only because the official Linux x86_64 MuJoCo + # wheel enables platform SIMD; linux_arm64 only needs coverage + # disabled for stable long oracle rollouts. + "-O3", + "-fno-profile-arcs", + "-fno-test-coverage", + ], + "//:linux_arm64": [ + "-fno-profile-arcs", + "-fno-test-coverage", + ], + "@platforms//os:osx": [ + "-fno-coverage-mapping", + "-fno-profile-instr-generate", + ], + "//conditions:default": [], + }), + features = ["-coverage"], + deps = [ + ":myosuite_env", + "//envpool/core:py_envpool", + ], + alwayslink = True, +) + +pybind_extension( + name = "myosuite_envpool", + deps = [ + ":myosuite_envpool_module", + ], +) + +py_binary_312( + name = "generate_myosuite_render_sample", + srcs = [ + "myosuite/myosuite_oracle_probe.py", + "//third_party/myosuite:generate_render_sample.py", + ], + data = select({ + # Linux oracle rendering uses the pinned pip MuJoCo wheel directly. + "//:linux": [], + "//conditions:default": ["@mujoco//:mujoco_shared_lib"], + }) + [ + ":myosuite_oracle_probe", + "@myosuite_furniture_sim//:source", + "@myosuite_mpl_sim//:source", + "@myosuite_myo_sim//:source", + "@myosuite_object_sim//:source", + "@myosuite_source//:source", + "@myosuite_ycb_sim//:source", + ], + imports = ["../.."], + main = "generate_render_sample.py", + deps = [ + ":myosuite", + ":myosuite_registration", + oracle_requirement("click"), + oracle_requirement("dm-control"), + oracle_requirement("flatten-dict"), + oracle_requirement("gitpython"), + oracle_requirement("gymnasium"), + oracle_requirement("h5py"), + oracle_requirement("mujoco"), + oracle_requirement("numpy"), + oracle_requirement("packaging"), + oracle_requirement("pillow"), + oracle_requirement("pink-noise-rl"), + oracle_requirement("sk-video"), + oracle_requirement("termcolor"), + ], +) + +py_library( + name = "myosuite_task_metadata", + srcs = ["myosuite/tasks.py"], + data = [":gen_myosuite_assets"], + imports = ["../.."], +) + +py_library( + name = "myosuite", + srcs = ["myosuite/__init__.py"], + data = [ + ":gen_myosuite_assets", + ":myosuite_envpool", + ], + imports = ["../.."], + deps = [ + ":myosuite_task_metadata", + "//envpool/python:api", + "//envpool/python:glfw_context", + ], +) + cc_test( name = "mujoco_envpool_test", size = "enormous", @@ -323,6 +480,17 @@ py_library( ], ) +py_library( + name = "myosuite_registration", + srcs = ["myosuite/registration.py"], + imports = ["../.."], + deps = [ + ":myosuite", + ":myosuite_task_metadata", + "//envpool:registration", + ], +) + py_test( name = "metaworld_test", size = "enormous", @@ -336,6 +504,94 @@ py_test( ], ) +py_test( + name = "myosuite_test", + size = "enormous", + srcs = ["myosuite/myosuite_test.py"], + imports = ["../.."], + shard_count = 8, + deps = [ + ":myosuite", + ":myosuite_registration", + ":myosuite_task_metadata", + "//envpool/python:glfw_context", + requirement("absl-py"), + requirement("numpy"), + ], +) + +py_test( + name = "myosuite_render_test", + size = "enormous", + srcs = ["myosuite/myosuite_render_test.py"], + data = [":myosuite_oracle_probe"], + imports = ["../.."], + shard_count = 8, + deps = [ + ":myosuite", + ":myosuite_registration", + ":myosuite_task_metadata", + "//envpool/python:glfw_context", + requirement("absl-py"), + requirement("numpy"), + ], +) + +py_test( + name = "myosuite_oracle_align_test", + size = "enormous", + srcs = ["myosuite/myosuite_oracle_align_test.py"], + data = [":myosuite_oracle_probe"], + # This test is an oracle correctness gate, not a Python coverage target. + # rules_python coverage tries to import native object dirs from the + # transitive MyoSuite deps as Python modules and can fail after the unittest + # body already passed with module-not-imported warnings. + features = ["-coverage"], + imports = ["../.."], + shard_count = 8, + deps = [ + ":myosuite", + ":myosuite_registration", + ":myosuite_task_metadata", + requirement("absl-py"), + requirement("numpy"), + ], +) + +py_binary_312( + name = "myosuite_oracle_probe", + srcs = ["myosuite/myosuite_oracle_probe.py"], + data = select({ + # Linux oracle rendering uses the pinned pip MuJoCo wheel directly. + "//:linux": [], + "//conditions:default": ["@mujoco//:mujoco_shared_lib"], + }) + [ + "@myosuite_furniture_sim//:source", + "@myosuite_mpl_sim//:source", + "@myosuite_myo_sim//:source", + "@myosuite_object_sim//:source", + "@myosuite_source//:source", + "@myosuite_ycb_sim//:source", + ], + imports = ["../.."], + deps = [ + "//envpool/python:glfw_context", + oracle_requirement("click"), + oracle_requirement("dm-control"), + oracle_requirement("flatten-dict"), + oracle_requirement("gitpython"), + oracle_requirement("gymnasium"), + oracle_requirement("h5py"), + oracle_requirement("mujoco"), + oracle_requirement("numpy"), + oracle_requirement("packaging"), + oracle_requirement("pillow"), + oracle_requirement("pink-noise-rl"), + oracle_requirement("sk-video"), + oracle_requirement("termcolor"), + ], +) + py_test( name = "metaworld_align_test", size = "enormous", @@ -407,6 +663,27 @@ py_test( ], ) +py_test( + name = "mujoco_egl_teardown_test", + size = "enormous", + srcs = ["mujoco_egl_teardown_test.py"], + imports = ["../.."], + deps = [ + ":metaworld", + ":metaworld_registration", + ":mujoco_dmc", + ":mujoco_dmc_registration", + ":mujoco_gym", + ":mujoco_gym_registration", + ":mujoco_pixel_observation_test_utils", + ":myosuite", + ":myosuite_registration", + ":robotics", + ":robotics_registration", + requirement("absl-py"), + ], +) + py_test( name = "mujoco_dmc_pixel_observation_test", size = "enormous", @@ -473,6 +750,7 @@ py_test( test_suite( name = "mujoco_pixel_observation_test", tests = [ + ":mujoco_egl_teardown_test", ":mujoco_dmc_pixel_observation_test", ":mujoco_gym_pixel_observation_test", ":robotics_pixel_observation_test", diff --git a/envpool/mujoco/gym/ant.h b/envpool/mujoco/gym/ant.h index f7edd3f25..3ddc0d347 100644 --- a/envpool/mujoco/gym/ant.h +++ b/envpool/mujoco/gym/ant.h @@ -39,6 +39,7 @@ class AntEnvFns { "terminate_when_unhealthy"_.Bind(true), "exclude_current_positions_from_observation"_.Bind(true), "xml_file"_.Bind(std::string("ant.xml")), + "gymnasium_v5_render_camera"_.Bind(false), "forward_reward_weight"_.Bind(1.0), "ctrl_cost_weight"_.Bind(0.5), "contact_cost_weight"_.Bind(5e-4), "healthy_reward"_.Bind(1.0), "healthy_z_min"_.Bind(0.2), "healthy_z_max"_.Bind(1.0), @@ -89,6 +90,7 @@ class AntEnvBase : public Env, public MujocoEnv { int id_torso_; bool terminate_when_unhealthy_, no_pos_, use_contact_force_; bool legacy_healthy_reward_, exclude_worldbody_contact_forces_; + bool gymnasium_v5_render_camera_; mjtNum ctrl_cost_weight_, contact_cost_weight_; mjtNum forward_reward_weight_, healthy_reward_; mjtNum healthy_z_min_, healthy_z_max_; @@ -117,6 +119,7 @@ class AntEnvBase : public Env, public MujocoEnv { legacy_healthy_reward_(spec.config["legacy_healthy_reward"_]), exclude_worldbody_contact_forces_( spec.config["exclude_worldbody_contact_forces"_]), + gymnasium_v5_render_camera_(spec.config["gymnasium_v5_render_camera"_]), ctrl_cost_weight_(spec.config["ctrl_cost_weight"_]), contact_cost_weight_(spec.config["contact_cost_weight"_]), forward_reward_weight_(spec.config["forward_reward_weight"_]), @@ -142,6 +145,15 @@ class AntEnvBase : public Env, public MujocoEnv { #endif } + bool RenderCamera(mjvCamera* camera) override { + if (!gymnasium_v5_render_camera_) { + return false; + } + camera->distance = 4.0; + ApplyGymnasiumDefaultCameraId(camera); + return true; + } + bool IsDone() override { return done_; } void Reset() override { diff --git a/envpool/mujoco/gym/half_cheetah.h b/envpool/mujoco/gym/half_cheetah.h index 91acc2202..6435b4b6e 100644 --- a/envpool/mujoco/gym/half_cheetah.h +++ b/envpool/mujoco/gym/half_cheetah.h @@ -35,6 +35,7 @@ class HalfCheetahEnvFns { "frame_stack"_.Bind(1), "post_constraint"_.Bind(true), "exclude_current_positions_from_observation"_.Bind(true), "xml_file"_.Bind(std::string("half_cheetah.xml")), + "gymnasium_v5_render_camera"_.Bind(false), "ctrl_cost_weight"_.Bind(0.1), "forward_reward_weight"_.Bind(1.0), "reset_noise_scale"_.Bind(0.1)); @@ -74,6 +75,7 @@ class HalfCheetahEnvBase : public Env, public MujocoEnv { using Base::spec_; bool no_pos_; + bool gymnasium_v5_render_camera_; mjtNum ctrl_cost_weight_, forward_reward_weight_; std::uniform_real_distribution<> dist_qpos_; std::normal_distribution<> dist_qvel_; @@ -93,6 +95,7 @@ class HalfCheetahEnvBase : public Env, public MujocoEnv { RenderHeightOrDefault(spec.config), RenderCameraIdOrDefault(spec.config)), no_pos_(spec.config["exclude_current_positions_from_observation"_]), + gymnasium_v5_render_camera_(spec.config["gymnasium_v5_render_camera"_]), ctrl_cost_weight_(spec.config["ctrl_cost_weight"_]), forward_reward_weight_(spec.config["forward_reward_weight"_]), dist_qpos_(-spec.config["reset_noise_scale"_], @@ -112,6 +115,15 @@ class HalfCheetahEnvBase : public Env, public MujocoEnv { #endif } + bool RenderCamera(mjvCamera* camera) override { + if (!gymnasium_v5_render_camera_) { + return false; + } + camera->distance = 4.0; + ApplyGymnasiumDefaultCameraId(camera); + return true; + } + bool IsDone() override { return done_; } void Reset() override { diff --git a/envpool/mujoco/gym/hopper.h b/envpool/mujoco/gym/hopper.h index 764785666..a98a00fb2 100644 --- a/envpool/mujoco/gym/hopper.h +++ b/envpool/mujoco/gym/hopper.h @@ -38,6 +38,7 @@ class HopperEnvFns { "legacy_healthy_reward"_.Bind(true), "exclude_current_positions_from_observation"_.Bind(true), "xml_file"_.Bind(std::string("hopper.xml")), + "gymnasium_v5_render_camera"_.Bind(false), "ctrl_cost_weight"_.Bind(1e-3), "forward_reward_weight"_.Bind(1.0), "healthy_reward"_.Bind(1.0), "velocity_min"_.Bind(-10.0), "velocity_max"_.Bind(10.0), "healthy_state_min"_.Bind(-100.0), @@ -79,6 +80,7 @@ class HopperEnvBase : public Env, public MujocoEnv { bool terminate_when_unhealthy_, no_pos_; bool legacy_healthy_reward_; + bool gymnasium_v5_render_camera_; mjtNum ctrl_cost_weight_, forward_reward_weight_; mjtNum healthy_reward_, healthy_z_min_; mjtNum velocity_min_, velocity_max_; @@ -103,6 +105,7 @@ class HopperEnvBase : public Env, public MujocoEnv { terminate_when_unhealthy_(spec.config["terminate_when_unhealthy"_]), no_pos_(spec.config["exclude_current_positions_from_observation"_]), legacy_healthy_reward_(spec.config["legacy_healthy_reward"_]), + gymnasium_v5_render_camera_(spec.config["gymnasium_v5_render_camera"_]), ctrl_cost_weight_(spec.config["ctrl_cost_weight"_]), forward_reward_weight_(spec.config["forward_reward_weight"_]), healthy_reward_(spec.config["healthy_reward"_]), @@ -129,6 +132,20 @@ class HopperEnvBase : public Env, public MujocoEnv { #endif } + bool RenderCamera(mjvCamera* camera) override { + if (!gymnasium_v5_render_camera_) { + return false; + } + camera->trackbodyid = 2; + camera->distance = 3.0; + camera->lookat[0] = 0.0; + camera->lookat[1] = 0.0; + camera->lookat[2] = 1.15; + camera->elevation = -20.0; + ApplyGymnasiumDefaultCameraId(camera); + return true; + } + bool IsDone() override { return done_; } void Reset() override { diff --git a/envpool/mujoco/gym/humanoid.h b/envpool/mujoco/gym/humanoid.h index 6fdca03b9..c0690a177 100644 --- a/envpool/mujoco/gym/humanoid.h +++ b/envpool/mujoco/gym/humanoid.h @@ -40,6 +40,7 @@ class HumanoidEnvFns { "terminate_when_unhealthy"_.Bind(true), "exclude_current_positions_from_observation"_.Bind(true), "xml_file"_.Bind(std::string("humanoid.xml")), + "gymnasium_v5_render_camera"_.Bind(false), "ctrl_cost_weight"_.Bind(0.1), "healthy_reward"_.Bind(5.0), "healthy_z_min"_.Bind(1.0), "healthy_z_max"_.Bind(2.0), "contact_cost_weight"_.Bind(5e-7), "contact_cost_max"_.Bind(10.0), @@ -93,6 +94,7 @@ class HumanoidEnvBase : public Env, public MujocoEnv { bool terminate_when_unhealthy_, no_pos_, use_contact_force_; bool legacy_healthy_reward_, exclude_worldbody_observations_; bool exclude_root_actuator_forces_; + bool gymnasium_v5_render_camera_; mjtNum ctrl_cost_weight_, forward_reward_weight_, healthy_reward_; mjtNum healthy_z_min_, healthy_z_max_; mjtNum contact_cost_weight_, contact_cost_max_; @@ -120,6 +122,7 @@ class HumanoidEnvBase : public Env, public MujocoEnv { spec.config["exclude_worldbody_observations"_]), exclude_root_actuator_forces_( spec.config["exclude_root_actuator_forces"_]), + gymnasium_v5_render_camera_(spec.config["gymnasium_v5_render_camera"_]), ctrl_cost_weight_(spec.config["ctrl_cost_weight"_]), forward_reward_weight_(spec.config["forward_reward_weight"_]), healthy_reward_(spec.config["healthy_reward"_]), @@ -143,6 +146,20 @@ class HumanoidEnvBase : public Env, public MujocoEnv { #endif } + bool RenderCamera(mjvCamera* camera) override { + if (!gymnasium_v5_render_camera_) { + return false; + } + camera->trackbodyid = 1; + camera->distance = 4.0; + camera->lookat[0] = 0.0; + camera->lookat[1] = 0.0; + camera->lookat[2] = 2.0; + camera->elevation = -20.0; + ApplyGymnasiumDefaultCameraId(camera); + return true; + } + bool IsDone() override { return done_; } void Reset() override { diff --git a/envpool/mujoco/gym/humanoid_standup.h b/envpool/mujoco/gym/humanoid_standup.h index ec5dba974..ec6a707d0 100644 --- a/envpool/mujoco/gym/humanoid_standup.h +++ b/envpool/mujoco/gym/humanoid_standup.h @@ -38,6 +38,7 @@ class HumanoidStandupEnvFns { "exclude_worldbody_observations"_.Bind(false), "exclude_root_actuator_forces"_.Bind(false), "xml_file"_.Bind(std::string("humanoidstandup.xml")), + "gymnasium_v5_render_camera"_.Bind(false), "ctrl_cost_weight"_.Bind(0.1), "contact_cost_weight"_.Bind(5e-7), "contact_cost_max"_.Bind(10.0), "healthy_reward"_.Bind(1.0), "reset_noise_scale"_.Bind(1e-2)); @@ -86,6 +87,7 @@ class HumanoidStandupEnvBase : public Env, public MujocoEnv { bool no_pos_; bool exclude_worldbody_observations_, exclude_root_actuator_forces_; + bool gymnasium_v5_render_camera_; mjtNum ctrl_cost_weight_, contact_cost_weight_, contact_cost_max_; mjtNum forward_reward_weight_, healthy_reward_; std::uniform_real_distribution<> dist_; @@ -109,6 +111,7 @@ class HumanoidStandupEnvBase : public Env, public MujocoEnv { spec.config["exclude_worldbody_observations"_]), exclude_root_actuator_forces_( spec.config["exclude_root_actuator_forces"_]), + gymnasium_v5_render_camera_(spec.config["gymnasium_v5_render_camera"_]), ctrl_cost_weight_(spec.config["ctrl_cost_weight"_]), contact_cost_weight_(spec.config["contact_cost_weight"_]), contact_cost_max_(spec.config["contact_cost_max"_]), @@ -130,6 +133,20 @@ class HumanoidStandupEnvBase : public Env, public MujocoEnv { #endif } + bool RenderCamera(mjvCamera* camera) override { + if (!gymnasium_v5_render_camera_) { + return false; + } + camera->trackbodyid = 1; + camera->distance = 4.0; + camera->lookat[0] = 0.0; + camera->lookat[1] = 0.0; + camera->lookat[2] = 0.8925; + camera->elevation = -20.0; + ApplyGymnasiumDefaultCameraId(camera); + return true; + } + bool IsDone() override { return done_; } void Reset() override { diff --git a/envpool/mujoco/gym/inverted_double_pendulum.h b/envpool/mujoco/gym/inverted_double_pendulum.h index 8e569474b..304063994 100644 --- a/envpool/mujoco/gym/inverted_double_pendulum.h +++ b/envpool/mujoco/gym/inverted_double_pendulum.h @@ -37,8 +37,9 @@ class InvertedDoublePendulumEnvFns { "healthy_reward"_.Bind(10.0), "reward_if_not_terminated"_.Bind(false), "constraint_obs_dim"_.Bind(3), "xml_file"_.Bind(std::string("inverted_double_pendulum.xml")), - "healthy_z_max"_.Bind(1.0), "observation_min"_.Bind(-10.0), - "observation_max"_.Bind(10.0), "reset_noise_scale"_.Bind(0.1)); + "gymnasium_v5_render_camera"_.Bind(false), "healthy_z_max"_.Bind(1.0), + "observation_min"_.Bind(-10.0), "observation_max"_.Bind(10.0), + "reset_noise_scale"_.Bind(0.1)); } template static decltype(auto) StateSpec(const Config& conf) { diff --git a/envpool/mujoco/gym/inverted_pendulum.h b/envpool/mujoco/gym/inverted_pendulum.h index 61c02c13b..9a8987596 100644 --- a/envpool/mujoco/gym/inverted_pendulum.h +++ b/envpool/mujoco/gym/inverted_pendulum.h @@ -31,13 +31,13 @@ namespace mujoco_gym { class InvertedPendulumEnvFns { public: static decltype(auto) DefaultConfig() { - return MakeDict("reward_threshold"_.Bind(950.0), "frame_skip"_.Bind(2), - "frame_stack"_.Bind(1), "post_constraint"_.Bind(true), - "healthy_reward"_.Bind(1.0), - "reward_if_not_terminated"_.Bind(false), - "xml_file"_.Bind(std::string("inverted_pendulum.xml")), - "healthy_z_min"_.Bind(-0.2), "healthy_z_max"_.Bind(0.2), - "reset_noise_scale"_.Bind(0.01)); + return MakeDict( + "reward_threshold"_.Bind(950.0), "frame_skip"_.Bind(2), + "frame_stack"_.Bind(1), "post_constraint"_.Bind(true), + "healthy_reward"_.Bind(1.0), "reward_if_not_terminated"_.Bind(false), + "xml_file"_.Bind(std::string("inverted_pendulum.xml")), + "gymnasium_v5_render_camera"_.Bind(false), "healthy_z_min"_.Bind(-0.2), + "healthy_z_max"_.Bind(0.2), "reset_noise_scale"_.Bind(0.01)); } template static decltype(auto) StateSpec(const Config& conf) { @@ -72,6 +72,7 @@ class InvertedPendulumEnvBase : public Env, public MujocoEnv { using Base::spec_; bool reward_if_not_terminated_; + bool gymnasium_v5_render_camera_; mjtNum healthy_reward_, healthy_z_min_, healthy_z_max_; std::uniform_real_distribution<> dist_; @@ -90,6 +91,7 @@ class InvertedPendulumEnvBase : public Env, public MujocoEnv { RenderHeightOrDefault(spec.config), RenderCameraIdOrDefault(spec.config)), reward_if_not_terminated_(spec.config["reward_if_not_terminated"_]), + gymnasium_v5_render_camera_(spec.config["gymnasium_v5_render_camera"_]), healthy_reward_(spec.config["healthy_reward"_]), healthy_z_min_(spec.config["healthy_z_min"_]), healthy_z_max_(spec.config["healthy_z_max"_]), @@ -109,6 +111,16 @@ class InvertedPendulumEnvBase : public Env, public MujocoEnv { #endif } + bool RenderCamera(mjvCamera* camera) override { + if (!gymnasium_v5_render_camera_) { + return false; + } + camera->trackbodyid = 0; + camera->distance = 2.04; + ApplyGymnasiumDefaultCameraId(camera); + return true; + } + bool IsDone() override { return done_; } void Reset() override { diff --git a/envpool/mujoco/gym/mujoco_env.h b/envpool/mujoco/gym/mujoco_env.h index bb762e820..3b78a5a21 100644 --- a/envpool/mujoco/gym/mujoco_env.h +++ b/envpool/mujoco/gym/mujoco_env.h @@ -162,7 +162,9 @@ class MujocoEnv : public RenderableEnv { #else if (renderer_ == nullptr) { renderer_ = std::make_unique( - envpool::mujoco::CameraPolicy::kGymLike); + envpool::mujoco::CameraPolicy::kGymLike, + /*disable_auxiliary_visuals=*/false, /*share_cgl_context=*/false, + /*prefer_offline_cgl_context=*/false, /*resize_offscreen=*/true); } renderer_->Render(model_, data_, width, height, camera_id, rgb, camera); #endif @@ -224,6 +226,7 @@ class MujocoEnv : public RenderableEnv { mjv_defaultCamera(camera); camera->type = mjCAMERA_FREE; camera->fixedcamid = -1; + camera->trackbodyid = -1; camera->distance = model_->stat.extent; if (model_->ngeom == 0) { return; @@ -233,6 +236,17 @@ class MujocoEnv : public RenderableEnv { } } + void ApplyGymnasiumDefaultCameraId(mjvCamera* camera) const { + int track_camera_id = mj_name2id(model_, mjOBJ_CAMERA, "track"); + if (track_camera_id >= 0) { + camera->type = mjCAMERA_FIXED; + camera->fixedcamid = track_camera_id; + return; + } + camera->type = mjCAMERA_FREE; + camera->fixedcamid = -1; + } + mjtNum* PrepareObservation(Array* target) { return frame_stack_buffer_.Prepare("obs", target); } diff --git a/envpool/mujoco/gym/mujoco_render_test.py b/envpool/mujoco/gym/mujoco_render_test.py index 58e91409a..723cd15b1 100644 --- a/envpool/mujoco/gym/mujoco_render_test.py +++ b/envpool/mujoco/gym/mujoco_render_test.py @@ -94,9 +94,12 @@ def __init__(self, width: int, height: int): del width, height from mujoco.cgl import cgl + self._pixel_format: Any = None + self._context: Any = None + self._locked = False attrib = cgl.CGLPixelFormatAttribute profile = cgl.CGLOpenGLProfile - attrib_values = ( + preferred_attribs = ( attrib.CGLPFAOpenGLProfile, profile.CGLOGLPVersion_Legacy, attrib.CGLPFAColorSize, @@ -107,19 +110,32 @@ def __init__(self, width: int, height: int): 24, attrib.CGLPFAStencilSize, 8, - attrib.CGLPFAAllowOfflineRenderers, - 0, + attrib.CGLPFAMultisample, + attrib.CGLPFASampleBuffers, + 1, + attrib.CGLPFASample, + 4, + attrib.CGLPFAAccelerated, 0, # terminator ) - attribs = (ctypes.c_int * len(attrib_values))(*attrib_values) - self._pixel_format = cgl.CGLPixelFormatObj() - num_pixel_formats = cgl.GLint() - cgl.CGLChoosePixelFormat( - attribs, - ctypes.byref(self._pixel_format), - ctypes.byref(num_pixel_formats), + offline_attribs = ( + attrib.CGLPFAOpenGLProfile, + profile.CGLOGLPVersion_Legacy, + attrib.CGLPFAColorSize, + 24, + attrib.CGLPFAAlphaSize, + 8, + attrib.CGLPFADepthSize, + 24, + attrib.CGLPFAStencilSize, + 8, + attrib.CGLPFAAllowOfflineRenderers, + 0, # terminator ) - if not self._pixel_format or num_pixel_formats.value == 0: + + if not self._choose_pixel_format( + cgl, preferred_attribs + ) and not self._choose_pixel_format(cgl, offline_attribs): raise RuntimeError("failed to create CGL pixel format") self._context = cgl.CGLContextObj() @@ -132,14 +148,32 @@ def __init__(self, width: int, height: int): cgl.CGLReleasePixelFormat(self._pixel_format) self._pixel_format = None raise RuntimeError("failed to create CGL context") - self._locked = False + + def _choose_pixel_format( + self, cgl: Any, attrib_values: tuple[int, ...] + ) -> bool: + attribs = (ctypes.c_int * len(attrib_values))(*attrib_values) + pixel_format = cgl.CGLPixelFormatObj() + num_pixel_formats = cgl.GLint() + try: + cgl.CGLChoosePixelFormat( + attribs, + ctypes.byref(pixel_format), + ctypes.byref(num_pixel_formats), + ) + except cgl.CGLError: + return False + if not pixel_format or num_pixel_formats.value == 0: + return False + self._pixel_format = pixel_format + return True def make_current(self) -> None: from mujoco.cgl import cgl cgl.CGLSetCurrentContext(self._context) - # Mirror mujoco.cgl.GLContext so the official renderer uses the - # same CGL lifecycle as EnvPool's native renderer. + # Mirror mujoco.cgl.GLContext's pixel format while keeping the + # lock lifecycle idempotent for repeated render() calls. if not self._locked: cgl.CGLLockContext(self._context) self._locked = True diff --git a/envpool/mujoco/gym/pusher.h b/envpool/mujoco/gym/pusher.h index e5936add5..ba9e98241 100644 --- a/envpool/mujoco/gym/pusher.h +++ b/envpool/mujoco/gym/pusher.h @@ -37,6 +37,7 @@ class PusherEnvFns { "ctrl_cost_weight"_.Bind(0.1), "dist_cost_weight"_.Bind(1.0), "near_cost_weight"_.Bind(0.5), "xml_file"_.Bind(std::string("pusher.xml")), + "gymnasium_v5_render_camera"_.Bind(false), "reward_after_step"_.Bind(false), "weighted_reward_info"_.Bind(false), "reset_qvel_scale"_.Bind(0.005), "cylinder_x_min"_.Bind(-0.3), "cylinder_x_max"_.Bind(0.0), "cylinder_y_min"_.Bind(-0.2), @@ -102,8 +103,7 @@ class PusherEnvBase : public Env, public MujocoEnv { cylinder_dist_min_(spec.config["cylinder_dist_min"_]), reward_after_step_(spec.config["reward_after_step"_]), weighted_reward_info_(spec.config["weighted_reward_info"_]), - gymnasium_v5_render_camera_(spec.config["xml_file"_] == - std::string("pusher_v5.xml")), + gymnasium_v5_render_camera_(spec.config["gymnasium_v5_render_camera"_]), dist_qpos_x_(spec.config["cylinder_x_min"_], spec.config["cylinder_x_max"_]), dist_qpos_y_(spec.config["cylinder_y_min"_], @@ -142,6 +142,7 @@ class PusherEnvBase : public Env, public MujocoEnv { } camera->trackbodyid = -1; camera->distance = 4.0; + ApplyGymnasiumDefaultCameraId(camera); return true; } diff --git a/envpool/mujoco/gym/reacher.h b/envpool/mujoco/gym/reacher.h index 753a950de..dc8d94f88 100644 --- a/envpool/mujoco/gym/reacher.h +++ b/envpool/mujoco/gym/reacher.h @@ -37,6 +37,7 @@ class ReacherEnvFns { "ctrl_cost_weight"_.Bind(1.0), "reward_after_step"_.Bind(false), "obs_include_z_distance"_.Bind(true), "dist_cost_weight"_.Bind(1.0), "xml_file"_.Bind(std::string("reacher.xml")), + "gymnasium_v5_render_camera"_.Bind(false), "reset_qpos_scale"_.Bind(0.1), "reset_qvel_scale"_.Bind(0.005), "reset_goal_scale"_.Bind(0.2)); } @@ -75,6 +76,7 @@ class ReacherEnvBase : public Env, public MujocoEnv { int id_fingertip_, id_target_; bool reward_after_step_, obs_include_z_distance_; + bool gymnasium_v5_render_camera_; mjtNum ctrl_cost_weight_, dist_cost_weight_, reset_goal_scale_; std::uniform_real_distribution<> dist_qpos_, dist_qvel_, dist_goal_; @@ -96,6 +98,7 @@ class ReacherEnvBase : public Env, public MujocoEnv { id_target_(mj_name2id(model_, mjOBJ_XBODY, "target")), reward_after_step_(spec.config["reward_after_step"_]), obs_include_z_distance_(spec.config["obs_include_z_distance"_]), + gymnasium_v5_render_camera_(spec.config["gymnasium_v5_render_camera"_]), ctrl_cost_weight_(spec.config["ctrl_cost_weight"_]), dist_cost_weight_(spec.config["dist_cost_weight"_]), reset_goal_scale_(spec.config["reset_goal_scale"_]), @@ -129,6 +132,15 @@ class ReacherEnvBase : public Env, public MujocoEnv { #endif } + bool RenderCamera(mjvCamera* camera) override { + if (!gymnasium_v5_render_camera_) { + return false; + } + camera->trackbodyid = 0; + ApplyGymnasiumDefaultCameraId(camera); + return true; + } + bool IsDone() override { return done_; } void Reset() override { diff --git a/envpool/mujoco/gym/registration.py b/envpool/mujoco/gym/registration.py index a10fc2fa9..d59ab0053 100644 --- a/envpool/mujoco/gym/registration.py +++ b/envpool/mujoco/gym/registration.py @@ -34,6 +34,8 @@ for task, versions, max_episode_steps in gym_mujoco_envs: for version in versions: extra_args: dict[str, Any] = {} + if version == "v5": + extra_args["gymnasium_v5_render_camera"] = True if task in ["Ant", "Humanoid"] and version == "v3": extra_args["use_contact_force"] = True if task == "Ant" and version == "v5": diff --git a/envpool/mujoco/gym/swimmer.h b/envpool/mujoco/gym/swimmer.h index ae897eb1b..8b3d914f2 100644 --- a/envpool/mujoco/gym/swimmer.h +++ b/envpool/mujoco/gym/swimmer.h @@ -35,6 +35,7 @@ class SwimmerEnvFns { "frame_stack"_.Bind(1), "post_constraint"_.Bind(true), "exclude_current_positions_from_observation"_.Bind(true), "xml_file"_.Bind(std::string("swimmer.xml")), + "gymnasium_v5_render_camera"_.Bind(false), "forward_reward_weight"_.Bind(1.0), "ctrl_cost_weight"_.Bind(1e-4), "reset_noise_scale"_.Bind(0.1)); diff --git a/envpool/mujoco/gym/walker2d.h b/envpool/mujoco/gym/walker2d.h index 5cab0800d..9d1b1d5a8 100644 --- a/envpool/mujoco/gym/walker2d.h +++ b/envpool/mujoco/gym/walker2d.h @@ -38,6 +38,7 @@ class Walker2dEnvFns { "exclude_current_positions_from_observation"_.Bind(true), "legacy_healthy_reward"_.Bind(true), "xml_file"_.Bind(std::string("walker2d.xml")), + "gymnasium_v5_render_camera"_.Bind(false), "forward_reward_weight"_.Bind(1.0), "healthy_reward"_.Bind(1.0), "healthy_z_min"_.Bind(0.8), "healthy_z_max"_.Bind(2.0), "healthy_angle_min"_.Bind(-1.0), "healthy_angle_max"_.Bind(1.0), @@ -78,6 +79,7 @@ class Walker2dEnvBase : public Env, public MujocoEnv { bool terminate_when_unhealthy_, no_pos_; bool legacy_healthy_reward_; + bool gymnasium_v5_render_camera_; mjtNum ctrl_cost_weight_, forward_reward_weight_; mjtNum healthy_reward_, healthy_z_min_, healthy_z_max_; mjtNum healthy_angle_min_, healthy_angle_max_; @@ -101,6 +103,7 @@ class Walker2dEnvBase : public Env, public MujocoEnv { terminate_when_unhealthy_(spec.config["terminate_when_unhealthy"_]), no_pos_(spec.config["exclude_current_positions_from_observation"_]), legacy_healthy_reward_(spec.config["legacy_healthy_reward"_]), + gymnasium_v5_render_camera_(spec.config["gymnasium_v5_render_camera"_]), ctrl_cost_weight_(spec.config["ctrl_cost_weight"_]), forward_reward_weight_(spec.config["forward_reward_weight"_]), healthy_reward_(spec.config["healthy_reward"_]), @@ -126,6 +129,20 @@ class Walker2dEnvBase : public Env, public MujocoEnv { #endif } + bool RenderCamera(mjvCamera* camera) override { + if (!gymnasium_v5_render_camera_) { + return false; + } + camera->trackbodyid = 2; + camera->distance = 4.0; + camera->lookat[0] = 0.0; + camera->lookat[1] = 0.0; + camera->lookat[2] = 1.15; + camera->elevation = -20.0; + ApplyGymnasiumDefaultCameraId(camera); + return true; + } + bool IsDone() override { return done_; } void Reset() override { diff --git a/envpool/mujoco/mujoco_egl_teardown_test.py b/envpool/mujoco/mujoco_egl_teardown_test.py new file mode 100644 index 000000000..3f11a307f --- /dev/null +++ b/envpool/mujoco/mujoco_egl_teardown_test.py @@ -0,0 +1,48 @@ +# Copyright 2026 Garena Online Private Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""EGL teardown regression coverage for native MuJoCo pixel envs.""" + +from absl.testing import absltest + +from envpool.mujoco.pixel_observation_test_utils import ( + assert_egl_pixel_env_teardown_exits_cleanly, +) + +_EGL_TEARDOWN_CASES = ( + ("dmc", "envpool.mujoco.dmc.registration", "WalkerWalk-v1"), + ("gym", "envpool.mujoco.gym.registration", "Walker2d-v4"), + ("robotics", "envpool.mujoco.robotics.registration", "FetchReach-v4"), + ( + "metaworld", + "envpool.mujoco.metaworld.registration", + "MetaWorld/Reach-v3", + ), + ( + "myosuite", + "envpool.mujoco.myosuite.registration", + "myoFingerReachFixed-v0", + ), +) + + +class MujocoEglTeardownTest(absltest.TestCase): + """Teardown tests for MuJoCo env families backed by native GL rendering.""" + + def test_pixel_env_teardown_exits_cleanly_for_all_gl_families(self) -> None: + """All native MuJoCo pixel families should exit cleanly under EGL.""" + assert_egl_pixel_env_teardown_exits_cleanly(self, _EGL_TEARDOWN_CASES) + + +if __name__ == "__main__": + absltest.main() diff --git a/envpool/mujoco/myosuite/__init__.py b/envpool/mujoco/myosuite/__init__.py new file mode 100644 index 000000000..189d77394 --- /dev/null +++ b/envpool/mujoco/myosuite/__init__.py @@ -0,0 +1,49 @@ +# Copyright 2026 Garena Online Private Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""MyoSuite native MuJoCo envs.""" + +import os +import platform + +from envpool.python.glfw_context import preload_windows_gl_dlls + +if platform.system() == "Windows": + preload_windows_gl_dlls(strict=bool(os.environ.get("ENVPOOL_DLL_DIR"))) + +from envpool.mujoco.myosuite_envpool import ( + _MyoSuiteEnvPool, + _MyoSuiteEnvSpec, + _MyoSuitePixelEnvPool, + _MyoSuitePixelEnvSpec, +) + +from envpool.python.api import py_env + +MyoSuiteEnvSpec, MyoSuiteDMEnvPool, MyoSuiteGymnasiumEnvPool = py_env( + _MyoSuiteEnvSpec, _MyoSuiteEnvPool +) +( + MyoSuitePixelEnvSpec, + MyoSuitePixelDMEnvPool, + MyoSuitePixelGymnasiumEnvPool, +) = py_env(_MyoSuitePixelEnvSpec, _MyoSuitePixelEnvPool) + +__all__ = [ + "MyoSuiteDMEnvPool", + "MyoSuiteEnvSpec", + "MyoSuiteGymnasiumEnvPool", + "MyoSuitePixelDMEnvPool", + "MyoSuitePixelEnvSpec", + "MyoSuitePixelGymnasiumEnvPool", +] diff --git a/envpool/mujoco/myosuite/mujoco_plugin_init.cc b/envpool/mujoco/myosuite/mujoco_plugin_init.cc new file mode 100644 index 000000000..f58cf0a64 --- /dev/null +++ b/envpool/mujoco/myosuite/mujoco_plugin_init.cc @@ -0,0 +1,29 @@ +// Copyright 2026 Garena Online Private Limited +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifdef _WIN32 + +#include + +extern "C" int __stdcall MjObjDecoderDllMain(void* hinst, DWORD reason, + void* reserved); +extern "C" int __stdcall MjStlDecoderDllMain(void* hinst, DWORD reason, + void* reserved); + +extern "C" BOOL WINAPI DllMain(HINSTANCE hinst, DWORD reason, LPVOID reserved) { + return MjObjDecoderDllMain(hinst, reason, reserved) && + MjStlDecoderDllMain(hinst, reason, reserved); +} + +#endif // _WIN32 diff --git a/envpool/mujoco/myosuite/myosuite_clang_tidy.cc b/envpool/mujoco/myosuite/myosuite_clang_tidy.cc new file mode 100644 index 000000000..c1dd5b30b --- /dev/null +++ b/envpool/mujoco/myosuite/myosuite_clang_tidy.cc @@ -0,0 +1,26 @@ +// Copyright 2026 Garena Online Private Limited +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "envpool/mujoco/myosuite/myosuite_env.h" + +namespace { + +static_assert( + std::is_same_v); +static_assert(std::is_same_v); + +} // namespace diff --git a/envpool/mujoco/myosuite/myosuite_env.h b/envpool/mujoco/myosuite/myosuite_env.h new file mode 100644 index 000000000..3d2a899f8 --- /dev/null +++ b/envpool/mujoco/myosuite/myosuite_env.h @@ -0,0 +1,2595 @@ +// Copyright 2026 Garena Online Private Limited +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ENVPOOL_MUJOCO_MYOSUITE_MYOSUITE_ENV_H_ +#define ENVPOOL_MUJOCO_MYOSUITE_MYOSUITE_ENV_H_ + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "envpool/core/async_envpool.h" +#include "envpool/core/env.h" +#include "envpool/mujoco/robotics/mujoco_env.h" +#include "third_party/myosuite/myosuite_reference_data.h" +#include "third_party/myosuite/myosuite_task_metadata.h" +#include "third_party/myosuite/myosuite_tasks.h" + +namespace myosuite { + +constexpr int kMyoSuiteTestStatePad = 65536; + +using envpool::mujoco::PixelObservationEnvFns; +using envpool::mujoco::RenderCameraIdOrDefault; +using envpool::mujoco::RenderHeightOrDefault; +using envpool::mujoco::RenderWidthOrDefault; +using envpool::mujoco::StackSpec; +using third_party::myosuite::GetMyoSuiteReferenceData; +using third_party::myosuite::GetMyoSuiteTask; +using third_party::myosuite::GetMyoSuiteTaskMetadata; +using third_party::myosuite::MyoSuiteMuscleCondition; +using third_party::myosuite::MyoSuiteReferenceData; +using third_party::myosuite::MyoSuiteReferenceType; +using third_party::myosuite::MyoSuiteTaskDef; +using third_party::myosuite::MyoSuiteTaskKind; +using third_party::myosuite::MyoSuiteTaskMetadata; + +class MyoSuiteEnvFns { + public: + static decltype(auto) DefaultConfig() { + return MakeDict("frame_stack"_.Bind(1), + "task_name"_.Bind(std::string("myoFingerReachFixed-v0"))); + } + + template + static decltype(auto) StateSpec(const Config& conf) { + const auto& task = GetMyoSuiteTask(std::string(conf["task_name"_])); + mjtNum inf = std::numeric_limits::infinity(); + return MakeDict( + "obs"_.Bind(StackSpec(Spec({task.obs_dim}, {-inf, inf}), + conf["frame_stack"_])), + "info:task_id"_.Bind(Spec({-1})), + "info:sparse"_.Bind(Spec({-1})), + "info:solved"_.Bind(Spec({-1}, {0.0, 1.0})), + "info:oracle_numpy2_broken"_.Bind(Spec({})), +#ifdef ENVPOOL_TEST + "info:qpos0"_.Bind(Spec({2048})), + "info:qvel0"_.Bind(Spec({2048})), + "info:act0"_.Bind(Spec({2048})), + "info:qacc0"_.Bind(Spec({2048})), + "info:qacc_warmstart0"_.Bind(Spec({2048})), + "info:qpos"_.Bind(Spec({2048})), + "info:qvel"_.Bind(Spec({2048})), + "info:act"_.Bind(Spec({2048})), + "info:ctrl"_.Bind(Spec({2048})), + "info:qacc"_.Bind(Spec({2048})), + "info:qacc_warmstart"_.Bind(Spec({2048})), + "info:actuator_length"_.Bind(Spec({2048})), + "info:actuator_velocity"_.Bind(Spec({2048})), + "info:actuator_force"_.Bind(Spec({2048})), + "info:fatigue_ma"_.Bind(Spec({2048})), + "info:fatigue_mr"_.Bind(Spec({2048})), + "info:fatigue_mf"_.Bind(Spec({2048})), + "info:fatigue_tl"_.Bind(Spec({2048})), + "info:fatigue_tauact"_.Bind(Spec({2048})), + "info:fatigue_taudeact"_.Bind(Spec({2048})), + "info:fatigue_dt"_.Bind(Spec({})), + "info:site_pos"_.Bind(Spec({kMyoSuiteTestStatePad})), + "info:site_quat"_.Bind(Spec({kMyoSuiteTestStatePad})), + "info:site_xpos"_.Bind(Spec({kMyoSuiteTestStatePad})), + "info:site_size"_.Bind(Spec({kMyoSuiteTestStatePad})), + "info:site_rgba"_.Bind(Spec({kMyoSuiteTestStatePad})), + "info:body_pos"_.Bind(Spec({kMyoSuiteTestStatePad})), + "info:body_quat"_.Bind(Spec({kMyoSuiteTestStatePad})), + "info:body_mass"_.Bind(Spec({kMyoSuiteTestStatePad})), + "info:light_xpos"_.Bind(Spec({kMyoSuiteTestStatePad})), + "info:light_xdir"_.Bind(Spec({kMyoSuiteTestStatePad})), + "info:geom_pos"_.Bind(Spec({kMyoSuiteTestStatePad})), + "info:geom_quat"_.Bind(Spec({kMyoSuiteTestStatePad})), + "info:geom_size"_.Bind(Spec({kMyoSuiteTestStatePad})), + "info:geom_xpos"_.Bind(Spec({kMyoSuiteTestStatePad})), + "info:geom_xmat"_.Bind(Spec({kMyoSuiteTestStatePad})), + "info:geom_rgba"_.Bind(Spec({kMyoSuiteTestStatePad})), + "info:geom_friction"_.Bind(Spec({kMyoSuiteTestStatePad})), + "info:geom_aabb"_.Bind(Spec({kMyoSuiteTestStatePad})), + "info:geom_rbound"_.Bind(Spec({kMyoSuiteTestStatePad})), + "info:geom_contype"_.Bind(Spec({kMyoSuiteTestStatePad})), + "info:geom_conaffinity"_.Bind(Spec({kMyoSuiteTestStatePad})), + "info:geom_type"_.Bind(Spec({kMyoSuiteTestStatePad})), + "info:geom_condim"_.Bind(Spec({kMyoSuiteTestStatePad})), + "info:hfield_data"_.Bind(Spec({kMyoSuiteTestStatePad})), + "info:mocap_pos"_.Bind(Spec({kMyoSuiteTestStatePad})), + "info:mocap_quat"_.Bind(Spec({kMyoSuiteTestStatePad})), + "info:time"_.Bind(Spec({})), + "info:model_timestep"_.Bind(Spec({})), + "info:frame_skip"_.Bind(Spec({})), +#endif + "info:model_nq"_.Bind(Spec({})), + "info:model_nv"_.Bind(Spec({})), + "info:model_na"_.Bind(Spec({})), + "info:model_nu"_.Bind(Spec({})), + "info:model_nsite"_.Bind(Spec({})), + "info:model_nbody"_.Bind(Spec({})), + "info:model_ngeom"_.Bind(Spec({})), + "info:model_nhfielddata"_.Bind(Spec({})), + "info:model_nmocap"_.Bind(Spec({}))); + } + + template + static decltype(auto) ActionSpec(const Config& conf) { + const auto& task = GetMyoSuiteTask(std::string(conf["task_name"_])); + return MakeDict( + "action"_.Bind(Spec({-1, task.action_dim}, {-1.0, 1.0}))); + } +}; + +using MyoSuiteEnvSpec = EnvSpec; +using MyoSuitePixelEnvFns = PixelObservationEnvFns; +using MyoSuitePixelEnvSpec = EnvSpec; + +template +class MyoSuiteEnvBase : public Env, + public gymnasium_robotics::MujocoRobotEnv { + protected: + using Base = Env; + using Base::Allocate; + using Base::gen_; + using Base::spec_; + + enum class OslPhase : std::uint8_t { + kEStance, + kLStance, + kESwing, + kLSwing, + }; + + struct OslStateParams { + mjtNum knee_stiffness; + mjtNum knee_damping; + mjtNum knee_target_angle; + mjtNum ankle_stiffness; + mjtNum ankle_damping; + mjtNum ankle_target_angle; + }; + + const MyoSuiteTaskDef& task_; + const MyoSuiteTaskMetadata& metadata_; + const MyoSuiteReferenceData& reference_; + int task_index_; + std::vector obs_keys_; + std::vector> reward_weights_; + std::vector metadata_init_qpos_; + std::vector metadata_init_qvel_; + std::vector reset_qacc_warmstart_; + std::vector target_jnt_value_; + std::vector tip_sites_; + std::vector target_sites_; + std::vector> target_reach_low_; + std::vector> target_reach_high_; + std::vector last_ctrl_; + std::vector muscle_actuator_ids_; + std::vector fatigue_tauact_; + std::vector fatigue_taudeact_; + std::vector fatigue_ma_; + std::vector fatigue_mr_; + std::vector fatigue_mf_; + std::vector fatigue_tl_; + int task_step_{0}; + int myodm_reference_index_{0}; + mjtNum myodm_lift_z_{0.0}; + OslPhase osl_phase_{OslPhase::kEStance}; + mjtNum osl_body_weight_{0.0}; + int osl_knee_actuator_id_{-1}; + int osl_ankle_actuator_id_{-1}; + int osl_knee_joint_id_{-1}; + int osl_ankle_joint_id_{-1}; + int osl_load_sensor_id_{-1}; + std::vector tabletennis_init_paddle_quat_; + std::array challenge_reorient_goal_obj_offset_{}; + int bimanual_goal_touch_{0}; + mjtNum bimanual_init_obj_z_{0.0}; + mjtNum bimanual_init_palm_z_{0.0}; + mjtNum sparse_{0.0}; + mjtNum solved_{0.0}; +#ifdef ENVPOOL_TEST + std::vector qpos0_pad_; + std::vector qvel0_pad_; + std::vector act0_pad_; + std::vector qacc0_pad_; + std::vector qacc_warmstart0_pad_; + std::vector qpos_pad_; + std::vector qvel_pad_; + std::vector act_pad_; + std::vector ctrl_pad_; + std::vector qacc_pad_; + std::vector qacc_warmstart_pad_; + std::vector actuator_length_pad_; + std::vector actuator_velocity_pad_; + std::vector actuator_force_pad_; + std::vector fatigue_ma_pad_; + std::vector fatigue_mr_pad_; + std::vector fatigue_mf_pad_; + std::vector fatigue_tl_pad_; + std::vector fatigue_tauact_pad_; + std::vector fatigue_taudeact_pad_; + std::vector site_pos_pad_; + std::vector site_quat_pad_; + std::vector site_xpos_pad_; + std::vector site_size_pad_; + std::vector site_rgba_pad_; + std::vector body_pos_pad_; + std::vector body_quat_pad_; + std::vector body_mass_pad_; + std::vector light_xpos_pad_; + std::vector light_xdir_pad_; + std::vector geom_pos_pad_; + std::vector geom_quat_pad_; + std::vector geom_size_pad_; + std::vector geom_xpos_pad_; + std::vector geom_xmat_pad_; + std::vector geom_rgba_pad_; + std::vector geom_friction_pad_; + std::vector geom_aabb_pad_; + std::vector geom_rbound_pad_; + std::vector geom_contype_pad_; + std::vector geom_conaffinity_pad_; + std::vector geom_type_pad_; + std::vector geom_condim_pad_; + std::vector hfield_data_pad_; + std::vector mocap_pos_pad_; + std::vector mocap_quat_pad_; +#endif + + public: + using Spec = EnvSpecT; + using Action = typename Base::Action; + + MyoSuiteEnvBase(const Spec& spec, int env_id) + : Env(spec, env_id), + gymnasium_robotics::MujocoRobotEnv( + spec.config["base_path"_], + AssetPath(spec.config["base_path"_], + GetMyoSuiteTask(std::string(spec.config["task_name"_])) + .model_path), + GetMyoSuiteTask(std::string(spec.config["task_name"_])).frame_skip, + spec.config["max_episode_steps"_], spec.config["frame_stack"_], + RenderWidthOrDefault(spec.config), + RenderHeightOrDefault(spec.config), + RenderCameraIdOrDefault(spec.config)), + task_(GetMyoSuiteTask(std::string(spec.config["task_name"_]))), + metadata_( + GetMyoSuiteTaskMetadata(std::string(spec.config["task_name"_]))), + reference_( + GetMyoSuiteReferenceData(std::string(spec.config["task_name"_]))), + task_index_(TaskIndex(task_.id)), + obs_keys_(SplitList(metadata_.obs_keys)), + reward_weights_(ParseWeights(metadata_.rwd_keys_wt)), + metadata_init_qpos_(ParseNumbers(metadata_.init_qpos)), + metadata_init_qvel_(ParseNumbers(metadata_.init_qvel)), + reset_qacc_warmstart_(ParseNumbers(metadata_.reset_qacc_warmstart)), + target_jnt_value_(ParseNumbers(metadata_.target_jnt_value)), + tip_sites_(SplitList(metadata_.tip_sites)), + target_sites_(SplitList(metadata_.target_sites)), + target_reach_low_(ParseNumberGroups(metadata_.target_reach_low)), + target_reach_high_(ParseNumberGroups(metadata_.target_reach_high)), + last_ctrl_(model_->nu, 0.0) +#ifdef ENVPOOL_TEST + , + qpos0_pad_(2048, 0.0), + qvel0_pad_(2048, 0.0), + act0_pad_(2048, 0.0), + qacc0_pad_(2048, 0.0), + qacc_warmstart0_pad_(2048, 0.0), + qpos_pad_(2048, 0.0), + qvel_pad_(2048, 0.0), + act_pad_(2048, 0.0), + ctrl_pad_(2048, 0.0), + qacc_pad_(2048, 0.0), + qacc_warmstart_pad_(2048, 0.0), + actuator_length_pad_(2048, 0.0), + actuator_velocity_pad_(2048, 0.0), + actuator_force_pad_(2048, 0.0), + fatigue_ma_pad_(2048, 0.0), + fatigue_mr_pad_(2048, 0.0), + fatigue_mf_pad_(2048, 0.0), + fatigue_tl_pad_(2048, 0.0), + fatigue_tauact_pad_(2048, 0.0), + fatigue_taudeact_pad_(2048, 0.0), + site_pos_pad_(kMyoSuiteTestStatePad, 0.0), + site_quat_pad_(kMyoSuiteTestStatePad, 0.0), + site_xpos_pad_(kMyoSuiteTestStatePad, 0.0), + site_size_pad_(kMyoSuiteTestStatePad, 0.0), + site_rgba_pad_(kMyoSuiteTestStatePad, 0.0), + body_pos_pad_(kMyoSuiteTestStatePad, 0.0), + body_quat_pad_(kMyoSuiteTestStatePad, 0.0), + body_mass_pad_(kMyoSuiteTestStatePad, 0.0), + light_xpos_pad_(kMyoSuiteTestStatePad, 0.0), + light_xdir_pad_(kMyoSuiteTestStatePad, 0.0), + geom_pos_pad_(kMyoSuiteTestStatePad, 0.0), + geom_quat_pad_(kMyoSuiteTestStatePad, 0.0), + geom_size_pad_(kMyoSuiteTestStatePad, 0.0), + geom_xpos_pad_(kMyoSuiteTestStatePad, 0.0), + geom_xmat_pad_(kMyoSuiteTestStatePad, 0.0), + geom_rgba_pad_(kMyoSuiteTestStatePad, 0.0), + geom_friction_pad_(kMyoSuiteTestStatePad, 0.0), + geom_aabb_pad_(kMyoSuiteTestStatePad, 0.0), + geom_rbound_pad_(kMyoSuiteTestStatePad, 0.0), + geom_contype_pad_(kMyoSuiteTestStatePad, 0.0), + geom_conaffinity_pad_(kMyoSuiteTestStatePad, 0.0), + geom_type_pad_(kMyoSuiteTestStatePad, 0.0), + geom_condim_pad_(kMyoSuiteTestStatePad, 0.0), + hfield_data_pad_(kMyoSuiteTestStatePad, 0.0), + mocap_pos_pad_(kMyoSuiteTestStatePad, 0.0), + mocap_quat_pad_(kMyoSuiteTestStatePad, 0.0) +#endif + { + ApplyMuscleCondition(); + InitializeFatigue(); + ApplyMyoDmModelEdits(); + ApplyBaodingModelEdits(); + ApplyMetadataInitialState(); + SetDefaultInitialQpos(); + ApplyBimanualInitialState(); + InitializeRobotEnv(); + InitializeTaskCaches(); + } + + bool IsDone() override { return done_; } + + void Reset() override { + done_ = false; + elapsed_step_ = 0; + task_step_ = 0; + myodm_reference_index_ = 0; + bimanual_goal_touch_ = 0; + sparse_ = 0.0; + solved_ = 0.0; + ResetFatigue(); + ResetToInitialState(); + ResetOslController(); + ApplyResetTargets(); + WarmstartFromCurrentAcceleration(); + std::fill(last_ctrl_.begin(), last_ctrl_.end(), 0.0); + CapturePaddedResetState(); + WriteState(0.0, true); + } + + void Step(const Action& action) override { + const auto* raw = static_cast(action["action"_].Data()); + std::vector ctrl(model_->nu); + const int action_dim = + std::min(static_cast(model_->nu), task_.action_dim); + for (int i = 0; i < action_dim; ++i) { + mjtNum value = std::max(-1.0, std::min(1.0, raw[i])); + if (model_->na > 0 && task_.normalize_act && + model_->actuator_dyntype[i] == mjDYN_MUSCLE) { + const auto action_value = static_cast(value); + value = static_cast( + 1.0F / (1.0F + std::exp(-5.0F * (action_value - 0.5F)))); + } + if (task_.muscle_condition == MyoSuiteMuscleCondition::kReafferentation) { + // Applied below after EIP has been copied to EPL. + } + ctrl[i] = value; + } + ApplyFatigue(&ctrl); + ApplyReafferentation(&ctrl); + ApplyOslControls(&ctrl); + PreStepTaskUpdate(); + const bool robot_step_normalizes_ctrl = + task_.normalize_act && + (model_->na == 0 || + task_.kind == MyoSuiteTaskKind::kChallengeTableTennis); + for (int i = 0; i < model_->nu; ++i) { + if (robot_step_normalizes_ctrl && + (model_->na == 0 || model_->actuator_dyntype[i] != mjDYN_MUSCLE)) { + const mjtNum low = model_->actuator_ctrlrange[2 * i]; + const mjtNum high = model_->actuator_ctrlrange[2 * i + 1]; + ctrl[i] = (low + high) * 0.5 + ctrl[i] * (high - low) * 0.5; + } + data_->ctrl[i] = ctrl[i]; + } + last_ctrl_ = ctrl; + DoSimulation(); + // MyoSuite observes through Robot.sensor2sim(), which calls sim.forward() + // after copying the final qpos/qvel/act back into the observed sim. + mj_forward(model_, data_); + ++elapsed_step_; + const auto obs_dict = BuildObsDict(); + const RewardResult reward = ComputeReward(obs_dict); + sparse_ = reward.sparse; + solved_ = reward.solved; + done_ = reward.terminated || elapsed_step_ >= max_episode_steps_; + WriteState(reward.dense, false, obs_dict); + PostStepTaskUpdate(); + } + + bool RenderCamera(mjvCamera* camera) override { + mjv_defaultCamera(camera); + camera->type = mjCAMERA_FREE; + camera->fixedcamid = -1; + mjv_defaultFreeCamera(model_, camera); + return true; + } + + bool RenderOption(mjvOption* option) override { + mjv_defaultOption(option); + option->flags[mjVIS_ACTUATOR] = 1; + option->flags[mjVIS_ACTIVATION] = 1; + if (task_.kind == MyoSuiteTaskKind::kChallengeRunTrack || + task_.kind == MyoSuiteTaskKind::kChallengeChaseTag || + task_.kind == MyoSuiteTaskKind::kChallengeSoccer) { + option->flags[mjVIS_TENDON] = 1; + } + return true; + } + + void RenderCallback() override { mj_forward(model_, data_); } + + bool DisableAuxiliaryRenderVisuals() const override { return false; } + + bool ShareRenderContext() const override { return false; } + + bool PreferOfflineRenderContext() const override { return false; } + + bool ResizeOffscreenRenderContext() const override { return false; } + + protected: + struct RewardResult { + mjtNum dense{0.0}; + mjtNum sparse{0.0}; + mjtNum solved{0.0}; + bool terminated{false}; + }; + + struct MyoDmReferenceFrame { + std::vector robot; + std::vector robot_vel; + std::vector object; + }; + + using ObsDict = std::unordered_map>; + + static std::vector SplitList(std::string_view text, + char delimiter = ',') { + std::vector result; + std::string item; + std::stringstream stream{std::string(text)}; + while (std::getline(stream, item, delimiter)) { + if (!item.empty()) { + result.push_back(item); + } + } + return result; + } + + static std::vector ParseNumbers(std::string_view text) { + std::vector result; + for (const auto& item : SplitList(text)) { + result.push_back(static_cast(std::stod(item))); + } + return result; + } + + static std::vector> ParseNumberGroups( + std::string_view text) { + std::vector> result; + for (const auto& group : SplitList(text, ';')) { + result.push_back(ParseNumbers(group)); + } + return result; + } + + static std::vector> ParseWeights( + std::string_view text) { + std::vector> result; + for (const auto& item : SplitList(text)) { + const std::size_t sep = item.find(':'); + if (sep == std::string::npos) { + continue; + } + result.emplace_back(item.substr(0, sep), + static_cast(std::stod(item.substr(sep + 1)))); + } + return result; + } + + static mjtNum Norm(const std::vector& values) { + mjtNum sum = 0.0; + for (mjtNum value : values) { + sum += value * value; + } + return std::sqrt(sum); + } + + static mjtNum SquaredNorm(const std::vector& values) { + mjtNum sum = 0.0; + for (mjtNum value : values) { + sum += value * value; + } + return sum; + } + + static mjtNum MeanSquare(const std::vector& values) { + if (values.empty()) { + return 0.0; + } + return SquaredNorm(values) / static_cast(values.size()); + } + + static mjtNum Dot(const std::vector& lhs, + const std::vector& rhs) { + mjtNum sum = 0.0; + const int size = + std::min(static_cast(lhs.size()), static_cast(rhs.size())); + for (int i = 0; i < size; ++i) { + sum += lhs[i] * rhs[i]; + } + return sum; + } + + static mjtNum Cosine(const std::vector& lhs, + const std::vector& rhs) { + const mjtNum denom = Norm(lhs) * Norm(rhs); + if (denom <= 0.0) { + return 0.0; + } + return Dot(lhs, rhs) / denom; + } + + static mjtNum QuaternionDistance(const std::vector& current, + const std::vector& target) { + if (current.size() < 4 || target.size() < 4) { + return 0.0; + } + const mjtNum cw = current[0]; + const mjtNum cx = current[1]; + const mjtNum cy = current[2]; + const mjtNum cz = current[3]; + const mjtNum tw = target[0]; + const mjtNum tx = -target[1]; + const mjtNum ty = -target[2]; + const mjtNum tz = -target[3]; + const mjtNum dw = cw * tw - cx * tx - cy * ty - cz * tz; + const mjtNum dx = cw * tx + cx * tw + cy * tz - cz * ty; + const mjtNum dy = cw * ty - cx * tz + cy * tw + cz * tx; + const mjtNum dz = cw * tz + cx * ty - cy * tx + cz * tw; + const mjtNum axis_norm = std::sqrt(dx * dx + dy * dy + dz * dz); + return std::abs(2.0 * std::atan2(axis_norm, dw)); + } + + static mjtNum YawFromQuat(const mjtNum* quat) { + const mjtNum w = quat[0]; + const mjtNum x = quat[1]; + const mjtNum y = quat[2]; + const mjtNum z = quat[3]; + return std::atan2(2.0 * (w * z + x * y), 1.0 - 2.0 * (y * y + z * z)); + } + + static bool StartsWith(std::string_view value, std::string_view prefix) { + return value.size() >= prefix.size() && + value.substr(0, prefix.size()) == prefix; + } + + static int JointQposWidth(int joint_type) { + if (joint_type == mjJNT_FREE) { + return 7; + } + if (joint_type == mjJNT_BALL) { + return 4; + } + return 1; + } + + static int JointDofWidth(int joint_type) { + if (joint_type == mjJNT_FREE) { + return 6; + } + if (joint_type == mjJNT_BALL) { + return 3; + } + return 1; + } + + static std::vector Subtract(const std::vector& lhs, + const std::vector& rhs) { + const int size = + std::min(static_cast(lhs.size()), static_cast(rhs.size())); + std::vector result(size); + for (int i = 0; i < size; ++i) { + result[i] = lhs[i] - rhs[i]; + } + return result; + } + + static std::vector Add(const std::vector& lhs, + const std::vector& rhs) { + const int size = + std::min(static_cast(lhs.size()), static_cast(rhs.size())); + std::vector result(size); + for (int i = 0; i < size; ++i) { + result[i] = lhs[i] + rhs[i]; + } + return result; + } + + static std::vector MatToEuler(const mjtNum* mat) { + const mjtNum eps4 = std::numeric_limits::epsilon() * 4.0; + const mjtNum cy = std::sqrt(mat[8] * mat[8] + mat[5] * mat[5]); + std::vector euler(3, 0.0); + if (cy > eps4) { + euler[2] = -std::atan2(mat[1], mat[0]); + euler[1] = -std::atan2(-mat[2], cy); + euler[0] = -std::atan2(mat[5], mat[8]); + } else { + euler[2] = -std::atan2(-mat[3], mat[4]); + euler[1] = -std::atan2(-mat[2], cy); + euler[0] = 0.0; + } + return euler; + } + + int SiteId(const char* name) const { + return mj_name2id(model_, mjOBJ_SITE, name); + } + + int BodyId(const char* name) const { + return mj_name2id(model_, mjOBJ_BODY, name); + } + + int GeomId(const char* name) const { + return mj_name2id(model_, mjOBJ_GEOM, name); + } + + int JointId(const char* name) const { + return mj_name2id(model_, mjOBJ_JOINT, name); + } + + int ActuatorId(const char* name) const { + return mj_name2id(model_, mjOBJ_ACTUATOR, name); + } + + int SensorId(const char* name) const { + return mj_name2id(model_, mjOBJ_SENSOR, name); + } + + std::vector SiteXpos(int site_id) const { + if (site_id < 0) { + return {0.0, 0.0, 0.0}; + } + return {data_->site_xpos[3 * site_id], data_->site_xpos[3 * site_id + 1], + data_->site_xpos[3 * site_id + 2]}; + } + + std::vector BodyXpos(int body_id) const { + if (body_id < 0) { + return {0.0, 0.0, 0.0}; + } + return {data_->xpos[3 * body_id], data_->xpos[3 * body_id + 1], + data_->xpos[3 * body_id + 2]}; + } + + std::vector BodyPos(int body_id) const { + if (body_id < 0) { + return {0.0, 0.0, 0.0}; + } + return {model_->body_pos[3 * body_id], model_->body_pos[3 * body_id + 1], + model_->body_pos[3 * body_id + 2]}; + } + + std::vector BodyXquat(int body_id) const { + if (body_id < 0) { + return {1.0, 0.0, 0.0, 0.0}; + } + return {data_->xquat[4 * body_id], data_->xquat[4 * body_id + 1], + data_->xquat[4 * body_id + 2], data_->xquat[4 * body_id + 3]}; + } + + std::vector BodyXmat(int body_id) const { + std::vector result(9, 0.0); + if (body_id >= 0) { + std::memcpy(result.data(), data_->ximat + 9 * body_id, + sizeof(mjtNum) * result.size()); + } + return result; + } + + std::vector GeomXpos(int geom_id) const { + if (geom_id < 0) { + return {0.0, 0.0, 0.0}; + } + return {data_->geom_xpos[3 * geom_id], data_->geom_xpos[3 * geom_id + 1], + data_->geom_xpos[3 * geom_id + 2]}; + } + + std::vector SensorData(const char* name) const { + int sensor_id = SensorId(name); + if (sensor_id < 0) { + return {}; + } + const int start = model_->sensor_adr[sensor_id]; + const int dim = model_->sensor_dim[sensor_id]; + return {data_->sensordata + start, data_->sensordata + start + dim}; + } + + static std::string AssetPath(const std::string& base_path, + const std::string& model_path) { + return base_path + "/mujoco/myosuite/assets/" + model_path; + } + + static int TaskIndex(std::string_view task_id) { + const auto& tasks = third_party::myosuite::kMyoSuiteTasks; + for (int i = 0; i < static_cast(tasks.size()); ++i) { + if (tasks[i].id == task_id) { + return i; + } + } + throw std::runtime_error("Unknown MyoSuite task index."); + } + + void ApplyMuscleCondition() { + if (task_.muscle_condition == MyoSuiteMuscleCondition::kSarcopenia) { + for (int i = 0; i < model_->nu; ++i) { + model_->actuator_gainprm[i * mjNGAIN + 2] *= 0.5; + } + } + } + + void InitializeFatigue() { + if (task_.muscle_condition != MyoSuiteMuscleCondition::kFatigue) { + return; + } + for (int i = 0; i < model_->nu; ++i) { + if (model_->actuator_dyntype[i] == mjDYN_MUSCLE) { + muscle_actuator_ids_.push_back(i); + fatigue_tauact_.push_back(model_->actuator_dynprm[i * mjNDYN]); + fatigue_taudeact_.push_back(model_->actuator_dynprm[i * mjNDYN + 1]); + } + } + fatigue_ma_.resize(muscle_actuator_ids_.size()); + fatigue_mr_.resize(muscle_actuator_ids_.size()); + fatigue_mf_.resize(muscle_actuator_ids_.size()); + fatigue_tl_.resize(muscle_actuator_ids_.size()); + ResetFatigue(); + } + + void ResetFatigue() { + if (task_.muscle_condition != MyoSuiteMuscleCondition::kFatigue) { + return; + } + std::fill(fatigue_ma_.begin(), fatigue_ma_.end(), 0.0); + std::fill(fatigue_mr_.begin(), fatigue_mr_.end(), 1.0); + std::fill(fatigue_mf_.begin(), fatigue_mf_.end(), 0.0); + std::fill(fatigue_tl_.begin(), fatigue_tl_.end(), 0.0); + } + + void InitializeOslController() { + if (task_.kind != MyoSuiteTaskKind::kChallengeRunTrack) { + return; + } + osl_knee_actuator_id_ = ActuatorId("osl_knee_torque_actuator"); + osl_ankle_actuator_id_ = ActuatorId("osl_ankle_torque_actuator"); + osl_knee_joint_id_ = JointId("osl_knee_angle_r"); + osl_ankle_joint_id_ = JointId("osl_ankle_angle_r"); + osl_load_sensor_id_ = SensorId("r_osl_load"); + mjtNum body_mass = 0.0; + for (int i = 0; i < model_->nbody; ++i) { + body_mass += model_->body_mass[i]; + } + osl_body_weight_ = body_mass * static_cast(9.81); + ResetOslController(); + } + + void ResetOslController() { + if (task_.kind != MyoSuiteTaskKind::kChallengeRunTrack) { + return; + } + osl_phase_ = OslPhase::kEStance; + if (model_->nkey < 3) { + return; + } + int closest_key = 0; + mjtNum closest_distance = std::numeric_limits::infinity(); + for (int key = 0; key < 3; ++key) { + const mjtNum* key_qpos = model_->key_qpos + key * model_->nq; + mjtNum distance = 0.0; + for (int i = std::min(7, model_->nq); i < model_->nq; ++i) { + const mjtNum diff = data_->qpos[i] - key_qpos[i]; + distance += diff * diff; + } + if (distance < closest_distance) { + closest_distance = distance; + closest_key = key; + } + } + osl_phase_ = closest_key == 1 ? OslPhase::kESwing : OslPhase::kEStance; + } + + static mjtNum DegreesToRadians(mjtNum degrees) { + return degrees * std::acos(static_cast(-1.0)) / + static_cast(180.0); + } + + OslStateParams CurrentOslStateParams() const { + switch (osl_phase_) { + case OslPhase::kLStance: + return {99.372, 1.272, DegreesToRadians(8.0), + 79.498, 0.063, DegreesToRadians(-20.0)}; + case OslPhase::kESwing: + return {39.749, 0.063, DegreesToRadians(60.0), + 7.949, 0.0, DegreesToRadians(25.0)}; + case OslPhase::kLSwing: + return {15.899, 3.816, DegreesToRadians(5.0), + 7.949, 0.0, DegreesToRadians(15.0)}; + case OslPhase::kEStance: + default: + return {99.372, 3.180, DegreesToRadians(5.0), + 19.874, 0.0, DegreesToRadians(-2.0)}; + } + } + + void UpdateOslPhase(mjtNum knee_angle, mjtNum knee_vel, mjtNum ankle_angle, + mjtNum load) { + switch (osl_phase_) { + case OslPhase::kEStance: + if (load > 0.25 * osl_body_weight_ || + ankle_angle > DegreesToRadians(6.0)) { + osl_phase_ = OslPhase::kLStance; + } + break; + case OslPhase::kLStance: + if (load < 0.15 * osl_body_weight_) { + osl_phase_ = OslPhase::kESwing; + } + break; + case OslPhase::kESwing: + if (knee_angle > DegreesToRadians(50.0) || + knee_vel < DegreesToRadians(3.0)) { + osl_phase_ = OslPhase::kLSwing; + } + break; + case OslPhase::kLSwing: + if (load > 0.4 * osl_body_weight_ || + knee_angle < DegreesToRadians(30.0)) { + osl_phase_ = OslPhase::kEStance; + } + break; + } + } + + mjtNum JointQpos(int joint_id) const { + if (joint_id < 0) { + return 0.0; + } + return data_->qpos[model_->jnt_qposadr[joint_id]]; + } + + mjtNum JointQvel(int joint_id) const { + if (joint_id < 0) { + return 0.0; + } + return data_->qvel[model_->jnt_dofadr[joint_id]]; + } + + mjtNum OslLoad() const { + if (osl_load_sensor_id_ < 0 || + model_->sensor_dim[osl_load_sensor_id_] <= 1) { + return 0.0; + } + const int adr = model_->sensor_adr[osl_load_sensor_id_]; + return -data_->sensordata[adr + 1]; + } + + mjtNum OslActuatorControl(int actuator_id, mjtNum torque) const { + if (actuator_id < 0) { + return 0.0; + } + const mjtNum gear = model_->actuator_gear[6 * actuator_id]; + mjtNum ctrl = gear != 0.0 ? torque / gear : 0.0; + const mjtNum low = model_->actuator_ctrlrange[2 * actuator_id]; + const mjtNum high = model_->actuator_ctrlrange[2 * actuator_id + 1]; + ctrl = std::max(low, std::min(high, ctrl)); + if (task_.normalize_act) { + const mjtNum mean = (low + high) * 0.5; + const mjtNum range = (high - low) * 0.5; + ctrl = range != 0.0 ? (ctrl - mean) / range : 0.0; + } + return ctrl; + } + + void ApplyOslControls(std::vector* ctrl) { + if (task_.kind != MyoSuiteTaskKind::kChallengeRunTrack || + osl_knee_actuator_id_ < 0 || osl_ankle_actuator_id_ < 0) { + return; + } + const mjtNum knee_angle = JointQpos(osl_knee_joint_id_); + const mjtNum knee_vel = JointQvel(osl_knee_joint_id_); + const mjtNum ankle_angle = JointQpos(osl_ankle_joint_id_); + const mjtNum ankle_vel = JointQvel(osl_ankle_joint_id_); + UpdateOslPhase(knee_angle, knee_vel, ankle_angle, OslLoad()); + const auto params = CurrentOslStateParams(); + const mjtNum knee_torque = std::max( + -142.272, std::min( + 142.272, params.knee_stiffness * + (params.knee_target_angle - knee_angle) - + params.knee_damping * knee_vel)); + const mjtNum ankle_torque = std::max( + -168.192, + std::min( + 168.192, + params.ankle_stiffness * (params.ankle_target_angle - ankle_angle) - + params.ankle_damping * ankle_vel)); + (*ctrl)[osl_knee_actuator_id_] = + OslActuatorControl(osl_knee_actuator_id_, knee_torque); + (*ctrl)[osl_ankle_actuator_id_] = + OslActuatorControl(osl_ankle_actuator_id_, ankle_torque); + } + + void ApplyFatigue(std::vector* ctrl) { + if (task_.muscle_condition != MyoSuiteMuscleCondition::kFatigue) { + return; + } + constexpr mjtNum k_recovery_multiplier = 10.0 * 15.0; + constexpr mjtNum k_fatigue_coefficient = 0.00912; + constexpr mjtNum k_recovery_coefficient = 0.1 * 0.00094; + const auto dt = static_cast(Dt()); + for (std::size_t i = 0; i < muscle_actuator_ids_.size(); ++i) { + const int actuator_id = muscle_actuator_ids_[i]; + fatigue_tl_[i] = (*ctrl)[actuator_id]; + + const mjtNum ma = fatigue_ma_[i]; + const mjtNum mr = fatigue_mr_[i]; + const mjtNum mf = fatigue_mf_[i]; + const mjtNum tl = fatigue_tl_[i]; + const mjtNum ld = (1.0 / fatigue_tauact_[i]) * (0.5 + 1.5 * ma); + const mjtNum lr = (0.5 + 1.5 * ma) / fatigue_taudeact_[i]; + + mjtNum transfer = 0.0; + if (ma < tl && mr > tl - ma) { + transfer = ld * (tl - ma); + } else if (ma < tl) { + transfer = ld * mr; + } else { + transfer = lr * (tl - ma); + } + + const mjtNum recovery = + ma >= tl ? k_recovery_multiplier * k_recovery_coefficient + : k_recovery_coefficient; + const mjtNum lower = std::max(-ma / dt + k_fatigue_coefficient * ma, + (mr - 1.0) / dt + recovery * mf); + const mjtNum upper = + std::min((1.0 - ma) / dt + k_fatigue_coefficient * ma, + mr / dt + recovery * mf); + transfer = std::max(lower, std::min(upper, transfer)); + + fatigue_ma_[i] += (transfer - k_fatigue_coefficient * ma) * dt; + fatigue_mr_[i] += (-transfer + recovery * mf) * dt; + fatigue_mf_[i] += (k_fatigue_coefficient * ma - recovery * mf) * dt; + (*ctrl)[actuator_id] = + static_cast(static_cast(fatigue_ma_[i])); + } + } + + void ApplyReafferentation(std::vector* ctrl) const { + if (task_.muscle_condition != MyoSuiteMuscleCondition::kReafferentation) { + return; + } + int epl = mj_name2id(model_, mjOBJ_ACTUATOR, "EPL"); + int eip = mj_name2id(model_, mjOBJ_ACTUATOR, "EIP"); + if (epl >= 0 && eip >= 0) { + (*ctrl)[epl] = (*ctrl)[eip]; + (*ctrl)[eip] = 0.0; + } + } + + void ApplyMyoDmModelEdits() { + if (task_.kind != MyoSuiteTaskKind::kMyoDmTrack) { + return; + } + const int body_geom = GeomId("body"); + if (body_geom >= 0) { + model_->geom_rgba[4 * body_geom + 3] = 0.0; + } + } + + void ApplyMetadataInitialState() { + if (static_cast(metadata_init_qpos_.size()) == model_->nq) { + std::memcpy(data_->qpos, metadata_init_qpos_.data(), + sizeof(mjtNum) * model_->nq); + } + if (static_cast(metadata_init_qvel_.size()) == model_->nv) { + std::memcpy(data_->qvel, metadata_init_qvel_.data(), + sizeof(mjtNum) * model_->nv); + } + if (model_->na > 0) { + mju_zero(data_->act, model_->na); + } + mj_forward(model_, data_); + } + + void ApplyBaodingModelEdits() { + if (task_.kind != MyoSuiteTaskKind::kChallengeBaoding) { + return; + } + const int target1 = SiteId("target1_site"); + const int target2 = SiteId("target2_site"); + if (target1 >= 0) { + model_->site_group[target1] = 2; + } + if (target2 >= 0) { + model_->site_group[target2] = 2; + } + } + + void SetDefaultInitialQpos() { + if (!task_.normalize_act) { + return; + } + for (int actuator_id = 0; actuator_id < model_->nu; ++actuator_id) { + if (model_->actuator_trntype[actuator_id] != mjTRN_JOINT) { + continue; + } + const int joint_id = model_->actuator_trnid[2 * actuator_id]; + if (joint_id < 0) { + continue; + } + const int joint_type = model_->jnt_type[joint_id]; + if (joint_type != mjJNT_HINGE && joint_type != mjJNT_SLIDE) { + continue; + } + const int qpos_id = model_->jnt_qposadr[joint_id]; + if (metadata_init_qpos_.empty()) { + data_->qpos[qpos_id] = (model_->jnt_range[2 * joint_id] + + model_->jnt_range[2 * joint_id + 1]) * + 0.5; + } + } + mj_forward(model_, data_); + } + + void ApplyBimanualInitialState() { + if (task_.kind != MyoSuiteTaskKind::kChallengeBimanual || + model_->nkey <= 2) { + return; + } + std::memcpy(data_->qpos, model_->key_qpos + 2 * model_->nq, + sizeof(mjtNum) * model_->nq); + std::fill(data_->qvel, data_->qvel + model_->nv, 0.0); + if (model_->na > 0) { + mju_zero(data_->act, model_->na); + } + mj_forward(model_, data_); + } + + void ApplyResetTargets() { + const int count = std::min({static_cast(target_sites_.size()), + static_cast(target_reach_low_.size()), + static_cast(target_reach_high_.size())}); + for (int i = 0; i < count; ++i) { + const int site_id = SiteId(target_sites_[i].c_str()); + if (site_id < 0 || target_reach_low_[i].size() < 3 || + target_reach_high_[i].size() < 3) { + continue; + } + for (int axis = 0; axis < 3; ++axis) { + // Use the midpoint as the native deterministic reset target. Oracle + // alignment tests sync randomized upstream state immediately after + // reset when the official task randomizes this site. + model_->site_pos[3 * site_id + axis] = + (target_reach_low_[i][axis] + target_reach_high_[i][axis]) * 0.5; + } + } + if (task_.kind == MyoSuiteTaskKind::kChallengeBimanual) { + const int start = BodyId("start"); + const int goal = BodyId("goal"); + if (start >= 0) { + model_->body_pos[3 * start] = -0.4; + model_->body_pos[3 * start + 1] = -0.25; + model_->body_pos[3 * start + 2] = 1.05; + } + if (goal >= 0) { + model_->body_pos[3 * goal] = 0.4; + model_->body_pos[3 * goal + 1] = -0.25; + model_->body_pos[3 * goal + 2] = 1.05; + } + const int object_joint = JointId("manip_object/freejoint"); + if (object_joint >= 0) { + const int qpos_id = model_->jnt_qposadr[object_joint]; + data_->qpos[qpos_id] = -0.4; + data_->qpos[qpos_id + 1] = -0.25; + data_->qpos[qpos_id + 2] = 1.15; + } + } + ApplyChallengeRunTrackTerrainReset(); + mj_forward(model_, data_); + if (task_.kind == MyoSuiteTaskKind::kChallengeBimanual) { + bimanual_init_obj_z_ = SiteXpos(SiteId("touch_site"))[2]; + bimanual_init_palm_z_ = SiteXpos(SiteId("S_grasp"))[2]; + } + } + + void ApplyChallengeRunTrackTerrainReset() { + if (task_.kind != MyoSuiteTaskKind::kChallengeRunTrack) { + return; + } + const int terrain = GeomId("terrain"); + if (terrain < 0) { + return; + } + model_->geom_pos[3 * terrain] = 0.0; + model_->geom_pos[3 * terrain + 1] = 0.0; + model_->geom_pos[3 * terrain + 2] = 0.005; + model_->geom_rgba[4 * terrain + 3] = 1.0; + } + + void WarmstartFromCurrentAcceleration() { + if (static_cast(reset_qacc_warmstart_.size()) == model_->nv) { + std::memcpy(data_->qacc_warmstart, reset_qacc_warmstart_.data(), + sizeof(mjtNum) * model_->nv); + } else { + std::memcpy(data_->qacc_warmstart, data_->qacc, + sizeof(mjtNum) * model_->nv); + } + } + + void InitializeTaskCaches() { + if (task_.kind == MyoSuiteTaskKind::kMyoDmTrack) { + const int object_bid = task_.object_name[0] != '\0' + ? BodyId(task_.object_name) + : BodyId("Object"); + myodm_lift_z_ = BodyXpos(object_bid)[2] + 0.02; + } + InitializeOslController(); + if (task_.kind == MyoSuiteTaskKind::kChallengeTableTennis) { + tabletennis_init_paddle_quat_ = BodyXquat(BodyId("paddle")); + } + if (task_.kind == MyoSuiteTaskKind::kChallengeReorient) { + const auto target = SiteXpos(SiteId("target_o")); + const auto object = SiteXpos(SiteId("object_o")); + for (int axis = 0; axis < 3; ++axis) { + challenge_reorient_goal_obj_offset_[axis] = target[axis] - object[axis]; + } + } + if (task_.kind == MyoSuiteTaskKind::kChallengeBimanual) { + bimanual_init_obj_z_ = SiteXpos(SiteId("touch_site"))[2]; + bimanual_init_palm_z_ = SiteXpos(SiteId("S_grasp"))[2]; + } + } + + static std::vector ReferenceRow(const double* values, int rows, + int cols, int row) { + if (rows <= 0 || cols <= 0) { + return {}; + } + row = std::max(0, std::min(rows - 1, row)); + std::vector result(cols); + for (int i = 0; i < cols; ++i) { + result[i] = static_cast(values[row * cols + i]); + } + return result; + } + + static std::vector ReferenceBlend(const double* values, int rows, + int cols, int row, int next, + mjtNum blend) { + if (rows <= 0 || cols <= 0) { + return {}; + } + row = std::max(0, std::min(rows - 1, row)); + next = std::max(0, std::min(rows - 1, next)); + if (row == next) { + return ReferenceRow(values, rows, cols, row); + } + std::vector result(cols); + for (int i = 0; i < cols; ++i) { + const auto a = static_cast(values[row * cols + i]); + const auto b = static_cast(values[next * cols + i]); + result[i] = (1.0 - blend) * a + blend * b; + } + return result; + } + + std::pair MyoDmReferenceRows(mjtNum time) { + if (reference_.time_size <= 1) { + return {0, 0}; + } + const mjtNum rounded = std::round(time * 10000.0) / 10000.0; + const int last = reference_.time_size - 1; + if (rounded >= static_cast(reference_.time[last])) { + myodm_reference_index_ = last; + return {last, last}; + } + if (myodm_reference_index_ < last && + rounded == + static_cast(reference_.time[myodm_reference_index_ + 1])) { + ++myodm_reference_index_; + return {myodm_reference_index_, myodm_reference_index_}; + } + if (rounded == + static_cast(reference_.time[myodm_reference_index_])) { + return {myodm_reference_index_, myodm_reference_index_}; + } + const double* begin = reference_.time; + const double* end = reference_.time + reference_.time_size; + const auto* const upper = std::upper_bound(begin, end, rounded); + int next = static_cast(upper - begin); + next = std::max(1, std::min(last, next)); + myodm_reference_index_ = next - 1; + if (rounded == static_cast(reference_.time[next])) { + myodm_reference_index_ = next; + return {next, next}; + } + return {myodm_reference_index_, next}; + } + + MyoDmReferenceFrame MyoDmReferenceAt(mjtNum time) { + MyoDmReferenceFrame frame; + if (reference_.type == MyoSuiteReferenceType::kNone) { + return frame; + } + if (reference_.type == MyoSuiteReferenceType::kRandom) { + auto sample = [this](const double* values, int rows, int cols) { + if (rows < 2 || cols <= 0) { + return ReferenceRow(values, rows, cols, 0); + } + std::vector result(cols); + for (int i = 0; i < cols; ++i) { + std::uniform_real_distribution dist( + static_cast(values[i]), + static_cast(values[cols + i])); + result[i] = dist(gen_); + } + return result; + }; + frame.robot = sample(reference_.robot, reference_.robot_rows, + reference_.robot_cols); + frame.robot_vel = sample(reference_.robot_vel, reference_.robot_vel_rows, + reference_.robot_vel_cols); + frame.object = sample(reference_.object, reference_.object_rows, + reference_.object_cols); + return frame; + } + + const auto [row, next] = MyoDmReferenceRows(time); + mjtNum blend = 0.0; + if (row != next) { + const auto t0 = static_cast(reference_.time[row]); + const auto t1 = static_cast(reference_.time[next]); + if (t1 > t0) { + blend = (time - t0) / (t1 - t0); + } + } + frame.robot = ReferenceBlend(reference_.robot, reference_.robot_rows, + reference_.robot_cols, row, next, blend); + frame.robot_vel = + ReferenceBlend(reference_.robot_vel, reference_.robot_vel_rows, + reference_.robot_vel_cols, row, next, blend); + frame.object = ReferenceBlend(reference_.object, reference_.object_rows, + reference_.object_cols, row, next, blend); + return frame; + } + + void ApplyMyoDmReferenceSite(const MyoDmReferenceFrame& reference) { + if (task_.kind != MyoSuiteTaskKind::kMyoDmTrack || + reference.object.size() < 3) { + return; + } + const int target_sid = SiteId("target"); + if (target_sid < 0) { + return; + } + for (int axis = 0; axis < 3; ++axis) { + model_->site_pos[3 * target_sid + axis] = reference.object[axis]; + } + mj_forward(model_, data_); + } + + void PreStepTaskUpdate() { + if (task_.kind != MyoSuiteTaskKind::kChallengeBaoding) { + return; + } + const int target1 = SiteId("target1_site"); + const int target2 = SiteId("target2_site"); + if (target1 < 0 || target2 < 0) { + return; + } + const mjtNum x_radius = 0.025; + const mjtNum y_radius = 0.028; + const mjtNum center_x = -0.0125; + const mjtNum center_y = -0.07; + const mjtNum phase = + (static_cast(task_step_) * Dt()) / static_cast(6.0); + const mjtNum angle1 = 2.0 * M_PI * phase + M_PI / 4.0; + const mjtNum angle2 = angle1 - M_PI; + model_->site_pos[3 * target1] = x_radius * std::cos(angle1) + center_x; + model_->site_pos[3 * target1 + 1] = y_radius * std::sin(angle1) + center_y; + model_->site_pos[3 * target2] = x_radius * std::cos(angle2) + center_x; + model_->site_pos[3 * target2 + 1] = y_radius * std::sin(angle2) + center_y; + mj_forward(model_, data_); + } + + void PostStepTaskUpdate() { + if (task_.kind == MyoSuiteTaskKind::kWalk || + task_.kind == MyoSuiteTaskKind::kTerrain || + task_.kind == MyoSuiteTaskKind::kWalkReach || + task_.kind == MyoSuiteTaskKind::kChallengeBaoding) { + ++task_step_; + } + } + + std::vector Observation(const ObsDict& obs_dict) const { + std::vector obs(task_.obs_dim, 0.0); + std::size_t pos = 0; + auto append = [&](mjtNum value) { + if (pos < obs.size()) { + obs[pos++] = value; + } + }; + for (const std::string& key : obs_keys_) { + auto it = obs_dict.find(key); + if (it == obs_dict.end()) { + continue; + } + for (mjtNum value : it->second) { + append(value); + } + } + return obs; + } + + std::vector QposSlice(int begin, int end) const { + begin = std::max(0, begin); + end = std::min(static_cast(model_->nq), end); + if (end <= begin) { + return {}; + } + return {data_->qpos + begin, data_->qpos + end}; + } + + std::vector QvelSlice(int begin, int end, bool scale_dt) const { + begin = std::max(0, begin); + end = std::min(static_cast(model_->nv), end); + std::vector result; + const mjtNum scale = scale_dt ? static_cast(Dt()) : 1.0; + for (int i = begin; i < end; ++i) { + result.push_back(data_->qvel[i] * scale); + } + return result; + } + + std::vector JointQposValues(int joint_id) const { + if (joint_id < 0) { + return {}; + } + const int begin = model_->jnt_qposadr[joint_id]; + return QposSlice(begin, begin + JointQposWidth(model_->jnt_type[joint_id])); + } + + std::vector JointQvelValues(int joint_id, bool scale_dt) const { + if (joint_id < 0) { + return {}; + } + const int begin = model_->jnt_dofadr[joint_id]; + return QvelSlice(begin, begin + JointDofWidth(model_->jnt_type[joint_id]), + scale_dt); + } + + bool IsBimanualProsthesisJoint(int joint_id) const { + const char* name = mj_id2name(model_, mjOBJ_JOINT, joint_id); + return name != nullptr && StartsWith(name, "prosthesis"); + } + + bool IsBimanualManipObjectJoint(int joint_id) const { + const char* name = mj_id2name(model_, mjOBJ_JOINT, joint_id); + return name != nullptr && + std::string_view(name) == "manip_object/freejoint"; + } + + std::vector BimanualJointQpos(bool prosthesis) const { + std::vector result; + for (int joint_id = 0; joint_id < model_->njnt; ++joint_id) { + if (IsBimanualManipObjectJoint(joint_id) || + IsBimanualProsthesisJoint(joint_id) != prosthesis) { + continue; + } + const auto values = JointQposValues(joint_id); + result.insert(result.end(), values.begin(), values.end()); + } + return result; + } + + std::vector BimanualJointQvel(bool prosthesis) const { + std::vector result; + for (int joint_id = 0; joint_id < model_->njnt; ++joint_id) { + if (IsBimanualManipObjectJoint(joint_id) || + IsBimanualProsthesisJoint(joint_id) != prosthesis) { + continue; + } + const auto values = JointQvelValues(joint_id, false); + result.insert(result.end(), values.begin(), values.end()); + } + return result; + } + + std::vector AverageSites(const char* lhs, const char* rhs) const { + const auto lhs_pos = SiteXpos(SiteId(lhs)); + const auto rhs_pos = SiteXpos(SiteId(rhs)); + return {(lhs_pos[0] + rhs_pos[0]) * 0.5, (lhs_pos[1] + rhs_pos[1]) * 0.5, + (lhs_pos[2] + rhs_pos[2]) * 0.5}; + } + + int BimanualBodyLabel(int body_id) const { + const int start_id = BodyId("start"); + const int goal_id = BodyId("goal"); + const int object_id = BodyId("manip_object"); + int myo_min = model_->nbody; + int myo_max = -1; + int prosth_min = model_->nbody; + int prosth_max = -1; + for (int id = 0; id < model_->nbody; ++id) { + const char* raw_name = mj_id2name(model_, mjOBJ_BODY, id); + const std::string_view name = raw_name == nullptr ? "" : raw_name; + if (StartsWith(name, "prosthesis/")) { + prosth_min = std::min(prosth_min, id); + prosth_max = std::max(prosth_max, id); + } else if (id != start_id && id != goal_id && id != object_id) { + myo_min = std::min(myo_min, id); + myo_max = std::max(myo_max, id); + } + } + if (myo_min <= body_id && body_id <= myo_max) { + return 0; + } + if (prosth_min <= body_id && body_id <= prosth_max) { + return 1; + } + if (body_id == start_id) { + return 2; + } + if (body_id == goal_id) { + return 3; + } + return 4; + } + + std::vector BimanualTouchingBody() const { + std::vector result(5, 0.0); + const int object_id = BodyId("manip_object"); + if (object_id < 0) { + return result; + } + for (int i = 0; i < data_->ncon; ++i) { + const mjContact& contact = data_->contact[i]; + const int body1 = model_->geom_bodyid[contact.geom1]; + const int body2 = model_->geom_bodyid[contact.geom2]; + if (body1 == object_id) { + result[BimanualBodyLabel(body2)] += 1.0; + } else if (body2 == object_id) { + result[BimanualBodyLabel(body1)] += 1.0; + } + } + return result; + } + + std::vector TableTennisTouchingInfo() const { + std::vector result(6, 0.0); + const int ball_id = BodyId("pingpong"); + if (ball_id < 0) { + return result; + } + const int paddle = GeomId("pad"); + const int own = GeomId("coll_own_half"); + const int opponent = GeomId("coll_opponent_half"); + const int net = GeomId("coll_net"); + const int ground = GeomId("ground"); + auto label = [&](int geom_id) { + if (geom_id == paddle) { + return 0; + } + if (geom_id == own) { + return 1; + } + if (geom_id == opponent) { + return 2; + } + if (geom_id == net) { + return 3; + } + if (geom_id == ground) { + return 4; + } + return 5; + }; + for (int i = 0; i < data_->ncon; ++i) { + const mjContact& contact = data_->contact[i]; + const int body1 = model_->geom_bodyid[contact.geom1]; + const int body2 = model_->geom_bodyid[contact.geom2]; + if (body1 == ball_id) { + result[label(contact.geom2)] += 1.0; + } else if (body2 == ball_id) { + result[label(contact.geom1)] += 1.0; + } + } + return result; + } + + std::vector Act() const { + if (model_->na <= 0) { + return {}; + } + return {data_->act, data_->act + model_->na}; + } + + std::vector ActuatorLength() const { + return {data_->actuator_length, data_->actuator_length + model_->nu}; + } + + std::vector ActuatorVelocity() const { + std::vector result(model_->nu); + for (int i = 0; i < model_->nu; ++i) { + result[i] = std::max( + -100.0, std::min(100.0, data_->actuator_velocity[i])); + } + return result; + } + + std::vector ActuatorForce(bool scaled) const { + std::vector result(model_->nu); + for (int i = 0; i < model_->nu; ++i) { + mjtNum value = data_->actuator_force[i]; + if (scaled) { + value = + std::max(-100.0, std::min(100.0, value / 1000.0)); + } + result[i] = value; + } + return result; + } + + std::vector ReachPositions(bool target) const { + const auto& sites = target ? target_sites_ : tip_sites_; + std::vector result; + for (const std::string& site_name : sites) { + const auto xyz = SiteXpos(SiteId(site_name.c_str())); + result.insert(result.end(), xyz.begin(), xyz.end()); + } + return result; + } + + std::vector PenRotation(bool target) const { + const bool sar = task_.kind == MyoSuiteTaskKind::kReorientSar; + const int top = sar ? GeomId(target ? "t_top" : "top") + : SiteId(target ? "target_top" : "object_top"); + const int bottom = sar ? GeomId(target ? "t_bot" : "bot") + : SiteId(target ? "target_bottom" : "object_bottom"); + const auto top_pos = sar ? GeomXpos(top) : SiteXpos(top); + const auto bottom_pos = sar ? GeomXpos(bottom) : SiteXpos(bottom); + auto rot = Subtract(top_pos, bottom_pos); + const mjtNum length = Norm(rot); + if (length > 0.0) { + for (mjtNum& value : rot) { + value /= length; + } + } + return rot; + } + + std::vector CenterOfMassVelocity() const { + mjtNum mass_sum = 0.0; + std::array cvel{}; + for (int body = 0; body < model_->nbody; ++body) { + const mjtNum mass = model_->body_mass[body]; + mass_sum += mass; + for (int axis = 0; axis < 6; ++axis) { + cvel[axis] += mass * (-data_->cvel[6 * body + axis]); + } + } + if (mass_sum > 0.0) { + cvel[3] /= mass_sum; + cvel[4] /= mass_sum; + } + return {cvel[3], cvel[4]}; + } + + std::vector CenterOfMass() const { + mjtNum mass_sum = 0.0; + std::array com{}; + for (int body = 0; body < model_->nbody; ++body) { + const mjtNum mass = model_->body_mass[body]; + mass_sum += mass; + for (int axis = 0; axis < 3; ++axis) { + com[axis] += mass * data_->xipos[3 * body + axis]; + } + } + if (mass_sum > 0.0) { + for (mjtNum& value : com) { + value /= mass_sum; + } + } + return {com[0], com[1], com[2]}; + } + + std::vector FeetRelativePositions() const { + const auto left = BodyXpos(BodyId("talus_l")); + const auto right = BodyXpos(BodyId("talus_r")); + const auto pelvis = BodyXpos(BodyId("pelvis")); + auto left_rel = Subtract(left, pelvis); + auto right_rel = Subtract(right, pelvis); + left_rel.insert(left_rel.end(), right_rel.begin(), right_rel.end()); + return left_rel; + } + + std::vector JointAngles( + const std::vector& joint_names) const { + std::vector result; + for (const char* name : joint_names) { + const int joint_id = JointId(name); + if (joint_id >= 0) { + result.push_back(data_->qpos[model_->jnt_qposadr[joint_id]]); + } + } + return result; + } + + ObsDict BuildObsDict() { + ObsDict obs; + MyoDmReferenceFrame myodm_reference; + if (task_.kind == MyoSuiteTaskKind::kMyoDmTrack) { + myodm_reference = MyoDmReferenceAt(data_->time); + ApplyMyoDmReferenceSite(myodm_reference); + } + const auto dt = static_cast(Dt()); + obs["time"] = {data_->time}; + obs["t"] = {data_->time}; + obs["qpos"] = QposSlice(0, model_->nq); + obs["qp"] = obs["qpos"]; + obs["qvel"] = QvelSlice(0, model_->nv, true); + obs["qv"] = QvelSlice(0, model_->nv, false); + obs["act"] = Act(); + + const int pose_size = !target_jnt_value_.empty() + ? static_cast(target_jnt_value_.size()) + : model_->nq; + obs["pose_err"] = std::vector(pose_size, 0.0); + for (int i = 0; i < pose_size && i < model_->nq; ++i) { + const mjtNum target = i < static_cast(target_jnt_value_.size()) + ? target_jnt_value_[i] + : 0.0; + obs["pose_err"][i] = target - data_->qpos[i]; + } + + obs["tip_pos"] = ReachPositions(false); + obs["target_pos"] = ReachPositions(true); + obs["reach_err"] = Subtract(obs["target_pos"], obs["tip_pos"]); + + const bool key_turn = task_.kind == MyoSuiteTaskKind::kKeyTurn; + const int hand_qpos_end = key_turn ? model_->nq - 1 : model_->nq - 7; + const int hand_qvel_end = key_turn ? model_->nv - 1 : model_->nv - 6; + obs["hand_qpos"] = QposSlice(0, hand_qpos_end); + obs["hand_qpos_noMD5"] = QposSlice(0, model_->nq - 7); + obs["hand_qpos_corrected"] = QposSlice(0, model_->nq - 6); + obs["hand_qvel"] = QvelSlice(0, hand_qvel_end, true); + obs["hand_jnt"] = QposSlice(0, model_->nq - 6); + obs["hand_pos"] = QposSlice(0, model_->nq - 14); + obs["key_qpos"] = QposSlice(model_->nq - 1, model_->nq); + obs["key_qvel"] = QvelSlice(model_->nv - 1, model_->nv, true); + obs["IFtip_approach"] = + Subtract(SiteXpos(SiteId("keyhead")), SiteXpos(SiteId("IFtip"))); + obs["THtip_approach"] = + Subtract(SiteXpos(SiteId("keyhead")), SiteXpos(SiteId("THtip"))); + + if (task_.kind == MyoSuiteTaskKind::kObjHoldFixed || + task_.kind == MyoSuiteTaskKind::kObjHoldRandom) { + obs["obj_pos"] = SiteXpos(SiteId("object")); + obs["obj_err"] = + Subtract(SiteXpos(SiteId("goal")), SiteXpos(SiteId("object"))); + } else if (task_.kind == MyoSuiteTaskKind::kPenTwirlFixed || + task_.kind == MyoSuiteTaskKind::kPenTwirlRandom || + task_.kind == MyoSuiteTaskKind::kReorientSar) { + obs["obj_pos"] = BodyXpos(BodyId("Object")); + obs["obj_des_pos"] = SiteXpos(SiteId("eps_ball")); + obs["obj_vel"] = QvelSlice(model_->nv - 6, model_->nv, true); + obs["obj_rot"] = PenRotation(false); + obs["obj_des_rot"] = PenRotation(true); + obs["obj_err_pos"] = Subtract(obs["obj_pos"], obs["obj_des_pos"]); + obs["obj_err_rot"] = Subtract(obs["obj_rot"], obs["obj_des_rot"]); + obs["mlen"] = ActuatorLength(); + obs["mvel"] = std::vector(data_->actuator_velocity, + data_->actuator_velocity + model_->nu); + obs["mforce"] = std::vector(data_->actuator_force, + data_->actuator_force + model_->nu); + } else if (task_.kind == MyoSuiteTaskKind::kChallengeRelocate || + task_.kind == MyoSuiteTaskKind::kChallengeReorient) { + obs["obj_pos"] = SiteXpos(SiteId("object_o")); + obs["goal_pos"] = SiteXpos(SiteId("target_o")); + obs["palm_pos"] = SiteXpos(SiteId("S_grasp")); + obs["pos_err"] = Subtract(obs["goal_pos"], obs["obj_pos"]); + if (task_.kind == MyoSuiteTaskKind::kChallengeReorient) { + for (int axis = 0; axis < 3; ++axis) { + obs["pos_err"][axis] -= challenge_reorient_goal_obj_offset_[axis]; + } + } + obs["reach_err"] = Subtract(obs["palm_pos"], obs["obj_pos"]); + obs["obj_rot"] = MatToEuler(data_->site_xmat + 9 * SiteId("object_o")); + obs["goal_rot"] = MatToEuler(data_->site_xmat + 9 * SiteId("target_o")); + obs["rot_err"] = Subtract(obs["goal_rot"], obs["obj_rot"]); + } + + obs["qpos_without_xy"] = QposSlice(2, model_->nq); + obs["com_vel"] = CenterOfMassVelocity(); + obs["torso_angle"] = BodyXquat(BodyId("torso")); + obs["feet_heights"] = {BodyXpos(BodyId("talus_l"))[2], + BodyXpos(BodyId("talus_r"))[2]}; + obs["height"] = {CenterOfMass()[2]}; + obs["feet_rel_positions"] = FeetRelativePositions(); + const int hip_period = + metadata_.hip_period > 0 ? metadata_.hip_period : 100; + obs["phase_var"] = {std::fmod( + static_cast(task_step_) / static_cast(hip_period), + 1.0)}; + obs["muscle_length"] = ActuatorLength(); + obs["muscle_velocity"] = ActuatorVelocity(); + obs["muscle_force"] = ActuatorForce(true); + + obs["object1_pos"] = SiteXpos(SiteId("ball1_site")); + obs["object2_pos"] = SiteXpos(SiteId("ball2_site")); + obs["object1_velp"] = QvelSlice(model_->nv - 12, model_->nv - 9, true); + obs["object2_velp"] = QvelSlice(model_->nv - 6, model_->nv - 3, true); + obs["target1_pos"] = SiteXpos(SiteId("target1_site")); + obs["target2_pos"] = SiteXpos(SiteId("target2_site")); + obs["target1_err"] = Subtract(obs["target1_pos"], obs["object1_pos"]); + obs["target2_err"] = Subtract(obs["target2_pos"], obs["object2_pos"]); + + obs["internal_qpos"] = QposSlice(0, model_->nq); + obs["internal_qvel"] = QvelSlice(0, model_->nv, true); + obs["grf"] = SensorData("r_foot"); + auto lfoot = SensorData("l_foot"); + obs["grf"].insert(obs["grf"].end(), lfoot.begin(), lfoot.end()); + obs["model_root_pos"] = + QposSlice(0, std::min(7, static_cast(model_->nq))); + obs["model_root_vel"] = + QvelSlice(0, std::min(6, static_cast(model_->nv)), false); + if (model_->nmocap > 0) { + obs["opponent_pose"] = {data_->mocap_pos[0], data_->mocap_pos[1], + YawFromQuat(data_->mocap_quat)}; + } else { + const auto opponent = BodyXpos(BodyId("opponent")); + obs["opponent_pose"] = {opponent[0], opponent[1], 0.0}; + } + obs["opponent_vel"] = {0.0, 0.0}; + + obs["pelvis_pos"] = SiteXpos(SiteId("pelvis")); + obs["body_qpos"] = + QposSlice(0, std::max(0, static_cast(model_->nq) - 7)); + obs["body_qvel"] = + QvelSlice(0, std::max(0, static_cast(model_->nv) - 6), true); + if (task_.kind == MyoSuiteTaskKind::kChallengeSoccer) { + obs["ball_pos"] = BodyXpos(BodyId("soccer_ball")); + } + if (task_.kind == MyoSuiteTaskKind::kChallengeTableTennis) { + obs["ball_pos"] = SiteXpos(SiteId("pingpong")); + obs["ball_vel"] = SensorData("pingpong_vel_sensor"); + obs["paddle_pos"] = SiteXpos(SiteId("paddle")); + obs["paddle_vel"] = SensorData("paddle_vel_sensor"); + obs["paddle_ori"] = BodyXquat(BodyId("paddle")); + obs["padde_ori_err"] = + Subtract(obs["paddle_ori"], tabletennis_init_paddle_quat_); + obs["reach_err"] = Subtract(obs["paddle_pos"], obs["ball_pos"]); + obs["palm_pos"] = SiteXpos(SiteId("S_grasp")); + obs["palm_err"] = Subtract(obs["palm_pos"], obs["paddle_pos"]); + obs["touching_info"] = TableTennisTouchingInfo(); + } + + if (task_.kind == MyoSuiteTaskKind::kChallengeBimanual) { + obs["myohand_qpos"] = BimanualJointQpos(false); + obs["myohand_qvel"] = BimanualJointQvel(false); + obs["pros_hand_qpos"] = BimanualJointQpos(true); + obs["pros_hand_qvel"] = BimanualJointQvel(true); + const int object_joint = JointId("manip_object/freejoint"); + obs["object_qpos"] = JointQposValues(object_joint); + obs["object_qvel"] = JointQvelValues(object_joint, false); + obs["touching_body"] = BimanualTouchingBody(); + obs["palm_pos"] = SiteXpos(SiteId("S_grasp")); + obs["fin0"] = SiteXpos(SiteId("THtip")); + obs["fin1"] = SiteXpos(SiteId("IFtip")); + obs["fin2"] = SiteXpos(SiteId("MFtip")); + obs["fin3"] = SiteXpos(SiteId("RFtip")); + obs["fin4"] = SiteXpos(SiteId("LFtip")); + obs["Rpalm_pos"] = + AverageSites("prosthesis/palm_thumb", "prosthesis/palm_pinky"); + obs["obj_pos"] = SiteXpos(SiteId("touch_site")); + obs["start_pos"] = BodyPos(BodyId("start")); + obs["goal_pos"] = BodyPos(BodyId("goal")); + obs["reach_err"] = Subtract(obs["palm_pos"], obs["obj_pos"]); + obs["pass_err"] = Subtract(obs["Rpalm_pos"], obs["obj_pos"]); + const int elbow = JointId("elbow_flexion"); + obs["elbow_fle"] = JointQposValues(elbow); + } + + if (task_.kind == MyoSuiteTaskKind::kMyoDmTrack) { + obs["curr_hand_qpos"] = + QposSlice(0, std::max(0, static_cast(model_->nq) - 6)); + obs["curr_hand_qvel"] = + QvelSlice(0, std::max(0, static_cast(model_->nv) - 6), false); + obs["targ_hand_qpos"] = + myodm_reference.robot.empty() + ? std::vector(obs["curr_hand_qpos"].size(), 0.0) + : myodm_reference.robot; + obs["targ_hand_qvel"] = myodm_reference.robot_vel.empty() + ? std::vector{0.0} + : myodm_reference.robot_vel; + obs["hand_qpos_err"] = + Subtract(obs["curr_hand_qpos"], obs["targ_hand_qpos"]); + obs["hand_qvel_err"] = + myodm_reference.robot_vel.empty() + ? std::vector{0.0} + : Subtract(obs["curr_hand_qvel"], obs["targ_hand_qvel"]); + const int object_bid = task_.object_name[0] != '\0' + ? BodyId(task_.object_name) + : BodyId("Object"); + obs["curr_obj_com"] = BodyXpos(object_bid); + obs["curr_obj_rot"] = BodyXquat(object_bid); + if (myodm_reference.object.size() >= 7) { + obs["targ_obj_com"] = {myodm_reference.object[0], + myodm_reference.object[1], + myodm_reference.object[2]}; + obs["targ_obj_rot"] = { + myodm_reference.object[3], myodm_reference.object[4], + myodm_reference.object[5], myodm_reference.object[6]}; + } else { + obs["targ_obj_com"] = {0.2, 0.2, 0.1}; + obs["targ_obj_rot"] = {1.0, 0.0, 0.0, 0.0}; + } + obs["obj_com_err"] = Subtract(obs["curr_obj_com"], obs["targ_obj_com"]); + obs["wrist_err"] = BodyXpos(BodyId("lunate")); + obs["base_error"] = Subtract(obs["curr_obj_com"], obs["wrist_err"]); + } + (void)dt; + return obs; + } + + const std::vector& ObsValue(const ObsDict& obs, + const std::string& key) const { + static const std::vector k_empty; + auto it = obs.find(key); + return it == obs.end() ? k_empty : it->second; + } + + mjtNum ActMagnitude(const ObsDict& obs) const { + const auto& act = ObsValue(obs, "act"); + if (act.empty() || model_->na == 0) { + return 0.0; + } + return Norm(act) / static_cast(model_->na); + } + + mjtNum ComponentValue( + const std::string& key, + const std::unordered_map& values) const { + auto it = values.find(key); + return it == values.end() ? 0.0 : it->second; + } + + bool WalkDone(const ObsDict& obs) const { + const auto& height = ObsValue(obs, "height"); + const mjtNum min_height = + metadata_.min_height > 0.0 ? metadata_.min_height : 0.8; + if (!height.empty() && height[0] < min_height) { + return true; + } + const mjtNum max_rot = metadata_.max_rot > 0.0 ? metadata_.max_rot : 0.8; + const auto quat = QposSlice(3, 7); + if (quat.size() == 4) { + std::array mat{}; + mju_quat2Mat(mat.data(), quat.data()); + if (std::abs(mat[0]) > max_rot) { + return true; + } + } + if (task_.kind == MyoSuiteTaskKind::kTerrain) { + const auto& feet = ObsValue(obs, "feet_heights"); + if (!height.empty() && feet.size() >= 2 && + height[0] - (feet[0] + feet[1]) * 0.5 < 0.61) { + return true; + } + } + return false; + } + + RewardResult ComputeReward(const ObsDict& obs) { + std::unordered_map values; + bool terminated = false; + const mjtNum pi = std::acos(static_cast(-1.0)); + + if (task_.kind == MyoSuiteTaskKind::kMyoDmTrack) { + const mjtNum obj_com_err = Norm(ObsValue(obs, "obj_com_err")); + const mjtNum obj_rot_err = + QuaternionDistance(ObsValue(obs, "curr_obj_rot"), + ObsValue(obs, "targ_obj_rot")) / + pi; + const mjtNum obj_reward = + std::exp(-50.0 * (obj_com_err + 0.1 * obj_rot_err)); + const auto& targ_obj_com = ObsValue(obs, "targ_obj_com"); + const auto& curr_obj_com = ObsValue(obs, "curr_obj_com"); + const bool lift_bonus = + targ_obj_com.size() > 2 && curr_obj_com.size() > 2 && + targ_obj_com[2] >= myodm_lift_z_ && curr_obj_com[2] >= myodm_lift_z_; + const mjtNum qpos_reward = + std::exp(-5.0 * SquaredNorm(ObsValue(obs, "hand_qpos_err"))); + const mjtNum qvel_reward = + std::exp(-0.1 * SquaredNorm(ObsValue(obs, "hand_qvel_err"))); + const mjtNum pose_reward = 0.35 * qpos_reward + 0.05 * qvel_reward; + const mjtNum base_error = Norm(ObsValue(obs, "base_error")); + const mjtNum base_reward = std::exp(-40.0 * base_error); + const bool myodm_done = + SquaredNorm(ObsValue(obs, "obj_com_err")) >= 0.25 * 0.25 || + SquaredNorm(ObsValue(obs, "base_error")) >= 0.25 * 0.25; + values["pose"] = pose_reward; + values["object"] = obj_reward + base_reward; + values["bonus"] = lift_bonus ? 1.0 : 0.0; + values["penalty"] = myodm_done ? 1.0 : 0.0; + values["sparse"] = 0.0; + values["solved"] = 0.0; + values["done"] = myodm_done ? 1.0 : 0.0; + terminated = myodm_done; + } else if (task_.kind == MyoSuiteTaskKind::kPose || + task_.kind == MyoSuiteTaskKind::kTorsoPose) { + const mjtNum pose_dist = Norm(ObsValue(obs, "pose_err")); + const mjtNum pose_thd = + metadata_.pose_thd > 0.0 ? metadata_.pose_thd : 0.35; + const mjtNum far_th = + task_.kind == MyoSuiteTaskKind::kTorsoPose ? pi : 2.0 * pi; + values["pose"] = -pose_dist; + values["bonus"] = (pose_dist < pose_thd ? 1.0 : 0.0) + + (pose_dist < 1.5 * pose_thd ? 1.0 : 0.0); + values["penalty"] = pose_dist > far_th ? -1.0 : 0.0; + values["act_reg"] = -ActMagnitude(obs); + values["sparse"] = -pose_dist; + values["solved"] = pose_dist < pose_thd ? 1.0 : 0.0; + terminated = pose_dist > far_th; + } else if (task_.kind == MyoSuiteTaskKind::kReach || + task_.kind == MyoSuiteTaskKind::kWalkReach) { + const mjtNum reach_dist = Norm(ObsValue(obs, "reach_err")); + const mjtNum vel_dist = Norm(ObsValue(obs, "qvel")); + const mjtNum nsites = + std::max(1.0, static_cast(tip_sites_.size())); + const mjtNum far_base = metadata_.far_th > 0.0 ? metadata_.far_th : 0.35; + const mjtNum far_th = data_->time > 2.0 * Dt() + ? far_base * nsites + : std::numeric_limits::infinity(); + const mjtNum near_th = + (task_.kind == MyoSuiteTaskKind::kWalkReach ? 0.050 : 0.0125) * + nsites; + values["reach"] = task_.kind == MyoSuiteTaskKind::kWalkReach + ? 10.0 - reach_dist - 10.0 * vel_dist + : -reach_dist; + values["bonus"] = (reach_dist < 2.0 * near_th ? 1.0 : 0.0) + + (reach_dist < near_th ? 1.0 : 0.0); + values["act_reg"] = task_.kind == MyoSuiteTaskKind::kWalkReach + ? -100.0 * ActMagnitude(obs) + : -ActMagnitude(obs); + values["penalty"] = reach_dist > far_th ? -1.0 : 0.0; + values["sparse"] = -reach_dist; + values["solved"] = reach_dist < near_th ? 1.0 : 0.0; + terminated = reach_dist > far_th; + } else if (task_.kind == MyoSuiteTaskKind::kKeyTurn) { + const mjtNum if_dist = + std::abs(Norm(ObsValue(obs, "IFtip_approach")) - 0.030); + const mjtNum th_dist = + std::abs(Norm(ObsValue(obs, "THtip_approach")) - 0.030); + const auto& key_qpos = ObsValue(obs, "key_qpos"); + const mjtNum key_pos = key_qpos.empty() ? 0.0 : key_qpos[0]; + const mjtNum far_th = 0.1; + const mjtNum goal_th = metadata_.goal_th > 0.0 ? metadata_.goal_th : 3.14; + values["key_turn"] = key_pos; + values["IFtip_approach"] = -if_dist; + values["THtip_approach"] = -th_dist; + values["act_reg"] = -ActMagnitude(obs); + values["bonus"] = + (key_pos > pi / 2.0 ? 1.0 : 0.0) + (key_pos > pi ? 1.0 : 0.0); + values["penalty"] = (if_dist > far_th / 2.0 ? -1.0 : 0.0) + + (th_dist > far_th / 2.0 ? -1.0 : 0.0); + values["sparse"] = key_pos; + values["solved"] = key_pos > goal_th ? 1.0 : 0.0; + terminated = if_dist > far_th || th_dist > far_th; + } else if (task_.kind == MyoSuiteTaskKind::kObjHoldFixed || + task_.kind == MyoSuiteTaskKind::kObjHoldRandom) { + const mjtNum goal_dist = Norm(ObsValue(obs, "obj_err")); + const mjtNum goal_th = 0.010; + const bool drop = goal_dist > 0.300; + values["goal_dist"] = -goal_dist; + values["bonus"] = (goal_dist < 2.0 * goal_th ? 1.0 : 0.0) + + (goal_dist < goal_th ? 1.0 : 0.0); + values["act_reg"] = -ActMagnitude(obs); + values["penalty"] = drop ? -1.0 : 0.0; + values["sparse"] = -goal_dist; + values["solved"] = goal_dist < goal_th ? 1.0 : 0.0; + terminated = drop; + } else if (task_.kind == MyoSuiteTaskKind::kPenTwirlFixed || + task_.kind == MyoSuiteTaskKind::kPenTwirlRandom || + task_.kind == MyoSuiteTaskKind::kReorientSar) { + const mjtNum pos_align = Norm(ObsValue(obs, "obj_err_pos")); + const mjtNum rot_align = + Cosine(ObsValue(obs, "obj_rot"), ObsValue(obs, "obj_des_rot")); + const bool dropped = pos_align > 0.075; + values["pos_align"] = -pos_align; + values["rot_align"] = rot_align; + values["act_reg"] = -ActMagnitude(obs); + values["drop"] = dropped ? -1.0 : 0.0; + values["bonus"] = (rot_align > 0.9 && pos_align < 0.075 ? 1.0 : 0.0) + + (rot_align > 0.95 && pos_align < 0.075 ? 5.0 : 0.0); + values["sparse"] = -pos_align + rot_align; + const bool solved = rot_align > 0.95 && !dropped; + values["solved"] = solved ? 1.0 : 0.0; + if (task_.kind == MyoSuiteTaskKind::kReorientSar) { + const int indicator = SiteId("success"); + if (indicator >= 0 && (model_->site_rgba[4 * indicator] != 0.0 || + model_->site_rgba[4 * indicator + 1] != 2.0)) { + model_->site_rgba[4 * indicator] = solved ? 0.0 : 2.0; + model_->site_rgba[4 * indicator + 1] = solved ? 2.0 : 0.0; + } + } + terminated = dropped; + } else if (task_.kind == MyoSuiteTaskKind::kWalk || + task_.kind == MyoSuiteTaskKind::kTerrain) { + const mjtNum target_x = metadata_.target_x_vel; + const mjtNum target_y = + metadata_.target_y_vel != 0.0 ? metadata_.target_y_vel : 1.2; + const auto& com_vel = ObsValue(obs, "com_vel"); + const mjtNum vx = com_vel.size() > 0 ? com_vel[0] : 0.0; + const mjtNum vy = com_vel.size() > 1 ? com_vel[1] : 0.0; + const mjtNum vel_reward = std::exp(-std::pow(target_y - vy, 2)) + + std::exp(-std::pow(target_x - vx, 2)); + const int hip_period = + metadata_.hip_period > 0 ? metadata_.hip_period : 100; + const mjtNum phase = + std::fmod(static_cast(task_step_) / hip_period, 1.0); + const std::vector des_angles = { + 0.8 * std::cos(phase * 2.0 * pi + pi), + 0.8 * std::cos(phase * 2.0 * pi)}; + const mjtNum cyclic_hip = Norm(Subtract( + des_angles, JointAngles({"hip_flexion_l", "hip_flexion_r"}))); + const mjtNum ref_rot = 1.0; + const auto joint_angles = + JointAngles({"hip_adduction_l", "hip_adduction_r", "hip_rotation_l", + "hip_rotation_r"}); + mjtNum joint_mag = 0.0; + for (mjtNum angle : joint_angles) { + joint_mag += std::abs(angle); + } + if (!joint_angles.empty()) { + joint_mag /= static_cast(joint_angles.size()); + } + const mjtNum joint_angle_rew = std::exp(-5.0 * joint_mag); + const bool done = WalkDone(obs); + values["vel_reward"] = vel_reward; + values["cyclic_hip"] = cyclic_hip; + values["ref_rot"] = ref_rot; + values["joint_angle_rew"] = joint_angle_rew; + values["act_mag"] = ActMagnitude(obs); + values["sparse"] = vel_reward; + values["solved"] = vel_reward >= 1.0 ? 1.0 : 0.0; + values["done"] = done ? 1.0 : 0.0; + terminated = done; + } else if (task_.kind == MyoSuiteTaskKind::kChallengeBaoding) { + const mjtNum d1 = Norm(ObsValue(obs, "target1_err")); + const mjtNum d2 = Norm(ObsValue(obs, "target2_err")); + const auto& object1 = ObsValue(obs, "object1_pos"); + const auto& object2 = ObsValue(obs, "object2_pos"); + const bool fall = (object1.size() > 2 && object1[2] < 1.25) || + (object2.size() > 2 && object2[2] < 1.25); + values["pos_dist_1"] = -d1; + values["pos_dist_2"] = -d2; + values["act_reg"] = -ActMagnitude(obs); + values["sparse"] = -(d1 + d2); + values["solved"] = d1 < 0.015 && d2 < 0.015 && !fall ? 1.0 : 0.0; + const int object1_gid = GeomId("ball1"); + const int object2_gid = GeomId("ball2"); + if (object1_gid >= 0) { + model_->geom_rgba[4 * object1_gid] = + d1 < 0.015 ? static_cast(1.0) : static_cast(0.5); + model_->geom_rgba[4 * object1_gid + 1] = + d1 < 0.015 ? static_cast(1.0) : static_cast(0.5); + } + if (object2_gid >= 0) { + model_->geom_rgba[4 * object2_gid] = + d1 < 0.015 ? static_cast(0.9) : static_cast(0.5); + model_->geom_rgba[4 * object2_gid + 1] = + d1 < 0.015 ? static_cast(0.7) : static_cast(0.5); + } + terminated = fall; + } else if (task_.kind == MyoSuiteTaskKind::kChallengeRelocate || + task_.kind == MyoSuiteTaskKind::kChallengeReorient) { + const mjtNum pos_dist = Norm(ObsValue(obs, "pos_err")); + const mjtNum rot_dist = Norm(ObsValue(obs, "rot_err")); + const bool drop = task_.kind == MyoSuiteTaskKind::kChallengeRelocate + ? Norm(ObsValue(obs, "reach_err")) > 0.50 + : pos_dist > 0.200; + values["pos_dist"] = -pos_dist; + values["rot_dist"] = -rot_dist; + values["bonus"] = + (pos_dist < 0.05 ? 1.0 : 0.0) + (pos_dist < 0.025 ? 1.0 : 0.0); + values["act_reg"] = -ActMagnitude(obs); + values["penalty"] = drop ? -1.0 : 0.0; + values["sparse"] = -rot_dist - 10.0 * pos_dist; + const bool solved = pos_dist < 0.025 && rot_dist < 0.262 && !drop; + values["solved"] = solved ? 1.0 : 0.0; + const int indicator = SiteId("target_ball"); + if (indicator >= 0) { + model_->site_rgba[4 * indicator] = solved ? 0.0 : 2.0; + model_->site_rgba[4 * indicator + 1] = solved ? 2.0 : 0.0; + if (task_.kind == MyoSuiteTaskKind::kChallengeRelocate) { + for (int axis = 0; axis < 3; ++axis) { + model_->site_size[3 * indicator + axis] = solved ? 0.25 : 0.1; + } + } + } + terminated = drop; + } else if (task_.kind == MyoSuiteTaskKind::kChallengeRunTrack) { + const auto& root_pos = ObsValue(obs, "model_root_pos"); + const auto& root_vel = ObsValue(obs, "model_root_vel"); + const mjtNum x = root_pos.size() > 0 ? root_pos[0] : 0.0; + const mjtNum y = root_pos.size() > 1 ? root_pos[1] : 0.0; + const mjtNum y_vel = root_vel.size() > 1 ? root_vel[1] : 0.0; + const bool random_track = + std::string_view(task_.id).find("Random") != std::string_view::npos; + const mjtNum start_pos = random_track ? 58.0 : 14.0; + const mjtNum end_pos = random_track ? -45.0 : -15.0; + const bool fallen = WalkDone(obs); + const bool win = y < end_pos; + const bool lose = x > 1.0 || x < -1.0 || y > start_pos + 2.0 || fallen; + values["act_reg"] = MeanSquare(ObsValue(obs, "act")); + values["pain"] = 0.0; + values["sparse"] = -y_vel; + values["solved"] = win ? 1.0 : 0.0; + values["done"] = (win || lose) ? 1.0 : 0.0; + terminated = win || lose; + } else if (task_.kind == MyoSuiteTaskKind::kChallengeChaseTag) { + const auto& root_pos = ObsValue(obs, "model_root_pos"); + const auto& opponent_pose = ObsValue(obs, "opponent_pose"); + const mjtNum dx = (root_pos.size() > 0 ? root_pos[0] : 0.0) - + (opponent_pose.size() > 0 ? opponent_pose[0] : 0.0); + const mjtNum dy = (root_pos.size() > 1 ? root_pos[1] : 0.0) - + (opponent_pose.size() > 1 ? opponent_pose[1] : 0.0); + const mjtNum distance = std::sqrt(dx * dx + dy * dy); + const bool tagged = distance <= 0.5; + const auto pelvis = BodyXpos(BodyId("pelvis")); + const bool out_of_bounds = + std::abs(pelvis[0]) > 6.5 || std::abs(pelvis[1]) > 6.5; + const bool fallen = pelvis[2] < 0.5; + const bool win = tagged; + const bool lose = data_->time >= 20.0 || out_of_bounds || fallen; + values["act_reg"] = ActMagnitude(obs); + values["distance"] = distance; + values["lose"] = lose ? 1.0 : 0.0; + values["sparse"] = + win ? 1.0 - std::round(data_->time * 100.0) / 100.0 / 20.0 : 0.0; + values["solved"] = win ? 1.0 : 0.0; + values["done"] = (win || lose) ? 1.0 : 0.0; + const int indicator = SiteId("opponent_indicator"); + if (indicator >= 0) { + model_->site_rgba[4 * indicator] = win ? 0.0 : 2.0; + model_->site_rgba[4 * indicator + 1] = win ? 2.0 : 0.0; + model_->site_rgba[4 * indicator + 2] = 0.0; + model_->site_rgba[4 * indicator + 3] = win ? 0.2 : 0.0; + } + terminated = win || lose; + } else if (task_.kind == MyoSuiteTaskKind::kChallengeTableTennis) { + const mjtNum reach_dist = Norm(ObsValue(obs, "reach_err")); + const mjtNum palm_dist = Norm(ObsValue(obs, "palm_err")); + const mjtNum paddle_quat_err = Norm(ObsValue(obs, "padde_ori_err")); + const int torso_joint = JointId("flex_extension"); + const mjtNum torso_err = + torso_joint >= 0 + ? std::abs(data_->qpos[model_->jnt_qposadr[torso_joint]]) + : 0.0; + const auto& ball_pos = ObsValue(obs, "ball_pos"); + const auto& touching = ObsValue(obs, "touching_info"); + const bool paddle_touch = !touching.empty() && touching[0] == 1.0; + const bool ball_done = + data_->time > 20.0 || (ball_pos.size() > 2 && ball_pos[2] < 0.3); + values["reach_dist"] = std::exp(-reach_dist); + values["palm_dist"] = std::exp(-5.0 * palm_dist); + values["paddle_quat"] = std::exp(-5.0 * paddle_quat_err); + values["torso_up"] = std::exp(-5.0 * torso_err); + values["act_reg"] = -ActMagnitude(obs); + values["sparse"] = paddle_touch ? 1.0 : 0.0; + values["solved"] = 0.0; + values["done"] = ball_done ? 1.0 : 0.0; + terminated = ball_done; + } else if (task_.kind == MyoSuiteTaskKind::kChallengeBimanual) { + const mjtNum reach_dist = Norm(ObsValue(obs, "reach_err")); + const mjtNum pass_dist = Norm(ObsValue(obs, "pass_err")); + const auto& obj_pos = ObsValue(obs, "obj_pos"); + const auto& palm_pos = ObsValue(obs, "palm_pos"); + auto goal_pos = ObsValue(obs, "goal_pos"); + if (goal_pos.size() >= 3) { + goal_pos[2] = 1.09; + } + mjtNum lift_height = 0.0; + if (obj_pos.size() >= 3 && palm_pos.size() >= 3) { + const mjtNum obj_lift = obj_pos[2] - bimanual_init_obj_z_; + const mjtNum palm_lift = palm_pos[2] - bimanual_init_palm_z_; + lift_height = + 5.0 * std::exp(-10.0 * ((obj_lift - 0.2) * (obj_lift - 0.2) + + (palm_lift - 0.2) * (palm_lift - 0.2))) - + 5.0; + } + mjtNum fin_open = 0.0; + mjtNum fin_dis = 0.0; + for (const char* key : {"fin0", "fin1", "fin2", "fin3", "fin4"}) { + fin_open += Norm(Subtract(ObsValue(obs, key), palm_pos)); + fin_dis += Norm(Subtract(ObsValue(obs, key), obj_pos)); + } + const auto& elbow = ObsValue(obs, "elbow_fle"); + const mjtNum elbow_value = elbow.empty() ? 0.0 : elbow[0]; + const mjtNum elbow_err = + 5.0 * std::exp(-10.0 * (elbow_value - 1.0) * (elbow_value - 1.0)) - + 5.0; + const mjtNum goal_dist = Norm(Subtract(obj_pos, goal_pos)); + const auto& touching = ObsValue(obs, "touching_body"); + if (touching.size() > 3 && touching[3] == 1.0) { + ++bimanual_goal_touch_; + } + const bool solved = goal_dist < 0.17 && bimanual_goal_touch_ >= 10; + const bool done = data_->time > 10.0 || + (obj_pos.size() > 2 && obj_pos[2] < 0.3) || solved; + values["reach_dist"] = reach_dist + std::log(reach_dist + 1e-6); + values["pass_err"] = pass_dist + std::log(pass_dist + 1e-3); + values["act"] = ActMagnitude(obs); + values["fin_open"] = std::exp(-5.0 * fin_open); + values["fin_dis"] = fin_dis + std::log(fin_dis + 1e-6); + values["lift_bonus"] = elbow_err; + values["lift_height"] = lift_height; + values["goal_dist"] = goal_dist; + values["sparse"] = 0.0; + values["solved"] = solved ? 1.0 : 0.0; + values["done"] = done ? 1.0 : 0.0; + terminated = done; + } else if (task_.kind == MyoSuiteTaskKind::kChallengeSoccer) { + const auto& root_pos = ObsValue(obs, "model_root_pos"); + const auto& ball_pos = ObsValue(obs, "ball_pos"); + std::vector root_xyz(3, 0.0); + for (std::size_t i = 0; i < root_xyz.size() && i < root_pos.size(); ++i) { + root_xyz[i] = root_pos[i]; + } + const mjtNum distance = Norm(Subtract(root_xyz, ball_pos)); + const bool goal_scored = ball_pos.size() >= 3 && ball_pos[0] >= 50.0 && + ball_pos[1] >= -3.3 && ball_pos[1] <= 3.3 && + ball_pos[2] >= 0.0 && ball_pos[2] <= 2.2; + const auto pelvis = BodyXpos(BodyId("pelvis")); + const bool fallen = pelvis[2] < 0.2; + const bool done = goal_scored || data_->time >= 10.0 || fallen; + values["goal_scored"] = goal_scored ? 1.0 : 0.0; + values["time_cost"] = data_->time; + values["act_reg"] = ActMagnitude(obs); + values["pain"] = 0.0; + values["distance"] = distance; + values["sparse"] = done ? 1.0 : 0.0; + values["solved"] = goal_scored ? 1.0 : 0.0; + values["done"] = done ? 1.0 : 0.0; + terminated = done; + } else { + throw std::runtime_error("Unhandled MyoSuite reward task kind: " + + std::string(task_.id)); + } + + RewardResult result; + for (const auto& [key, weight] : reward_weights_) { + result.dense += weight * ComponentValue(key, values); + } + result.sparse = ComponentValue("sparse", values); + result.solved = ComponentValue("solved", values); + result.terminated = terminated || ComponentValue("done", values) != 0.0; + return result; + } + + void CapturePaddedResetState() { + CaptureResetState(); +#ifdef ENVPOOL_TEST + std::fill(qpos0_pad_.begin(), qpos0_pad_.end(), 0.0); + std::fill(qvel0_pad_.begin(), qvel0_pad_.end(), 0.0); + std::fill(act0_pad_.begin(), act0_pad_.end(), 0.0); + std::fill(qacc0_pad_.begin(), qacc0_pad_.end(), 0.0); + std::fill(qacc_warmstart0_pad_.begin(), qacc_warmstart0_pad_.end(), 0.0); + for (int i = 0; i < model_->nq && i < 2048; ++i) { + qpos0_pad_[i] = data_->qpos[i]; + } + for (int i = 0; i < model_->nv && i < 2048; ++i) { + qvel0_pad_[i] = data_->qvel[i]; + } + for (int i = 0; i < model_->na && i < 2048; ++i) { + act0_pad_[i] = data_->act[i]; + } + for (int i = 0; i < model_->nv && i < 2048; ++i) { + qacc0_pad_[i] = data_->qacc[i]; + qacc_warmstart0_pad_[i] = data_->qacc_warmstart[i]; + } +#endif + } + + void CapturePaddedCurrentState() { +#ifdef ENVPOOL_TEST + std::fill(qpos_pad_.begin(), qpos_pad_.end(), 0.0); + std::fill(qvel_pad_.begin(), qvel_pad_.end(), 0.0); + std::fill(act_pad_.begin(), act_pad_.end(), 0.0); + std::fill(ctrl_pad_.begin(), ctrl_pad_.end(), 0.0); + std::fill(qacc_pad_.begin(), qacc_pad_.end(), 0.0); + std::fill(qacc_warmstart_pad_.begin(), qacc_warmstart_pad_.end(), 0.0); + std::fill(actuator_length_pad_.begin(), actuator_length_pad_.end(), 0.0); + std::fill(actuator_velocity_pad_.begin(), actuator_velocity_pad_.end(), + 0.0); + std::fill(actuator_force_pad_.begin(), actuator_force_pad_.end(), 0.0); + std::fill(fatigue_ma_pad_.begin(), fatigue_ma_pad_.end(), 0.0); + std::fill(fatigue_mr_pad_.begin(), fatigue_mr_pad_.end(), 0.0); + std::fill(fatigue_mf_pad_.begin(), fatigue_mf_pad_.end(), 0.0); + std::fill(fatigue_tl_pad_.begin(), fatigue_tl_pad_.end(), 0.0); + std::fill(fatigue_tauact_pad_.begin(), fatigue_tauact_pad_.end(), 0.0); + std::fill(fatigue_taudeact_pad_.begin(), fatigue_taudeact_pad_.end(), 0.0); + std::fill(site_pos_pad_.begin(), site_pos_pad_.end(), 0.0); + std::fill(site_quat_pad_.begin(), site_quat_pad_.end(), 0.0); + std::fill(site_xpos_pad_.begin(), site_xpos_pad_.end(), 0.0); + std::fill(site_size_pad_.begin(), site_size_pad_.end(), 0.0); + std::fill(site_rgba_pad_.begin(), site_rgba_pad_.end(), 0.0); + std::fill(body_pos_pad_.begin(), body_pos_pad_.end(), 0.0); + std::fill(body_quat_pad_.begin(), body_quat_pad_.end(), 0.0); + std::fill(body_mass_pad_.begin(), body_mass_pad_.end(), 0.0); + std::fill(light_xpos_pad_.begin(), light_xpos_pad_.end(), 0.0); + std::fill(light_xdir_pad_.begin(), light_xdir_pad_.end(), 0.0); + std::fill(geom_pos_pad_.begin(), geom_pos_pad_.end(), 0.0); + std::fill(geom_quat_pad_.begin(), geom_quat_pad_.end(), 0.0); + std::fill(geom_size_pad_.begin(), geom_size_pad_.end(), 0.0); + std::fill(geom_xpos_pad_.begin(), geom_xpos_pad_.end(), 0.0); + std::fill(geom_xmat_pad_.begin(), geom_xmat_pad_.end(), 0.0); + std::fill(geom_rgba_pad_.begin(), geom_rgba_pad_.end(), 0.0); + std::fill(geom_friction_pad_.begin(), geom_friction_pad_.end(), 0.0); + std::fill(geom_aabb_pad_.begin(), geom_aabb_pad_.end(), 0.0); + std::fill(geom_rbound_pad_.begin(), geom_rbound_pad_.end(), 0.0); + std::fill(geom_contype_pad_.begin(), geom_contype_pad_.end(), 0.0); + std::fill(geom_conaffinity_pad_.begin(), geom_conaffinity_pad_.end(), 0.0); + std::fill(geom_type_pad_.begin(), geom_type_pad_.end(), 0.0); + std::fill(geom_condim_pad_.begin(), geom_condim_pad_.end(), 0.0); + std::fill(hfield_data_pad_.begin(), hfield_data_pad_.end(), 0.0); + std::fill(mocap_pos_pad_.begin(), mocap_pos_pad_.end(), 0.0); + std::fill(mocap_quat_pad_.begin(), mocap_quat_pad_.end(), 0.0); + for (int i = 0; i < model_->nq && i < 2048; ++i) { + qpos_pad_[i] = data_->qpos[i]; + } + for (int i = 0; i < model_->nv && i < 2048; ++i) { + qvel_pad_[i] = data_->qvel[i]; + qacc_pad_[i] = data_->qacc[i]; + qacc_warmstart_pad_[i] = data_->qacc_warmstart[i]; + } + for (int i = 0; i < model_->na && i < 2048; ++i) { + act_pad_[i] = data_->act[i]; + } + for (int i = 0; i < model_->nu && i < 2048; ++i) { + ctrl_pad_[i] = data_->ctrl[i]; + actuator_length_pad_[i] = data_->actuator_length[i]; + actuator_velocity_pad_[i] = data_->actuator_velocity[i]; + actuator_force_pad_[i] = data_->actuator_force[i]; + } + for (int i = 0; i < static_cast(fatigue_ma_.size()) && i < 2048; ++i) { + fatigue_ma_pad_[i] = fatigue_ma_[i]; + fatigue_mr_pad_[i] = fatigue_mr_[i]; + fatigue_mf_pad_[i] = fatigue_mf_[i]; + fatigue_tl_pad_[i] = fatigue_tl_[i]; + fatigue_tauact_pad_[i] = fatigue_tauact_[i]; + fatigue_taudeact_pad_[i] = fatigue_taudeact_[i]; + } + for (int i = 0; i < model_->nsite * 3 && i < kMyoSuiteTestStatePad; ++i) { + site_pos_pad_[i] = model_->site_pos[i]; + site_xpos_pad_[i] = data_->site_xpos[i]; + } + for (int i = 0; i < model_->nsite * 4 && i < kMyoSuiteTestStatePad; ++i) { + site_quat_pad_[i] = model_->site_quat[i]; + } + for (int i = 0; i < model_->nsite * 3 && i < kMyoSuiteTestStatePad; ++i) { + site_size_pad_[i] = model_->site_size[i]; + } + for (int i = 0; i < model_->nsite * 4 && i < kMyoSuiteTestStatePad; ++i) { + site_rgba_pad_[i] = model_->site_rgba[i]; + } + for (int i = 0; i < model_->nbody * 3 && i < kMyoSuiteTestStatePad; ++i) { + body_pos_pad_[i] = model_->body_pos[i]; + } + for (int i = 0; i < model_->nbody * 4 && i < kMyoSuiteTestStatePad; ++i) { + body_quat_pad_[i] = model_->body_quat[i]; + } + for (int i = 0; i < model_->nbody && i < kMyoSuiteTestStatePad; ++i) { + body_mass_pad_[i] = model_->body_mass[i]; + } + for (int i = 0; i < model_->nlight * 3 && i < kMyoSuiteTestStatePad; ++i) { + light_xpos_pad_[i] = data_->light_xpos[i]; + light_xdir_pad_[i] = data_->light_xdir[i]; + } + for (int i = 0; i < model_->ngeom * 3 && i < kMyoSuiteTestStatePad; ++i) { + geom_pos_pad_[i] = model_->geom_pos[i]; + geom_size_pad_[i] = model_->geom_size[i]; + geom_friction_pad_[i] = model_->geom_friction[i]; + geom_xpos_pad_[i] = data_->geom_xpos[i]; + } + for (int i = 0; i < model_->ngeom * 4 && i < kMyoSuiteTestStatePad; ++i) { + geom_quat_pad_[i] = model_->geom_quat[i]; + geom_rgba_pad_[i] = model_->geom_rgba[i]; + } + for (int i = 0; i < model_->ngeom * 9 && i < kMyoSuiteTestStatePad; ++i) { + geom_xmat_pad_[i] = data_->geom_xmat[i]; + } + for (int i = 0; i < model_->ngeom && i < kMyoSuiteTestStatePad; ++i) { + geom_rbound_pad_[i] = model_->geom_rbound[i]; + geom_contype_pad_[i] = model_->geom_contype[i]; + geom_conaffinity_pad_[i] = model_->geom_conaffinity[i]; + geom_type_pad_[i] = model_->geom_type[i]; + geom_condim_pad_[i] = model_->geom_condim[i]; + } + for (int i = 0; i < model_->ngeom * 6 && i < kMyoSuiteTestStatePad; ++i) { + geom_aabb_pad_[i] = model_->geom_aabb[i]; + } + for (int i = 0; i < model_->nhfielddata && i < kMyoSuiteTestStatePad; ++i) { + hfield_data_pad_[i] = model_->hfield_data[i]; + } + for (int i = 0; i < model_->nmocap * 3 && i < kMyoSuiteTestStatePad; ++i) { + mocap_pos_pad_[i] = data_->mocap_pos[i]; + } + for (int i = 0; i < model_->nmocap * 4 && i < kMyoSuiteTestStatePad; ++i) { + mocap_quat_pad_[i] = data_->mocap_quat[i]; + } +#endif + } + + void WriteState(mjtNum reward, bool reset) { + WriteState(reward, reset, BuildObsDict()); + } + + void WriteState(mjtNum reward, bool reset, const ObsDict& obs_dict) { + auto state = Allocate(); + if constexpr (kFromPixels) { + auto obs_pixels = state["obs:pixels"_]; + AssignPixelObservation("obs:pixels", &obs_pixels, reset); + } else { + std::vector obs = Observation(obs_dict); + auto obs_state = state["obs"_]; + AssignObservation("obs", &obs_state, obs.data(), obs.size(), reset); + } + state["reward"_] = static_cast(reward); + state["trunc"_] = elapsed_step_ >= max_episode_steps_; + state["info:task_id"_] = task_index_; + state["info:sparse"_] = sparse_; + state["info:solved"_] = solved_; + state["info:oracle_numpy2_broken"_] = task_.oracle_numpy2_broken; + state["info:model_nq"_] = model_->nq; + state["info:model_nv"_] = model_->nv; + state["info:model_na"_] = model_->na; + state["info:model_nu"_] = model_->nu; + state["info:model_nsite"_] = model_->nsite; + state["info:model_nbody"_] = model_->nbody; + state["info:model_ngeom"_] = model_->ngeom; + state["info:model_nhfielddata"_] = model_->nhfielddata; + state["info:model_nmocap"_] = model_->nmocap; +#ifdef ENVPOOL_TEST + CapturePaddedCurrentState(); + state["info:qpos0"_].Assign(qpos0_pad_.data(), qpos0_pad_.size()); + state["info:qvel0"_].Assign(qvel0_pad_.data(), qvel0_pad_.size()); + state["info:act0"_].Assign(act0_pad_.data(), act0_pad_.size()); + state["info:qacc0"_].Assign(qacc0_pad_.data(), qacc0_pad_.size()); + state["info:qacc_warmstart0"_].Assign(qacc_warmstart0_pad_.data(), + qacc_warmstart0_pad_.size()); + state["info:qpos"_].Assign(qpos_pad_.data(), qpos_pad_.size()); + state["info:qvel"_].Assign(qvel_pad_.data(), qvel_pad_.size()); + state["info:act"_].Assign(act_pad_.data(), act_pad_.size()); + state["info:ctrl"_].Assign(ctrl_pad_.data(), ctrl_pad_.size()); + state["info:qacc"_].Assign(qacc_pad_.data(), qacc_pad_.size()); + state["info:qacc_warmstart"_].Assign(qacc_warmstart_pad_.data(), + qacc_warmstart_pad_.size()); + state["info:actuator_length"_].Assign(actuator_length_pad_.data(), + actuator_length_pad_.size()); + state["info:actuator_velocity"_].Assign(actuator_velocity_pad_.data(), + actuator_velocity_pad_.size()); + state["info:actuator_force"_].Assign(actuator_force_pad_.data(), + actuator_force_pad_.size()); + state["info:fatigue_ma"_].Assign(fatigue_ma_pad_.data(), + fatigue_ma_pad_.size()); + state["info:fatigue_mr"_].Assign(fatigue_mr_pad_.data(), + fatigue_mr_pad_.size()); + state["info:fatigue_mf"_].Assign(fatigue_mf_pad_.data(), + fatigue_mf_pad_.size()); + state["info:fatigue_tl"_].Assign(fatigue_tl_pad_.data(), + fatigue_tl_pad_.size()); + state["info:fatigue_tauact"_].Assign(fatigue_tauact_pad_.data(), + fatigue_tauact_pad_.size()); + state["info:fatigue_taudeact"_].Assign(fatigue_taudeact_pad_.data(), + fatigue_taudeact_pad_.size()); + state["info:fatigue_dt"_] = + task_.muscle_condition == MyoSuiteMuscleCondition::kFatigue + ? static_cast(Dt()) + : 0.0; + state["info:site_pos"_].Assign(site_pos_pad_.data(), site_pos_pad_.size()); + state["info:site_quat"_].Assign(site_quat_pad_.data(), + site_quat_pad_.size()); + state["info:site_xpos"_].Assign(site_xpos_pad_.data(), + site_xpos_pad_.size()); + state["info:site_size"_].Assign(site_size_pad_.data(), + site_size_pad_.size()); + state["info:site_rgba"_].Assign(site_rgba_pad_.data(), + site_rgba_pad_.size()); + state["info:body_pos"_].Assign(body_pos_pad_.data(), body_pos_pad_.size()); + state["info:body_quat"_].Assign(body_quat_pad_.data(), + body_quat_pad_.size()); + state["info:body_mass"_].Assign(body_mass_pad_.data(), + body_mass_pad_.size()); + state["info:light_xpos"_].Assign(light_xpos_pad_.data(), + light_xpos_pad_.size()); + state["info:light_xdir"_].Assign(light_xdir_pad_.data(), + light_xdir_pad_.size()); + state["info:geom_pos"_].Assign(geom_pos_pad_.data(), geom_pos_pad_.size()); + state["info:geom_quat"_].Assign(geom_quat_pad_.data(), + geom_quat_pad_.size()); + state["info:geom_size"_].Assign(geom_size_pad_.data(), + geom_size_pad_.size()); + state["info:geom_xpos"_].Assign(geom_xpos_pad_.data(), + geom_xpos_pad_.size()); + state["info:geom_xmat"_].Assign(geom_xmat_pad_.data(), + geom_xmat_pad_.size()); + state["info:geom_rgba"_].Assign(geom_rgba_pad_.data(), + geom_rgba_pad_.size()); + state["info:geom_friction"_].Assign(geom_friction_pad_.data(), + geom_friction_pad_.size()); + state["info:geom_aabb"_].Assign(geom_aabb_pad_.data(), + geom_aabb_pad_.size()); + state["info:geom_rbound"_].Assign(geom_rbound_pad_.data(), + geom_rbound_pad_.size()); + state["info:geom_contype"_].Assign(geom_contype_pad_.data(), + geom_contype_pad_.size()); + state["info:geom_conaffinity"_].Assign(geom_conaffinity_pad_.data(), + geom_conaffinity_pad_.size()); + state["info:geom_type"_].Assign(geom_type_pad_.data(), + geom_type_pad_.size()); + state["info:geom_condim"_].Assign(geom_condim_pad_.data(), + geom_condim_pad_.size()); + state["info:hfield_data"_].Assign(hfield_data_pad_.data(), + hfield_data_pad_.size()); + state["info:mocap_pos"_].Assign(mocap_pos_pad_.data(), + mocap_pos_pad_.size()); + state["info:mocap_quat"_].Assign(mocap_quat_pad_.data(), + mocap_quat_pad_.size()); + state["info:time"_] = data_->time; + state["info:model_timestep"_] = model_->opt.timestep; + state["info:frame_skip"_] = frame_skip_; +#endif + } +}; + +using MyoSuiteEnv = MyoSuiteEnvBase; +using MyoSuitePixelEnv = MyoSuiteEnvBase; +using MyoSuiteEnvPool = AsyncEnvPool; +using MyoSuitePixelEnvPool = AsyncEnvPool; + +} // namespace myosuite + +#endif // ENVPOOL_MUJOCO_MYOSUITE_MYOSUITE_ENV_H_ diff --git a/envpool/mujoco/myosuite/myosuite_envpool.cc b/envpool/mujoco/myosuite/myosuite_envpool.cc new file mode 100644 index 000000000..6c5a623b1 --- /dev/null +++ b/envpool/mujoco/myosuite/myosuite_envpool.cc @@ -0,0 +1,93 @@ +// Copyright 2026 Garena Online Private Limited +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "envpool/core/py_envpool.h" +#include "envpool/mujoco/myosuite/myosuite_env.h" + +using MyoSuiteEnvSpec = PyEnvSpec; +using MyoSuiteEnvPool = PyEnvPool; +using MyoSuitePixelEnvSpec = PyEnvSpec; +using MyoSuitePixelEnvPool = PyEnvPool; + +namespace { + +template +py::tuple ExportSpecEntry(const SpecT& spec) { + return py::make_tuple(py::dtype::of(), spec.shape, + spec.bounds, spec.elementwise_bounds, spec.is_discrete); +} + +template +py::tuple ExportSpecEntry(const Spec>& spec) { + return py::make_tuple( + py::dtype::of(), py::make_tuple(spec.shape, spec.inner_spec.shape), + spec.inner_spec.bounds, spec.inner_spec.elementwise_bounds, + spec.inner_spec.is_discrete); +} + +template +py::tuple ExportSpecsDynamic(const std::tuple& specs) { + py::tuple out(sizeof...(SpecT)); + std::size_t index = 0; + std::apply( + [&](const auto&... spec) { + ((out[index++] = ExportSpecEntry(spec)), ...); + }, + specs); + return out; +} + +template +void RegisterMyoSuite(py::module_& m, const char* spec_name, + const char* envpool_name) { + // Avoid binding SPEC::StateSpecT directly; MyoSuite's test info surface is + // large enough to exceed MSVC's 64K debug type-record limit through pybind11. + py::class_(m, spec_name, + py::metaclass(py::module_::import("abc").attr("ABCMeta"))) + .def(py::init()) + .def_readonly("_config_values", &SPEC::py_config_values) + .def_property_readonly( + "_state_spec", + [](const SPEC& self) { return ExportSpecsDynamic(self.state_spec); }) + .def_property_readonly( + "_action_spec", + [](const SPEC& self) { return ExportSpecsDynamic(self.action_spec); }) + .def_readonly_static("_state_keys", &SPEC::py_state_keys) + .def_readonly_static("_action_keys", &SPEC::py_action_keys) + .def_readonly_static("_config_keys", &SPEC::py_config_keys) + .def_readonly_static("_default_config_values", + &SPEC::py_default_config_values); + py::class_(m, envpool_name, + py::metaclass(py::module_::import("abc").attr("ABCMeta"))) + .def(py::init()) + .def_readonly("_spec", &ENVPOOL::py_spec) + .def("_recv", &ENVPOOL::PyRecv) + .def("_send", &ENVPOOL::PySend) + .def("_reset", &ENVPOOL::PyReset) + .def("_render", &ENVPOOL::PyRender) + .def_readonly_static("_state_keys", &ENVPOOL::py_state_keys) + .def_readonly_static("_action_keys", &ENVPOOL::py_action_keys) + .def("_xla", &ENVPOOL::Xla); +} + +} // namespace + +PYBIND11_MODULE(myosuite_envpool, m) { + RegisterMyoSuite(m, "_MyoSuiteEnvSpec", + "_MyoSuiteEnvPool"); + RegisterMyoSuite( + m, "_MyoSuitePixelEnvSpec", "_MyoSuitePixelEnvPool"); +} diff --git a/envpool/mujoco/myosuite/myosuite_oracle_align_test.py b/envpool/mujoco/myosuite/myosuite_oracle_align_test.py new file mode 100644 index 000000000..5340f30a3 --- /dev/null +++ b/envpool/mujoco/myosuite/myosuite_oracle_align_test.py @@ -0,0 +1,533 @@ +# Copyright 2026 Garena Online Private Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Oracle coverage checks for native MyoSuite envs.""" + +from __future__ import annotations + +import importlib +import json +import os +import platform +import subprocess +import sys +import tempfile +import unittest +from pathlib import Path +from typing import Any, cast + +import numpy as np +from absl import logging +from absl.testing import absltest + +from envpool.mujoco.myosuite.tasks import ( + MYOSUITE_ORACLE_NUMPY2_BROKEN_IDS, + MYOSUITE_ORACLE_VERSION, + MYOSUITE_TASKS, + MyoSuiteTask, +) +from envpool.registration import make_gymnasium, make_spec + +importlib.import_module("envpool.mujoco.myosuite.registration") + +_ROLLOUT_STEPS = 128 +_ORACLE_SPACE_BATCH_SIZE = 64 +_ROLLOUT_BATCH_SIZE = 4 +# Keep the expensive 128-step oracle traces to a diagonal sample across +# orthogonal task modifiers. Full registry/space/render coverage still checks +# every official ID; this set keeps muscle-condition/player-side combinations +# from growing as a cartesian product. +_ROLLOUT_TASK_IDS = frozenset({ + "MyoHandAirplaneFixed-v0", + "MyoHandAirplaneFly-v0", + "myoFingerReachFixed-v0", + "myoFingerPoseFixed-v0", + "myoHandReachFixed-v0", + "myoHandPoseFixed-v0", + "myoHandKeyTurnFixed-v0", + "myoHandObjHoldFixed-v0", + "myoHandPenTwirlFixed-v0", + "myoHandReorient8-v0", + "myoLegStandRandom-v0", + "myoLegWalk-v0", + "myoLegRoughTerrainWalk-v0", + "myoChallengeBaodingP1-v1", + "myoChallengeBimanual-v0", + "myoChallengeChaseTagP1-v0", + "myoChallengeDieReorientP1-v0", + "myoChallengeOslRunFixed-v0", + "myoChallengeRelocateP1-v0", + "myoChallengeSoccerP1-v0", + "myoChallengeTableTennisP0-v0", + "myoFatiChallengeBimanual-v0", + "myoSarcChallengeSoccerP2-v0", +}) +_BITWISE_ROLLOUT_TASK_IDS = frozenset({ + "myoFingerReachFixed-v0", + "myoFingerPoseFixed-v0", +}) +_LINUX_AARCH64_FINGER_ROLLOUT_RTOL = 1e-5 +_LINUX_AARCH64_FINGER_ROLLOUT_ATOL = 1e-7 +_EXPECTED_ORACLE_NUMPY2_BROKEN_IDS: frozenset[str] = frozenset() +_SYNC_STATE_KEYS = ( + "qpos0", + "qvel0", + "act0", + "qacc0", + "qacc_warmstart0", + "ctrl", + "site_pos", + "site_quat", + "site_size", + "site_rgba", + "body_pos", + "body_quat", + "body_mass", + "geom_pos", + "geom_quat", + "geom_size", + "geom_rgba", + "geom_friction", + "geom_aabb", + "geom_rbound", + "geom_contype", + "geom_conaffinity", + "geom_type", + "geom_condim", + "hfield_data", + "mocap_pos", + "mocap_quat", + "fatigue_ma", + "fatigue_mr", + "fatigue_mf", + "fatigue_tl", +) +_SYNC_STATE_SIZES = { + "qpos0": "nq", + "qvel0": "nv", + "act0": "na", + "qacc0": "nv", + "qacc_warmstart0": "nv", + "ctrl": "nu", + "site_pos": "nsite3", + "site_quat": "nsite4", + "site_size": "nsite3", + "site_rgba": "nsite4", + "body_pos": "nbody3", + "body_quat": "nbody4", + "body_mass": "nbody", + "geom_pos": "ngeom3", + "geom_quat": "ngeom4", + "geom_size": "ngeom3", + "geom_rgba": "ngeom4", + "geom_friction": "ngeom3", + "geom_aabb": "ngeom6", + "geom_rbound": "ngeom", + "geom_contype": "ngeom", + "geom_conaffinity": "ngeom", + "geom_type": "ngeom", + "geom_condim": "ngeom", + "hfield_data": "nhfielddata", + "mocap_pos": "nmocap3", + "mocap_quat": "nmocap4", + "fatigue_ma": "nu", + "fatigue_mr": "nu", + "fatigue_mf": "nu", + "fatigue_tl": "nu", +} + + +def _assert_bitwise_rollout_obs( + actual: np.ndarray, + desired: np.ndarray, + *, + label: str, +) -> None: + if sys.platform.startswith("linux") and platform.machine().lower() in { + "aarch64", + "arm64", + }: + # Linux aarch64 accumulates small float32 differences in these long + # MuJoCo finger traces after tens of steps. Keep the residual scoped to + # that platform and far below a semantically meaningful trajectory drift. + try: + np.testing.assert_allclose( + actual, + desired, + rtol=_LINUX_AARCH64_FINGER_ROLLOUT_RTOL, + atol=_LINUX_AARCH64_FINGER_ROLLOUT_ATOL, + ) + except AssertionError as exc: + raise AssertionError(f"{label}\n{exc}") from exc + return + np.testing.assert_array_equal(actual, desired, err_msg=label) + + +def _oracle_task_ids() -> tuple[str, ...]: + return tuple(task["id"] for task in MYOSUITE_TASKS) + + +def _shard_task_ids(task_ids: tuple[str, ...]) -> tuple[str, ...]: + total_shards = int( + os.environ.get( + "MYOSUITE_ORACLE_TOTAL_SHARDS", + os.environ.get("TEST_TOTAL_SHARDS", "1"), + ) + ) + shard_index = int( + os.environ.get( + "MYOSUITE_ORACLE_SHARD_INDEX", + os.environ.get("TEST_SHARD_INDEX", "0"), + ) + ) + shard_status_file = os.environ.get("TEST_SHARD_STATUS_FILE") + if shard_status_file: + Path(shard_status_file).touch() + if total_shards <= 1: + return task_ids + if shard_index < 0 or shard_index >= total_shards: + raise ValueError(f"invalid Bazel shard {shard_index} of {total_shards}") + return tuple( + task_id + for index, task_id in enumerate(task_ids) + if index % total_shards == shard_index + ) + + +def _oracle_rollout_task_ids() -> tuple[str, ...]: + return _shard_task_ids( + tuple( + task_id + for task_id in _oracle_task_ids() + if task_id in _ROLLOUT_TASK_IDS + ) + ) + + +def _task_batches( + task_ids: tuple[str, ...], + batch_size: int, +) -> tuple[tuple[str, ...], ...]: + return tuple( + task_ids[start : start + batch_size] + for start in range(0, len(task_ids), batch_size) + ) + + +def _task_metadata_by_id() -> dict[str, MyoSuiteTask]: + return {task["id"]: task for task in MYOSUITE_TASKS} + + +def _oracle_probe_path() -> Path: + runfiles = Path(os.environ["TEST_SRCDIR"]) + workspace = os.environ.get("TEST_WORKSPACE", "envpool") + launcher_names: tuple[str, ...] = ( + "myosuite_oracle_probe", + "myosuite_oracle_probe.exe", + ) + logical_suffixes = ( + tuple(f"envpool/mujoco/{launcher}" for launcher in launcher_names) + + launcher_names + ) + manifest = os.environ.get("RUNFILES_MANIFEST_FILE") + if manifest: + with Path(manifest).open(encoding="utf-8") as f: + for line in f: + logical, _, physical = line.rstrip("\n").partition(" ") + logical = logical.replace("\\", "/") + if any(logical.endswith(suffix) for suffix in logical_suffixes): + candidate = Path(physical or logical) + if candidate.is_file(): + return candidate + candidates = [ + runfiles / workspace / "envpool/mujoco" / launcher + for launcher in launcher_names + ] + if sys.platform == "win32": + candidates.extend( + runfiles.parent / launcher for launcher in launcher_names + ) + for candidate in candidates: + if candidate.is_file(): + return candidate + for launcher in launcher_names: + for match in runfiles.rglob(launcher): + if match.is_file(): + return match + raise RuntimeError( + f"could not locate myosuite_oracle_probe under {runfiles}" + ) + + +def _oracle_probe_cmd() -> list[str]: + path = _oracle_probe_path() + if sys.platform == "win32" and path.suffix.lower() != ".exe": + return [sys.executable, str(path)] + return [str(path)] + + +def _run_oracle_probe( + mode: str, + task_ids: tuple[str, ...] = (), + steps: int = _ROLLOUT_STEPS, + sync_states: dict[str, dict[str, Any]] | None = None, +) -> dict[str, Any]: + with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as out: + out_path = Path(out.name) + sync_path: Path | None = None + cmd = _oracle_probe_cmd() + [ + "--mode", + mode, + "--out", + str(out_path), + "--steps", + str(steps), + "--seed", + "5", + ] + for task_id in task_ids: + cmd.extend(["--task_id", task_id]) + if sync_states is not None: + with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as sync: + sync_path = Path(sync.name) + sync_path.write_text(json.dumps(sync_states, sort_keys=True)) + cmd.extend(["--sync_state", str(sync_path)]) + env = os.environ.copy() + env["ROBOHIVE_VERBOSITY"] = "SILENT" + try: + try: + result = subprocess.run( + cmd, + check=False, + capture_output=True, + env=env, + text=True, + ) + except OSError as exc: + raise RuntimeError( + f"MyoSuite oracle probe failed to start\ncmd: {' '.join(cmd)}" + ) from exc + if result.returncode != 0: + raise RuntimeError( + "MyoSuite oracle probe failed\n" + f"cmd: {' '.join(cmd)}\n" + f"stdout:\n{result.stdout}\n" + f"stderr:\n{result.stderr}" + ) + return cast(dict[str, Any], json.loads(out_path.read_text())) + finally: + out_path.unlink(missing_ok=True) + if sync_path is not None: + sync_path.unlink(missing_ok=True) + + +def _run_oracle_space_reports( + task_ids: tuple[str, ...], +) -> dict[str, dict[str, Any]]: + tasks: dict[str, dict[str, Any]] = {} + for start in range(0, len(task_ids), _ORACLE_SPACE_BATCH_SIZE): + batch = task_ids[start : start + _ORACLE_SPACE_BATCH_SIZE] + report = _run_oracle_probe("space", batch) + if report["version"] != MYOSUITE_ORACLE_VERSION: + raise AssertionError(report["version"]) + tasks.update(cast(dict[str, dict[str, Any]], report["tasks"])) + return tasks + + +def _sync_state_from_info(info: dict[str, Any]) -> dict[str, Any]: + dims = { + "nq": int(np.asarray(info["model_nq"]).ravel()[0]), + "nv": int(np.asarray(info["model_nv"]).ravel()[0]), + "na": int(np.asarray(info["model_na"]).ravel()[0]), + "nu": int(np.asarray(info["model_nu"]).ravel()[0]), + "nsite": int(np.asarray(info["model_nsite"]).ravel()[0]), + "nbody": int(np.asarray(info["model_nbody"]).ravel()[0]), + "ngeom": int(np.asarray(info["model_ngeom"]).ravel()[0]), + "nhfielddata": int(np.asarray(info["model_nhfielddata"]).ravel()[0]), + "nmocap": int(np.asarray(info["model_nmocap"]).ravel()[0]), + } + dims.update({ + "nsite3": dims["nsite"] * 3, + "nsite4": dims["nsite"] * 4, + "nbody3": dims["nbody"] * 3, + "nbody4": dims["nbody"] * 4, + "ngeom3": dims["ngeom"] * 3, + "ngeom4": dims["ngeom"] * 4, + "ngeom6": dims["ngeom"] * 6, + "nmocap3": dims["nmocap"] * 3, + "nmocap4": dims["nmocap"] * 4, + }) + sync_state = {} + for key in _SYNC_STATE_KEYS: + if key not in info: + continue + size = dims[_SYNC_STATE_SIZES[key]] + sync_state[key] = ( + np.asarray(info[key][0], dtype=np.float64).ravel()[:size].tolist() + ) + return sync_state + + +class MyoSuiteOracleAlignTest(absltest.TestCase): + """Validate native MyoSuite coverage against the pinned oracle surface.""" + + def test_no_numpy2_oracle_failures_are_excluded(self) -> None: + """Every pinned upstream ID is instantiable by the oracle.""" + self.assertSetEqual( + MYOSUITE_ORACLE_NUMPY2_BROKEN_IDS, + _EXPECTED_ORACLE_NUMPY2_BROKEN_IDS, + ) + self.assertEmpty(MYOSUITE_ORACLE_NUMPY2_BROKEN_IDS) + + def test_pinned_official_registry_coverage(self) -> None: + """Every pinned upstream registry ID must be represented locally.""" + report = _run_oracle_probe("space") + self.assertEqual(report["version"], MYOSUITE_ORACLE_VERSION) + official_ids = tuple(cast(list[str], report["ids"])) + envpool_ids = tuple(task["id"] for task in MYOSUITE_TASKS) + self.assertEqual(official_ids, envpool_ids) + self.assertLen(official_ids, 398) + + def test_oracle_space_coverage(self) -> None: + """Native spaces must match every official oracle env.""" + task_ids = _shard_task_ids(_oracle_task_ids()) + oracle_tasks = _run_oracle_space_reports(task_ids) + self.assertLen(oracle_tasks, len(task_ids)) + task_metadata = _task_metadata_by_id() + for task_id, oracle_task in oracle_tasks.items(): + with self.subTest(task_id=task_id): + task = task_metadata[task_id] + envpool_spec = make_spec(task_id) + self.assertEqual( + tuple(oracle_task["observation_shape"]), + (task["obs_dim"],), + ) + self.assertEqual( + tuple(oracle_task["action_shape"]), + (task["action_dim"],), + ) + self.assertEqual( + oracle_task["max_episode_steps"], + task["max_episode_steps"], + ) + self.assertEqual( + envpool_spec.observation_space.shape, + (task["obs_dim"],), + ) + self.assertEqual( + envpool_spec.action_space.shape, + (task["action_dim"],), + ) + self.assertEqual( + envpool_spec.config.max_episode_steps, + task["max_episode_steps"], + ) + + def test_oracle_rollout_surface(self) -> None: + """Exercise nontrivial rollouts with oracle-generated actions.""" + rollout_task_ids = _oracle_rollout_task_ids() + task_metadata = _task_metadata_by_id() + for batch in _task_batches(rollout_task_ids, _ROLLOUT_BATCH_SIZE): + envpools: dict[str, Any] = {} + envpool_reset_obs: dict[str, np.ndarray] = {} + sync_states: dict[str, dict[str, Any]] = {} + try: + for task_id in batch: + envpool = make_gymnasium(task_id, num_envs=1, seed=5) + envpool_obs, info = envpool.reset() + envpools[task_id] = envpool + envpool_reset_obs[task_id] = envpool_obs + sync_states[task_id] = _sync_state_from_info(info) + + report = _run_oracle_probe( + "trace", batch, sync_states=sync_states + ) + oracle_tasks = cast(dict[str, dict[str, Any]], report["tasks"]) + self.assertSetEqual(set(oracle_tasks), set(batch)) + + for task_id in batch: + task = task_metadata[task_id] + envpool = envpools[task_id] + envpool_obs = envpool_reset_obs[task_id] + oracle_task = oracle_tasks[task_id] + with self.subTest(task_id=task_id): + self.assertLen(oracle_task["obs"], _ROLLOUT_STEPS + 1) + self.assertLen(oracle_task["actions"], _ROLLOUT_STEPS) + self.assertEqual( + envpool_obs.shape, (1, task["obs_dim"]) + ) + if task_id in _BITWISE_ROLLOUT_TASK_IDS: + _assert_bitwise_rollout_obs( + envpool_obs[0].astype(np.float32), + np.asarray( + oracle_task["obs"][0], dtype=np.float32 + ), + label=f"{task_id} reset obs", + ) + + for step_id, action in enumerate( + oracle_task["actions"] + ): + action = np.asarray(action, dtype=np.float32) + envpool_step = envpool.step(action[None, :]) + self.assertEqual( + envpool_step[0].shape, (1, task["obs_dim"]) + ) + self.assertEqual(envpool_step[1].shape, (1,)) + self.assertEqual(envpool_step[2].shape, (1,)) + self.assertEqual(envpool_step[3].shape, (1,)) + if task_id in _BITWISE_ROLLOUT_TASK_IDS: + oracle_obs = np.asarray( + oracle_task["obs"][step_id + 1], + dtype=np.float32, + ) + oracle_reward = np.asarray( + oracle_task["rewards"][step_id], + dtype=np.float32, + ) + _assert_bitwise_rollout_obs( + envpool_step[0][0].astype(np.float32), + oracle_obs, + label=f"{task_id} step {step_id} obs", + ) + self.assertEqual( + float(envpool_step[1][0]), + float(oracle_reward), + msg=f"{task_id} step {step_id} reward", + ) + self.assertEqual( + bool(envpool_step[2][0]), + bool(oracle_task["terminated"][step_id]), + msg=f"{task_id} step {step_id} terminated", + ) + self.assertEqual( + bool(envpool_step[3][0]), + bool(oracle_task["truncated"][step_id]), + msg=f"{task_id} step {step_id} truncated", + ) + if bool(oracle_task["terminated"][step_id]) or bool( + oracle_task["truncated"][step_id] + ): + break + finally: + for envpool in envpools.values(): + try: + envpool.close() + except Exception as exc: + logging.warning( + "ignored MyoSuite env close failure: %s", exc + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/envpool/mujoco/myosuite/myosuite_oracle_probe.py b/envpool/mujoco/myosuite/myosuite_oracle_probe.py new file mode 100644 index 000000000..f870db903 --- /dev/null +++ b/envpool/mujoco/myosuite/myosuite_oracle_probe.py @@ -0,0 +1,1217 @@ +# Copyright 2026 Garena Online Private Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Pinned official MyoSuite oracle helper. + +This binary is used only by tests. It intentionally runs in a separate Python +process from EnvPool so the official MyoSuite dependencies can stay pinned to +the upstream v2.11.6 contract without replacing EnvPool's normal runtime deps. +""" + +from __future__ import annotations + +import argparse +import atexit +import ctypes +import importlib +import importlib.util +import json +import os +import platform +import shutil +import sys +import tempfile +import warnings +from pathlib import Path +from typing import Any, ClassVar + +# MyoSuite projects normalized muscle actions through np.exp(float32). NumPy's +# optional x86 SIMD kernels and its scalar kernel differ by single-ULP amounts, +# so pin the oracle helper to the portable baseline before NumPy is imported. +_NUMPY_X86_BASELINE_FEATURE_MASK = ( + "AVX", + "AVX2", + "FMA3", + "F16C", + "SSE42", + "SSE41", + "POPCNT", + "SSSE3", + "AVX512F", + "AVX512CD", + "AVX512_SKX", + "AVX512_CLX", + "AVX512_CNL", + "AVX512_ICL", + "AVX512_SPR", +) +if platform.machine().lower() in {"amd64", "x86_64"}: + os.environ.setdefault( + "NPY_DISABLE_CPU_FEATURES", + ",".join(_NUMPY_X86_BASELINE_FEATURE_MASK), + ) + +import numpy as np + +from envpool.python.glfw_context import preload_windows_gl_dlls + +if platform.system() == "Windows": + preload_windows_gl_dlls(strict=True) + +_CGL_FIRST_FRAME_SETTLE_PASSES = 4 + + +def _runfiles_root() -> Path: + path = Path(__file__).absolute() + for parent in (path, *path.parents): + if parent.name.endswith(".runfiles"): + return parent + path = Path(__file__).resolve() + runfiles_dir = os.environ.get("RUNFILES_DIR") + if runfiles_dir: + return Path(runfiles_dir) + if "TEST_SRCDIR" in os.environ: + return Path(os.environ["TEST_SRCDIR"]) + return path.parents[3] + + +def _runfiles_manifests(runfiles: Path) -> tuple[Path, ...]: + manifests = [] + env_manifest = os.environ.get("RUNFILES_MANIFEST_FILE") + if env_manifest: + manifests.append(Path(env_manifest)) + manifests.extend([ + runfiles / "MANIFEST", + runfiles.parent / f"{runfiles.name}_manifest", + ]) + + unique_manifests = [] + seen = set() + for manifest in manifests: + key = os.fspath(manifest) + if key not in seen: + unique_manifests.append(manifest) + seen.add(key) + return tuple(unique_manifests) + + +def _mujoco_shared_lib_name() -> str | None: + system = platform.system() + if system == "Darwin": + return "libmujoco.3.6.0.dylib" + if system == "Windows": + return "mujoco.dll" + return None + + +def _bazel_mujoco_shared_lib_path() -> Path: + shared_lib = _mujoco_shared_lib_name() + if shared_lib is None: + raise RuntimeError( + f"no Bazel-built MuJoCo shared library for {platform.system()}" + ) + runfiles = _runfiles_root() + workspace = os.environ.get("TEST_WORKSPACE", "envpool") + manifest_keys = ( + f"mujoco/{shared_lib}", + f"{workspace}/external/mujoco/{shared_lib}", + ) + for manifest in _runfiles_manifests(runfiles): + if not manifest.is_file(): + continue + with manifest.open(encoding="utf-8") as f: + for line in f: + logical_path, _, real_path = line.rstrip("\n").partition(" ") + if logical_path not in manifest_keys: + continue + candidate = Path(real_path) + if ( + candidate.is_file() + and "site-packages" not in candidate.parts + ): + return candidate + + candidates = ( + runfiles / "mujoco" / shared_lib, + runfiles / workspace / "external" / "mujoco" / shared_lib, + ) + for candidate in candidates: + if candidate.is_file(): + return candidate + for candidate in runfiles.rglob(shared_lib): + if candidate.is_file() and "site-packages" not in candidate.parts: + return candidate + raise RuntimeError( + f"could not locate Bazel-built {shared_lib} under {runfiles}" + ) + + +def _configure_mujoco_package_shared_lib() -> None: + """Make the pinned oracle import use EnvPool's Bazel-built MuJoCo lib. + + Linux uses the pinned pip MuJoCo wheel directly. Replacing or preloading the + package library there corrupts the Python binding's model-name reads in + MuJoCo 3.6.0, while the pip wheel already works with the EGL render path. + """ + shared_lib = _mujoco_shared_lib_name() + if shared_lib is None or getattr( + _configure_mujoco_package_shared_lib, "_configured", False + ): + return + + spec = importlib.util.find_spec("mujoco") + if spec is None or spec.submodule_search_locations is None: + raise RuntimeError("could not locate pinned mujoco Python package") + package_dir = Path(next(iter(spec.submodule_search_locations))) + if not (package_dir / "__init__.py").is_file(): + raise RuntimeError(f"invalid mujoco package path: {package_dir}") + + patched_root = Path(tempfile.mkdtemp(prefix="mujoco-oracle-")) + atexit.register(shutil.rmtree, patched_root, ignore_errors=True) + patched_package = patched_root / "mujoco" + shutil.copytree(package_dir, patched_package, symlinks=False) + shutil.copy2(_bazel_mujoco_shared_lib_path(), patched_package / shared_lib) + sys.path.insert(0, str(patched_root)) + _configure_mujoco_package_shared_lib._configured = True # type: ignore[attr-defined] + + +def _configure_macos_mujoco_renderer() -> None: + """Use MuJoCo's default CGL pixel format with EnvPool's lock lifecycle.""" + if platform.system() != "Darwin": + return + + import mujoco + from mujoco import cgl as mujoco_cgl + from mujoco import gl_context + from mujoco.cgl import cgl + from mujoco.rendering.classic import renderer as classic_renderer + + class _CglContext: + def __init__(self, width: int, height: int) -> None: + del width, height + self._pixel_format: Any = None + self._context: Any = None + self._locked = False + attrib = cgl.CGLPixelFormatAttribute + profile = cgl.CGLOpenGLProfile + preferred_attribs = ( + attrib.CGLPFAOpenGLProfile, + profile.CGLOGLPVersion_Legacy, + attrib.CGLPFAColorSize, + 24, + attrib.CGLPFAAlphaSize, + 8, + attrib.CGLPFADepthSize, + 24, + attrib.CGLPFAStencilSize, + 8, + attrib.CGLPFAMultisample, + attrib.CGLPFASampleBuffers, + 1, + attrib.CGLPFASample, + 4, + attrib.CGLPFAAccelerated, + 0, # terminator + ) + offline_attribs = ( + attrib.CGLPFAOpenGLProfile, + profile.CGLOGLPVersion_Legacy, + attrib.CGLPFAColorSize, + 24, + attrib.CGLPFAAlphaSize, + 8, + attrib.CGLPFADepthSize, + 24, + attrib.CGLPFAStencilSize, + 8, + attrib.CGLPFAAllowOfflineRenderers, + 0, # terminator + ) + + if not self._choose_pixel_format( + cgl, preferred_attribs + ) and not self._choose_pixel_format(cgl, offline_attribs): + raise RuntimeError("failed to create CGL pixel format") + + self._context = cgl.CGLContextObj() + cgl.CGLCreateContext( + self._pixel_format, + 0, + ctypes.byref(self._context), + ) + if not self._context: + cgl.CGLReleasePixelFormat(self._pixel_format) + self._pixel_format = None + raise RuntimeError("failed to create CGL context") + + def _choose_pixel_format( + self, cgl: Any, attrib_values: tuple[int, ...] + ) -> bool: + attribs = (ctypes.c_int * len(attrib_values))(*attrib_values) + pixel_format = cgl.CGLPixelFormatObj() + num_pixel_formats = cgl.GLint() + try: + cgl.CGLChoosePixelFormat( + attribs, + ctypes.byref(pixel_format), + ctypes.byref(num_pixel_formats), + ) + except cgl.CGLError: + return False + if not pixel_format or num_pixel_formats.value == 0: + return False + self._pixel_format = pixel_format + return True + + def make_current(self) -> None: + cgl.CGLSetCurrentContext(self._context) + if not self._locked: + cgl.CGLLockContext(self._context) + self._locked = True + + def free(self) -> None: + if self._context: + if self._locked: + cgl.CGLUnlockContext(self._context) + self._locked = False + cgl.CGLSetCurrentContext(None) + cgl.CGLReleaseContext(self._context) + self._context = None + if self._pixel_format: + cgl.CGLReleasePixelFormat(self._pixel_format) + self._pixel_format = None + + def __del__(self) -> None: + self.free() + + gl_context.GLContext = _CglContext + mujoco.gl_context.GLContext = _CglContext + mujoco_cgl.GLContext = _CglContext + classic_renderer.gl_context.GLContext = _CglContext + + +def _configure_windows_mujoco_renderer() -> None: + """Match EnvPool's native WGL context for official render-oracle tests.""" + if platform.system() != "Windows" or getattr( + _configure_windows_mujoco_renderer, "_configured", False + ): + return + + import mujoco + from mujoco import gl_context + from mujoco import glfw as mujoco_glfw + from mujoco.rendering.classic import gl_context as classic_gl_context + from mujoco.rendering.classic import renderer as classic_renderer + + ctypes_attrs = vars(ctypes) + wintypes = importlib.import_module("ctypes.wintypes") + windll = ctypes_attrs["WinDLL"] + winfunctype = ctypes_attrs["WINFUNCTYPE"] + win_error = ctypes_attrs["WinError"] + get_last_error = ctypes_attrs["get_last_error"] + kernel32 = windll("kernel32", use_last_error=True) + user32 = windll("user32", use_last_error=True) + gdi32 = windll("gdi32", use_last_error=True) + opengl32 = windll("opengl32", use_last_error=True) + + lresult = getattr(wintypes, "LRESULT", ctypes.c_ssize_t) + hcursor = vars(wintypes).get("HCURSOR", wintypes.HANDLE) + wndproc = winfunctype( + lresult, + wintypes.HWND, + wintypes.UINT, + wintypes.WPARAM, + wintypes.LPARAM, + ) + user32.DefWindowProcW.argtypes = [ + wintypes.HWND, + wintypes.UINT, + wintypes.WPARAM, + wintypes.LPARAM, + ] + user32.DefWindowProcW.restype = lresult + window_proc = wndproc(user32.DefWindowProcW) + + class _WndClass(ctypes.Structure): + _fields_: ClassVar[Any] = [ + ("style", wintypes.UINT), + ("lpfnWndProc", wndproc), + ("cbClsExtra", ctypes.c_int), + ("cbWndExtra", ctypes.c_int), + ("hInstance", wintypes.HINSTANCE), + ("hIcon", wintypes.HICON), + ("hCursor", hcursor), + ("hbrBackground", wintypes.HBRUSH), + ("lpszMenuName", wintypes.LPCWSTR), + ("lpszClassName", wintypes.LPCWSTR), + ] + + class _PixelFormatDescriptor(ctypes.Structure): + _fields_: ClassVar[Any] = [ + ("nSize", wintypes.WORD), + ("nVersion", wintypes.WORD), + ("dwFlags", wintypes.DWORD), + ("iPixelType", ctypes.c_ubyte), + ("cColorBits", ctypes.c_ubyte), + ("cRedBits", ctypes.c_ubyte), + ("cRedShift", ctypes.c_ubyte), + ("cGreenBits", ctypes.c_ubyte), + ("cGreenShift", ctypes.c_ubyte), + ("cBlueBits", ctypes.c_ubyte), + ("cBlueShift", ctypes.c_ubyte), + ("cAlphaBits", ctypes.c_ubyte), + ("cAlphaShift", ctypes.c_ubyte), + ("cAccumBits", ctypes.c_ubyte), + ("cAccumRedBits", ctypes.c_ubyte), + ("cAccumGreenBits", ctypes.c_ubyte), + ("cAccumBlueBits", ctypes.c_ubyte), + ("cAccumAlphaBits", ctypes.c_ubyte), + ("cDepthBits", ctypes.c_ubyte), + ("cStencilBits", ctypes.c_ubyte), + ("cAuxBuffers", ctypes.c_ubyte), + ("iLayerType", ctypes.c_ubyte), + ("bReserved", ctypes.c_ubyte), + ("dwLayerMask", wintypes.DWORD), + ("dwVisibleMask", wintypes.DWORD), + ("dwDamageMask", wintypes.DWORD), + ] + + kernel32.GetModuleHandleW.argtypes = [wintypes.LPCWSTR] + kernel32.GetModuleHandleW.restype = wintypes.HMODULE + user32.RegisterClassW.argtypes = [ctypes.POINTER(_WndClass)] + user32.RegisterClassW.restype = wintypes.ATOM + user32.CreateWindowExW.argtypes = [ + wintypes.DWORD, + wintypes.LPCWSTR, + wintypes.LPCWSTR, + wintypes.DWORD, + ctypes.c_int, + ctypes.c_int, + ctypes.c_int, + ctypes.c_int, + wintypes.HWND, + wintypes.HMENU, + wintypes.HINSTANCE, + wintypes.LPVOID, + ] + user32.CreateWindowExW.restype = wintypes.HWND + user32.GetDC.argtypes = [wintypes.HWND] + user32.GetDC.restype = wintypes.HDC + user32.ReleaseDC.argtypes = [wintypes.HWND, wintypes.HDC] + user32.ReleaseDC.restype = ctypes.c_int + user32.DestroyWindow.argtypes = [wintypes.HWND] + user32.DestroyWindow.restype = wintypes.BOOL + gdi32.ChoosePixelFormat.argtypes = [ + wintypes.HDC, + ctypes.POINTER(_PixelFormatDescriptor), + ] + gdi32.ChoosePixelFormat.restype = ctypes.c_int + gdi32.SetPixelFormat.argtypes = [ + wintypes.HDC, + ctypes.c_int, + ctypes.POINTER(_PixelFormatDescriptor), + ] + gdi32.SetPixelFormat.restype = wintypes.BOOL + opengl32.wglCreateContext.argtypes = [wintypes.HDC] + opengl32.wglCreateContext.restype = ctypes.c_void_p + opengl32.wglMakeCurrent.argtypes = [wintypes.HDC, ctypes.c_void_p] + opengl32.wglMakeCurrent.restype = wintypes.BOOL + opengl32.wglDeleteContext.argtypes = [ctypes.c_void_p] + opengl32.wglDeleteContext.restype = wintypes.BOOL + + class _WglContext: + _class_name = "EnvPoolMyoSuiteOracleOffscreen" + _window_proc = window_proc + _registered = False + + def __init__(self, width: int, height: int) -> None: + del width, height + self._window = None + self._device_context = None + self._context = None + self._ensure_window_class() + self._window = user32.CreateWindowExW( + 0, + self._class_name, + "EnvPool MyoSuite Oracle Offscreen", + 0x00CF0000, # WS_OVERLAPPEDWINDOW + 0, + 0, + 1, + 1, + None, + None, + kernel32.GetModuleHandleW(None), + None, + ) + if not self._window: + raise win_error(get_last_error()) + self._device_context = user32.GetDC(self._window) + if not self._device_context: + self.free() + raise win_error(get_last_error()) + pixel_format = _PixelFormatDescriptor() + pixel_format.nSize = ctypes.sizeof(_PixelFormatDescriptor) + pixel_format.nVersion = 1 + pixel_format.dwFlags = 0x00000004 | 0x00000020 + pixel_format.iPixelType = 0 + pixel_format.cColorBits = 24 + pixel_format.cAlphaBits = 8 + pixel_format.cDepthBits = 24 + pixel_format.cStencilBits = 8 + pixel_format.iLayerType = 0 + format_id = gdi32.ChoosePixelFormat( + self._device_context, ctypes.byref(pixel_format) + ) + if format_id == 0 or not gdi32.SetPixelFormat( + self._device_context, format_id, ctypes.byref(pixel_format) + ): + self.free() + raise win_error(get_last_error()) + self._context = opengl32.wglCreateContext(self._device_context) + if not self._context: + self.free() + raise win_error(get_last_error()) + + @classmethod + def _ensure_window_class(cls) -> None: + if cls._registered: + return + window_class = _WndClass() + window_class.style = 0x0020 # CS_OWNDC + window_class.lpfnWndProc = cls._window_proc + window_class.hInstance = kernel32.GetModuleHandleW(None) + window_class.lpszClassName = cls._class_name + if not user32.RegisterClassW(ctypes.byref(window_class)): + error = get_last_error() + if error != 1410: # ERROR_CLASS_ALREADY_EXISTS + raise win_error(error) + cls._registered = True + + def make_current(self) -> None: + if not opengl32.wglMakeCurrent(self._device_context, self._context): + raise win_error(get_last_error()) + + def free(self) -> None: + if self._context: + opengl32.wglMakeCurrent(None, None) + opengl32.wglDeleteContext(self._context) + self._context = None + if self._window and self._device_context: + user32.ReleaseDC(self._window, self._device_context) + self._device_context = None + if self._window: + user32.DestroyWindow(self._window) + self._window = None + + def __del__(self) -> None: + self.free() + + gl_context.GLContext = _WglContext + mujoco.GLContext = _WglContext + mujoco.glfw.GLContext = _WglContext + mujoco_glfw.GLContext = _WglContext + classic_gl_context.GLContext = _WglContext + classic_renderer.GLContext = _WglContext + classic_renderer.gl_context.GLContext = _WglContext + _configure_windows_mujoco_renderer._configured = True # type: ignore[attr-defined] + + +def _configure_linux_mujoco_renderer(render: bool) -> None: + """Force the pinned oracle onto EnvPool CI's headless EGL renderer.""" + if not render or platform.system() != "Linux": + return + + os.environ["MUJOCO_GL"] = "egl" + os.environ["PYOPENGL_PLATFORM"] = "egl" + os.environ.setdefault("EGL_PLATFORM", "surfaceless") + + +def _link_or_copy_file(src: str, dst: str) -> None: + try: + os.link(src, dst) + except OSError: + shutil.copy2(src, dst) + + +def _overlay_tree( + source: Path, + destination: Path, + *, + ignore: Any = None, + prefer_directory_symlink: bool = True, +) -> None: + if prefer_directory_symlink: + try: + os.symlink(source, destination, target_is_directory=True) + return + except OSError: + # Some Bazel runfiles and Windows paths cannot be symlinked. + pass + shutil.copytree( + source, + destination, + symlinks=True, + copy_function=_link_or_copy_file, + ignore=ignore, + ) + + +def _oracle_source_path() -> Path: + runfiles = _runfiles_root() + source = runfiles / "myosuite_source/myosuite" + if not (source / "__init__.py").is_file(): + raise RuntimeError(f"could not locate MyoSuite source at {source}") + assembled = Path(tempfile.mkdtemp(prefix="myosuite-oracle-")) + atexit.register(shutil.rmtree, assembled, ignore_errors=True) + package = assembled / "myosuite" + _overlay_tree( + source, + package, + ignore=lambda _root, names: ( + {"simhive"} if "simhive" in names else set() + ), + prefer_directory_symlink=False, + ) + simhive = package / "simhive" + simhive.mkdir() + for repo, name in ( + ("myosuite_mpl_sim", "MPL_sim"), + ("myosuite_ycb_sim", "YCB_sim"), + ("myosuite_furniture_sim", "furniture_sim"), + ("myosuite_myo_sim", "myo_sim"), + ("myosuite_object_sim", "object_sim"), + ): + repo_path = runfiles / repo + if not repo_path.is_dir(): + raise RuntimeError(f"could not locate {repo_path}") + _overlay_tree(repo_path, simhive / name) + return assembled + + +def _import_official() -> tuple[Any, Any, Any]: + warnings.filterwarnings("ignore") + _configure_mujoco_package_shared_lib() + sys.path.insert(0, str(_oracle_source_path())) + _configure_macos_mujoco_renderer() + _configure_windows_mujoco_renderer() + official_myosuite = importlib.import_module("myosuite") + gym = importlib.import_module("myosuite.utils").gym + gym_registry_specs = official_myosuite.gym_registry_specs + return official_myosuite, gym_registry_specs, gym + + +def _space_report(task_ids: list[str]) -> dict[str, Any]: + official_myosuite, gym_registry_specs, gym = _import_official() + registry = gym_registry_specs() + tasks: dict[str, dict[str, Any]] = {} + for task_id in task_ids: + spec = registry[task_id] + env = gym.make(task_id) + try: + tasks[task_id] = { + "action_shape": list(env.action_space.shape), + "max_episode_steps": int(spec.max_episode_steps), + "observation_shape": list(env.observation_space.shape), + } + except Exception as exc: + raise RuntimeError(f"oracle space failed for {task_id}") from exc + finally: + env.close() + return { + "ids": list(official_myosuite.myosuite_env_suite), + "tasks": tasks, + "version": official_myosuite.__version__, + } + + +def _array(value: Any) -> np.ndarray: + return np.asarray(value) + + +def _jsonable_array(value: Any) -> Any: + if isinstance(value, dict): + return {str(key): _jsonable_array(item) for key, item in value.items()} + if isinstance(value, list | tuple): + return [_jsonable_array(item) for item in value] + if isinstance(value, np.generic): + return value.item() + array = _array(value) + if array.ndim == 0: + return array.item() + if array.dtype == object: + return [str(item) for item in array.ravel()] + return array.tolist() + + +def _names_from_ids(model: Any, obj_type: Any, ids: list[int]) -> list[str]: + import mujoco + + raw_model = model.ptr if hasattr(model, "ptr") else model + return [ + mujoco.mj_id2name(raw_model, int(obj_type), int(obj_id)) + for obj_id in ids + ] + + +def _metadata_report(task_ids: list[str]) -> dict[str, Any]: + official_myosuite, _, gym = _import_official() + import mujoco + + tasks: dict[str, dict[str, Any]] = {} + for task_id in task_ids: + env = gym.make(task_id) + try: + unwrapped = env.unwrapped + model = unwrapped.sim.model + data = unwrapped.sim.data + task: dict[str, Any] = { + "action_shape": list(env.action_space.shape), + "entry_class": type(unwrapped).__name__, + "frame_skip": int(unwrapped.frame_skip), + "init_qpos": _jsonable_array(unwrapped.init_qpos), + "init_qvel": _jsonable_array(unwrapped.init_qvel), + "model_nq": int(model.nq), + "model_nv": int(model.nv), + "model_na": int(model.na), + "model_nu": int(model.nu), + "obs_keys": list(unwrapped.obs_keys), + "observation_shape": list(env.observation_space.shape), + "rwd_keys_wt": dict(unwrapped.rwd_keys_wt), + } + for attr in ( + "far_th", + "goal_th", + "hip_period", + "max_rot", + "min_height", + "pose_thd", + "reset_type", + "target_rot", + "target_x_vel", + "target_y_vel", + "terrain", + "variant", + ): + if hasattr(unwrapped, attr): + task[attr] = _jsonable_array(getattr(unwrapped, attr)) + if hasattr(unwrapped, "tip_sids"): + task["tip_sites"] = _names_from_ids( + model, mujoco.mjtObj.mjOBJ_SITE, unwrapped.tip_sids + ) + if hasattr(unwrapped, "target_sids"): + task["target_sites"] = _names_from_ids( + model, mujoco.mjtObj.mjOBJ_SITE, unwrapped.target_sids + ) + if hasattr(unwrapped, "target_jnt_ids"): + task["target_joints"] = _names_from_ids( + model, mujoco.mjtObj.mjOBJ_JOINT, unwrapped.target_jnt_ids + ) + for attr in ( + "target_jnt_range", + "target_jnt_value", + "target_reach_range", + ): + if hasattr(unwrapped, attr): + task[attr] = _jsonable_array(getattr(unwrapped, attr)) + task["initial_state"] = { + "qpos": _jsonable_array(data.qpos), + "qvel": _jsonable_array(data.qvel), + "act": _jsonable_array(data.act) if model.na > 0 else [], + "qacc_warmstart": _jsonable_array(data.qacc_warmstart), + "site_pos": _jsonable_array(model.site_pos), + "site_quat": _jsonable_array(model.site_quat), + "body_pos": _jsonable_array(model.body_pos), + "body_quat": _jsonable_array(model.body_quat), + } + env.reset(seed=0) + task["reset_state"] = _state_report(unwrapped) + tasks[task_id] = task + finally: + env.close() + return {"tasks": tasks, "version": official_myosuite.__version__} + + +def _state_report(env: Any) -> dict[str, Any]: + model = env.sim.model + data = env.sim.data + state = { + "act": _jsonable_array(data.act) if model.na > 0 else [], + "actuator_force": _jsonable_array(data.actuator_force), + "actuator_length": _jsonable_array(data.actuator_length), + "actuator_velocity": _jsonable_array(data.actuator_velocity), + "ctrl": _jsonable_array(data.ctrl), + "geom_xpos": _jsonable_array(data.geom_xpos), + "geom_xmat": _jsonable_array(data.geom_xmat), + "geom_rgba": _jsonable_array(model.geom_rgba), + "qacc_warmstart": _jsonable_array(data.qacc_warmstart), + "body_pos": _jsonable_array(model.body_pos), + "body_quat": _jsonable_array(model.body_quat), + "light_xdir": _jsonable_array(data.light_xdir), + "light_xpos": _jsonable_array(data.light_xpos), + "mocap_pos": _jsonable_array(data.mocap_pos), + "mocap_quat": _jsonable_array(data.mocap_quat), + "qpos": _jsonable_array(data.qpos), + "qvel": _jsonable_array(data.qvel), + "site_pos": _jsonable_array(model.site_pos), + "site_quat": _jsonable_array(model.site_quat), + "site_size": _jsonable_array(model.site_size), + "site_xpos": _jsonable_array(data.site_xpos), + "site_rgba": _jsonable_array(model.site_rgba), + "time": float(data.time), + } + fatigue = getattr(env, "muscle_fatigue", None) + if fatigue is not None: + state.update({ + "fatigue_ma": _jsonable_array(fatigue._MA), + "fatigue_mr": _jsonable_array(fatigue._MR), + "fatigue_mf": _jsonable_array(fatigue._MF), + "fatigue_tl": _jsonable_array(fatigue.TL), + "fatigue_tauact": _jsonable_array(fatigue._tauact), + "fatigue_taudeact": _jsonable_array(fatigue._taudeact), + "fatigue_dt": float(fatigue._dt), + }) + return state + + +def _state_array( + state: dict[str, Any], key: str, shape: tuple[int, ...] +) -> np.ndarray | None: + value = state.get(key) + if value is None: + return None + array = np.asarray(value, dtype=np.float64) + size = int(np.prod(shape, dtype=np.int64)) + if array.size < size: + raise ValueError( + f"sync state {key} has {array.size} values, expected {size}" + ) + return array[:size].reshape(shape) + + +def _assign_sync_array( + state: dict[str, Any], key: str, target: np.ndarray +) -> None: + value = state.get(key) + if value is None: + return + array = np.asarray(value, dtype=np.float64).ravel() + target_flat = target.reshape(-1) + count = min(array.size, target_flat.size) + target_flat[:count] = array[:count] + if count < target_flat.size: + target_flat[count:] = 0.0 + + +def _assign_sync_array_if_same_size( + state: dict[str, Any], key: str, target: np.ndarray +) -> None: + value = state.get(key) + if value is None: + return + array = np.asarray(value, dtype=np.float64).ravel() + if array.size != target.size: + return + target.reshape(-1)[:] = array + + +def _sync_osl_phase_from_qpos(env: Any) -> None: + controller = getattr(env, "OSL_CTRL", None) + if controller is None: + return + model = env.sim.model + data = env.sim.data + if model.nkey < 3: + controller.reset("e_stance") + controller.start() + return + qpos = np.asarray(data.qpos, dtype=np.float64) + key_qpos = np.asarray(model.key_qpos, dtype=np.float64).reshape( + model.nkey, model.nq + ) + start = min(7, model.nq) + distances = np.sum((key_qpos[:3, start:] - qpos[start:]) ** 2, axis=1) + phase = "e_swing" if int(np.argmin(distances)) == 1 else "e_stance" + controller.reset(phase) + controller.start() + + +def _sync_baoding_goal_from_envpool_reset_state(env: Any) -> None: + if not all( + hasattr(env, attr) + for attr in ( + "ball_1_starting_angle", + "ball_2_starting_angle", + "center_pos", + "create_goal_trajectory", + "x_radius", + "y_radius", + ) + ): + return + task_type = type(getattr(env, "which_task", object())) + if hasattr(task_type, "BAODING_CCW"): + env.which_task = task_type.BAODING_CCW + env.ball_1_starting_angle = np.pi / 4.0 + env.ball_2_starting_angle = env.ball_1_starting_angle - np.pi + env.center_pos = np.array([-0.0125, -0.07], dtype=np.float64) + env.x_radius = 0.025 + env.y_radius = 0.028 + env.goal = env.create_goal_trajectory( + time_step=float(getattr(env, "dt", 0.025)), time_period=6.0 + ) + env.counter = 0 + + +def _sync_chasetag_hidden_state(env: Any) -> None: + if not all(hasattr(env, attr) for attr in ("current_task", "opponent")): + return + task_type = type(env.current_task) + if hasattr(task_type, "CHASE"): + env.current_task = task_type.CHASE + opponent = env.opponent + opponent.opponent_policy = "stationary" + opponent.opponent_vel = np.zeros((2,), dtype=np.float64) + if hasattr(opponent, "chase_velocity"): + opponent.chase_velocity = 1.0 + + +def _sync_fatigue_hidden_state(env: Any, state: dict[str, Any]) -> None: + fatigue = getattr(env, "muscle_fatigue", None) + if fatigue is None: + return + _assign_sync_array(state, "fatigue_ma", fatigue._MA) + _assign_sync_array(state, "fatigue_mr", fatigue._MR) + _assign_sync_array(state, "fatigue_mf", fatigue._MF) + _assign_sync_array(state, "fatigue_tl", fatigue.TL) + + +def _sync_to_envpool_reset_state(env: Any, state: dict[str, Any]) -> np.ndarray: + """Patch the official oracle to EnvPool's reset-time MuJoCo state once.""" + sim = env.sim + model = sim.model + data = sim.data + + _assign_sync_array(state, "site_pos", model.site_pos) + _assign_sync_array(state, "site_quat", model.site_quat) + _assign_sync_array(state, "site_size", model.site_size) + _assign_sync_array(state, "site_rgba", model.site_rgba) + _assign_sync_array(state, "body_pos", model.body_pos) + _assign_sync_array(state, "body_quat", model.body_quat) + _assign_sync_array(state, "body_mass", model.body_mass) + _assign_sync_array(state, "geom_pos", model.geom_pos) + _assign_sync_array(state, "geom_quat", model.geom_quat) + _assign_sync_array(state, "geom_size", model.geom_size) + _assign_sync_array(state, "geom_rgba", model.geom_rgba) + _assign_sync_array(state, "geom_friction", model.geom_friction) + _assign_sync_array_if_same_size(state, "geom_aabb", model.geom_aabb) + _assign_sync_array_if_same_size(state, "geom_rbound", model.geom_rbound) + _assign_sync_array_if_same_size(state, "geom_contype", model.geom_contype) + _assign_sync_array_if_same_size( + state, "geom_conaffinity", model.geom_conaffinity + ) + _assign_sync_array_if_same_size(state, "geom_type", model.geom_type) + _assign_sync_array_if_same_size(state, "geom_condim", model.geom_condim) + _assign_sync_array(state, "hfield_data", model.hfield_data) + if model.nmocap > 0: + _assign_sync_array(state, "mocap_pos", data.mocap_pos) + _assign_sync_array(state, "mocap_quat", data.mocap_quat) + + qpos = _state_array(state, "qpos0", data.qpos.shape) + qvel = _state_array(state, "qvel0", data.qvel.shape) + act = _state_array(state, "act0", data.act.shape) if model.na > 0 else None + sim.set_state(time=0.0, qpos=qpos, qvel=qvel, act=act) + + _assign_sync_array(state, "ctrl", data.ctrl) + sim.forward() + _sync_osl_phase_from_qpos(env) + _sync_baoding_goal_from_envpool_reset_state(env) + _sync_chasetag_hidden_state(env) + _sync_fatigue_hidden_state(env, state) + obs = env.get_obs() + _assign_sync_array(state, "qacc0", data.qacc) + _assign_sync_array(state, "qacc_warmstart0", data.qacc_warmstart) + if hasattr(env, "last_ctrl"): + env.last_ctrl = data.ctrl.copy() + return obs + + +def _trace_info(info: dict[str, Any]) -> dict[str, Any]: + scalar_info: dict[str, Any] = {} + for key in ("rwd_dense", "rwd_sparse", "solved", "done", "time"): + if key in info: + scalar_info[key] = _jsonable_array(info[key]) + if "rwd_dict" in info: + scalar_info["rwd_dict"] = { + key: _jsonable_array(value) + for key, value in info["rwd_dict"].items() + if np.asarray(value).size <= 16 + } + return scalar_info + + +def _render_frame(env: Any, width: int, height: int, camera_id: int) -> Any: + env.unwrapped.sim.forward() + renderer = env.unwrapped.sim.renderer + frame = renderer.render_offscreen( + width=width, + height=height, + camera_id=camera_id, + ) + if platform.system() == "Darwin" and not getattr( + renderer, "_envpool_cgl_first_render_done", False + ): + renderer._envpool_cgl_first_render_done = True + for _ in range(_CGL_FIRST_FRAME_SETTLE_PASSES): + frame = renderer.render_offscreen( + width=width, + height=height, + camera_id=camera_id, + ) + return frame + + +def _next_action( + rng: np.random.Generator, + low: np.ndarray, + high: np.ndarray, + action_mode: str, +) -> np.ndarray: + if action_mode == "random": + return rng.uniform(low, high).astype(np.float32) + if action_mode == "midpoint": + return ((low + high) * 0.5).astype(np.float32) + if action_mode == "zero": + return np.clip(np.zeros_like(low), low, high).astype(np.float32) + raise ValueError(f"unknown action mode: {action_mode}") + + +def _rollout_report( + task_ids: list[str], steps: int, seed: int, action_mode: str +) -> dict[str, Any]: + official_myosuite, _, gym = _import_official() + rng = np.random.default_rng(seed + 17) + tasks: dict[str, dict[str, Any]] = {} + for task_id in task_ids: + env = gym.make(task_id) + try: + reset = env.reset(seed=seed) + obs = reset[0] if isinstance(reset, tuple) else reset + low = _array(env.action_space.low).astype(np.float32) + high = _array(env.action_space.high).astype(np.float32) + rewards: list[float] = [] + terminals: list[bool] = [] + truncateds: list[bool] = [] + obs_checksum = [float(_array(obs).astype(np.float64).sum())] + for _ in range(steps): + action = _next_action(rng, low, high, action_mode) + step = env.step(action) + obs = step[0] + rewards.append(float(step[1])) + terminals.append(bool(step[2])) + truncateds.append(bool(step[3]) if len(step) > 4 else False) + obs_checksum.append(float(_array(obs).astype(np.float64).sum())) + tasks[task_id] = { + "obs_checksum": obs_checksum, + "rewards": rewards, + "terminated": terminals, + "truncated": truncateds, + } + finally: + env.close() + return {"tasks": tasks, "version": official_myosuite.__version__} + + +def _trace_report( + task_ids: list[str], + steps: int, + seed: int, + render: bool, + render_width: int, + render_height: int, + camera_id: int, + action_mode: str, + sync_states: dict[str, Any] | None = None, + trace_plan: dict[str, Any] | None = None, +) -> dict[str, Any]: + official_myosuite, _, gym = _import_official() + rng = np.random.default_rng(seed + 17) + tasks: dict[str, dict[str, Any]] = {} + for task_id in task_ids: + task_plan = trace_plan.get(task_id, {}) if trace_plan else {} + planned_actions = task_plan.get("actions") + planned_resets = task_plan.get("reset_before_step", []) + planned_sync_states = task_plan.get("sync_states", []) + env = gym.make(task_id) + try: + reset = env.reset(seed=seed) + obs = reset[0] if isinstance(reset, tuple) else reset + unwrapped = env.unwrapped + if planned_sync_states: + obs = _sync_to_envpool_reset_state( + unwrapped, planned_sync_states[0] + ) + elif sync_states is not None and task_id in sync_states: + obs = _sync_to_envpool_reset_state( + unwrapped, sync_states[task_id] + ) + low = _array(env.action_space.low).astype(np.float32) + high = _array(env.action_space.high).astype(np.float32) + frames: list[Any] = [] + if render: + frames.append( + _jsonable_array( + _render_frame( + env, + render_width, + render_height, + camera_id, + ) + ) + ) + trace: dict[str, Any] = { + "actions": [], + "infos": [], + "obs": [_jsonable_array(obs)], + "reset_state": _state_report(unwrapped), + "rewards": [], + "states": [], + "terminated": [], + "truncated": [], + } + trace_steps = ( + len(planned_actions) if planned_actions is not None else steps + ) + for step_id in range(trace_steps): + if planned_actions is None: + action = _next_action(rng, low, high, action_mode) + else: + action = np.asarray( + planned_actions[step_id], dtype=np.float32 + ) + reset_before_step = step_id < len(planned_resets) and bool( + planned_resets[step_id] + ) + trace["actions"].append(_jsonable_array(action)) + if reset_before_step: + reset = env.reset() + obs = reset[0] if isinstance(reset, tuple) else reset + if step_id + 1 < len(planned_sync_states): + obs = _sync_to_envpool_reset_state( + unwrapped, planned_sync_states[step_id + 1] + ) + else: + env.sim.forward() + trace["obs"].append(_jsonable_array(obs)) + trace["rewards"].append(0.0) + trace["terminated"].append(False) + trace["truncated"].append(False) + trace["infos"].append({}) + else: + step = env.step(action) + obs = step[0] + trace["obs"].append(_jsonable_array(obs)) + trace["rewards"].append(float(step[1])) + trace["terminated"].append(bool(step[2])) + trace["truncated"].append( + bool(step[3]) if len(step) > 4 else False + ) + trace["infos"].append(_trace_info(step[-1])) + state = _state_report(unwrapped) + if hasattr(unwrapped, "last_ctrl"): + state["last_ctrl"] = _jsonable_array(unwrapped.last_ctrl) + trace["states"].append(state) + if render: + frames.append( + _jsonable_array( + _render_frame( + env, + render_width, + render_height, + camera_id, + ) + ) + ) + if render: + trace["frames"] = frames + tasks[task_id] = trace + except Exception as exc: + raise RuntimeError(f"oracle trace failed for {task_id}") from exc + finally: + env.close() + return {"tasks": tasks, "version": official_myosuite.__version__} + + +def main() -> None: + """Run the requested pinned-oracle probe and write a JSON report.""" + parser = argparse.ArgumentParser() + parser.add_argument( + "--mode", + choices=("metadata", "space", "rollout", "trace"), + required=True, + ) + parser.add_argument("--out", required=True) + parser.add_argument("--render", action="store_true") + parser.add_argument("--render_width", type=int, default=64) + parser.add_argument("--render_height", type=int, default=48) + parser.add_argument("--camera_id", type=int, default=-1) + parser.add_argument("--sync_state") + parser.add_argument("--trace_plan") + parser.add_argument( + "--action_mode", + choices=("random", "midpoint", "zero"), + default="random", + ) + parser.add_argument("--task_id", action="append", default=[]) + parser.add_argument("--steps", type=int, default=64) + parser.add_argument("--seed", type=int, default=5) + args = parser.parse_args() + _configure_linux_mujoco_renderer(args.render) + + sync_states = ( + json.loads(Path(args.sync_state).read_text()) + if args.sync_state is not None + else None + ) + trace_plan = ( + json.loads(Path(args.trace_plan).read_text()) + if args.trace_plan is not None + else None + ) + + if args.mode == "space": + report = _space_report(args.task_id) + elif args.mode == "rollout": + report = _rollout_report( + args.task_id, args.steps, args.seed, args.action_mode + ) + elif args.mode == "trace": + report = _trace_report( + args.task_id, + args.steps, + args.seed, + args.render, + args.render_width, + args.render_height, + args.camera_id, + args.action_mode, + sync_states, + trace_plan, + ) + else: + report = _metadata_report(args.task_id) + Path(args.out).write_text(json.dumps(report, sort_keys=True)) + + +if __name__ == "__main__": + main() diff --git a/envpool/mujoco/myosuite/myosuite_render_test.py b/envpool/mujoco/myosuite/myosuite_render_test.py new file mode 100644 index 000000000..03b0b3b1d --- /dev/null +++ b/envpool/mujoco/myosuite/myosuite_render_test.py @@ -0,0 +1,579 @@ +# Copyright 2026 Garena Online Private Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Render smoke tests for native MyoSuite envs.""" + +from __future__ import annotations + +import importlib +import json +import os +import platform +import subprocess +import sys +import tempfile +import unittest +from pathlib import Path +from typing import Any + +import numpy as np +from absl.testing import absltest + +from envpool.python.glfw_context import preload_windows_gl_dlls + +if platform.system() == "Windows": + preload_windows_gl_dlls(strict=True) + +from envpool.mujoco.myosuite.tasks import ( + MYOSUITE_ORACLE_NUMPY2_BROKEN_IDS, + MYOSUITE_TASKS, +) +from envpool.registration import make_gymnasium + +importlib.import_module("envpool.mujoco.myosuite.registration") + +_TASK_IDS = tuple(str(task["id"]) for task in MYOSUITE_TASKS) +_TASK_ID_SET = frozenset(_TASK_IDS) +# Render traces are the expensive part of this suite. The full 398-ID surface is +# covered by registry, space, reset/step, determinism, and generated docs. Keep +# CI render checks to a diagonal set that catches wrong-camera/model/scene +# regressions without retesting the full modifier cartesian product. +_ORACLE_RENDER_REPRESENTATIVE_TASK_IDS = frozenset({ + "MyoHandAirplaneFixed-v0", + "MyoHandAirplaneFly-v0", + "MyoHandCupPour-v0", + "MyoHandHammerUse-v0", + "MyoHandWatchRandom-v0", + "motorFingerPoseFixed-v0", + "myoArmReachFixed-v0", + "myoElbowPose1D6MExoFixed-v0", + "myoFingerReachFixed-v0", + "myoFingerPoseFixed-v0", + "myoHandReachFixed-v0", + "myoHandPoseFixed-v0", + "myoHandKeyTurnFixed-v0", + "myoHandObjHoldFixed-v0", + "myoHandPenTwirlFixed-v0", + "myoHandReorient8-v0", + "myoLegStandRandom-v0", + "myoLegWalk-v0", + "myoLegRoughTerrainWalk-v0", + "myoFatiArmReachFixed-v0", + "myoFatiHandReorient8-v0", + "myoFatiLegWalk-v0", + "myoSarcArmReachFixed-v0", + "myoSarcHandReorient8-v0", + "myoSarcLegWalk-v0", + "myoChallengeBaodingP1-v1", + "myoChallengeBimanual-v0", + "myoChallengeChaseTagP1-v0", + "myoChallengeDieReorientP1-v0", + "myoChallengeOslRunFixed-v0", + "myoChallengeRelocateP1-v0", + "myoChallengeSoccerP1-v0", + "myoChallengeTableTennisP0-v0", + "myoFatiChallengeBimanual-v0", + "myoSarcChallengeSoccerP2-v0", +}) +_UNKNOWN_ORACLE_RENDER_REPRESENTATIVE_TASK_IDS = ( + _ORACLE_RENDER_REPRESENTATIVE_TASK_IDS - _TASK_ID_SET +) +if _UNKNOWN_ORACLE_RENDER_REPRESENTATIVE_TASK_IDS: + raise ValueError( + "unknown MyoSuite oracle render representatives: " + f"{sorted(_UNKNOWN_ORACLE_RENDER_REPRESENTATIVE_TASK_IDS)}" + ) + + +def _render_task_allowlist_from_env() -> tuple[str, ...] | None: + raw = os.environ.get("MYOSUITE_RENDER_TASK_IDS") + if raw is None: + return None + task_ids = tuple( + dict.fromkeys( + task_id for task_id in raw.replace(",", " ").split() if task_id + ) + ) + if not task_ids: + raise ValueError("MYOSUITE_RENDER_TASK_IDS is set but empty") + unknown = sorted(set(task_ids) - _TASK_ID_SET) + if unknown: + raise ValueError(f"unknown MYOSUITE_RENDER_TASK_IDS: {unknown}") + return task_ids + + +_RENDER_TASK_ALLOWLIST = _render_task_allowlist_from_env() + + +def _filter_render_task_ids(task_ids: tuple[str, ...]) -> tuple[str, ...]: + if _RENDER_TASK_ALLOWLIST is None: + return task_ids + return tuple( + task_id for task_id in task_ids if task_id in _RENDER_TASK_ALLOWLIST + ) + + +def _native_render_task_ids() -> tuple[str, ...]: + if _RENDER_TASK_ALLOWLIST is not None: + return _filter_render_task_ids(_TASK_IDS) + return tuple( + task_id + for task_id in _TASK_IDS + if task_id in _ORACLE_RENDER_REPRESENTATIVE_TASK_IDS + ) + + +def _oracle_trace_task_ids() -> tuple[str, ...]: + if _RENDER_TASK_ALLOWLIST is None: + return tuple( + task_id + for task_id in _TASK_IDS + if task_id in _ORACLE_RENDER_REPRESENTATIVE_TASK_IDS + and task_id not in MYOSUITE_ORACLE_NUMPY2_BROKEN_IDS + ) + return tuple( + task_id + for task_id in _filter_render_task_ids(_TASK_IDS) + if task_id not in MYOSUITE_ORACLE_NUMPY2_BROKEN_IDS + ) + + +_NATIVE_RENDER_TASK_IDS = _native_render_task_ids() +_WIDTH = 64 +_HEIGHT = 48 +_ORACLE_RENDER_BATCH_SIZE = 8 +# Catch wrong camera/model/scene regressions without chasing backend pixel noise. +_RENDER_BLOCK_SIZE = 4 +_MAX_RENDER_MEAN_ABS_DIFF = 24.0 +_MAX_RENDER_BLOCK_MEAN_ABS_DIFF = 24.0 +_MAX_RENDER_LARGE_MISMATCH_RATIO = 0.50 +_LARGE_RENDER_DELTA = 32 +_SYNC_STATE_KEYS = ( + "qpos0", + "qvel0", + "act0", + "qacc0", + "qacc_warmstart0", + "ctrl", + "site_pos", + "site_quat", + "site_size", + "site_rgba", + "body_pos", + "body_quat", + "body_mass", + "geom_pos", + "geom_quat", + "geom_size", + "geom_rgba", + "geom_friction", + "geom_aabb", + "geom_rbound", + "geom_contype", + "geom_conaffinity", + "geom_type", + "geom_condim", + "hfield_data", + "mocap_pos", + "mocap_quat", + "fatigue_ma", + "fatigue_mr", + "fatigue_mf", + "fatigue_tl", +) +_SYNC_STATE_SIZES = { + "qpos0": "nq", + "qvel0": "nv", + "act0": "na", + "qacc0": "nv", + "qacc_warmstart0": "nv", + "ctrl": "nu", + "site_pos": "nsite3", + "site_quat": "nsite4", + "site_size": "nsite3", + "site_rgba": "nsite4", + "body_pos": "nbody3", + "body_quat": "nbody4", + "body_mass": "nbody", + "geom_pos": "ngeom3", + "geom_quat": "ngeom4", + "geom_size": "ngeom3", + "geom_rgba": "ngeom4", + "geom_friction": "ngeom3", + "geom_aabb": "ngeom6", + "geom_rbound": "ngeom", + "geom_contype": "ngeom", + "geom_conaffinity": "ngeom", + "geom_type": "ngeom", + "geom_condim": "ngeom", + "hfield_data": "nhfielddata", + "mocap_pos": "nmocap3", + "mocap_quat": "nmocap4", + "fatigue_ma": "nu", + "fatigue_mr": "nu", + "fatigue_mf": "nu", + "fatigue_tl": "nu", +} + + +def _render_shard_task_ids(task_ids: tuple[str, ...]) -> tuple[str, ...]: + total_shards = int( + os.environ.get( + "MYOSUITE_RENDER_TOTAL_SHARDS", + os.environ.get("TEST_TOTAL_SHARDS", "1"), + ) + ) + shard_index = int( + os.environ.get( + "MYOSUITE_RENDER_SHARD_INDEX", + os.environ.get("TEST_SHARD_INDEX", "0"), + ) + ) + shard_status_file = os.environ.get("TEST_SHARD_STATUS_FILE") + if shard_status_file: + Path(shard_status_file).touch() + if total_shards <= 1: + return task_ids + if shard_index < 0 or shard_index >= total_shards: + raise ValueError(f"invalid Bazel shard {shard_index} of {total_shards}") + return tuple( + task_id + for index, task_id in enumerate(task_ids) + if index % total_shards == shard_index + ) + + +def _task_batches( + task_ids: tuple[str, ...], + batch_size: int, +) -> tuple[tuple[str, ...], ...]: + return tuple( + task_ids[start : start + batch_size] + for start in range(0, len(task_ids), batch_size) + ) + + +_SHARDED_ORACLE_TRACE_TASK_IDS = _render_shard_task_ids( + _oracle_trace_task_ids() +) +_SHARDED_NATIVE_RENDER_TASK_IDS = _render_shard_task_ids( + _NATIVE_RENDER_TASK_IDS +) + + +def _oracle_probe_path() -> Path: + runfiles = Path(os.environ["TEST_SRCDIR"]) + workspace = os.environ.get("TEST_WORKSPACE", "envpool") + launcher_names: tuple[str, ...] = ( + "myosuite_oracle_probe", + "myosuite_oracle_probe.exe", + ) + logical_suffixes = ( + tuple(f"envpool/mujoco/{launcher}" for launcher in launcher_names) + + launcher_names + ) + manifest = os.environ.get("RUNFILES_MANIFEST_FILE") + if manifest: + with Path(manifest).open(encoding="utf-8") as f: + for line in f: + logical, _, physical = line.rstrip("\n").partition(" ") + logical = logical.replace("\\", "/") + if any(logical.endswith(suffix) for suffix in logical_suffixes): + candidate = Path(physical or logical) + if candidate.is_file(): + return candidate + candidates = [ + runfiles / workspace / "envpool/mujoco" / launcher + for launcher in launcher_names + ] + if sys.platform == "win32": + candidates.extend( + runfiles.parent / launcher for launcher in launcher_names + ) + for candidate in candidates: + if candidate.is_file(): + return candidate + for launcher in launcher_names: + for match in runfiles.rglob(launcher): + if match.is_file(): + return match + raise RuntimeError( + f"could not locate myosuite_oracle_probe under {runfiles}" + ) + + +def _oracle_probe_cmd() -> list[str]: + path = _oracle_probe_path() + if sys.platform == "win32" and path.suffix.lower() != ".exe": + return [sys.executable, str(path)] + return [str(path)] + + +def _oracle_trace( + task_ids: tuple[str, ...], + trace_plan: dict[str, dict[str, Any]], +) -> dict[str, Any]: + with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as out: + out_path = Path(out.name) + with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as plan: + plan_path = Path(plan.name) + plan_path.write_text(json.dumps(trace_plan, sort_keys=True)) + cmd = _oracle_probe_cmd() + [ + "--mode", + "trace", + "--render", + "--render_width", + str(_WIDTH), + "--render_height", + str(_HEIGHT), + "--action_mode", + "midpoint", + "--steps", + "3", + "--seed", + "3", + "--out", + str(out_path), + "--trace_plan", + str(plan_path), + ] + for task_id in task_ids: + cmd.extend(["--task_id", task_id]) + env = os.environ.copy() + env["ROBOHIVE_VERBOSITY"] = "SILENT" + try: + try: + result = subprocess.run( + cmd, + check=False, + capture_output=True, + env=env, + text=True, + ) + except OSError as exc: + raise RuntimeError( + f"MyoSuite oracle probe failed to start\ncmd: {' '.join(cmd)}" + ) from exc + if result.returncode != 0: + raise RuntimeError( + "MyoSuite oracle probe failed\n" + f"cmd: {' '.join(cmd)}\n" + f"stdout:\n{result.stdout}\n" + f"stderr:\n{result.stderr}" + ) + return json.loads(out_path.read_text())["tasks"] + finally: + out_path.unlink(missing_ok=True) + plan_path.unlink(missing_ok=True) + + +def _sync_state_from_info(info: dict[str, Any]) -> dict[str, Any]: + dims = { + "nq": int(np.asarray(info["model_nq"]).ravel()[0]), + "nv": int(np.asarray(info["model_nv"]).ravel()[0]), + "na": int(np.asarray(info["model_na"]).ravel()[0]), + "nu": int(np.asarray(info["model_nu"]).ravel()[0]), + "nsite": int(np.asarray(info["model_nsite"]).ravel()[0]), + "nbody": int(np.asarray(info["model_nbody"]).ravel()[0]), + "ngeom": int(np.asarray(info["model_ngeom"]).ravel()[0]), + "nhfielddata": int(np.asarray(info["model_nhfielddata"]).ravel()[0]), + "nmocap": int(np.asarray(info["model_nmocap"]).ravel()[0]), + } + dims.update({ + "nsite3": dims["nsite"] * 3, + "nsite4": dims["nsite"] * 4, + "nbody3": dims["nbody"] * 3, + "nbody4": dims["nbody"] * 4, + "ngeom3": dims["ngeom"] * 3, + "ngeom4": dims["ngeom"] * 4, + "ngeom6": dims["ngeom"] * 6, + "nmocap3": dims["nmocap"] * 3, + "nmocap4": dims["nmocap"] * 4, + }) + sync_state = {} + for key in _SYNC_STATE_KEYS: + if key not in info: + continue + size = dims[_SYNC_STATE_SIZES[key]] + sync_state[key] = ( + np.asarray(info[key][0], dtype=np.float64).ravel()[:size].tolist() + ) + return sync_state + + +def _render(env: Any) -> np.ndarray: + frame = env.render() + if frame is None: + raise AssertionError("MyoSuite render returned None") + return frame + + +def _block_mean(frame: np.ndarray) -> np.ndarray: + height, width, channels = frame.shape + block = _RENDER_BLOCK_SIZE + if height % block or width % block: + raise AssertionError( + f"render size {frame.shape} is not divisible by block {block}" + ) + return frame.reshape( + height // block, + block, + width // block, + block, + channels, + ).mean(axis=(1, 3)) + + +def _assert_render_aligned( + test: absltest.TestCase, + frame: np.ndarray, + oracle_frame: np.ndarray, + *, + task_id: str, + step_id: int, +) -> None: + test.assertEqual(frame.shape, oracle_frame.shape) + test.assertEqual(frame.dtype, np.uint8) + test.assertEqual(oracle_frame.dtype, np.uint8) + diff = np.abs(frame.astype(np.int16) - oracle_frame.astype(np.int16)) + max_abs = int(diff.max()) + mean_abs = float(np.mean(diff)) + block_diff = np.abs(_block_mean(frame) - _block_mean(oracle_frame)) + block_mean_abs = float(np.mean(block_diff)) + large_mismatch_ratio = float( + np.mean(np.max(diff, axis=-1) > _LARGE_RENDER_DELTA) + ) + if ( + mean_abs > _MAX_RENDER_MEAN_ABS_DIFF + or block_mean_abs > _MAX_RENDER_BLOCK_MEAN_ABS_DIFF + or large_mismatch_ratio > _MAX_RENDER_LARGE_MISMATCH_RATIO + ): + test.fail( + f"{task_id} render step {step_id} drifted: " + f"max_abs={max_abs}, mean_abs={mean_abs:.6f}, " + f"block_mean_abs={block_mean_abs:.6f}, " + f"large_mismatch_ratio={large_mismatch_ratio:.6f}" + ) + + +def _midpoint_action(env: Any) -> np.ndarray: + low = np.asarray(env.action_space.low, dtype=np.float32) + high = np.asarray(env.action_space.high, dtype=np.float32) + return ((low + high) * 0.5).astype(np.float32) + + +def _envpool_trace_record( + task_id: str, +) -> tuple[list[np.ndarray], dict[str, Any]]: + env = make_gymnasium( + task_id, + num_envs=1, + seed=3, + render_mode="rgb_array", + render_width=_WIDTH, + render_height=_HEIGHT, + ) + try: + _, info = env.reset() + frames = [_render(env)[0]] + actions: list[list[float]] = [] + reset_before_step: list[bool] = [] + sync_states = [_sync_state_from_info(info)] + action = _midpoint_action(env) + for _ in range(3): + actions.append(action.tolist()) + *_, info = env.step(action[None, :]) + frames.append(_render(env)[0]) + step_info = info + elapsed_step = int(np.asarray(step_info["elapsed_step"]).ravel()[0]) + reset_before_step.append(elapsed_step == 0) + sync_states.append(_sync_state_from_info(step_info)) + plan = { + "actions": actions, + "reset_before_step": reset_before_step, + "sync_states": sync_states, + } + return frames, plan + finally: + env.close() + + +class MyoSuiteRenderTest(absltest.TestCase): + """Validate native MyoSuite RGB rendering after reset and steps.""" + + def test_reset_and_first_three_step_render(self) -> None: + """Representative tasks render through reset and first three steps.""" + for task_id in _SHARDED_NATIVE_RENDER_TASK_IDS: + with self.subTest(task_id=task_id): + env = make_gymnasium( + task_id, + num_envs=1, + seed=3, + render_mode="rgb_array", + render_width=_WIDTH, + render_height=_HEIGHT, + ) + try: + env.reset() + frames = [_render(env)] + action = np.zeros( + (1, *env.action_space.shape), dtype=np.float32 + ) + for _ in range(3): + env.step(action) + frames.append(_render(env)) + for frame in frames: + self.assertEqual(frame.shape, (1, _HEIGHT, _WIDTH, 3)) + self.assertEqual(frame.dtype, np.uint8) + self.assertGreater(int(frame.max()), int(frame.min())) + finally: + env.close() + + def test_official_trace_native_render_alignment(self) -> None: + """Official render matches EnvPool reset and first 3 API frames.""" + for batch in _task_batches( + _SHARDED_ORACLE_TRACE_TASK_IDS, _ORACLE_RENDER_BATCH_SIZE + ): + envpool_frames: dict[str, list[np.ndarray]] = {} + trace_plan: dict[str, dict[str, Any]] = {} + for task_id in batch: + frames, plan = _envpool_trace_record(task_id) + envpool_frames[task_id] = frames + trace_plan[task_id] = plan + oracle_tasks = _oracle_trace(batch, trace_plan) + self.assertSetEqual(set(oracle_tasks), set(batch)) + for task_id in batch: + with self.subTest(task_id=task_id): + oracle = oracle_tasks[task_id] + frames = envpool_frames[task_id] + oracle_frames = [ + np.asarray(frame, dtype=np.uint8) + for frame in oracle["frames"] + ] + self.assertLen(frames, 4) + self.assertLen(oracle_frames, 4) + for step_id, (frame, oracle_frame) in enumerate( + zip(frames, oracle_frames, strict=True) + ): + _assert_render_aligned( + self, + frame, + oracle_frame, + task_id=task_id, + step_id=step_id, + ) + self.assertGreater(int(frame.max()), int(frame.min())) + + +if __name__ == "__main__": + unittest.main() diff --git a/envpool/mujoco/myosuite/myosuite_test.py b/envpool/mujoco/myosuite/myosuite_test.py new file mode 100644 index 000000000..fb434fd19 --- /dev/null +++ b/envpool/mujoco/myosuite/myosuite_test.py @@ -0,0 +1,184 @@ +# Copyright 2026 Garena Online Private Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Registry, smoke, and determinism tests for native MyoSuite envs.""" + +from __future__ import annotations + +import os +import unittest +from pathlib import Path +from typing import Any + +import numpy as np +from absl.testing import absltest + +import envpool.mujoco.myosuite.registration as myosuite_registration +from envpool.mujoco.myosuite.tasks import MYOSUITE_TASKS, MyoSuiteTask +from envpool.registration import list_all_envs, make_gymnasium, make_spec + +_TASKS = tuple(MYOSUITE_TASKS) +_TASK_IDS = tuple(task["id"] for task in _TASKS) +_DETERMINISM_STEPS = 8 +_INFO_KEYS = { + "task_id", + "sparse", + "solved", + "oracle_numpy2_broken", + "model_nq", + "model_nv", + "model_na", +} + + +def _shard_tasks(tasks: tuple[MyoSuiteTask, ...]) -> tuple[MyoSuiteTask, ...]: + total_shards = int( + os.environ.get( + "MYOSUITE_TEST_TOTAL_SHARDS", + os.environ.get("TEST_TOTAL_SHARDS", "1"), + ) + ) + shard_index = int( + os.environ.get( + "MYOSUITE_TEST_SHARD_INDEX", + os.environ.get("TEST_SHARD_INDEX", "0"), + ) + ) + shard_status_file = os.environ.get("TEST_SHARD_STATUS_FILE") + if shard_status_file: + Path(shard_status_file).touch() + if total_shards <= 1: + return tasks + if shard_index < 0 or shard_index >= total_shards: + raise ValueError(f"invalid Bazel shard {shard_index} of {total_shards}") + return tuple( + task + for index, task in enumerate(tasks) + if index % total_shards == shard_index + ) + + +_SHARDED_TASKS = _shard_tasks(_TASKS) + + +class MyoSuiteTest(absltest.TestCase): + """Validate registration, runtime surface, and determinism.""" + + def _assert_info_equal( + self, info0: dict[str, Any], info1: dict[str, Any] + ) -> None: + """Assert two vectorized EnvPool info dictionaries are equal.""" + self.assertEqual(info0.keys(), info1.keys()) + for key in info0: + arr0 = np.asarray(info0[key]) + arr1 = np.asarray(info1[key]) + if arr0.dtype == object or arr1.dtype == object: + self.assertEqual(arr0.shape, arr1.shape) + else: + np.testing.assert_array_equal( + arr0, arr1, err_msg=f"info[{key}]" + ) + + def test_generated_registry_matches_official_surface(self) -> None: + """Generated metadata must cover all pinned official MyoSuite IDs.""" + self.assertLen(_TASKS, 398) + self.assertEqual( + tuple(myosuite_registration.myosuite_task_ids), + _TASK_IDS, + ) + registered = set(list_all_envs()) + for task in _TASKS: + task_id = task["id"] + alias = f"MyoSuite/{task_id}" + with self.subTest(task_id=task_id): + self.assertIn(task_id, registered) + self.assertIn(alias, registered) + spec = make_spec(task_id) + alias_spec = make_spec(alias) + self.assertEqual( + spec.observation_space.shape, (task["obs_dim"],) + ) + self.assertEqual(spec.action_space.shape, (task["action_dim"],)) + self.assertEqual( + spec.config.max_episode_steps, + task["max_episode_steps"], + ) + self.assertEqual( + alias_spec.observation_space.shape, + spec.observation_space.shape, + ) + self.assertEqual( + alias_spec.action_space.shape, + spec.action_space.shape, + ) + + def test_reset_and_step_reference_surface(self) -> None: + """Every registered task must reset and step with expected shapes.""" + for task in _SHARDED_TASKS: + task_id = task["id"] + with self.subTest(task_id=task_id): + env = make_gymnasium(task_id, num_envs=2, seed=7) + try: + obs, info = env.reset() + self.assertEqual(obs.shape, (2, task["obs_dim"])) + self.assertTrue(_INFO_KEYS.issubset(info.keys())) + action = np.zeros((2, task["action_dim"]), dtype=np.float32) + obs, rew, term, trunc, info = env.step(action) + self.assertEqual(obs.shape, (2, task["obs_dim"])) + self.assertEqual(rew.shape, (2,)) + self.assertEqual(term.shape, (2,)) + self.assertEqual(trunc.shape, (2,)) + self.assertFalse(np.any(np.isnan(obs))) + self.assertFalse(np.any(np.isnan(rew))) + self.assertTrue(_INFO_KEYS.issubset(info.keys())) + finally: + env.close() + + def test_no_tasks_need_oracle_numpy2_exclusion(self) -> None: + """The pinned oracle can instantiate every official task under NumPy 2.""" + for task in _TASKS: + self.assertFalse(task["oracle_numpy2_broken"], task["id"]) + + def test_reference_surface_is_deterministic(self) -> None: + """Same seed and action sequence must produce identical rollouts.""" + rng = np.random.default_rng(123) + for task in _SHARDED_TASKS: + task_id = task["id"] + with self.subTest(task_id=task_id): + actions = rng.uniform( + -1.0, + 1.0, + size=(_DETERMINISM_STEPS, 2, task["action_dim"]), + ).astype(np.float32) + env0 = make_gymnasium(task_id, num_envs=2, seed=42) + env1 = make_gymnasium(task_id, num_envs=2, seed=42) + try: + obs0, info0 = env0.reset() + obs1, info1 = env1.reset() + np.testing.assert_array_equal(obs0, obs1) + self._assert_info_equal(info0, info1) + for action in actions: + step0 = env0.step(action) + step1 = env1.step(action) + for value0, value1 in zip( + step0[:4], step1[:4], strict=True + ): + np.testing.assert_array_equal(value0, value1) + self._assert_info_equal(step0[4], step1[4]) + finally: + env0.close() + env1.close() + + +if __name__ == "__main__": + unittest.main() diff --git a/envpool/mujoco/myosuite/registration.py b/envpool/mujoco/myosuite/registration.py new file mode 100644 index 000000000..09c9f00c5 --- /dev/null +++ b/envpool/mujoco/myosuite/registration.py @@ -0,0 +1,45 @@ +# Copyright 2026 Garena Online Private Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""MyoSuite v2.11.6 env registration.""" + +import os + +from envpool.registration import base_path, package_base_path, register + +from .tasks import MYOSUITE_TASKS + +myosuite_task_ids = [str(task["id"]) for task in MYOSUITE_TASKS] +myosuite_envpool_task_ids = [ + f"MyoSuite/{task_id}" for task_id in myosuite_task_ids +] +_myosuite_package_assets = os.path.join( + package_base_path, "mujoco/myosuite/assets" +) +_myosuite_base_path = ( + package_base_path if os.path.exists(_myosuite_package_assets) else base_path +) + +for task in MYOSUITE_TASKS: + task_id = str(task["id"]) + register( + task_id=task_id, + aliases=(f"MyoSuite/{task_id}",), + import_path="envpool.mujoco.myosuite", + spec_cls="MyoSuiteEnvSpec", + dm_cls="MyoSuiteDMEnvPool", + gymnasium_cls="MyoSuiteGymnasiumEnvPool", + task_name=task_id, + max_episode_steps=task["max_episode_steps"], + base_path=_myosuite_base_path, + ) diff --git a/envpool/mujoco/myosuite/tasks.py b/envpool/mujoco/myosuite/tasks.py new file mode 100644 index 000000000..0cf74335c --- /dev/null +++ b/envpool/mujoco/myosuite/tasks.py @@ -0,0 +1,113 @@ +# Copyright 2026 Garena Online Private Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""MyoSuite task metadata generated from the pinned upstream source.""" + +from __future__ import annotations + +import json +import os +from importlib import resources +from pathlib import Path +from typing import Any, TypedDict, cast + + +class MyoSuiteTask(TypedDict): + """Generated Python metadata for one pinned MyoSuite task.""" + + id: str + entry_point: str + model_path: str + reference_path: str + object_name: str + obs_dim: int + action_dim: int + max_episode_steps: int + frame_skip: int + normalize_act: bool + oracle_numpy2_broken: bool + + +_METADATA_DIR = Path("assets/metadata") +_ASSETS_METADATA_DIR = Path("mujoco/myosuite") / _METADATA_DIR +_TASKS_JSON = "myosuite_tasks.json" +_ORACLE_JSON = "myosuite_oracle_metadata.json" + + +def _metadata_candidates(filename: str) -> tuple[Path, ...]: + package_dir = Path(__file__).resolve().parent + candidates = [package_dir / _METADATA_DIR / filename] + assets_override = os.environ.get("ENVPOOL_ASSETS_PATH") + if assets_override: + candidates.append( + Path(assets_override) / _ASSETS_METADATA_DIR / filename + ) + runfiles = os.environ.get("TEST_SRCDIR") + if runfiles: + workspace = os.environ.get("TEST_WORKSPACE", "envpool") + candidates.append( + Path(runfiles) + / workspace + / "envpool/mujoco/myosuite" + / _METADATA_DIR + / filename + ) + return tuple(dict.fromkeys(candidates)) + + +def _read_metadata_json(filename: str) -> Any: + attempted: list[str] = [] + for path in _metadata_candidates(filename): + attempted.append(str(path)) + if path.is_file(): + return json.loads(path.read_text()) + resource_roots: list[tuple[str, tuple[str, ...]]] = [] + if __package__: + resource_roots.append((__package__, ("assets", "metadata"))) + resource_roots.append(( + "envpool_assets", + ("mujoco", "myosuite", "assets", "metadata"), + )) + for package, parts in resource_roots: + try: + resource = resources.files(package) + except ModuleNotFoundError: + continue + for part in (*parts, filename): + resource = resource.joinpath(part) + attempted.append(str(resource)) + if resource.is_file(): + return json.loads(resource.read_text()) + raise FileNotFoundError( + f"could not find MyoSuite generated metadata {filename}; " + f"tried {attempted}" + ) + + +MYOSUITE_TASKS = cast(list[MyoSuiteTask], _read_metadata_json(_TASKS_JSON)) +_ORACLE_METADATA = cast(dict[str, object], _read_metadata_json(_ORACLE_JSON)) +MYOSUITE_ORACLE_VERSION = str(_ORACLE_METADATA["version"]) +MYOSUITE_ORACLE_COMMIT = str(_ORACLE_METADATA["commit"]) +MYOSUITE_ORACLE_NUMPY2_BROKEN_IDS = frozenset( + str(task_id) + for task_id in cast(list[object], _ORACLE_METADATA["numpy2_broken_ids"]) +) + +__all__ = [ + "MYOSUITE_ORACLE_COMMIT", + "MYOSUITE_ORACLE_NUMPY2_BROKEN_IDS", + "MYOSUITE_ORACLE_VERSION", + "MYOSUITE_TASKS", + "MyoSuiteTask", +] diff --git a/envpool/mujoco/offscreen_renderer.cc b/envpool/mujoco/offscreen_renderer.cc index 7a6cecec1..744d1cd4f 100644 --- a/envpool/mujoco/offscreen_renderer.cc +++ b/envpool/mujoco/offscreen_renderer.cc @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -24,10 +25,12 @@ #include #include #include +#include #include #if defined(__APPLE__) && __has_include() #include +#include #define ENVPOOL_HAS_CGL 1 #elif defined(_WIN32) && __has_include() #ifndef NOMINMAX @@ -46,6 +49,10 @@ namespace envpool::mujoco { +#if defined(ENVPOOL_HAS_CGL) +constexpr int kCglFirstFrameSettlePasses = 4; +#endif + namespace { mjtNum MedianGeomPosition(const mjData* data, int ngeom, int axis) { @@ -67,8 +74,27 @@ mjtNum MedianGeomPosition(const mjData* data, int ngeom, int axis) { class CglContext final : public GlContext { public: - CglContext() { - const std::array attribs = { + explicit CglContext(bool prefer_offline_context) { + const std::array preferred_attribs = { + kCGLPFAOpenGLProfile, + static_cast(kCGLOGLPVersion_Legacy), + kCGLPFAColorSize, + static_cast(24), + kCGLPFAAlphaSize, + static_cast(8), + kCGLPFADepthSize, + static_cast(24), + kCGLPFAStencilSize, + static_cast(8), + kCGLPFAMultisample, + kCGLPFASampleBuffers, + static_cast(1), + kCGLPFASamples, + static_cast(4), + kCGLPFAAccelerated, + static_cast(0), // terminator + }; + const std::array offline_attribs = { kCGLPFAOpenGLProfile, static_cast(kCGLOGLPVersion_Legacy), kCGLPFAColorSize, @@ -83,12 +109,20 @@ class CglContext final : public GlContext { static_cast(0), // value static_cast(0), // terminator }; - GLint npix = 0; - CGLError err = CGLChoosePixelFormat(attribs.data(), &pixel_format_, &npix); - if (err != kCGLNoError || pixel_format_ == nullptr || npix == 0) { + // Most MuJoCo oracles use the default accelerated CGL format; callers can + // still request the offline renderer when an upstream oracle does so. + bool chose_pixel_format = prefer_offline_context + ? ChoosePixelFormat(offline_attribs) + : ChoosePixelFormat(preferred_attribs); + if (!chose_pixel_format) { + chose_pixel_format = prefer_offline_context + ? ChoosePixelFormat(preferred_attribs) + : ChoosePixelFormat(offline_attribs); + } + if (!chose_pixel_format) { throw std::runtime_error("failed to create CGL pixel format"); } - err = CGLCreateContext(pixel_format_, nullptr, &context_); + CGLError err = CGLCreateContext(pixel_format_, nullptr, &context_); if (err != kCGLNoError || context_ == nullptr) { CGLReleasePixelFormat(pixel_format_); pixel_format_ = nullptr; @@ -140,11 +174,40 @@ class CglContext final : public GlContext { } private: + template + bool ChoosePixelFormat( + const std::array& attribs) { + GLint npix = 0; + CGLPixelFormatObj pixel_format = nullptr; + CGLError err = CGLChoosePixelFormat(attribs.data(), &pixel_format, &npix); + if (err != kCGLNoError || pixel_format == nullptr || npix == 0) { + if (pixel_format != nullptr) { + CGLReleasePixelFormat(pixel_format); + } + return false; + } + pixel_format_ = pixel_format; + return true; + } + CGLPixelFormatObj pixel_format_{nullptr}; CGLContextObj context_{nullptr}; bool locked_{false}; }; +void PrimeCglContextForFirstReadback() { + // GitHub's macOS 14 CGL/Metal stack lazily finalizes renderer sample state. + // These no-op queries happen after MuJoCo creates the offscreen framebuffer + // and before the first render, making the first readback deterministic. + GLint value = 0; + glGetIntegerv(GL_MAX_SAMPLES, &value); + glGetIntegerv(GL_SAMPLE_BUFFERS, &value); + glGetIntegerv(GL_SAMPLES, &value); + (void)glGetString(GL_VENDOR); + (void)glGetString(GL_RENDERER); + (void)glGetString(GL_VERSION); +} + #elif defined(ENVPOOL_HAS_WGL) namespace { @@ -302,64 +365,66 @@ class EglContext final : public GlContext { EGL_WIDTH, 1, EGL_HEIGHT, 1, EGL_NONE, }; - display_ = CreateDisplay(); - if (display_ == EGL_NO_DISPLAY) { + display_ = AcquireDisplay(); + if (display_ == nullptr) { throw std::runtime_error("failed to initialize EGL"); } eglReleaseThread(); + EGLDisplay display = display_->Get(); EGLConfig config = nullptr; EGLint num_configs = 0; - if (eglChooseConfig(display_, config_attribs.data(), &config, 1, + if (eglChooseConfig(display, config_attribs.data(), &config, 1, &num_configs) != EGL_TRUE || num_configs < 1) { - eglTerminate(display_); - display_ = EGL_NO_DISPLAY; + display_.reset(); throw std::runtime_error("failed to choose EGL config"); } if (eglBindAPI(EGL_OPENGL_API) != EGL_TRUE) { - eglTerminate(display_); - display_ = EGL_NO_DISPLAY; + display_.reset(); throw std::runtime_error("failed to bind EGL OpenGL API"); } - surface_ = - eglCreatePbufferSurface(display_, config, pbuffer_attribs.data()); - if (surface_ == EGL_NO_SURFACE) { - eglTerminate(display_); - display_ = EGL_NO_DISPLAY; - throw std::runtime_error("failed to create EGL pbuffer surface"); + if (!HasExtension(display, "EGL_KHR_surfaceless_context")) { + surface_ = + eglCreatePbufferSurface(display, config, pbuffer_attribs.data()); + if (surface_ == EGL_NO_SURFACE) { + display_.reset(); + throw std::runtime_error("failed to create EGL pbuffer surface"); + } } - context_ = eglCreateContext(display_, config, EGL_NO_CONTEXT, nullptr); + context_ = eglCreateContext(display, config, EGL_NO_CONTEXT, nullptr); if (context_ == EGL_NO_CONTEXT) { - eglDestroySurface(display_, surface_); - surface_ = EGL_NO_SURFACE; - eglTerminate(display_); - display_ = EGL_NO_DISPLAY; + if (surface_ != EGL_NO_SURFACE) { + eglDestroySurface(display, surface_); + surface_ = EGL_NO_SURFACE; + } + display_.reset(); throw std::runtime_error("failed to create EGL context"); } } ~EglContext() override { - if (display_ != EGL_NO_DISPLAY) { - eglMakeCurrent(display_, EGL_NO_SURFACE, EGL_NO_SURFACE, EGL_NO_CONTEXT); + if (display_ != nullptr) { + EGLDisplay display = display_->Get(); + eglMakeCurrent(display, EGL_NO_SURFACE, EGL_NO_SURFACE, EGL_NO_CONTEXT); if (context_ != EGL_NO_CONTEXT) { - eglDestroyContext(display_, context_); + eglDestroyContext(display, context_); } if (surface_ != EGL_NO_SURFACE) { - eglDestroySurface(display_, surface_); + eglDestroySurface(display, surface_); } - eglTerminate(display_); eglReleaseThread(); } } void MakeCurrent() override { - if (eglMakeCurrent(display_, surface_, surface_, context_) != EGL_TRUE) { + if (eglMakeCurrent(display_->Get(), surface_, surface_, context_) != + EGL_TRUE) { throw std::runtime_error("failed to make EGL context current"); } } void ClearCurrent() override { - if (eglMakeCurrent(display_, EGL_NO_SURFACE, EGL_NO_SURFACE, + if (eglMakeCurrent(display_->Get(), EGL_NO_SURFACE, EGL_NO_SURFACE, EGL_NO_CONTEXT) != EGL_TRUE) { throw std::runtime_error("failed to clear EGL context"); } @@ -367,6 +432,43 @@ class EglContext final : public GlContext { } private: + class DisplayHandle { + public: + explicit DisplayHandle(EGLDisplay display) : display_(display) {} + + ~DisplayHandle() { + if (display_ != EGL_NO_DISPLAY) { + eglTerminate(display_); + eglReleaseThread(); + } + } + + EGLDisplay Get() const { return display_; } + + private: + EGLDisplay display_{EGL_NO_DISPLAY}; + }; + + static bool HasExtension(EGLDisplay display, const char* extension) { + const char* extensions = eglQueryString(display, EGL_EXTENSIONS); + if (extensions == nullptr) { + return false; + } + const std::string_view list(extensions); + const std::string_view needle(extension); + std::size_t pos = 0; + while ((pos = list.find(needle, pos)) != std::string_view::npos) { + const bool begins_token = pos == 0 || list[pos - 1] == ' '; + const std::size_t end = pos + needle.size(); + const bool ends_token = end == list.size() || list[end] == ' '; + if (begins_token && ends_token) { + return true; + } + pos = end; + } + return false; + } + static EGLDisplay TryInitializeDisplay(EGLDisplay display) { if (display == EGL_NO_DISPLAY) { return EGL_NO_DISPLAY; @@ -440,7 +542,7 @@ class EglContext final : public GlContext { } #endif - static EGLDisplay CreateDisplay() { + static EGLDisplay CreateRawDisplay() { #if defined(ENVPOOL_HAS_EGL_DEVICE_EXT) EGLDisplay display = TryInitializeDeviceDisplay(); if (display != EGL_NO_DISPLAY) { @@ -450,20 +552,51 @@ class EglContext final : public GlContext { return TryInitializeDisplay(eglGetDisplay(EGL_DEFAULT_DISPLAY)); } - EGLDisplay display_{EGL_NO_DISPLAY}; + static std::shared_ptr AcquireDisplay() { + static std::mutex mutex; + static std::weak_ptr weak_display; + + std::lock_guard lock(mutex); + std::shared_ptr display = weak_display.lock(); + if (display != nullptr) { + return display; + } + EGLDisplay raw_display = CreateRawDisplay(); + if (raw_display == EGL_NO_DISPLAY) { + return nullptr; + } + display = std::make_shared(raw_display); + weak_display = display; + return display; + } + + std::shared_ptr display_; EGLContext context_{EGL_NO_CONTEXT}; EGLSurface surface_{EGL_NO_SURFACE}; }; #endif -std::shared_ptr CreateGlContext() { +std::shared_ptr CreateGlContext(bool share_cgl_context, + bool prefer_offline_cgl_context) { #if defined(ENVPOOL_HAS_CGL) + if (share_cgl_context) { + if (prefer_offline_cgl_context) { + thread_local std::shared_ptr offline_context = + std::make_shared(true); + return offline_context; + } + thread_local std::shared_ptr preferred_context = + std::make_shared(false); + return preferred_context; + } // Match Gymnasium's CGL lifecycle: create a context per renderer/viewer. // Reusing one CGL context across different MuJoCo models can leave renderer // state behind on macOS software/offline renderers. - return std::make_shared(); + return std::make_shared(prefer_offline_cgl_context); #elif defined(ENVPOOL_HAS_WGL) + (void)share_cgl_context; + (void)prefer_offline_cgl_context; if (wglGetCurrentContext() != nullptr && wglGetCurrentDC() != nullptr) { // Borrowed WGL handles become invalid if another library later calls // `glfw.terminate()`, so do not cache them across renderer instances. @@ -473,22 +606,38 @@ std::shared_ptr CreateGlContext() { std::make_shared(); return context; #elif defined(ENVPOOL_HAS_EGL) + (void)share_cgl_context; + (void)prefer_offline_cgl_context; thread_local std::shared_ptr context = std::make_shared(); return context; #else + (void)share_cgl_context; + (void)prefer_offline_cgl_context; throw std::runtime_error( "MuJoCo rendering is unsupported on this platform/build"); #endif } -OffscreenRenderer::OffscreenRenderer(CameraPolicy camera_policy) - : camera_policy_(camera_policy) { +OffscreenRenderer::OffscreenRenderer(CameraPolicy camera_policy, + bool disable_auxiliary_visuals, + bool share_cgl_context, + bool prefer_offline_cgl_context, + bool resize_offscreen) + : camera_policy_(camera_policy), + share_cgl_context_(share_cgl_context), + prefer_offline_cgl_context_(prefer_offline_cgl_context), + resize_offscreen_(resize_offscreen) { mjv_defaultScene(&scene_); mjv_defaultCamera(&camera_); mjv_defaultOption(&option_); mjv_defaultPerturb(&perturb_); mjr_defaultContext(&context_); + if (disable_auxiliary_visuals) { + option_.flags[mjVIS_TENDON] = 0; + option_.flags[mjVIS_ACTUATOR] = 0; + option_.flags[mjVIS_ACTIVATION] = 0; + } camera_.fixedcamid = -1; } @@ -496,19 +645,42 @@ OffscreenRenderer::~OffscreenRenderer() { if (!initialized_) { return; } - gl_context_->MakeCurrent(); - mjr_freeContext(&context_); + bool made_current = false; + try { + gl_context_->MakeCurrent(); + made_current = true; + } catch (const std::exception& error) { + static_cast(error); + // Destructors must not throw during Python/interpreter shutdown. If the GL + // context is already unavailable, free CPU-side scene state and let process + // teardown reclaim the backend resources. + } + if (made_current) { + mjr_freeContext(&context_); + } mjv_freeScene(&scene_); - gl_context_->ClearCurrent(); + if (made_current) { + try { + gl_context_->ClearCurrent(); + } catch (const std::exception& error) { + static_cast(error); + // Preserve no-throw destructor semantics. + } + } } void OffscreenRenderer::Initialize(const mjModel* model) { - gl_context_ = CreateGlContext(); + gl_context_ = + CreateGlContext(share_cgl_context_, prefer_offline_cgl_context_); gl_context_->MakeCurrent(); mjv_makeScene(model, &scene_, 10000); mjr_makeContext(model, &context_, mjFONTSCALE_150); mjr_setBuffer(mjFB_OFFSCREEN, &context_); + context_.readDepthMap = mjDEPTH_ZEROFAR; initialized_ = true; +#if defined(ENVPOOL_HAS_CGL) + PrimeCglContextForFirstReadback(); +#endif } void OffscreenRenderer::UpdateCamera(const mjModel* model, const mjData* data, @@ -519,8 +691,6 @@ void OffscreenRenderer::UpdateCamera(const mjModel* model, const mjData* data, } if (camera_id == -1 && camera_override != nullptr) { camera_ = *camera_override; - camera_.type = mjCAMERA_FREE; - camera_.fixedcamid = -1; return; } if (camera_id == -1 && camera_policy_ == CameraPolicy::kGymLike) { @@ -554,28 +724,57 @@ void OffscreenRenderer::UpdateCamera(const mjModel* model, const mjData* data, void OffscreenRenderer::Render(const mjModel* model, mjData* data, int width, int height, int camera_id, unsigned char* rgb, - const mjvCamera* camera_override) { + const mjvCamera* camera_override, + const mjvOption* option_override) { if (!initialized_) { Initialize(model); } gl_context_->MakeCurrent(); - if (context_.offWidth != width || context_.offHeight != height) { + if (resize_offscreen_ && + (context_.offWidth != width || context_.offHeight != height)) { mjr_resizeOffscreen(width, height, &context_); } mjr_setBuffer(mjFB_OFFSCREEN, &context_); UpdateCamera(model, data, camera_id, camera_override); mjrRect viewport = {0, 0, width, height}; - mjv_updateScene(model, data, &option_, &perturb_, &camera_, mjCAT_ALL, - &scene_); - mjr_render(viewport, &scene_, &context_); + auto render_scene = [&] { + mjv_updateScene(model, data, + option_override != nullptr ? option_override : &option_, + &perturb_, &camera_, mjCAT_ALL, &scene_); + mjr_render(viewport, &scene_, &context_); + }; std::size_t frame_bytes = static_cast(width) * height * 3 * sizeof(unsigned char); if (scratch_.size() != frame_bytes) { scratch_.resize(frame_bytes); } - mjr_readPixels(scratch_.data(), nullptr, viewport, &context_); + + auto read_pixels = [&] { +#if defined(ENVPOOL_HAS_CGL) + mjr_finish(); +#endif + mjr_readPixels(scratch_.data(), nullptr, viewport, &context_); + }; + + render_scene(); +#if defined(ENVPOOL_HAS_CGL) + // macOS CGL can expose an unsettled first offscreen frame. Settle once per + // renderer so env code does not need task-specific render workarounds. + if (!cgl_first_frame_settled_) { + read_pixels(); + for (int pass = 0; pass < kCglFirstFrameSettlePasses; ++pass) { + render_scene(); + read_pixels(); + } + cgl_first_frame_settled_ = true; + } else { + read_pixels(); + } +#else + read_pixels(); +#endif std::size_t row_bytes = static_cast(width) * 3 * sizeof(unsigned char); diff --git a/envpool/mujoco/offscreen_renderer.h b/envpool/mujoco/offscreen_renderer.h index 7b35d409e..2696139ee 100644 --- a/envpool/mujoco/offscreen_renderer.h +++ b/envpool/mujoco/offscreen_renderer.h @@ -38,17 +38,21 @@ class GlContext { virtual void ClearCurrent() = 0; }; -std::shared_ptr CreateGlContext(); +std::shared_ptr CreateGlContext( + bool share_cgl_context = false, bool prefer_offline_cgl_context = false); class OffscreenRenderer { public: explicit OffscreenRenderer( - CameraPolicy camera_policy = CameraPolicy::kGymLike); + CameraPolicy camera_policy = CameraPolicy::kGymLike, + bool disable_auxiliary_visuals = false, bool share_cgl_context = false, + bool prefer_offline_cgl_context = false, bool resize_offscreen = true); ~OffscreenRenderer(); void Render(const mjModel* model, mjData* data, int width, int height, int camera_id, unsigned char* rgb, - const mjvCamera* camera_override = nullptr); + const mjvCamera* camera_override = nullptr, + const mjvOption* option_override = nullptr); private: void Initialize(const mjModel* model); @@ -63,8 +67,14 @@ class OffscreenRenderer { mjrContext context_; std::vector scratch_; CameraPolicy camera_policy_; + bool share_cgl_context_; + bool prefer_offline_cgl_context_; + bool resize_offscreen_; bool initialized_{false}; bool free_camera_initialized_{false}; +#if defined(__APPLE__) + bool cgl_first_frame_settled_{false}; +#endif }; } // namespace envpool::mujoco diff --git a/envpool/mujoco/pixel_observation_test_utils.py b/envpool/mujoco/pixel_observation_test_utils.py index 15b0a6814..4a14e7db5 100644 --- a/envpool/mujoco/pixel_observation_test_utils.py +++ b/envpool/mujoco/pixel_observation_test_utils.py @@ -13,6 +13,11 @@ # limitations under the License. """Shared helpers for native MuJoCo pixel-observation tests.""" +import os +import platform +import subprocess +import sys +from collections.abc import Sequence from typing import Any import gymnasium @@ -20,6 +25,7 @@ from absl import logging from absl.testing import absltest +from envpool.python import glfw_context as envpool_glfw_context from envpool.python.glfw_context import preload_windows_gl_dlls from envpool.registration import make_gymnasium, make_spec, registry @@ -28,6 +34,11 @@ RENDER_WIDTH = 64 RENDER_HEIGHT = 48 NUM_STEPS = 3 +EGL_TEARDOWN_RENDER_WIDTH = 84 +EGL_TEARDOWN_RENDER_HEIGHT = 84 +EGL_TEARDOWN_FRAME_STACK = 3 +EGL_TEARDOWN_SUBPROCESS_TIMEOUT_SECONDS = 180 +EglTeardownCase = tuple[str, str, str] def task_ids_for_import_path(import_path: str) -> list[str]: @@ -78,6 +89,14 @@ def _assert_nested_equal(lhs: Any, rhs: Any) -> None: np.testing.assert_allclose(np.asarray(lhs), np.asarray(rhs)) +def _subprocess_output_to_text(output: str | bytes | None) -> str: + if output is None: + return "" + if isinstance(output, bytes): + return output.decode(errors="replace") + return output + + def assert_make_spec_exposes_bchw_pixel_specs( test: absltest.TestCase, import_path: str ) -> None: @@ -239,3 +258,88 @@ def assert_tasks_align_with_render_for_three_steps( _assert_pixels_match(obs, render) finally: env.close() + + +def assert_egl_pixel_env_teardown_exits_cleanly( + test: absltest.TestCase, + cases: Sequence[EglTeardownCase], +) -> None: + """Checks EGL pixel envs from multiple families exit without GL noise. + + The subprocess intentionally keeps envs alive until interpreter shutdown: + issue #401 happens in teardown, after the rollout itself has succeeded. + """ + if platform.system() != "Linux": + test.skipTest("EGL teardown regression is Linux-specific.") + env = dict(os.environ) + env["MUJOCO_GL"] = "egl" + env.setdefault("EGL_PLATFORM", "surfaceless") + + package_parent = os.path.dirname( + os.path.dirname(os.path.dirname(envpool_glfw_context.__file__)) + ) + python_paths = [package_parent] + [ + path for path in sys.path if path and path != package_parent + ] + python_path = os.pathsep.join(python_paths) + if env.get("PYTHONPATH"): + python_path = f"{python_path}{os.pathsep}{env['PYTHONPATH']}" + env["PYTHONPATH"] = python_path + + code = f""" +import importlib +import sys + +sys.path.insert(0, {package_parent!r}) +sys.modules.pop("envpool", None) +from envpool.registration import make + +envs = [] +for label, registration_module, task_id in {tuple(cases)!r}: + importlib.import_module(registration_module) + pixels = make( + task_id, + env_type="gymnasium", + num_envs=2, + from_pixels=True, + frame_stack={EGL_TEARDOWN_FRAME_STACK}, + render_width={EGL_TEARDOWN_RENDER_WIDTH}, + render_height={EGL_TEARDOWN_RENDER_HEIGHT}, + ) + obs, info = pixels.reset() + assert obs.shape == ( + 2, + {3 * EGL_TEARDOWN_FRAME_STACK}, + {EGL_TEARDOWN_RENDER_HEIGHT}, + {EGL_TEARDOWN_RENDER_WIDTH}, + ), (label, task_id, obs.shape) + envs.append(pixels) + print("successful", label, task_id) +""" + try: + result = subprocess.run( + [sys.executable, "-c", code], + env=env, + check=False, + capture_output=True, + text=True, + timeout=EGL_TEARDOWN_SUBPROCESS_TIMEOUT_SECONDS, + ) + except subprocess.TimeoutExpired as exc: + stdout = _subprocess_output_to_text(exc.stdout) + stderr = _subprocess_output_to_text(exc.stderr) + test.fail( + "EGL teardown subprocess timed out after " + f"{EGL_TEARDOWN_SUBPROCESS_TIMEOUT_SECONDS} seconds.\n" + f"stdout:\n{stdout}\nstderr:\n{stderr}" + ) + test.assertEqual( + result.returncode, + 0, + msg=f"stdout:\n{result.stdout}\nstderr:\n{result.stderr}", + ) + test.assertNotIn( + "OpenGL error 0x502 in or before mjr_makeContext", + result.stderr, + msg=f"stdout:\n{result.stdout}\nstderr:\n{result.stderr}", + ) diff --git a/envpool/mujoco/robotics/mujoco_env.h b/envpool/mujoco/robotics/mujoco_env.h index d2d46121c..fb4427397 100644 --- a/envpool/mujoco/robotics/mujoco_env.h +++ b/envpool/mujoco/robotics/mujoco_env.h @@ -114,28 +114,38 @@ class MujocoRobotEnv : public RenderableEnv { // teardown happens on the Python thread. Recreating the renderer on // Windows avoids cross-thread WGL resource lifetime issues. envpool::mujoco::OffscreenRenderer renderer( - envpool::mujoco::CameraPolicy::kGymLike); + envpool::mujoco::CameraPolicy::kGymLike, + DisableAuxiliaryRenderVisuals(), ShareRenderContext(), + PreferOfflineRenderContext(), ResizeOffscreenRenderContext()); #else if (renderer_ == nullptr) { renderer_ = std::make_unique( - envpool::mujoco::CameraPolicy::kGymLike); + envpool::mujoco::CameraPolicy::kGymLike, + DisableAuxiliaryRenderVisuals(), ShareRenderContext(), + PreferOfflineRenderContext(), ResizeOffscreenRenderContext()); } #endif mjvCamera camera_override; InitializeRenderCamera(&camera_override); + mjvOption option_override; + mjv_defaultOption(&option_override); + const mjvOption* option = + RenderOption(&option_override) ? &option_override : nullptr; if (RenderCamera(&camera_override)) { #ifdef _WIN32 renderer.Render(model_, data_, width, height, camera_id, rgb, - &camera_override); + &camera_override, option); #else renderer_->Render(model_, data_, width, height, camera_id, rgb, - &camera_override); + &camera_override, option); #endif } else { #ifdef _WIN32 - renderer.Render(model_, data_, width, height, camera_id, rgb); + renderer.Render(model_, data_, width, height, camera_id, rgb, nullptr, + option); #else - renderer_->Render(model_, data_, width, height, camera_id, rgb); + renderer_->Render(model_, data_, width, height, camera_id, rgb, nullptr, + option); #endif } } @@ -202,6 +212,15 @@ class MujocoRobotEnv : public RenderableEnv { } static mjModel* LoadModel(const std::string& xml_path) { + if (xml_path.size() >= 4 && + xml_path.substr(xml_path.size() - 4) == ".mjb") { + mjModel* model = mj_loadModel(xml_path.c_str(), nullptr); + if (model == nullptr) { + throw std::runtime_error("failed to load MuJoCo binary model: " + + xml_path); + } + return model; + } std::array error{}; mjModel* model = mj_loadXML(xml_path.c_str(), nullptr, error.data(), 1000); if (model == nullptr) { @@ -230,6 +249,14 @@ class MujocoRobotEnv : public RenderableEnv { (void)camera; return false; } + virtual bool RenderOption(mjvOption* option) { + (void)option; + return false; + } + virtual bool DisableAuxiliaryRenderVisuals() const { return true; } + virtual bool ShareRenderContext() const { return false; } + virtual bool PreferOfflineRenderContext() const { return false; } + virtual bool ResizeOffscreenRenderContext() const { return true; } void InitializeRenderCamera(mjvCamera* camera) const { mjv_defaultCamera(camera); diff --git a/envpool/pip.bzl b/envpool/pip.bzl index defa30ef0..28b6bcaf5 100644 --- a/envpool/pip.bzl +++ b/envpool/pip.bzl @@ -15,6 +15,7 @@ """EnvPool pip requirements initialization, this is loaded in WORKSPACE.""" load("@python_versions//:pip.bzl", "multi_pip_parse") +load("//third_party/myosuite:oracle_workspace.bzl", "myosuite_oracle_pip_workspace") def workspace(): """Configure pip requirements.""" @@ -40,3 +41,5 @@ def workspace(): quiet = False, # extra_pip_args = ["--extra-index-url", "https://mirrors.aliyun.com/pypi/simple"], ) + + myosuite_oracle_pip_workspace() diff --git a/envpool/procgen/procgen_test.py b/envpool/procgen/procgen_test.py index 4daa62a1c..ef3bc5da8 100644 --- a/envpool/procgen/procgen_test.py +++ b/envpool/procgen/procgen_test.py @@ -14,6 +14,8 @@ """Unit tests for Procgen environments.""" # import cv2 +from typing import Any, cast + import numpy as np from absl import logging from absl.testing import absltest @@ -37,7 +39,11 @@ def procgen_oracle_check( total: int = 200, ) -> None: if ProcgenGym3Env is None: - self.skipTest("upstream procgen is not installed") + self.skipTest( + "optional upstream procgen oracle is not installed; " + "native Procgen checks still run" + ) + procgen_env_cls = cast(Any, ProcgenGym3Env) logging.info(f"procgen oracle check for {task_id}") envpool_env = make_gym( @@ -51,7 +57,7 @@ def procgen_oracle_check( procgen_env = None rng = np.random.default_rng(seed) try: - procgen_env = ProcgenGym3Env( + procgen_env = procgen_env_cls( num=1, env_name=env_name, distribution_mode=distribution[dist_value].lower(), diff --git a/envpool/python/BUILD b/envpool/python/BUILD index 3c779abff..545c5097b 100644 --- a/envpool/python/BUILD +++ b/envpool/python/BUILD @@ -112,6 +112,10 @@ py_library( py_library( name = "dm_envpool", srcs = ["dm_envpool.py"], + data = [ + "lax.py", + "xla_template.py", + ], imports = ["../.."], deps = [ requirement("optree"), @@ -119,7 +123,6 @@ py_library( requirement("numpy"), ":data", ":envpool", - ":lax", ":utils", ], ) @@ -127,6 +130,10 @@ py_library( py_library( name = "gymnasium_envpool", srcs = ["gymnasium_envpool.py"], + data = [ + "lax.py", + "xla_template.py", + ], imports = ["../.."], deps = [ requirement("optree"), @@ -135,7 +142,6 @@ py_library( requirement("numpy"), ":data", ":envpool", - ":lax", ":utils", ], ) diff --git a/envpool/python/envpool.py b/envpool/python/envpool.py index 1ea972d7b..3a76de4d5 100644 --- a/envpool/python/envpool.py +++ b/envpool/python/envpool.py @@ -26,6 +26,14 @@ from .glfw_context import try_ensure_mujoco_glfw_context from .protocol import EnvPool, EnvSpec +_MUJOCO_PYBIND_MODULE_PREFIXES = ( + "envpool.mujoco", + "metaworld_", + "mujoco_", + "myosuite_", + "robotics_", +) + def _normalize_env_id(env_id: Any) -> Any: """Normalize env_id while preserving traced arrays for XLA send paths.""" @@ -60,8 +68,10 @@ def _requires_windows_glfw_context(self) -> bool: return False # Dynamic wrapper classes are created in envpool.python.api, so detect # MuJoCo by scanning the MRO for the underlying pybind base module. + # Some pybind extensions are named by task family rather than + # `mujoco_*`, but they still share the same Windows GL bootstrap. return any( - base.__module__.startswith(("envpool.mujoco", "mujoco_")) + base.__module__.startswith(_MUJOCO_PYBIND_MODULE_PREFIXES) for base in type(self).__mro__ ) diff --git a/envpool/python/glfw_context.py b/envpool/python/glfw_context.py index 8de2948c6..c0206102a 100644 --- a/envpool/python/glfw_context.py +++ b/envpool/python/glfw_context.py @@ -29,6 +29,24 @@ _PRELOADED_DLL_PATHS: set[str] = set() +def _resolve_windows_gl_dll_dir(root: Path, *, strict: bool) -> Path | None: + if (root / "opengl32.dll").is_file(): + return root + candidates = sorted( + path.parent for path in root.rglob("opengl32.dll") if path.is_file() + ) + if candidates: + for candidate in candidates: + if candidate.name.lower() in {"x64", "bin"}: + return candidate + return candidates[0] + if strict: + raise FileNotFoundError( + f"ENVPOOL_DLL_DIR does not contain opengl32.dll: {root}" + ) + return None + + def preload_windows_gl_dlls( *, prepend_path: bool = True, strict: bool = False ) -> None: @@ -45,7 +63,10 @@ def preload_windows_gl_dlls( f"ENVPOOL_DLL_DIR does not exist: {resolved_dir}" ) return - resolved_str = str(resolved_dir) + resolved_dll_dir = _resolve_windows_gl_dll_dir(resolved_dir, strict=strict) + if resolved_dll_dir is None: + return + resolved_str = str(resolved_dll_dir) if prepend_path: path_entries = os.environ.get("PATH", "").split(os.pathsep) if resolved_str not in path_entries: @@ -61,7 +82,7 @@ def preload_windows_gl_dlls( if win_dll is None: return for dll_name in ("libglapi.dll", "libgallium_wgl.dll", "opengl32.dll"): - dll_path = resolved_dir / dll_name + dll_path = resolved_dll_dir / dll_name dll_path_str = str(dll_path) if dll_path.is_file() and dll_path_str not in _PRELOADED_DLL_PATHS: _WINDOWS_DLL_HANDLES.append(win_dll(str(dll_path))) diff --git a/envpool/python/gymnasium_envpool.py b/envpool/python/gymnasium_envpool.py index 6505d59b2..1f6454445 100644 --- a/envpool/python/gymnasium_envpool.py +++ b/envpool/python/gymnasium_envpool.py @@ -13,6 +13,7 @@ # limitations under the License. """EnvPool meta class for gymnasium.Env API.""" +import warnings from abc import ABCMeta from typing import Any, cast @@ -24,11 +25,87 @@ from .envpool import EnvPoolMixin from .utils import check_key_duplication +try: + from gymnasium.vector import VectorEnv as _GymnasiumVectorEnv +except (AttributeError, ImportError): + _VECTOR_ENV_CLS: type | None = None +else: + _VECTOR_ENV_CLS = _GymnasiumVectorEnv + +try: + from gymnasium.vector.vector_env import AutoresetMode as _AutoresetMode +except (AttributeError, ImportError): + _AUTORESET_MODE = None +else: + _AUTORESET_MODE = _AutoresetMode.NEXT_STEP + + +def _gymnasium_base_classes() -> tuple[type, ...]: + if _VECTOR_ENV_CLS is None: + return (gymnasium.Env,) + if issubclass(_VECTOR_ENV_CLS, gymnasium.Env): + return (_VECTOR_ENV_CLS,) + return (_VECTOR_ENV_CLS, gymnasium.Env) + + +def _env_ids_from_reset_options( + options: dict[str, Any] | None, num_envs: int +) -> np.ndarray | None: + if options is None: + return None + allowed_options = {"reset_mask"} + unknown_options = set(options) - allowed_options + if unknown_options: + raise ValueError( + "Unsupported Gymnasium reset options for EnvPool: " + f"{sorted(unknown_options)}" + ) + reset_mask = options.get("reset_mask") + if reset_mask is None: + return None + reset_mask = np.asarray(reset_mask, dtype=np.bool_) + if reset_mask.shape != (num_envs,): + raise ValueError( + f"reset_mask must have shape ({num_envs},), got {reset_mask.shape}" + ) + if not np.any(reset_mask): + raise ValueError("reset_mask must select at least one environment.") + return np.flatnonzero(reset_mask).astype(np.int32) + class GymnasiumEnvPoolMixin: """Special treatment for gymnasim API.""" - metadata = {"render_modes": ["rgb_array", "human"]} + metadata = ( + { + "render_modes": ["rgb_array", "human"], + "autoreset_mode": _AUTORESET_MODE, + } + if _AUTORESET_MODE is not None + else {"render_modes": ["rgb_array", "human"]} + ) + + @property + def num_envs(self: Any) -> int: + """Number of sub-environments in this vectorized EnvPool.""" + return int(self.config["num_envs"]) + + @property + def is_vector_env(self: Any) -> bool: + """Compatibility flag used by older Gymnasium vector-aware wrappers.""" + return True + + @property + def single_observation_space( + self: Any, + ) -> gymnasium.Space | dict[str, Any]: + """Single sub-environment observation space.""" + return self.observation_space + + @property + def single_action_space(self: Any) -> gymnasium.Space | dict[str, Any]: + """Single sub-environment action space.""" + return self.action_space @property def observation_space(self: Any) -> gymnasium.Space | dict[str, Any]: @@ -49,6 +126,37 @@ def render_mode(self: Any) -> str | None: """Render mode configured at construction time.""" return getattr(self, "_render_mode", None) + def reset( + self: Any, + env_id: np.ndarray | None = None, + *, + seed: int | list[int] | None = None, + options: dict[str, Any] | None = None, + ) -> Any: + """Reset with Gymnasium-compatible seed and options keywords.""" + if seed is not None: + warnings.warn( + "EnvPool seeds are fixed when the environment is created. " + "reset(seed=...) is ignored; pass seed to envpool.make " + "instead.", + stacklevel=2, + ) + option_env_id = _env_ids_from_reset_options( + options, self.config["num_envs"] + ) + if env_id is not None and option_env_id is not None: + raise ValueError( + "Pass either env_id or options['reset_mask'], not both." + ) + if option_env_id is not None: + env_id = option_env_id + return cast(Any, super()).reset(env_id) + + def close(self: Any, **kwargs: Any) -> None: + """Accept Gymnasium VectorEnv close kwargs without changing EnvPool.""" + del kwargs + return cast(Any, super()).close() + class GymnasiumEnvPoolMeta( ABCMeta, @@ -67,7 +175,7 @@ def __new__(cls: Any, name: str, parents: tuple, attrs: dict) -> Any: GymnasiumEnvPoolMixin, EnvPoolMixin, XlaMixin, - gymnasium.Env, + *_gymnasium_base_classes(), ) except (ImportError, AttributeError): @@ -77,7 +185,12 @@ def _xla(self: Any) -> None: ) attrs["xla"] = _xla - parents = (base, GymnasiumEnvPoolMixin, EnvPoolMixin, gymnasium.Env) + parents = ( + base, + GymnasiumEnvPoolMixin, + EnvPoolMixin, + *_gymnasium_base_classes(), + ) state_keys = base._state_keys action_keys = base._action_keys diff --git a/envpool/python/protocol.py b/envpool/python/protocol.py index 9c3472737..de56dd602 100644 --- a/envpool/python/protocol.py +++ b/envpool/python/protocol.py @@ -348,14 +348,30 @@ def reset( class GymnasiumEnvPool(EnvPool, Protocol): """gymnasium-compatible EnvPool interface.""" + @property + def num_envs(self) -> int: + """Number of sub-environments.""" + + @property + def is_vector_env(self) -> bool: + """Whether the EnvPool should be treated as vectorized.""" + @property def observation_space(self) -> Any: """Gymnasium observation space.""" + @property + def single_observation_space(self) -> Any: + """Gymnasium single sub-environment observation space.""" + @property def action_space(self) -> Any: """Gymnasium action space.""" + @property + def single_action_space(self) -> Any: + """Gymnasium single sub-environment action space.""" + @overload def recv( self, @@ -394,5 +410,8 @@ def step( def reset( self, env_id: np.ndarray | None = None, + *, + seed: int | list[int] | None = None, + options: dict[str, Any] | None = None, ) -> GymnasiumResetReturn: """Reset the gymnasium-compatible envpool.""" diff --git a/envpool/registration.py b/envpool/registration.py index 63a957d39..b8ff67254 100644 --- a/envpool/registration.py +++ b/envpool/registration.py @@ -14,6 +14,7 @@ """Global env registry.""" import importlib +import importlib.util import os from collections.abc import Sequence from typing import Any, Literal, overload @@ -26,7 +27,38 @@ GymnasiumEnvPool, ) -base_path = os.path.abspath(os.path.dirname(__file__)) + +def _package_dir(package_name: str) -> str | None: + spec = importlib.util.find_spec(package_name) + if spec is None or spec.submodule_search_locations is None: + return None + return os.path.abspath(next(iter(spec.submodule_search_locations))) + + +def _has_local_assets(path: str) -> bool: + return all( + os.path.exists(os.path.join(path, asset_path)) + for asset_path in ( + "atari/roms", + "gfootball/assets", + "mujoco/assets_dmc", + "procgen/assets", + "vizdoom/maps", + ) + ) + + +def _asset_base_path() -> str: + override = os.environ.get("ENVPOOL_ASSETS_PATH") + if override: + return os.path.abspath(override) + if _has_local_assets(package_base_path): + return package_base_path + return _package_dir("envpool_assets") or package_base_path + + +package_base_path = os.path.abspath(os.path.dirname(__file__)) +base_path = _asset_base_path() class EnvRegistry: @@ -119,6 +151,7 @@ def _pixel_variant_supported(import_path: str) -> bool: "envpool.mujoco.dmc", "envpool.mujoco.gym", "envpool.mujoco.metaworld", + "envpool.mujoco.myosuite", "envpool.mujoco.robotics", } diff --git a/envpool/vizdoom/registration.py b/envpool/vizdoom/registration.py index a791da72b..a60ddee3d 100644 --- a/envpool/vizdoom/registration.py +++ b/envpool/vizdoom/registration.py @@ -15,7 +15,7 @@ import os -from envpool.registration import base_path, register +from envpool.registration import base_path, package_base_path, register maps_path = os.path.join(base_path, "vizdoom", "maps") @@ -45,6 +45,7 @@ def _vizdoom_game_list() -> list[str]: dm_cls="VizdoomDMEnvPool", gymnasium_cls="VizdoomGymnasiumEnvPool", cfg_path=cfg_path, + vzd_path=os.path.join(package_base_path, "vizdoom", "bin", "vizdoom"), wad_path=wad_path, max_episode_steps=525, ) diff --git a/envpool/vizdoom/vizdoom_env.h b/envpool/vizdoom/vizdoom_env.h index 40bd324c8..55d924f52 100644 --- a/envpool/vizdoom/vizdoom_env.h +++ b/envpool/vizdoom/vizdoom_env.h @@ -17,6 +17,7 @@ #ifndef ENVPOOL_VIZDOOM_VIZDOOM_ENV_H_ #define ENVPOOL_VIZDOOM_VIZDOOM_ENV_H_ +#include #include #include #include @@ -32,9 +33,25 @@ namespace vizdoom { -std::string MergePath(const std::string& base_path, - const std::string& file_path) { - if (file_path[0] == '/') { +inline bool IsAbsolutePath(const std::string& file_path) { + if (file_path.empty()) { + return false; + } +#if defined(_WIN32) + if (file_path[0] == '/' || file_path[0] == '\\') { + return true; + } + return file_path.size() >= 3 && + std::isalpha(static_cast(file_path[0])) && + file_path[1] == ':' && (file_path[2] == '/' || file_path[2] == '\\'); +#else + return file_path[0] == '/'; +#endif +} + +inline std::string MergePath(const std::string& base_path, + const std::string& file_path) { + if (file_path.empty() || IsAbsolutePath(file_path)) { return file_path; } return base_path + "/" + file_path; diff --git a/envpool/vizdoom/vizdoom_pretrain_test.py b/envpool/vizdoom/vizdoom_pretrain_test.py index fff4c43d3..e9c401a05 100644 --- a/envpool/vizdoom/vizdoom_pretrain_test.py +++ b/envpool/vizdoom/vizdoom_pretrain_test.py @@ -147,12 +147,14 @@ def _eval_c51_subprocess( result_queue: mp.Queue, task: str, resume_path: str, + num_envs: int, cfg_path: str | None, reward_config: dict | None, ) -> None: reward, length = _eval_c51_impl( task, resume_path, + num_envs=num_envs, cfg_path=cfg_path, reward_config=reward_config, ) @@ -188,6 +190,7 @@ def eval_c51_subprocess( self, task: str, resume_path: str, + num_envs: int = 10, cfg_path: str | None = None, reward_config: dict | None = None, ) -> tuple[np.ndarray, np.ndarray]: @@ -195,7 +198,14 @@ def eval_c51_subprocess( result_queue: mp.Queue = ctx.Queue() proc = ctx.Process( target=_eval_c51_subprocess, - args=(result_queue, task, resume_path, cfg_path, reward_config), + args=( + result_queue, + task, + resume_path, + num_envs, + cfg_path, + reward_config, + ), ) proc.start() proc.join(timeout=360) @@ -222,9 +232,11 @@ def test_d3(self) -> None: model_path = self.get_package_path("policy-d3.pth") self.assertTrue(os.path.exists(model_path)) reward_config = {"KILLCOUNT": [1, 0]} + num_envs = 2 baseline_reward, baseline_length = self.eval_c51_subprocess( "D3_battle", model_path, + num_envs=num_envs, reward_config=reward_config, ) # test with customized config @@ -239,6 +251,7 @@ def test_d3(self) -> None: reward, length = self.eval_c51_subprocess( "D3_battle", model_path, + num_envs=num_envs, cfg_path=custom_cfg_path, reward_config=reward_config, ) diff --git a/envpool/workspace0.bzl b/envpool/workspace0.bzl index 11bb74497..3f724fa8b 100644 --- a/envpool/workspace0.bzl +++ b/envpool/workspace0.bzl @@ -17,6 +17,7 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe") load("//third_party/cuda:cuda.bzl", "cuda_configure") +load("//third_party/freedoom:defs.bzl", "freedoom_archive") load("//third_party/gfootball:repo.bzl", "gfootball_archive") load("//third_party/vizdoom:repo.bzl", "vizdoom_archive") @@ -146,7 +147,9 @@ def workspace(): build_file = "//third_party/openxla_ffi:ffi_api.BUILD", sha256 = "753df38eab0d430da20e614316401663bcfca433905b976745a6e59998635ce8", strip_prefix = "xla-187a5eb58277a85847d1516bd1e20b7faf03d5ef/xla/ffi/api", + type = "tar.gz", urls = [ + "https://codeload.github.com/openxla/xla/tar.gz/187a5eb58277a85847d1516bd1e20b7faf03d5ef", "https://github.com/openxla/xla/archive/187a5eb58277a85847d1516bd1e20b7faf03d5ef.tar.gz", ], ) @@ -212,6 +215,7 @@ def workspace(): strip_prefix = "ThreadPool-9a42ec1329f259a5f4881a291db1dcb8f2ad9040", urls = [ "https://github.com/progschj/ThreadPool/archive/9a42ec1329f259a5f4881a291db1dcb8f2ad9040.zip", + "https://codeload.github.com/progschj/ThreadPool/zip/9a42ec1329f259a5f4881a291db1dcb8f2ad9040", ], build_file = "//third_party/threadpool:threadpool.BUILD", patches = [ @@ -222,10 +226,10 @@ def workspace(): maybe( http_archive, name = "zlib", - sha256 = "bb329a0a2cd0274d05519d61c667c062e06990d72e125ee2dfa8de64f0119d16", + sha256 = "b99a0b86c0ba9360ec7e78c4f1e43b1cbdf1e6936c8fa0f6835c0cd694a495a1", strip_prefix = "zlib-1.3.2", urls = [ - "https://github.com/madler/zlib/releases/download/v1.3.2/zlib-1.3.2.tar.gz", + "https://github.com/madler/zlib/archive/refs/tags/v1.3.2.tar.gz", ], build_file = "//third_party/zlib:zlib.BUILD", ) @@ -339,6 +343,9 @@ perl -Iperllib -I. macros/macros.pl version.mac 'macros/*.mac' 'output/*.mac' "https://www.libsdl.org/release/SDL2-2.32.10.tar.gz", "https://github.com/libsdl-org/SDL/releases/download/release-2.32.10/SDL2-2.32.10.tar.gz", ], + patches = [ + "//third_party/sdl2:windows_xinput_stub.patch", + ], build_file = "//third_party/sdl2:sdl2.BUILD", ) @@ -411,14 +418,27 @@ perl -Iperllib -I. macros/macros.pl version.mac 'macros/*.mac' 'output/*.mac' ) maybe( - http_archive, + freedoom_archive, name = "freedoom", + attempts = 8, + build_file = "//third_party/freedoom:freedoom.BUILD", sha256 = "f42c6810fc89b0282de1466c2c9c7c9818031a8d556256a6db1b69f6a77b5806", strip_prefix = "freedoom-0.12.1/", + type = "zip", urls = [ "https://github.com/freedoom/freedoom/releases/download/v0.12.1/freedoom-0.12.1.zip", ], - build_file = "//third_party/freedoom:freedoom.BUILD", + ) + + maybe( + http_archive, + name = "re2c_4_5_1", + build_file = "//third_party/re2c:re2c.BUILD", + sha256 = "ffea067c11aa668bcb42885be6e6cd000302000b7747d2bb213299ec66b7864e", + strip_prefix = "re2c-4.5.1", + urls = [ + "https://github.com/skvadrik/re2c/releases/download/4.5.1/re2c-4.5.1.tar.xz", + ], ) maybe( @@ -584,6 +604,7 @@ perl -Iperllib -I. macros/macros.pl version.mac 'macros/*.mac' 'output/*.mac' strip_prefix = "Gymnasium-Robotics-1.4.2/gymnasium_robotics/envs", urls = [ "https://github.com/Farama-Foundation/Gymnasium-Robotics/archive/refs/tags/v1.4.2.tar.gz", + "https://codeload.github.com/Farama-Foundation/Gymnasium-Robotics/tar.gz/refs/tags/v1.4.2", ], build_file = "//third_party/gymnasium_robotics_assets:gymnasium_robotics_assets.BUILD", ) @@ -595,10 +616,87 @@ perl -Iperllib -I. macros/macros.pl version.mac 'macros/*.mac' 'output/*.mac' strip_prefix = "Metaworld-3.0.0", urls = [ "https://github.com/Farama-Foundation/Metaworld/archive/refs/tags/v3.0.0.tar.gz", + "https://codeload.github.com/Farama-Foundation/Metaworld/tar.gz/refs/tags/v3.0.0", ], build_file = "//third_party/metaworld_assets:metaworld_assets.BUILD", ) + maybe( + http_archive, + name = "myosuite_source", + sha256 = "f75b77563547fce6d9be46abee2b86e636dd5e57a6f1d470fdbc2104dcd61d34", + strip_prefix = "myosuite-2.11.6", + urls = [ + "https://github.com/MyoHub/myosuite/archive/refs/tags/v2.11.6.tar.gz", + "https://codeload.github.com/MyoHub/myosuite/tar.gz/refs/tags/v2.11.6", + ], + patch_args = ["-p1"], + patches = ["//third_party/myosuite:mujoco36_mjspec_compat.patch"], + build_file = "//third_party/myosuite:myosuite_source.BUILD", + ) + + maybe( + http_archive, + name = "myosuite_myo_sim", + sha256 = "bd8fdf313b46dbefcd25bf42cf8ddcc45066798164bb3551a990690cad514ebd", + strip_prefix = "myo_sim-33f3ded946f55adbdcf963c99999587aadaf975f", + urls = [ + "https://github.com/MyoHub/myo_sim/archive/33f3ded946f55adbdcf963c99999587aadaf975f.tar.gz", + "https://codeload.github.com/MyoHub/myo_sim/tar.gz/33f3ded946f55adbdcf963c99999587aadaf975f", + ], + build_file = "//third_party/myosuite:simhive_source.BUILD", + ) + + maybe( + http_archive, + name = "myosuite_object_sim", + sha256 = "beed226fcf1d27b91f9147221ef450c2ccab8e5bb7b5954dbcb5635024ed4874", + strip_prefix = "object_sim-0.1.0", + urls = [ + # MyoSuite v2.11.6 gitlinks vikashplus/object_sim@87cd8dd, but + # that commit is no longer fetchable from GitHub archives. + "https://github.com/MyoHub/object_sim/archive/refs/tags/v0.1.0.tar.gz", + "https://codeload.github.com/MyoHub/object_sim/tar.gz/refs/tags/v0.1.0", + ], + build_file = "//third_party/myosuite:simhive_source.BUILD", + ) + + maybe( + http_archive, + name = "myosuite_mpl_sim", + sha256 = "591fce117832c789e227499ea45c601a9ca142c7dd636492f8bbcd825d54ea0a", + strip_prefix = "MPL_sim-58dd1abc6058e0dc06e62f13a61c36adb4916815", + urls = [ + "https://github.com/vikashplus/MPL_sim/archive/58dd1abc6058e0dc06e62f13a61c36adb4916815.tar.gz", + "https://codeload.github.com/vikashplus/MPL_sim/tar.gz/58dd1abc6058e0dc06e62f13a61c36adb4916815", + ], + build_file = "//third_party/myosuite:simhive_source.BUILD", + ) + + maybe( + http_archive, + name = "myosuite_ycb_sim", + sha256 = "81caf29e5b5c01b4af56991731b3f731a95d486addccafaaaedc7600a9f2437e", + strip_prefix = "YCB_sim-46edd9c361061c5d81a82f2511d4fbf76fead569", + urls = [ + "https://github.com/vikashplus/YCB_sim/archive/46edd9c361061c5d81a82f2511d4fbf76fead569.tar.gz", + "https://codeload.github.com/vikashplus/YCB_sim/tar.gz/46edd9c361061c5d81a82f2511d4fbf76fead569", + ], + build_file = "//third_party/myosuite:simhive_source.BUILD", + ) + + maybe( + http_archive, + name = "myosuite_furniture_sim", + sha256 = "5fb42ed8c932f7c820a72fbb86ea736957476020bdf008e17277380c3693ce9e", + strip_prefix = "furniture_sim-c97995afb81c9e2d7325b0069f9abc9a2c74a2f0", + urls = [ + "https://github.com/vikashplus/furniture_sim/archive/c97995afb81c9e2d7325b0069f9abc9a2c74a2f0.tar.gz", + "https://codeload.github.com/vikashplus/furniture_sim/tar.gz/c97995afb81c9e2d7325b0069f9abc9a2c74a2f0", + ], + build_file = "//third_party/myosuite:simhive_source.BUILD", + ) + maybe( http_archive, name = "box2d", diff --git a/scripts/clang_tidy_targets.py b/scripts/clang_tidy_targets.py index b7071260b..a47b6a3ef 100755 --- a/scripts/clang_tidy_targets.py +++ b/scripts/clang_tidy_targets.py @@ -25,7 +25,12 @@ BINDING_ONLY_FILES = frozenset(("envpool/minigrid/minigrid_bindings.cc",)) BINDING_ONLY_PREFIXES = ("envpool/minigrid_bindings/",) CC_RULE_KIND = "cc_(library|test)" -SKIP_TARGETS = frozenset() +# clang-tidy-18 does not finish the MyoSuite pybind translation unit: it gets +# stuck analyzing pybind11/PyEnvPool templates after the native runtime has +# already been parsed. Keep runtime coverage through +# //envpool/mujoco:myosuite_clang_tidy and leave only the binding module to +# normal compile/test coverage. +SKIP_TARGETS = frozenset(("//envpool/mujoco:myosuite_envpool_module",)) def _filter_targets(targets: list[str]) -> list[str]: diff --git a/scripts/release_installed_wheel_smoke.py b/scripts/release_installed_wheel_smoke.py new file mode 100644 index 000000000..e296d8f2d --- /dev/null +++ b/scripts/release_installed_wheel_smoke.py @@ -0,0 +1,93 @@ +#!/usr/bin/env python3 +"""Smoke-check that release tests import installed wheels, including assets.""" + +from __future__ import annotations + +import argparse +from importlib.metadata import version +from pathlib import Path + +from packaging.version import Version + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument( + "--source-root", + type=Path, + required=True, + help="EnvPool source checkout that must not be imported.", + ) + return parser.parse_args() + + +def _is_relative_to(path: Path, parent: Path) -> bool: + try: + path.relative_to(parent) + except ValueError: + return False + return True + + +def main() -> None: + """Check installed EnvPool and envpool-assets package wiring.""" + args = _parse_args() + source_root = args.source_root.resolve() + + import envpool_assets + + import envpool + from envpool.registration import base_path, package_base_path + + envpool_package = Path(envpool.__file__).resolve().parent + assets_package = Path(envpool_assets.asset_path()).resolve() + + if _is_relative_to(envpool_package, source_root): + raise RuntimeError( + f"imported EnvPool from source tree: {envpool_package}" + ) + + asset_version = Version(version("envpool-assets")) + if not Version("0.2.0") <= asset_version < Version("0.3.0"): + raise RuntimeError( + f"unexpected envpool-assets version: {asset_version}" + ) + + if Path(base_path).resolve() != assets_package: + raise RuntimeError( + f"EnvPool asset base_path is {base_path}, expected {assets_package}" + ) + if Path(package_base_path).resolve() != envpool_package: + raise RuntimeError( + "EnvPool package_base_path is " + f"{package_base_path}, expected {envpool_package}" + ) + + required_assets = [ + assets_package / "atari/roms/pong.bin", + assets_package / "gfootball/assets/data", + assets_package / "mujoco/assets_dmc/cartpole.xml", + assets_package + / "mujoco/myosuite/assets/myosuite/envs/myo/assets/hand/myohand_pose.xml", + assets_package / "procgen/assets/platformer/playerBlue_dead.png", + assets_package / "vizdoom/bin/freedoom2.wad", + assets_package / "vizdoom/maps/basic.wad", + ] + missing = [path for path in required_assets if not path.exists()] + if missing: + raise RuntimeError(f"envpool-assets missing required files: {missing}") + + required_package_files = [ + envpool_package / "vizdoom/bin/vizdoom", + envpool_package / "vizdoom/bin/vizdoom.pk3", + ] + missing = [path for path in required_package_files if not path.exists()] + if missing: + raise RuntimeError(f"envpool wheel missing required files: {missing}") + + print(f"envpool installed at {envpool_package}") + print(f"envpool-assets {asset_version} installed at {assets_package}") + + +if __name__ == "__main__": + main() diff --git a/setup.cfg b/setup.cfg index f2fa3acdf..4c4d5dff9 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = envpool -version = 1.2.0 +version = 1.2.2 author = "EnvPool Contributors" author_email = "sail@sea.com" description = "C++-based high-performance parallel environment execution engine (vectorized env) for general RL environments." @@ -26,6 +26,7 @@ packages = find: python_requires = >=3.11 install_requires = dm-env>=1.4 + envpool-assets>=0.2.0,<0.3.0 glfw>=2.10.0 gymnasium>=0.26 numpy>=1.19 @@ -43,16 +44,9 @@ envpool = **/*_envpool.so **/*_envpool.pyd **/*_envpool.pyd.dll - atari/roms/*.bin mujoco/*.so.* - mujoco/assets*/**/*.xml - mujoco/assets_dmc/dog_assets/* - mujoco/metaworld/assets/**/* - mujoco/robotics/assets/**/* - gfootball/assets/**/* - procgen/assets/**/*.png - vizdoom/bin/* - vizdoom/maps/* + vizdoom/bin/vizdoom + vizdoom/bin/vizdoom.pk3 [mypy] allow_redefinition = True diff --git a/third_party/freedoom/defs.bzl b/third_party/freedoom/defs.bzl new file mode 100644 index 000000000..2e9c5c164 --- /dev/null +++ b/third_party/freedoom/defs.bzl @@ -0,0 +1,53 @@ +# Copyright 2026 Garena Online Private Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Repository rule for the Freedoom release asset.""" + +def _freedoom_archive_impl(ctx): + last_error = "download was not attempted" + for attempt in range(ctx.attr.attempts): + ctx.report_progress( + "Fetching Freedoom asset, attempt %d/%d" % + (attempt + 1, ctx.attr.attempts), + ) + result = ctx.download_and_extract( + url = ctx.attr.urls, + sha256 = ctx.attr.sha256, + type = ctx.attr.type, + strip_prefix = ctx.attr.strip_prefix, + canonical_id = ctx.attr.canonical_id, + allow_fail = True, + ) + if result.success: + ctx.file("BUILD.bazel", ctx.read(ctx.attr.build_file)) + return + last_error = result.error + + fail("failed to fetch Freedoom asset after %d attempts: %s" % ( + ctx.attr.attempts, + last_error, + )) + +freedoom_archive = repository_rule( + implementation = _freedoom_archive_impl, + attrs = { + "attempts": attr.int(default = 8), + "build_file": attr.label(allow_single_file = True, mandatory = True), + "canonical_id": attr.string(default = "freedoom-0.12.1.zip"), + "sha256": attr.string(mandatory = True), + "strip_prefix": attr.string(default = ""), + "type": attr.string(default = ""), + "urls": attr.string_list(mandatory = True), + }, +) diff --git a/third_party/mujoco/mujoco.BUILD b/third_party/mujoco/mujoco.BUILD index 733fd37fd..8d96ccc05 100644 --- a/third_party/mujoco/mujoco.BUILD +++ b/third_party/mujoco/mujoco.BUILD @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -load("@rules_cc//cc:defs.bzl", "cc_library") +load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library") package(default_visibility = ["//visibility:public"]) @@ -33,6 +33,149 @@ cc_library( "src/render/classic/*.h", ]) + glob([ "src/render/classic/glad/*", + ]) + glob([ + "src/ui/*.c", + "src/ui/*.h", + ]) + glob([ + "src/user/*.c", + "src/user/*.cc", + "src/user/*.h", + ]) + glob([ + "src/xml/*.c", + "src/xml/*.cc", + "src/xml/*.h", + ]) + ), + hdrs = glob([ + "include/mujoco/*.h", + "include/mujoco/experimental/**/*.h", + "src/render/classic/**/*.h", + "src/render/classic/**/*.inc", + ]), + copts = [ + "-DCCD_STATIC_DEFINE", + ] + select({ + "@envpool//:windows": [ + # CI fastbuild otherwise compiles MuJoCo at /Od while the pinned + # official oracle wheel is release-built. Upstream MuJoCo's CMake + # build also enables AVX platform SIMD on MSVC when available. + "/O2", + "/arch:AVX", + ], + "//conditions:default": [ + "-D_GNU_SOURCE", + "-Wno-int-in-bool-context", + "-Wno-maybe-uninitialized", + "-Wno-sign-compare", + "-Wno-stringop-overflow", + "-Wno-stringop-truncation", + ], + }) + select({ + # Match upstream MuJoCo's default CMake build on Linux x86_64. The + # pinned official oracle wheel enables AVX platform SIMD there. + "@envpool//:linux_x86_64": [ + # CI runs Bazel tests in fastbuild, while the official Python + # wheel is release-built. Keep MuJoCo's integrator codegen aligned + # with the wheel instead of compensating with looser oracle checks. + "-O3", + "-mavx", + "-mpclmul", + ], + "//conditions:default": [], + }), + cxxopts = select({ + "@envpool//:windows": ["/std:c++20"], + "//conditions:default": ["-std=c++20"], + }), + defines = ["MJ_STATIC"] + select({ + "@envpool//:linux_x86_64": ["mjUSEPLATFORMSIMD"], + "@envpool//:windows": ["mjUSEPLATFORMSIMD"], + "//conditions:default": [], + }), + # Coverage instrumentation perturbs MuJoCo's floating-point integrator on + # Linux enough to invalidate long oracle rollouts. Keep third-party + # physics code out of EnvPool coverage instead of widening oracle drift. + features = ["-coverage"], + includes = [ + "include", + "include/mujoco", + "src", + ], + linkopts = select({ + "@envpool//:linux": [ + "-ldl", + ], + "//conditions:default": [], + }), + linkstatic = 1, + deps = [ + "@ccd", + "@lodepng", + "@marchingcubecpp", + "@qhull", + "@tinyobjloader", + "@tinyxml2", + ], +) + +cc_binary( + name = "libmujoco.so.3.6.0", + linkopts = select({ + "@envpool//:linux": ["-Wl,-soname,libmujoco.so.3.6.0"], + "//conditions:default": [], + }), + linkshared = True, + linkstatic = True, + deps = [":mujoco_shared_export_lib"], +) + +cc_binary( + name = "libmujoco.3.6.0.dylib", + linkopts = select({ + "@platforms//os:osx": ["-Wl,-install_name,@rpath/mujoco.framework/Versions/A/libmujoco.3.6.0.dylib"], + "//conditions:default": [], + }), + linkshared = True, + linkstatic = True, + deps = [":mujoco_shared_export_lib"], +) + +cc_binary( + name = "mujoco.dll", + linkshared = True, + linkstatic = True, + deps = [":mujoco_shared_export_lib"], +) + +filegroup( + name = "mujoco_shared_lib", + srcs = select({ + "@envpool//:windows": [":mujoco.dll"], + "@platforms//os:osx": [":libmujoco.3.6.0.dylib"], + "//conditions:default": [":libmujoco.so.3.6.0"], + }), +) + +cc_library( + name = "mujoco_shared_export_lib", + srcs = ( + glob(["src/cc/*.h"]) + glob([ + "src/engine/*.c", + "src/engine/*.cc", + "src/engine/*.h", + ]) + glob([ + "src/thread/*.c", + "src/thread/*.cc", + "src/thread/*.h", + ]) + glob([ + "src/render/classic/*.c", + "src/render/classic/*.cc", + "src/render/classic/*.h", + ]) + glob([ + "src/render/classic/glad/*", + ]) + glob([ + "src/ui/*.c", + "src/ui/*.h", ]) + glob([ "src/user/*.c", "src/user/*.cc", @@ -52,7 +195,10 @@ cc_library( copts = [ "-DCCD_STATIC_DEFINE", ] + select({ - "@envpool//:windows": [], + "@envpool//:windows": [ + "/O2", + "/arch:AVX", + ], "//conditions:default": [ "-D_GNU_SOURCE", "-Wno-int-in-bool-context", @@ -61,12 +207,24 @@ cc_library( "-Wno-stringop-overflow", "-Wno-stringop-truncation", ], + }) + select({ + "@envpool//:linux_x86_64": [ + "-O3", + "-mavx", + "-mpclmul", + ], + "//conditions:default": [], }), cxxopts = select({ "@envpool//:windows": ["/std:c++20"], "//conditions:default": ["-std=c++20"], }), - defines = ["MJ_STATIC"], + defines = ["MUJOCO_DLL_EXPORTS"] + select({ + "@envpool//:linux_x86_64": ["mjUSEPLATFORMSIMD"], + "@envpool//:windows": ["mjUSEPLATFORMSIMD"], + "//conditions:default": [], + }), + features = ["-coverage"], includes = [ "include", "include/mujoco", @@ -87,6 +245,7 @@ cc_library( "@tinyobjloader", "@tinyxml2", ], + alwayslink = True, ) cc_library( diff --git a/third_party/myosuite/BUILD b/third_party/myosuite/BUILD new file mode 100644 index 000000000..c8f59eba1 --- /dev/null +++ b/third_party/myosuite/BUILD @@ -0,0 +1,91 @@ +# Copyright 2026 Garena Online Private Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@python_versions//3.12:defs.bzl", py_binary_312 = "py_binary") +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("//third_party/myosuite:defs.bzl", "myosuite_native_assets") +load("//third_party/myosuite:oracle_requirements.bzl", "oracle_requirement") + +exports_files([ + "oracle_requirements.txt", + "myosuite_source.BUILD", + "simhive_source.BUILD", + "generate_render_sample.py", +]) + +py_binary_312( + name = "generate_native_assets", + srcs = [ + "generate_native_assets.py", + "generate_reference_data.py", + "generate_task_metadata.py", + "generate_task_registry.py", + ], + deps = [ + oracle_requirement("click"), + oracle_requirement("dm-control"), + oracle_requirement("flatten-dict"), + oracle_requirement("gitpython"), + oracle_requirement("gymnasium"), + oracle_requirement("h5py"), + oracle_requirement("mujoco"), + oracle_requirement("numpy"), + oracle_requirement("packaging"), + oracle_requirement("pillow"), + oracle_requirement("pink-noise-rl"), + oracle_requirement("sk-video"), + oracle_requirement("termcolor"), + ], +) + +py_binary_312( + name = "generate_runtime_assets", + srcs = ["generate_runtime_assets.py"], + visibility = ["//visibility:public"], + deps = [ + oracle_requirement("mujoco"), + oracle_requirement("typing_extensions"), + ], +) + +myosuite_native_assets( + name = "gen_myosuite_native_assets", + furniture_srcs = ["@myosuite_furniture_sim//:source"], + generator = ":generate_native_assets", + mpl_srcs = ["@myosuite_mpl_sim//:source"], + myo_srcs = ["@myosuite_myo_sim//:source"], + myosuite_srcs = ["@myosuite_source//:source"], + object_srcs = ["@myosuite_object_sim//:source"], + ycb_srcs = ["@myosuite_ycb_sim//:source"], +) + +filegroup( + name = "myosuite_native_headers", + srcs = [":gen_myosuite_native_assets"], + output_group = "headers", + visibility = ["//visibility:public"], +) + +filegroup( + name = "myosuite_generated_json", + srcs = [":gen_myosuite_native_assets"], + output_group = "json", + visibility = ["//visibility:public"], +) + +cc_library( + name = "myosuite_tasks", + hdrs = [":myosuite_native_headers"], + visibility = ["//visibility:public"], +) diff --git a/third_party/myosuite/defs.bzl b/third_party/myosuite/defs.bzl new file mode 100644 index 000000000..559d9ac46 --- /dev/null +++ b/third_party/myosuite/defs.bzl @@ -0,0 +1,239 @@ +# Copyright 2026 Garena Online Private Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Runtime asset assembler for pinned MyoSuite sources.""" + +_MYODM_OBJECTS = [ + "airplane", + "alarmclock", + "apple", + "banana", + "binoculars", + "bowl", + "camera", + "coffeemug", + "cubelarge", + "cubemedium", + "cubesmall", + "cup", + "cylinderlarge", + "cylindermedium", + "cylindersmall", + "duck", + "elephant", + "eyeglasses", + "flashlight", + "flute", + "gamecontroller", + "hammer", + "hand", + "headphones", + "knife", + "lightbulb", + "mouse", + "mug", + "phone", + "piggybank", + "pyramidlarge", + "pyramidmedium", + "pyramidsmall", + "scissors", + "spherelarge", + "spheremedium", + "spheresmall", + "stamp", + "stanfordbunny", + "stapler", + "teapot", + "toothbrush", + "toothpaste", + "toruslarge", + "torusmedium", + "torussmall", + "train", + "watch", + "waterbottle", + "wineglass", +] + +def _myosuite_runtime_assets_impl(ctx): + out = ctx.actions.declare_directory(ctx.attr.out) + manifest = ctx.actions.declare_file(ctx.label.name + "_srcs.txt") + srcs = sorted([src.path for src in ctx.files.srcs]) + ctx.actions.write(output = manifest, content = "\n".join(srcs) + "\n") + + metadata_manifest = ctx.actions.declare_file(ctx.label.name + "_metadata.txt") + metadata_srcs = sorted([src.path for src in ctx.files.metadata_srcs]) + ctx.actions.write( + output = metadata_manifest, + content = "\n".join(metadata_srcs) + ("\n" if metadata_srcs else ""), + ) + + object_manifest = ctx.actions.declare_file(ctx.label.name + "_objects.txt") + ctx.actions.write(output = object_manifest, content = "\n".join(_MYODM_OBJECTS) + "\n") + + args = ctx.actions.args() + args.add(out.path) + args.add(manifest.path) + args.add(object_manifest.path) + args.add(metadata_manifest.path) + + ctx.actions.run( + executable = ctx.executable.generator, + inputs = depset( + ctx.files.srcs + + ctx.files.metadata_srcs + + [manifest, metadata_manifest, object_manifest], + ), + outputs = [out], + arguments = [args], + env = { + "PATH": "/usr/sbin:/usr/bin:/bin:/opt/homebrew/bin:/usr/local/bin", + }, + mnemonic = "GenerateMyoSuiteRuntimeAssets", + progress_message = "Generating native MyoSuite runtime assets", + use_default_shell_env = True, + ) + + return [ + DefaultInfo( + files = depset([out]), + runfiles = ctx.runfiles(files = [out]), + ), + ] + +myosuite_runtime_assets = rule( + implementation = _myosuite_runtime_assets_impl, + attrs = { + "srcs": attr.label_list( + allow_files = True, + mandatory = True, + ), + "metadata_srcs": attr.label_list(allow_files = True), + "out": attr.string(mandatory = True), + "generator": attr.label( + default = "//third_party/myosuite:generate_runtime_assets", + executable = True, + cfg = "exec", + ), + }, +) + +def _write_manifest(ctx, name, files): + manifest = ctx.actions.declare_file(ctx.label.name + "_" + name + ".txt") + ctx.actions.write( + output = manifest, + content = "\n".join(sorted([src.path for src in files])) + "\n", + ) + return manifest + +def _myosuite_native_assets_impl(ctx): + tasks_json = ctx.actions.declare_file("myosuite_tasks.json") + tasks_header = ctx.actions.declare_file("myosuite_tasks.h") + metadata_json = ctx.actions.declare_file("myosuite_task_metadata.json") + metadata_header = ctx.actions.declare_file("myosuite_task_metadata.h") + oracle_json = ctx.actions.declare_file("myosuite_oracle_metadata.json") + reference_header = ctx.actions.declare_file("myosuite_reference_data.h") + + myosuite_manifest = _write_manifest(ctx, "myosuite", ctx.files.myosuite_srcs) + mpl_manifest = _write_manifest(ctx, "mpl", ctx.files.mpl_srcs) + ycb_manifest = _write_manifest(ctx, "ycb", ctx.files.ycb_srcs) + furniture_manifest = _write_manifest(ctx, "furniture", ctx.files.furniture_srcs) + myo_manifest = _write_manifest(ctx, "myo", ctx.files.myo_srcs) + object_manifest = _write_manifest(ctx, "object", ctx.files.object_srcs) + + args = ctx.actions.args() + args.add("--myosuite-manifest", myosuite_manifest) + args.add("--mpl-manifest", mpl_manifest) + args.add("--ycb-manifest", ycb_manifest) + args.add("--furniture-manifest", furniture_manifest) + args.add("--myo-manifest", myo_manifest) + args.add("--object-manifest", object_manifest) + args.add("--out-tasks-json", tasks_json) + args.add("--out-tasks-header", tasks_header) + args.add("--out-metadata-json", metadata_json) + args.add("--out-metadata-header", metadata_header) + args.add("--out-oracle-json", oracle_json) + args.add("--out-reference-header", reference_header) + + manifests = [ + myosuite_manifest, + mpl_manifest, + ycb_manifest, + furniture_manifest, + myo_manifest, + object_manifest, + ] + ctx.actions.run( + executable = ctx.executable.generator, + inputs = depset( + ctx.files.myosuite_srcs + + ctx.files.mpl_srcs + + ctx.files.ycb_srcs + + ctx.files.furniture_srcs + + ctx.files.myo_srcs + + ctx.files.object_srcs + + manifests, + ), + outputs = [ + tasks_json, + tasks_header, + metadata_json, + metadata_header, + oracle_json, + reference_header, + ], + arguments = [args], + env = { + "PATH": "/usr/sbin:/usr/bin:/bin:/opt/homebrew/bin:/usr/local/bin", + }, + mnemonic = "GenerateMyoSuiteNativeAssets", + progress_message = "Generating native MyoSuite metadata from pinned upstream source", + use_default_shell_env = True, + ) + + headers = depset([tasks_header, metadata_header, reference_header]) + json_files = depset([tasks_json, metadata_json, oracle_json]) + all_files = depset([ + tasks_json, + tasks_header, + metadata_json, + metadata_header, + oracle_json, + reference_header, + ]) + return [ + DefaultInfo(files = all_files), + OutputGroupInfo( + headers = headers, + json = json_files, + ), + ] + +myosuite_native_assets = rule( + implementation = _myosuite_native_assets_impl, + attrs = { + "myosuite_srcs": attr.label_list(allow_files = True, mandatory = True), + "mpl_srcs": attr.label_list(allow_files = True, mandatory = True), + "ycb_srcs": attr.label_list(allow_files = True, mandatory = True), + "furniture_srcs": attr.label_list(allow_files = True, mandatory = True), + "myo_srcs": attr.label_list(allow_files = True, mandatory = True), + "object_srcs": attr.label_list(allow_files = True, mandatory = True), + "generator": attr.label( + executable = True, + cfg = "exec", + mandatory = True, + ), + }, +) diff --git a/third_party/myosuite/generate_native_assets.py b/third_party/myosuite/generate_native_assets.py new file mode 100644 index 000000000..54782bdcb --- /dev/null +++ b/third_party/myosuite/generate_native_assets.py @@ -0,0 +1,553 @@ +#!/usr/bin/env python3 +# Copyright 2026 Garena Online Private Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Generate native MyoSuite assets from pinned upstream archives. + +This tool runs at build/codegen time only. It imports the pinned upstream +MyoSuite source as an oracle to derive compact registry and task metadata, then +emits native C++/Python data files consumed by EnvPool. The generated files are +not checked in, and the native runtime never imports the official Python +package. +""" + +from __future__ import annotations + +import argparse +import contextlib +import importlib +import importlib.util +import json +import os +import posixpath +import shutil +import sys +import tempfile +import warnings +from pathlib import Path +from typing import Any, cast + +import numpy as np + +sys.path.insert(0, str(Path(__file__).resolve().parent)) +import generate_reference_data +import generate_task_metadata +import generate_task_registry + +_ORACLE_VERSION = generate_task_registry.ORACLE_VERSION +_ORACLE_COMMIT = generate_task_registry.ORACLE_COMMIT +_BROKEN_IDS = set(generate_task_registry.BROKEN_IDS) +_SIMHIVE_DIRS = { + "mpl": "MPL_sim", + "ycb": "YCB_sim", + "furniture": "furniture_sim", + "myo": "myo_sim", + "object": "object_sim", +} +_SIMHIVE_REPOS = { + "mpl": "myosuite_mpl_sim", + "ycb": "myosuite_ycb_sim", + "furniture": "myosuite_furniture_sim", + "myo": "myosuite_myo_sim", + "object": "myosuite_object_sim", +} +_WINDOWS_SHORT_IMPORT_PACKAGES = ("mujoco", "h5py") +_DLL_DIRECTORY_HANDLES: list[Any] = [] + + +def _manifest_paths(path: Path) -> list[Path]: + return [ + Path(line) for line in path.read_text().splitlines() if line.strip() + ] + + +def _common_root(paths: list[Path]) -> Path: + if not paths: + raise ValueError("empty source manifest") + return Path(os.path.commonpath([str(path) for path in paths])) + + +def _repo_root(paths: list[Path], repo: str) -> Path: + for path in paths: + parts = path.parts + if repo in parts: + idx = parts.index(repo) + return Path(*parts[: idx + 1]) + return _common_root(paths) + + +def _myosuite_source_root(paths: list[Path]) -> Path: + repo_root = _repo_root(paths, "myosuite_source") + if (repo_root / "myosuite").is_dir(): + return repo_root + root = _common_root(paths) + if root.name == "myosuite": + return root.parent + package = root / "myosuite" + if package.is_dir(): + return root + for path in paths: + parts = path.parts + if "myosuite" in parts: + idx = parts.index("myosuite") + return Path(*parts[:idx]) + raise ValueError(f"could not infer MyoSuite source root from {root}") + + +def _patch_codegen_only_imports(package: Path) -> None: + import_utils = package / "utils" / "import_utils.py" + text = import_utils.read_text() + eager_git = "from os.path import expanduser\nimport git\n\n\n" + fetch_def = ( + "def fetch_git(repo_url, commit_hash, clone_directory, " + "clone_path=None):\n" + ) + lazy_fetch = fetch_def + " import git\n" + if eager_git not in text: + if lazy_fetch in text: + return + raise ValueError("unexpected MyoSuite import_utils.py layout") + text = text.replace(eager_git, "from os.path import expanduser\n\n\n", 1) + if fetch_def not in text: + raise ValueError("unexpected MyoSuite fetch_git layout") + text = text.replace(fetch_def, lazy_fetch, 1) + if import_utils.is_symlink(): + import_utils.unlink() + import_utils.write_text(text) + + +@contextlib.contextmanager +def _assembled_source( + source_manifest: Path, sim_manifests: dict[str, Path] +) -> Any: + source_root = _myosuite_source_root(_manifest_paths(source_manifest)) + sim_roots = { + key: _repo_root(_manifest_paths(manifest), _SIMHIVE_REPOS[key]) + for key, manifest in sim_manifests.items() + } + with tempfile.TemporaryDirectory( + prefix="envpool-myosuite-src-", + ignore_cleanup_errors=os.name == "nt", + ) as tmp: + root = Path(tmp) + package = root / "myosuite" + shutil.copytree(source_root / "myosuite", package, symlinks=True) + _patch_codegen_only_imports(package) + simhive = package / "simhive" + simhive.mkdir(exist_ok=True) + for key, dirname in _SIMHIVE_DIRS.items(): + dst = simhive / dirname + if dst.exists() or dst.is_symlink(): + if dst.is_dir() and not dst.is_symlink(): + shutil.rmtree(dst) + else: + dst.unlink() + shutil.copytree(sim_roots[key], dst, symlinks=True) + old_path = list(sys.path) + for name in tuple(sys.modules): + if name == "myosuite" or name.startswith("myosuite."): + del sys.modules[name] + sys.path.insert(0, str(root)) + try: + yield root + finally: + sys.path[:] = old_path + for name in tuple(sys.modules): + if name == "myosuite" or name.startswith("myosuite."): + del sys.modules[name] + + +def _import_official() -> tuple[Any, Any, Any]: + official_myosuite = importlib.import_module("myosuite") + from myosuite import gym_registry_specs + from myosuite.utils import gym + + return official_myosuite, gym_registry_specs, gym + + +def _copy_short_import_package( + source_root: Path, package_name: str +) -> list[Path]: + spec = importlib.util.find_spec(package_name) + if spec is None or spec.submodule_search_locations is None: + return [] + source = Path(next(iter(spec.submodule_search_locations))) + if not source.is_dir(): + return [] + destination_root = source_root / "_oracle_site" + destination = destination_root / package_name + if not destination.exists(): + shutil.copytree( + source, + destination, + symlinks=False, + ignore=shutil.ignore_patterns("__pycache__"), + ) + copied = [destination] + sibling = source.parent / f"{package_name}.libs" + if sibling.is_dir(): + sibling_destination = destination_root / sibling.name + if not sibling_destination.exists(): + shutil.copytree(sibling, sibling_destination, symlinks=False) + copied.append(sibling_destination) + return copied + + +def _shorten_windows_binary_imports(source_root: Path) -> None: + if os.name != "nt": + return + destination_root = source_root / "_oracle_site" + copied: list[Path] = [] + for package_name in _WINDOWS_SHORT_IMPORT_PACKAGES: + copied.extend(_copy_short_import_package(source_root, package_name)) + for name in tuple(sys.modules): + if name == package_name or name.startswith(f"{package_name}."): + del sys.modules[name] + if copied: + sys.path.insert(0, str(destination_root)) + if hasattr(os, "add_dll_directory"): + for path in copied: + _DLL_DIRECTORY_HANDLES.append(os.add_dll_directory(str(path))) + + +def _jsonable(value: Any) -> Any: + if isinstance(value, dict): + return {str(key): _jsonable(item) for key, item in value.items()} + if isinstance(value, list | tuple): + return [_jsonable(item) for item in value] + if isinstance(value, np.generic): + return value.item() + array = np.asarray(value) + if array.ndim == 0: + return array.item() + if array.dtype == object: + return [str(item) for item in array.ravel()] + return array.tolist() + + +def _names_from_ids(model: Any, obj_type: Any, ids: list[int]) -> list[str]: + import mujoco + + raw_model = model.ptr if hasattr(model, "ptr") else model + return [ + mujoco.mj_id2name(raw_model, int(obj_type), int(obj_id)) + for obj_id in ids + ] + + +def _state_report(env: Any) -> dict[str, Any]: + model = env.sim.model + data = env.sim.data + return { + "act": _jsonable(data.act) if model.na > 0 else [], + "ctrl": _jsonable(data.ctrl), + "qacc_warmstart": _jsonable(data.qacc_warmstart), + "body_pos": _jsonable(model.body_pos), + "body_quat": _jsonable(model.body_quat), + "mocap_pos": _jsonable(data.mocap_pos), + "mocap_quat": _jsonable(data.mocap_quat), + "qpos": _jsonable(data.qpos), + "qvel": _jsonable(data.qvel), + "site_pos": _jsonable(model.site_pos), + "site_quat": _jsonable(model.site_quat), + "time": float(data.time), + } + + +def _metadata_report(task_ids: list[str], gym: Any) -> dict[str, Any]: + import mujoco + + tasks: dict[str, dict[str, Any]] = {} + for task_id in task_ids: + env = gym.make(task_id) + try: + unwrapped = env.unwrapped + model = unwrapped.sim.model + data = unwrapped.sim.data + task: dict[str, Any] = { + "action_shape": list(env.action_space.shape), + "entry_class": type(unwrapped).__name__, + "frame_skip": int(unwrapped.frame_skip), + "init_qpos": _jsonable(unwrapped.init_qpos), + "init_qvel": _jsonable(unwrapped.init_qvel), + "model_nq": int(model.nq), + "model_nv": int(model.nv), + "model_na": int(model.na), + "model_nu": int(model.nu), + "obs_keys": list(unwrapped.obs_keys), + "observation_shape": list(env.observation_space.shape), + "rwd_keys_wt": dict(unwrapped.rwd_keys_wt), + } + for attr in ( + "far_th", + "goal_th", + "hip_period", + "max_rot", + "min_height", + "normalize_act", + "pose_thd", + "reset_type", + "target_rot", + "target_x_vel", + "target_y_vel", + "terrain", + "variant", + ): + if hasattr(unwrapped, attr): + task[attr] = _jsonable(getattr(unwrapped, attr)) + if hasattr(unwrapped, "tip_sids"): + task["tip_sites"] = _names_from_ids( + model, mujoco.mjtObj.mjOBJ_SITE, unwrapped.tip_sids + ) + if hasattr(unwrapped, "target_sids"): + task["target_sites"] = _names_from_ids( + model, mujoco.mjtObj.mjOBJ_SITE, unwrapped.target_sids + ) + if hasattr(unwrapped, "target_jnt_ids"): + task["target_joints"] = _names_from_ids( + model, mujoco.mjtObj.mjOBJ_JOINT, unwrapped.target_jnt_ids + ) + for attr in ( + "target_jnt_range", + "target_jnt_value", + "target_reach_range", + ): + if hasattr(unwrapped, attr): + task[attr] = _jsonable(getattr(unwrapped, attr)) + task["initial_state"] = { + "qpos": _jsonable(data.qpos), + "qvel": _jsonable(data.qvel), + "act": _jsonable(data.act) if model.na > 0 else [], + "qacc_warmstart": _jsonable(data.qacc_warmstart), + "site_pos": _jsonable(model.site_pos), + "site_quat": _jsonable(model.site_quat), + "body_pos": _jsonable(model.body_pos), + "body_quat": _jsonable(model.body_quat), + } + env.reset(seed=0) + task["reset_state"] = _state_report(unwrapped) + tasks[task_id] = task + finally: + env.close() + return {"tasks": tasks, "version": _ORACLE_VERSION} + + +def _kind(entry_point: str, task_id: str) -> str: + if "myodm_v0:TrackEnv" in entry_point: + return "kMyoDmTrack" + if "torso_v0" in entry_point: + return "kTorsoPose" + if "pose_v0" in entry_point: + return "kPose" + if "walk_v0:ReachEnvV0" in entry_point: + return "kWalkReach" + if "walk_v0:WalkEnvV0" in entry_point: + return "kWalk" + if "walk_v0:TerrainEnvV0" in entry_point: + return "kTerrain" + if "reach_v0" in entry_point: + return "kReach" + if "key_turn_v0" in entry_point: + return "kKeyTurn" + if "obj_hold_v0" in entry_point: + return "kObjHoldRandom" if "Random" in task_id else "kObjHoldFixed" + if "pen_v0" in entry_point: + return "kPenTwirlRandom" if "Random" in task_id else "kPenTwirlFixed" + if "reorient_sar_v0" in entry_point: + return "kReorientSar" + if "baoding" in entry_point: + return "kChallengeBaoding" + if "bimanual" in entry_point: + return "kChallengeBimanual" + if "chasetag" in entry_point: + return "kChallengeChaseTag" + if "relocate" in entry_point: + return "kChallengeRelocate" + if "reorient_v0" in entry_point: + return "kChallengeReorient" + if "run_track" in entry_point: + return "kChallengeRunTrack" + if "soccer" in entry_point: + return "kChallengeSoccer" + if "tabletennis" in entry_point: + return "kChallengeTableTennis" + raise ValueError(f"unknown MyoSuite task kind for {task_id}: {entry_point}") + + +def _muscle(kwargs: dict[str, Any]) -> str: + value = kwargs.get("muscle_condition", "") + return { + "sarcopenia": "kSarcopenia", + "fatigue": "kFatigue", + "reafferentation": "kReafferentation", + }.get(value, "kNormal") + + +def _normalize_path(value: str, source_root: Path) -> str: + if not value: + return "" + path = Path(value) + if not path.is_absolute(): + return value.lstrip("/") + marker = "/myosuite/" + raw = str(path) + if marker in raw: + return "myosuite/" + posixpath.normpath(raw.split(marker, 1)[1]) + try: + rel = path.resolve().relative_to(source_root.resolve() / "myosuite") + return "myosuite/" + rel.as_posix() + except ValueError: + return value + + +def _task_from_spec( + task_id: str, + spec: Any, + source_root: Path, + metadata: dict[str, dict[str, Any]], +) -> dict[str, Any]: + kwargs = dict(getattr(spec, "kwargs", {}) or {}) + entry_point = str(spec.entry_point) + object_name = str(kwargs.get("object_name", "")) + reference_path = "" + model_path = _normalize_path(str(kwargs.get("model_path", "")), source_root) + if "ArmReach" in task_id: + model_path = "myosuite/simhive/myo_sim/arm/myoarm_reach.xml" + if _kind(entry_point, task_id) == "kChallengeTableTennis": + model_path = ( + "myosuite/envs/myo/assets/arm/myoarm_tabletennis_native.xml" + ) + if _kind(entry_point, task_id) == "kMyoDmTrack": + model_path = ( + f"myosuite/envs/myo/assets/hand/myohand_object_{object_name}.xml" + ) + reference = kwargs.get("reference") + if isinstance(reference, str): + reference_path = _normalize_path(reference, source_root) + task = { + "id": task_id, + "entry_point": entry_point, + "kind": _kind(entry_point, task_id), + "model_path": model_path, + "reference_path": reference_path, + "object_name": object_name, + "obs_dim": 0, + "action_dim": 0, + "max_episode_steps": int(spec.max_episode_steps), + "frame_skip": int(kwargs.get("frame_skip", 10)), + "normalize_act": bool( + metadata.get(task_id, {}).get( + "normalize_act", kwargs.get("normalize_act", False) + ) + ), + "muscle": _muscle(kwargs), + "oracle_numpy2_broken": task_id in _BROKEN_IDS, + } + if task_id in metadata: + task["obs_dim"] = int(metadata[task_id]["observation_shape"][0]) + task["action_dim"] = int(metadata[task_id]["action_shape"][0]) + task["frame_skip"] = int(metadata[task_id]["frame_skip"]) + else: + raise ValueError(f"missing metadata for {task_id}") + return task + + +def _write_outputs(args: argparse.Namespace, source_root: Path) -> None: + _shorten_windows_binary_imports(source_root) + official_myosuite, gym_registry_specs, gym = _import_official() + if official_myosuite.__version__ != _ORACLE_VERSION: + raise ValueError( + f"expected MyoSuite {_ORACLE_VERSION}, " + f"got {official_myosuite.__version__}" + ) + task_ids = list(official_myosuite.myosuite_env_suite) + metadata_report = _metadata_report(task_ids, gym) + metadata = cast(dict[str, dict[str, Any]], metadata_report["tasks"]) + specs = gym_registry_specs() + tasks = [ + _task_from_spec(task_id, specs[task_id], source_root, metadata) + for task_id in task_ids + ] + + generate_task_registry._write_json(tasks, args.out_tasks_json) + generate_task_registry._write_header(tasks, args.out_tasks_header) + + metadata_entries = [ + generate_task_metadata._entry(task, metadata.get(task["id"])) + for task in tasks + ] + args.out_metadata_json.write_text( + json.dumps(metadata_entries, indent=2, sort_keys=True) + "\n" + ) + generate_task_metadata._write_header( + metadata_entries, args.out_metadata_header + ) + args.out_oracle_json.write_text( + json.dumps( + { + "commit": _ORACLE_COMMIT, + "numpy2_broken_ids": sorted(_BROKEN_IDS), + "version": _ORACLE_VERSION, + }, + indent=2, + sort_keys=True, + ) + + "\n" + ) + + reference_entries = [ + generate_reference_data._reference_entry(task, source_root) + for task in tasks + if task["kind"] == "kMyoDmTrack" + ] + generate_reference_data._write_header( + reference_entries, args.out_reference_header + ) + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument("--myosuite-manifest", type=Path, required=True) + parser.add_argument("--mpl-manifest", type=Path, required=True) + parser.add_argument("--ycb-manifest", type=Path, required=True) + parser.add_argument("--furniture-manifest", type=Path, required=True) + parser.add_argument("--myo-manifest", type=Path, required=True) + parser.add_argument("--object-manifest", type=Path, required=True) + parser.add_argument("--out-tasks-json", type=Path, required=True) + parser.add_argument("--out-tasks-header", type=Path, required=True) + parser.add_argument("--out-metadata-json", type=Path, required=True) + parser.add_argument("--out-metadata-header", type=Path, required=True) + parser.add_argument("--out-oracle-json", type=Path, required=True) + parser.add_argument("--out-reference-header", type=Path, required=True) + return parser.parse_args() + + +def main() -> None: + """Generate all native MyoSuite registry and metadata assets.""" + os.environ.setdefault("ROBOHIVE_VERBOSITY", "SILENT") + warnings.filterwarnings("ignore") + args = _parse_args() + sim_manifests = { + "mpl": args.mpl_manifest, + "ycb": args.ycb_manifest, + "furniture": args.furniture_manifest, + "myo": args.myo_manifest, + "object": args.object_manifest, + } + with _assembled_source(args.myosuite_manifest, sim_manifests) as root: + _write_outputs(args, root) + + +if __name__ == "__main__": + main() diff --git a/third_party/myosuite/generate_reference_data.py b/third_party/myosuite/generate_reference_data.py new file mode 100644 index 000000000..45c32729c --- /dev/null +++ b/third_party/myosuite/generate_reference_data.py @@ -0,0 +1,280 @@ +#!/usr/bin/env python3 +# Copyright 2026 Garena Online Private Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Generate native MyoDM reference trajectories from pinned MyoSuite assets.""" + +from __future__ import annotations + +import argparse +import json +import re +from pathlib import Path +from typing import Any + +import numpy as np + + +def _flat(values: Any) -> list[float]: + if values is None: + return [] + array = np.asarray(values, dtype=np.float64) + return [float(value) for value in array.ravel()] + + +def _shape(values: Any) -> tuple[int, int]: + if values is None: + return (0, 0) + array = np.asarray(values, dtype=np.float64) + if array.ndim == 1: + return (1, int(array.shape[0])) + if array.ndim == 2: + return (int(array.shape[0]), int(array.shape[1])) + raise ValueError(f"expected 1D or 2D reference array, got {array.shape}") + + +def _fixed_reference(randomized: bool) -> dict[str, Any]: + dof_robot = 29 + if randomized: + return { + "type": "kRandom", + "time": [0.0, 4.0], + "robot": np.zeros((2, dof_robot)), + "robot_vel": np.zeros((2, dof_robot)), + "object": np.asarray([ + [-0.2, -0.2, 0.1, 1.0, 0.0, 0.0, -1.0], + [0.2, 0.2, 0.1, 1.0, 0.0, 0.0, 1.0], + ]), + "robot_init": np.zeros(dof_robot), + "object_init": np.asarray([0.0, 0.0, 0.1, 1.0, 0.0, 0.0, 0.0]), + } + return { + "type": "kFixed", + "time": [0.0, 4.0], + "robot": np.zeros((1, dof_robot)), + "robot_vel": np.zeros((1, dof_robot)), + "object": np.asarray([[0.2, 0.2, 0.1, 1.0, 0.0, 0.0, 0.1]]), + "robot_init": np.zeros(dof_robot), + "object_init": np.asarray([-0.2, -0.2, 0.1, 1.0, 0.0, 0.0, 0.0]), + } + + +def _reference_entry(task: dict[str, Any], source_root: Path) -> dict[str, Any]: + ref_path = task.get("reference_path") or "" + if ref_path: + path = source_root / ref_path + data = dict(np.load(path).items()) + data.setdefault("robot_vel", None) + data["type"] = "kTrack" + else: + data = _fixed_reference(task["id"].endswith("Random-v0")) + + robot_rows, robot_cols = _shape(data.get("robot")) + robot_vel_rows, robot_vel_cols = _shape(data.get("robot_vel")) + object_rows, object_cols = _shape(data.get("object")) + return { + "id": task["id"], + "type": data["type"], + "time": _flat(data.get("time")), + "robot": _flat(data.get("robot")), + "robot_rows": robot_rows, + "robot_cols": robot_cols, + "robot_vel": _flat(data.get("robot_vel")), + "robot_vel_rows": robot_vel_rows, + "robot_vel_cols": robot_vel_cols, + "object": _flat(data.get("object")), + "object_rows": object_rows, + "object_cols": object_cols, + "robot_init": _flat(data.get("robot_init")), + "object_init": _flat(data.get("object_init")), + } + + +def _name(task_id: str, suffix: str) -> str: + parts = re.split(r"[^0-9A-Za-z]+", f"{task_id}_{suffix}") + stem = "".join(part[:1].upper() + part[1:] for part in parts if part) + return f"kMyoSuiteReference{stem}" + + +def _array(name: str, values: list[float]) -> list[str]: + lines = [f"inline constexpr std::array {name} = {{"] + if values: + for start in range(0, len(values), 6): + chunk = values[start : start + 6] + lines.append( + " " + ", ".join(f"{value:.17g}" for value in chunk) + "," + ) + lines.append("};") + return lines + + +def _write_header(entries: list[dict[str, Any]], output: Path) -> None: + lines = [ + "// Copyright 2026 Garena Online Private Limited", + "//", + '// Licensed under the Apache License, Version 2.0 (the "License");', + "// you may not use this file except in compliance with the License.", + "// You may obtain a copy of the License at", + "//", + "// http://www.apache.org/licenses/LICENSE-2.0", + "//", + "// Unless required by applicable law or agreed to in writing, software", + '// distributed under the License is distributed on an "AS IS" BASIS,', + "// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.", + "// See the License for the specific language governing permissions and", + "// limitations under the License.", + "", + "// Generated from pinned MyoSuite MyoDM reference assets; do not edit by hand.", + "#ifndef THIRD_PARTY_MYOSUITE_MYOSUITE_REFERENCE_DATA_H_", + "#define THIRD_PARTY_MYOSUITE_MYOSUITE_REFERENCE_DATA_H_", + "", + "#include ", + "#include ", + "#include ", + "#include ", + "#include ", + "", + "namespace third_party::myosuite {", + "", + "enum class MyoSuiteReferenceType : std::uint8_t {", + " kNone,", + " kFixed,", + " kRandom,", + " kTrack,", + "};", + "", + "struct MyoSuiteReferenceData {", + " const char* id;", + " MyoSuiteReferenceType type;", + " const double* time;", + " int time_size;", + " const double* robot;", + " int robot_rows;", + " int robot_cols;", + " const double* robot_vel;", + " int robot_vel_rows;", + " int robot_vel_cols;", + " const double* object;", + " int object_rows;", + " int object_cols;", + " const double* robot_init;", + " int robot_init_size;", + " const double* object_init;", + " int object_init_size;", + "};", + "", + "inline constexpr std::array kMyoSuiteEmptyReference = {};", + "", + "// clang-format off", + ] + for entry in entries: + for key in ( + "time", + "robot", + "robot_vel", + "object", + "robot_init", + "object_init", + ): + lines.extend(_array(_name(entry["id"], key), entry[key])) + lines.append("") + + lines.append( + f"inline constexpr std::array " + "kMyoSuiteReferenceData = {{" + ) + for entry in entries: + ident = entry["id"] + lines.extend([ + " MyoSuiteReferenceData{", + f' "{ident}",', + f" MyoSuiteReferenceType::{entry['type']},", + f" {_name(ident, 'time')}.data(),", + f" {len(entry['time'])},", + f" {_name(ident, 'robot')}.data(),", + f" {entry['robot_rows']},", + f" {entry['robot_cols']},", + f" {_name(ident, 'robot_vel')}.data(),", + f" {entry['robot_vel_rows']},", + f" {entry['robot_vel_cols']},", + f" {_name(ident, 'object')}.data(),", + f" {entry['object_rows']},", + f" {entry['object_cols']},", + f" {_name(ident, 'robot_init')}.data(),", + f" {len(entry['robot_init'])},", + f" {_name(ident, 'object_init')}.data(),", + f" {len(entry['object_init'])},", + " },", + ]) + lines.extend([ + "}};", + "// clang-format on", + "", + "inline constexpr MyoSuiteReferenceData kEmptyMyoSuiteReferenceData = {", + ' "",', + " MyoSuiteReferenceType::kNone,", + " kMyoSuiteEmptyReference.data(),", + " 0,", + " kMyoSuiteEmptyReference.data(),", + " 0,", + " 0,", + " kMyoSuiteEmptyReference.data(),", + " 0,", + " 0,", + " kMyoSuiteEmptyReference.data(),", + " 0,", + " 0,", + " kMyoSuiteEmptyReference.data(),", + " 0,", + " kMyoSuiteEmptyReference.data(),", + " 0,", + "};", + "", + "inline const MyoSuiteReferenceData& GetMyoSuiteReferenceData(", + " std::string_view task_id) {", + " for (const auto& reference : kMyoSuiteReferenceData) {", + " if (reference.id == task_id) {", + " return reference;", + " }", + " }", + " return kEmptyMyoSuiteReferenceData;", + "}", + "", + "} // namespace third_party::myosuite", + "", + "#endif // THIRD_PARTY_MYOSUITE_MYOSUITE_REFERENCE_DATA_H_", + "", + ]) + output.write_text("\n".join(lines)) + + +def main() -> None: + """Generate the C++ MyoDM reference data header.""" + parser = argparse.ArgumentParser() + parser.add_argument("--tasks", type=Path, required=True) + parser.add_argument("--myosuite-source", type=Path, required=True) + parser.add_argument("--out-header", type=Path, required=True) + args = parser.parse_args() + + tasks = json.loads(args.tasks.read_text()) + entries = [ + _reference_entry(task, args.myosuite_source) + for task in tasks + if task["kind"] == "kMyoDmTrack" + ] + _write_header(entries, args.out_header) + + +if __name__ == "__main__": + main() diff --git a/third_party/myosuite/generate_render_sample.py b/third_party/myosuite/generate_render_sample.py new file mode 100644 index 000000000..e72e1a70f --- /dev/null +++ b/third_party/myosuite/generate_render_sample.py @@ -0,0 +1,760 @@ +# Copyright 2026 Garena Online Private Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Generate MyoSuite EnvPool-vs-official render samples for docs.""" + +from __future__ import annotations + +import argparse +import importlib +import json +import os +import subprocess +import sys +import tempfile +import types +from collections.abc import Mapping, Sequence +from pathlib import Path +from typing import Any + +import numpy as np +from PIL import Image, ImageDraw + +_SYNC_STATE_KEYS = ( + "qpos0", + "qvel0", + "act0", + "qacc0", + "qacc_warmstart0", + "ctrl", + "site_pos", + "site_quat", + "site_size", + "site_rgba", + "body_pos", + "body_quat", + "body_mass", + "geom_pos", + "geom_quat", + "geom_size", + "geom_rgba", + "geom_friction", + "geom_aabb", + "geom_rbound", + "geom_contype", + "geom_conaffinity", + "geom_type", + "geom_condim", + "hfield_data", + "mocap_pos", + "mocap_quat", + "fatigue_ma", + "fatigue_mr", + "fatigue_mf", + "fatigue_tl", +) +_SYNC_STATE_SIZES = { + "qpos0": "nq", + "qvel0": "nv", + "act0": "na", + "qacc0": "nv", + "qacc_warmstart0": "nv", + "ctrl": "nu", + "site_pos": "nsite3", + "site_quat": "nsite4", + "site_size": "nsite3", + "site_rgba": "nsite4", + "body_pos": "nbody3", + "body_quat": "nbody4", + "body_mass": "nbody", + "geom_pos": "ngeom3", + "geom_quat": "ngeom4", + "geom_size": "ngeom3", + "geom_rgba": "ngeom4", + "geom_friction": "ngeom3", + "geom_aabb": "ngeom6", + "geom_rbound": "ngeom", + "geom_contype": "ngeom", + "geom_conaffinity": "ngeom", + "geom_type": "ngeom", + "geom_condim": "ngeom", + "hfield_data": "nhfielddata", + "mocap_pos": "nmocap3", + "mocap_quat": "nmocap4", + "fatigue_ma": "nu", + "fatigue_mr": "nu", + "fatigue_mf": "nu", + "fatigue_tl": "nu", +} + + +def _bootstrap_envpool_namespace() -> None: + """Load envpool modules without importing envpool.entry.""" + if "envpool" in sys.modules: + return + roots = [] + for path in sys.path: + envpool_root = Path(path) / "envpool" + if ( + envpool_root / "registration.py" + ).is_file() or envpool_root.is_dir(): + roots.append(str(envpool_root)) + if not roots: + raise RuntimeError("could not locate envpool package on PYTHONPATH") + module = types.ModuleType("envpool") + module.__file__ = str(Path(roots[0]) / "__init__.py") + module.__path__ = roots # type: ignore[attr-defined] + sys.modules["envpool"] = module + + +def _make_gymnasium(task_id: str, **kwargs: object): + _bootstrap_envpool_namespace() + if not getattr(_make_gymnasium, "_registered", False): + importlib.import_module("envpool.mujoco.myosuite.registration") + _make_gymnasium._registered = True # type: ignore[attr-defined] + from envpool.registration import make_gymnasium + + return make_gymnasium(task_id, **kwargs) + + +def _label(draw: ImageDraw.ImageDraw, xy: tuple[int, int], text: str) -> None: + draw.rectangle( + (xy[0] - 3, xy[1] - 2, xy[0] + 156, xy[1] + 14), + fill=(255, 255, 255), + ) + draw.text(xy, text, fill=(0, 0, 0)) + + +def _registered_tasks() -> tuple[dict[str, object], ...]: + _bootstrap_envpool_namespace() + from envpool.mujoco.myosuite.tasks import MYOSUITE_TASKS + + return tuple(MYOSUITE_TASKS) + + +def _group_name(task_id: str) -> str: + if task_id.startswith("MyoHand"): + return "myodm" + if "Challenge" in task_id: + return "myochallenge" + return "myobase" + + +def _envpool_frames_from_actions( + task_id: str, + width: int, + height: int, + seed: int, + actions: Sequence[Sequence[float]], +) -> tuple[list[np.ndarray], list[Mapping[str, object]]]: + env = _make_gymnasium( + task_id, + num_envs=1, + seed=seed, + render_mode="rgb_array", + render_width=width, + render_height=height, + ) + try: + _, info = env.reset() + frames = [env.render()[0]] + infos: list[Mapping[str, object]] = [info] + for action in actions: + *_, info = env.step(np.asarray(action, dtype=np.float32)[None, :]) + frames.append(env.render()[0]) + infos.append(info) + return frames, infos + finally: + env.close() + + +def _midpoint_action(env: object) -> np.ndarray: + action_space = env.action_space # type: ignore[attr-defined] + low = np.asarray(action_space.low, dtype=np.float32) + high = np.asarray(action_space.high, dtype=np.float32) + return ((low + high) * 0.5).astype(np.float32) + + +def _envpool_trace_record( + task_id: str, + width: int, + height: int, + seed: int, + steps: int = 3, +) -> tuple[list[np.ndarray], list[Mapping[str, object]], dict[str, object]]: + env = _make_gymnasium( + task_id, + num_envs=1, + seed=seed, + render_mode="rgb_array", + render_width=width, + render_height=height, + ) + try: + _, info = env.reset() + frames = [env.render()[0]] + infos: list[Mapping[str, object]] = [info] + actions: list[list[float]] = [] + reset_before_step: list[bool] = [] + action = _midpoint_action(env) + for _ in range(steps): + actions.append(action.tolist()) + *_, info = env.step(action[None, :]) + frames.append(env.render()[0]) + infos.append(info) + elapsed_step = int(np.asarray(info["elapsed_step"]).ravel()[0]) + reset_before_step.append(elapsed_step == 0) + plan = { + "actions": actions, + "reset_before_step": reset_before_step, + "sync_states": [_sync_state_from_info(item) for item in infos], + } + return frames, infos, plan + finally: + env.close() + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument("--oracle_trace", type=Path) + parser.add_argument("--out", type=Path) + parser.add_argument("--out_dir", type=Path) + parser.add_argument("--all_tasks", action="store_true") + parser.add_argument("--quiet", action="store_true") + parser.add_argument("--debug_json", type=Path) + parser.add_argument("--task_id", default="myoFingerReachFixed-v0") + parser.add_argument("--seed", default=3, type=int) + parser.add_argument("--width", default=160, type=int) + parser.add_argument("--height", default=120, type=int) + parser.add_argument("--camera_id", default=-1, type=int) + return parser.parse_args() + + +def _runfiles_root() -> Path: + path = Path(__file__).absolute() + for parent in (path, *path.parents): + if parent.name.endswith(".runfiles"): + return parent + runfiles_dir = os.environ.get("RUNFILES_DIR") + if runfiles_dir: + return Path(runfiles_dir) + if "TEST_SRCDIR" in os.environ: + return Path(os.environ["TEST_SRCDIR"]) + return path.parents[3] + + +def _oracle_probe_path() -> Path: + runfiles = _runfiles_root() + workspace = os.environ.get("TEST_WORKSPACE", "envpool") + launcher_names = ( + ("myosuite_oracle_probe.exe", "myosuite_oracle_probe") + if sys.platform == "win32" + else ("myosuite_oracle_probe", "myosuite_oracle_probe.exe") + ) + candidates = [ + runfiles / workspace / "envpool/mujoco" / launcher + for launcher in launcher_names + ] + for candidate in candidates: + if candidate.is_file(): + return candidate + for launcher in launcher_names: + matches = list(runfiles.rglob(launcher)) + if matches: + return matches[0] + raise RuntimeError( + f"could not locate myosuite_oracle_probe under {runfiles}" + ) + + +def _sync_state_from_info(info: Mapping[str, object]) -> dict[str, Any]: + dims = { + "nq": int(np.asarray(info["model_nq"]).ravel()[0]), + "nv": int(np.asarray(info["model_nv"]).ravel()[0]), + "na": int(np.asarray(info["model_na"]).ravel()[0]), + "nu": int(np.asarray(info["model_nu"]).ravel()[0]), + "nsite": int(np.asarray(info["model_nsite"]).ravel()[0]), + "nbody": int(np.asarray(info["model_nbody"]).ravel()[0]), + "ngeom": int(np.asarray(info["model_ngeom"]).ravel()[0]), + "nhfielddata": int(np.asarray(info["model_nhfielddata"]).ravel()[0]), + "nmocap": int(np.asarray(info["model_nmocap"]).ravel()[0]), + } + dims.update({ + "nsite3": dims["nsite"] * 3, + "nsite4": dims["nsite"] * 4, + "nbody3": dims["nbody"] * 3, + "nbody4": dims["nbody"] * 4, + "ngeom3": dims["ngeom"] * 3, + "ngeom4": dims["ngeom"] * 4, + "ngeom6": dims["ngeom"] * 6, + "nmocap3": dims["nmocap"] * 3, + "nmocap4": dims["nmocap"] * 4, + }) + sync_state = {} + for key in _SYNC_STATE_KEYS: + if key not in info: + continue + size = dims[_SYNC_STATE_SIZES[key]] + sync_state[key] = ( + np.asarray(info[key][0], dtype=np.float64).ravel()[:size].tolist() + ) + return sync_state + + +def _oracle_trace( + task_ids: Sequence[str], + trace_plan: Mapping[str, Mapping[str, object]] | None, + width: int, + height: int, + seed: int, + camera_id: int, +) -> Mapping[str, Mapping[str, object]]: + with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as out: + out_path = Path(out.name) + plan_path: Path | None = None + cmd = [ + str(_oracle_probe_path()), + "--mode", + "trace", + "--render", + "--render_width", + str(width), + "--render_height", + str(height), + "--camera_id", + str(camera_id), + "--action_mode", + "midpoint", + "--steps", + "3", + "--seed", + str(seed), + "--out", + str(out_path), + ] + if trace_plan is not None: + with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as plan: + plan_path = Path(plan.name) + plan_path.write_text(json.dumps(trace_plan, sort_keys=True)) + cmd.extend(["--trace_plan", str(plan_path)]) + for task_id in task_ids: + cmd.extend(["--task_id", task_id]) + env = os.environ.copy() + env["ROBOHIVE_VERBOSITY"] = "SILENT" + try: + result = subprocess.run( + cmd, + check=False, + capture_output=True, + env=env, + text=True, + ) + if result.returncode != 0: + raise RuntimeError( + "MyoSuite oracle probe failed\n" + f"cmd: {' '.join(cmd)}\n" + f"stdout:\n{result.stdout}\n" + f"stderr:\n{result.stderr}" + ) + return json.loads(out_path.read_text())["tasks"] + finally: + out_path.unlink(missing_ok=True) + if plan_path is not None: + plan_path.unlink(missing_ok=True) + + +def _task_frames_from_trace( + task_id: str, + task_trace: Mapping[str, object], + width: int, + height: int, + seed: int, + envpool_records: Mapping[ + str, tuple[list[np.ndarray], list[Mapping[str, object]]] + ] + | None = None, +) -> tuple[list[np.ndarray], list[np.ndarray], list[Mapping[str, object]]]: + oracle_frames = [ + np.asarray(frame, dtype=np.uint8) + for frame in task_trace["frames"] # type: ignore[index] + ] + if envpool_records is not None and task_id in envpool_records: + envpool_frames, envpool_infos = envpool_records[task_id] + else: + envpool_frames, envpool_infos = _envpool_frames_from_actions( + task_id, + width, + height, + seed, + task_trace["actions"], # type: ignore[arg-type,index] + ) + if len(envpool_frames) != 4 or len(oracle_frames) != 4: + raise ValueError("expected reset plus three step frames") + return envpool_frames, oracle_frames, envpool_infos + + +def _render_stats( + task_id: str, + envpool_frames: Sequence[np.ndarray], + oracle_frames: Sequence[np.ndarray], + task_trace: Mapping[str, object], + envpool_infos: Sequence[Mapping[str, object]], +) -> dict[str, object]: + diffs = [ + np.abs(a.astype(np.int16) - b.astype(np.int16)) + for a, b in zip(envpool_frames, oracle_frames, strict=True) + ] + frame_diffs = [] + for idx, diff in enumerate(diffs): + mismatched_pixels = int(np.count_nonzero(np.any(diff != 0, axis=-1))) + frame_diffs.append({ + "max_abs_diff": int(diff.max()), + "mismatched_pixels": mismatched_pixels, + "mean_abs_diff": float(np.mean(diff)), + "first_diff": ( + { + "index": [int(item) for item in np.argwhere(diff)[0]], + "envpool": [ + int(item) + for item in envpool_frames[idx][ + tuple(np.argwhere(diff)[0][:2]) + ] + ], + "official": [ + int(item) + for item in oracle_frames[idx][ + tuple(np.argwhere(diff)[0][:2]) + ] + ], + } + if mismatched_pixels + else None + ), + }) + stats: dict[str, object] = { + "task_id": task_id, + "frames": len(diffs), + "frame_diffs": frame_diffs, + "max_abs_diff": max(int(diff.max()) for diff in diffs), + "mismatched_pixels": sum( + int(np.count_nonzero(np.any(diff != 0, axis=-1))) for diff in diffs + ), + "mean_abs_diff": max(float(np.mean(diff)) for diff in diffs), + "envpool_elapsed_steps": [ + int(np.asarray(info["elapsed_step"]).ravel()[0]) + for info in envpool_infos + if "elapsed_step" in info + ], + "envpool_times": [ + float(np.asarray(info["time"]).ravel()[0]) + for info in envpool_infos + if "time" in info + ], + } + states = [task_trace.get("reset_state", {})] + list( + task_trace.get("states", []) # type: ignore[arg-type] + ) + state_diffs: dict[str, list[float]] = {} + max_state_diff: dict[str, tuple[float, int]] = {} + for step_id, (info, state) in enumerate( + zip(envpool_infos, states, strict=False) + ): + if not isinstance(state, Mapping) or "qpos" not in state: + continue + for key in ( + "qpos", + "qvel", + "act", + "actuator_force", + "actuator_length", + "actuator_velocity", + "ctrl", + "geom_rgba", + "geom_xpos", + "geom_xmat", + "qacc_warmstart", + "site_pos", + "site_quat", + "site_size", + "site_xpos", + "site_rgba", + "body_pos", + "body_quat", + "light_xpos", + "light_xdir", + "mocap_pos", + "mocap_quat", + "fatigue_ma", + "fatigue_mr", + "fatigue_mf", + "fatigue_tl", + "fatigue_tauact", + "fatigue_taudeact", + ): + if key not in info or key not in state: + continue + value = np.asarray(info[key], dtype=np.float64).ravel() + oracle_value = np.asarray(state[key], dtype=np.float64).ravel() + if oracle_value.size == 0: + continue + max_diff = float( + np.max(np.abs(value[: oracle_value.size] - oracle_value)) + ) + state_diffs.setdefault(key, []).append(max_diff) + if key not in max_state_diff or max_diff > max_state_diff[key][0]: + max_state_diff[key] = (max_diff, step_id) + for key, diffs_for_key in state_diffs.items(): + stats[f"max_{key}_abs_diff"] = max(diffs_for_key) + stats[f"max_{key}_abs_diff_step"] = max_state_diff[key][1] + return stats + + +def _render_single( + task_id: str, + task_trace: Mapping[str, object], + out: Path, + width: int, + height: int, + seed: int, + envpool_records: Mapping[ + str, tuple[list[np.ndarray], list[Mapping[str, object]]] + ] + | None = None, +) -> dict[str, object]: + envpool_frames, oracle_frames, envpool_infos = _task_frames_from_trace( + task_id, task_trace, width, height, seed, envpool_records + ) + margin = 18 + label_h = 20 + gutter = 12 + row_gap = 10 + cell_w = width + cell_h = height + label_h + canvas_w = margin * 2 + cell_w * 2 + gutter + canvas_h = margin * 2 + cell_h * 4 + row_gap * 3 + canvas = Image.new("RGB", (canvas_w, canvas_h), "white") + draw = ImageDraw.Draw(canvas) + + labels = ["reset", "step 1", "step 2", "step 3"] + for idx, label in enumerate(labels): + y = margin + idx * (cell_h + row_gap) + left_x = margin + right_x = margin + cell_w + gutter + _label(draw, (left_x, y), f"EnvPool {label}") + _label(draw, (right_x, y), f"Official {label}") + canvas.paste( + Image.fromarray(envpool_frames[idx]), (left_x, y + label_h) + ) + canvas.paste( + Image.fromarray(oracle_frames[idx]), (right_x, y + label_h) + ) + + out.parent.mkdir(parents=True, exist_ok=True) + canvas.save(out) + return _render_stats( + task_id, envpool_frames, oracle_frames, task_trace, envpool_infos + ) + + +def _thumbnail(frame: np.ndarray, width: int, height: int) -> Image.Image: + image = Image.fromarray(frame) + if image.size == (width, height): + return image + return image.resize((width, height), Image.Resampling.LANCZOS) + + +def _render_group( + group: str, + task_ids: Sequence[str], + traces: Mapping[str, Mapping[str, object]], + envpool_records: Mapping[ + str, tuple[list[np.ndarray], list[Mapping[str, object]]] + ], + out_dir: Path, + width: int, + height: int, + seed: int, +) -> list[dict[str, object]]: + labels = ("reset", "step 1", "step 2", "step 3") + margin = 12 + label_w = 238 + label_h = 18 + gutter = 4 + row_gap = 6 + header_h = 36 + thumb_w = width + thumb_h = height + row_h = thumb_h + label_h + row_gap + canvas_w = margin * 2 + label_w + (thumb_w + gutter) * 8 - gutter + canvas_h = margin * 2 + header_h + row_h * len(task_ids) + canvas = Image.new("RGB", (canvas_w, canvas_h), "white") + draw = ImageDraw.Draw(canvas) + draw.text( + (margin, margin), + f"{group}: EnvPool left, Official right; reset + first 3 steps", + fill=(0, 0, 0), + ) + for idx, label in enumerate(labels): + x = margin + label_w + idx * 2 * (thumb_w + gutter) + draw.text((x, margin + label_h), f"{label} Env", fill=(0, 0, 0)) + draw.text( + (x + thumb_w + gutter, margin + label_h), + f"{label} Official", + fill=(0, 0, 0), + ) + + stats: list[dict[str, object]] = [] + for row, task_id in enumerate(task_ids): + y = margin + header_h + row * row_h + draw.text((margin, y + label_h), task_id[:36], fill=(0, 0, 0)) + envpool_frames, oracle_frames, envpool_infos = _task_frames_from_trace( + task_id, traces[task_id], width, height, seed, envpool_records + ) + stats.append( + _render_stats( + task_id, + envpool_frames, + oracle_frames, + traces[task_id], + envpool_infos, + ) + ) + for idx in range(4): + x = margin + label_w + idx * 2 * (thumb_w + gutter) + canvas.paste( + _thumbnail(envpool_frames[idx], thumb_w, thumb_h), (x, y) + ) + canvas.paste( + _thumbnail(oracle_frames[idx], thumb_w, thumb_h), + (x + thumb_w + gutter, y), + ) + + out_dir.mkdir(parents=True, exist_ok=True) + canvas.save(out_dir / f"myosuite_{group}_official_compare.png") + return stats + + +def main() -> None: + """Generate side-by-side MyoSuite render sample images.""" + args = _parse_args() + traces = None + envpool_records: dict[ + str, tuple[list[np.ndarray], list[Mapping[str, object]]] + ] = {} + trace_plan: dict[str, Mapping[str, object]] = {} + if args.oracle_trace is not None: + traces = json.loads(args.oracle_trace.read_text())["tasks"] + if args.all_tasks: + if args.out_dir is None: + raise ValueError("--all_tasks requires --out_dir") + task_order = [] + for task in _registered_tasks(): + task_id = str(task["id"]) + if traces is not None and task_id not in traces: + continue + if traces is None and bool(task["oracle_numpy2_broken"]): + continue + task_order.append(task_id) + if traces is None: + for task_id in task_order: + frames, infos, plan = _envpool_trace_record( + task_id, args.width, args.height, args.seed + ) + envpool_records[task_id] = (frames, infos) + trace_plan[task_id] = plan + traces = _oracle_trace( + task_order, + trace_plan, + args.width, + args.height, + args.seed, + args.camera_id, + ) + groups: dict[str, list[str]] = { + "myobase": [], + "myochallenge": [], + "myodm": [], + } + for task_id in task_order: + groups[_group_name(task_id)].append(task_id) + stats = [] + for group, group_task_ids in groups.items(): + stats.extend( + _render_group( + group, + group_task_ids, + traces, + envpool_records, + args.out_dir, + args.width, + args.height, + args.seed, + ) + ) + else: + if args.out is None: + raise ValueError("--out is required unless --all_tasks is set") + if traces is None: + frames, infos, plan = _envpool_trace_record( + args.task_id, args.width, args.height, args.seed + ) + envpool_records[args.task_id] = (frames, infos) + trace_plan[args.task_id] = plan + traces = _oracle_trace( + (args.task_id,), + trace_plan, + args.width, + args.height, + args.seed, + args.camera_id, + ) + stats = [ + _render_single( + args.task_id, + traces[args.task_id], + args.out, + args.width, + args.height, + args.seed, + envpool_records, + ) + ] + if args.debug_json is not None: + args.debug_json.write_text( + json.dumps( + { + "stats": stats, + "trace_plan": trace_plan, + "traces": traces, + }, + sort_keys=True, + ) + ) + if args.quiet: + print( + json.dumps( + { + "max_abs_diff": max( + int(stat["max_abs_diff"]) for stat in stats + ), + "tasks": len(stats), + }, + sort_keys=True, + ) + ) + else: + print(json.dumps({"tasks": stats}, sort_keys=True)) + + +if __name__ == "__main__": + main() diff --git a/third_party/myosuite/generate_runtime_assets.py b/third_party/myosuite/generate_runtime_assets.py new file mode 100644 index 000000000..b8a4153b5 --- /dev/null +++ b/third_party/myosuite/generate_runtime_assets.py @@ -0,0 +1,621 @@ +# Copyright 2026 Garena Online Private Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Assemble the minimal MyoSuite runtime asset tree.""" + +from __future__ import annotations + +import argparse +import importlib.util +import os +import shutil +import sys +import tempfile +import xml.etree.ElementTree as ET +from collections.abc import Callable +from pathlib import Path +from typing import Any + +_MYODM_OBJECTS = ( + "airplane", + "alarmclock", + "apple", + "banana", + "binoculars", + "bowl", + "camera", + "coffeemug", + "cup", + "cylinderlarge", + "cylindermedium", + "cylindersmall", + "duck", + "elephant", + "eyeglasses", + "flashlight", + "flute", + "fryingpan", + "gamecontroller", + "gelatinbox", + "hammer", + "hand", + "headphones", + "knife", + "lightbulb", + "mouse", + "mug", + "phone", + "piggybank", + "pyramidlarge", + "pyramidmedium", + "pyramidsmall", + "rubberduck", + "scissors", + "spherelarge", + "spheremedium", + "spheresmall", + "stanfordbunny", + "stapler", + "table", + "teapot", + "toothbrush", + "toothpaste", + "toruslarge", + "torusmedium", + "torussmall", + "train", + "watch", + "waterbottle", + "wineglass", +) +_WINDOWS_SHORT_IMPORT_PACKAGES = ("mujoco",) +_DLL_DIRECTORY_HANDLES: list[Any] = [] +mujoco: Any = None + +_UNUSED_RUNTIME_ASSETS = { + "myosuite": { + "envs/myo/assets/leg_soccer/soccer_assets/soccer_scene/SoccerPitch_goal.obj", + "envs/myo/assets/pingpong.obj", + }, + "myosuite_furniture_sim": { + "simpleTable/simpleTable.png", + }, + "myosuite_mpl_sim": { + "meshes/mplL/palm_link-cvx.stl", + "meshes/mplL/wrist_dev_link-cvx.stl", + "meshes/mplL/wrist_fe_link-cvx.stl", + "meshes/mplL/wrist_rot_link-cvx.stl", + }, + "myosuite_myo_sim": { + "meshes/hand_2distph.stl", + "meshes/hand_2midph.stl", + "meshes/hand_2proxph.stl", + "meshes/hat_spine.stl", + "meshes/human_highpoly.stl", + "meshes/human_lowpoly.stl", + "meshes/osl_generic_socket_enlarged.stl", + "meshes/ribcage_s.stl", + "scene/myosuite_scene.msh", + "scene/myosuite_scene_noFloor.png", + }, + "myosuite_myo_sim_patterns": ( + "meshes/fingers", + "meshes/movaxesfin", + ), + "myosuite_myo_sim_suffixes": ( + "_lvs.stl", + "_rvs.stl", + ), + "myosuite_object_sim": { + "hammer/hammer.stl", + "knife/knife.stl", + "lightbulb/lightbulb.stl", + "scissors/scissors.stl", + }, +} + + +def _copy_short_import_package( + destination_root: Path, package_name: str +) -> list[Path]: + spec = importlib.util.find_spec(package_name) + if spec is None or spec.submodule_search_locations is None: + return [] + source = Path(next(iter(spec.submodule_search_locations))) + if not source.is_dir(): + return [] + destination = destination_root / package_name + if not destination.exists(): + shutil.copytree( + source, + destination, + symlinks=False, + ignore=shutil.ignore_patterns("__pycache__"), + ) + copied = [destination] + sibling = source.parent / f"{package_name}.libs" + if sibling.is_dir(): + sibling_destination = destination_root / sibling.name + if not sibling_destination.exists(): + shutil.copytree(sibling, sibling_destination, symlinks=False) + copied.append(sibling_destination) + return copied + + +def _shorten_windows_binary_imports() -> None: + if os.name != "nt": + return + # Bazel runfiles paths can exceed the Windows DLL loader path limit. + destination_root = ( + Path(tempfile.gettempdir()) / "envpool_mujoco_runtime_site" + ) + destination_root.mkdir(parents=True, exist_ok=True) + copied: list[Path] = [] + for package_name in _WINDOWS_SHORT_IMPORT_PACKAGES: + copied.extend( + _copy_short_import_package(destination_root, package_name) + ) + for name in tuple(sys.modules): + if name == package_name or name.startswith(f"{package_name}."): + del sys.modules[name] + if copied: + sys.path.insert(0, str(destination_root)) + if hasattr(os, "add_dll_directory"): + for path in copied: + _DLL_DIRECTORY_HANDLES.append(os.add_dll_directory(str(path))) + + +def _load_mujoco() -> Any: + global mujoco + if mujoco is None: + _shorten_windows_binary_imports() + import mujoco as mujoco_module + + mujoco = mujoco_module + return mujoco + + +def _read_manifest(path: Path) -> list[Path]: + return [Path(line) for line in path.read_text().splitlines() if line] + + +def _after_marker(path: Path, marker: str) -> str | None: + text = path.as_posix() + marker = marker.rstrip("/") + "/" + index = text.find(marker) + if index < 0: + return None + return text[index + len(marker) :] + + +def _skip_common(rel: str) -> bool: + return ( + rel + in { + ".gitignore", + "BUILD.bazel", + "README.md", + "REPO.bazel", + "WORKSPACE", + "__init__.py", + "objects.png", + "preview.py", + "pyproject.toml", + "test_sims.py", + } + or rel.startswith((".github/", ".idea/", "tests/")) + or rel.endswith("/object.xml") + or rel.startswith("scene/") + and rel.endswith((".mtl", ".obj")) + ) + + +def _object_needed(rel: str, objects: set[str]) -> bool: + if rel in {"LICENSE", "common.xml"}: + return True + object_dir = rel.split("/", 1)[0] + return object_dir in objects + + +def _destination(src: Path, out: Path, objects: set[str]) -> Path | None: + rel = _after_marker(src, "myosuite_source/myosuite") + if rel is not None: + if rel in _UNUSED_RUNTIME_ASSETS["myosuite"] or _skip_common(rel): + return None + return out / "myosuite" / rel + + rel = _after_marker(src, "myosuite_mpl_sim") + if rel is not None: + keep = rel in { + "LICENSE", + "assets/handL_assets.xml", + "assets/handL_chain.xml", + "assets/left_arm_assets.xml", + "assets/left_arm_chain_myochallenge.xml", + } or rel.startswith("meshes/mplL/") + if ( + not keep + or rel in _UNUSED_RUNTIME_ASSETS["myosuite_mpl_sim"] + or _skip_common(rel) + ): + return None + return out / "myosuite/simhive/MPL_sim" / rel + + rel = _after_marker(src, "myosuite_ycb_sim") + if rel is not None: + keep = rel in { + "LICENSE", + "includes/assets_009_gelatin_box.xml", + "includes/body_009_gelatin_box.xml", + "includes/defaults_ycb.xml", + "meshes/009_gelatin_box.msh", + "textures/009_gelatin_box.png", + } + if not keep or _skip_common(rel): + return None + return out / "myosuite/simhive/YCB_sim" / rel + + rel = _after_marker(src, "myosuite_furniture_sim") + if rel is not None: + keep = rel in { + "LICENSE", + "common/textures/stone0.png", + "common/textures/stone1.png", + "common/textures/wood1.png", + "simpleTable.xml", + } or rel.startswith("simpleTable/") + if ( + not keep + or rel in _UNUSED_RUNTIME_ASSETS["myosuite_furniture_sim"] + or _skip_common(rel) + ): + return None + return out / "myosuite/simhive/furniture_sim" / rel + + rel = _after_marker(src, "myosuite_myo_sim") + if rel is not None: + if ( + rel in _UNUSED_RUNTIME_ASSETS["myosuite_myo_sim"] + or rel.startswith( + _UNUSED_RUNTIME_ASSETS["myosuite_myo_sim_patterns"] + ) + or rel.endswith(_UNUSED_RUNTIME_ASSETS["myosuite_myo_sim_suffixes"]) + or _skip_common(rel) + ): + return None + return out / "myosuite/simhive/myo_sim" / rel + + rel = _after_marker(src, "myosuite_object_sim") + if rel is not None: + if ( + not _object_needed(rel, objects) + or rel in _UNUSED_RUNTIME_ASSETS["myosuite_object_sim"] + or _skip_common(rel) + ): + return None + return out / "myosuite/simhive/object_sim" / rel + + return None + + +def _copy_runtime_sources(out: Path, manifest: Path, objects: set[str]) -> None: + for src in _read_manifest(manifest): + dst = _destination(src, out, objects) + if dst is None or not src.is_file(): + continue + dst.parent.mkdir(parents=True, exist_ok=True) + shutil.copy2(src, dst) + + +def _copy_runtime_metadata(out: Path, manifest: Path) -> None: + metadata_dir = out / "metadata" + for src in _read_manifest(manifest): + if not src.is_file() or src.suffix != ".json": + continue + metadata_dir.mkdir(parents=True, exist_ok=True) + shutil.copy2(src, metadata_dir / src.name) + + +def _generate_myodm_object_xml(out: Path, objects: set[str]) -> None: + template = out / "myosuite/envs/myo/assets/hand/myohand_object.xml" + text = template.read_text() + for object_name in sorted(objects): + (template.parent / f"myohand_object_{object_name}.xml").write_text( + text.replace("OBJECT_NAME", object_name) + ) + + +def _mesh_geoms(body: mujoco.MjsBody) -> list[str]: + return [ + geom.name + for geom in body.geoms + if geom.type == mujoco.mjtGeom.mjGEOM_MESH + ] + + +def _apply_arm_reach_edit(spec: mujoco.MjSpec) -> None: + roots = ("firstmc", "secondmc", "thirdmc", "fourthmc", "fifthmc") + body_positions: dict[str, list[tuple[str, list[float], list[str]]]] = {} + for root in roots: + body_positions[root] = [] + child = spec.body(root).first_body() + while child is not None: + body_positions[root].append(( + child.name, + child.pos.copy(), + _mesh_geoms(child), + )) + child = child.first_body() + + site = spec.site("IFtip") + site_size = site.size.copy() + site_pos = site.pos.copy() + site_rgba = site.rgba.copy() + + for root in roots: + child = spec.body(root).first_body() + if child is not None: + spec.delete(child) + + for root in roots: + parent = spec.body(root) + for body_name, pos, mesh_names in body_positions[root]: + parent.add_body(name=body_name, pos=pos) + body = spec.body(body_name) + for mesh_name in mesh_names: + body.add_geom( + meshname=mesh_name, + name=body_name, + type=mujoco.mjtGeom.mjGEOM_MESH, + ) + if body_name == "distph2": + body.add_site( + name="IFtip", + size=site_size * 2, + pos=site_pos, + rgba=site_rgba, + ) + parent = body + + spec.body("world").add_site( + name="IFtip_target", + type=mujoco.mjtGeom.mjGEOM_SPHERE, + size=[0.02, 0.02, 0.02], + pos=[-0.2, -0.2, 1.2], + rgba=[0.0, 0.0, 1.0, 0.3], + ) + + +def _generate_arm_reach_xml(out: Path) -> None: + _load_mujoco() + myo_sim = (out / "myosuite/simhive/myo_sim").resolve() + source_arm = myo_sim / "arm" + with tempfile.TemporaryDirectory(prefix="myosuite-arm-reach-") as temp: + temp_root = Path(temp) + spec_root = temp_root / "spec" + spec_arm = spec_root / "arm" + shutil.copytree(source_arm, spec_arm, symlinks=True) + (spec_root / "scene").symlink_to(myo_sim / "scene") + (spec_root / "myo_sim").symlink_to(myo_sim) + (spec_arm / "myo_sim").symlink_to(myo_sim) + (temp_root / "myo_sim").symlink_to(myo_sim) + + spec = mujoco.MjSpec.from_file(str(spec_arm / "myoarm.xml")) + _apply_arm_reach_edit(spec) + spec.compile() + (source_arm / "myoarm_reach.xml").write_text(spec.to_xml()) + + +def _recursive_immobilize( + spec: mujoco.MjSpec, + temp_model: mujoco.MjModel, + parent: mujoco.MjsBody, + remove_eqs: bool = False, + remove_actuators: bool = False, +) -> list[int]: + removed_joint_ids: list[int] = [] + for site in list(parent.sites): + spec.delete(site) + for joint in list(parent.joints): + removed_joint_ids.extend(temp_model.joint(joint.name).qposadr) + if remove_eqs: + for equality in list(spec.equalities): + if equality.type == mujoco.mjtEq.mjEQ_JOINT and ( + equality.name1 == joint.name or equality.name2 == joint.name + ): + spec.delete(equality) + if remove_actuators: + for actuator in list(spec.actuators): + if ( + actuator.trntype == mujoco.mjtTrn.mjTRN_JOINT + and actuator.target == joint.name + ): + spec.delete(actuator) + spec.delete(joint) + for child in list(parent.bodies): + removed_joint_ids.extend( + _recursive_immobilize( + spec, temp_model, child, remove_eqs, remove_actuators + ) + ) + return removed_joint_ids + + +def _recursive_remove_contacts( + parent: mujoco.MjsBody, + return_condition: Callable[[mujoco.MjsBody], bool] | None = None, +) -> None: + if return_condition is not None and return_condition(parent): + return + for geom in parent.geoms: + geom.contype = 0 + geom.conaffinity = 0 + for child in parent.bodies: + _recursive_remove_contacts(child, return_condition) + + +def _recursive_mirror( + meshes_to_mirror: set[str], + spec_copy: mujoco.MjSpec, + parent: mujoco.MjsBody, +) -> None: + parent.pos[1] *= -1 + parent.quat[[1, 3]] *= -1 + parent.name += "_mirrored" + for geom in list(parent.geoms): + if geom.type != mujoco.mjtGeom.mjGEOM_MESH: + spec_copy.delete(geom) + continue + geom.pos[1] *= -1 + geom.quat[[1, 3]] *= -1 + geom.name += "_mirrored" + geom.group = 1 + meshes_to_mirror.add(geom.meshname) + geom.meshname += "_mirrored" + for child in list(parent.bodies): + if "ping_pong" in child.name: + spec_copy.delete(child) + continue + _recursive_mirror(meshes_to_mirror, spec_copy, child) + + +def _preprocess_tabletennis_spec(spec: mujoco.MjSpec) -> mujoco.MjSpec: + for sensor in list(spec.sensors): + if "pingpong" not in sensor.name and "paddle" not in sensor.name: + spec.delete(sensor) + temp_model = spec.compile() + + removed_ids = _recursive_immobilize( + spec, temp_model, spec.body("femur_l"), remove_eqs=True + ) + removed_ids.extend( + _recursive_immobilize( + spec, temp_model, spec.body("femur_r"), remove_eqs=True + ) + ) + removed = set(removed_ids) + for key in spec.keys: + key.qpos = [ + value for idx, value in enumerate(key.qpos) if idx not in removed + ] + + _recursive_remove_contacts( + spec.body("full_body"), + return_condition=lambda body: "radius" in body.name, + ) + + torso = spec.body("torso") + spec_copy = spec.copy() + attachment_frame = torso.add_frame( + quat=[0.5, 0.5, -0.5, 0.5], + pos=[0.05, 0.373, -0.04], + ) + for collection in ( + spec_copy.keys, + spec_copy.textures, + spec_copy.materials, + spec_copy.tendons, + spec_copy.actuators, + spec_copy.equalities, + spec_copy.sensors, + spec_copy.cameras, + ): + for item in list(collection): + spec_copy.delete(item) + _recursive_immobilize(spec_copy, temp_model, spec_copy.worldbody) + _recursive_remove_contacts(spec_copy.worldbody) + + meshes_to_mirror: set[str] = set() + _recursive_mirror(meshes_to_mirror, spec_copy, spec_copy.body("clavicle")) + for mesh in list(spec_copy.meshes): + if mesh.name in meshes_to_mirror: + mesh.name += "_mirrored" + mesh.scale[1] *= -1 + else: + spec_copy.delete(mesh) + + attachment_frame.attach_body(spec_copy.body("clavicle_mirrored")) + spec.body("ulna_mirrored").quat = [0.546, 0, 0, -0.838] + spec.body("humerus_mirrored").quat = [0.924, 0.383, 0, 0] + return spec + + +def _generate_tabletennis_xml(out: Path) -> None: + _load_mujoco() + asset_arm = out / "myosuite/envs/myo/assets/arm" + source = asset_arm / "myoarm_tabletennis.xml" + spec = mujoco.MjSpec.from_file(str(source)) + _preprocess_tabletennis_spec(spec) + xml_text = _normalize_tabletennis_xml(spec.to_xml()) + (asset_arm / "myoarm_tabletennis_native.xml").write_text(xml_text) + + +def _normalize_tabletennis_xml(xml_text: str) -> str: + root = ET.fromstring(xml_text) + seen_default_classes: set[str] = set() + + def visit(parent: ET.Element) -> None: + index = 0 + while index < len(parent): + child = parent[index] + if ( + parent.tag == "default" + and child.tag == "default" + and not child.attrib + ): + grandchildren = list(child) + parent.remove(child) + for offset, grandchild in enumerate(grandchildren): + parent.insert(index + offset, grandchild) + continue + elif child.tag == "default" and "class" in child.attrib: + class_name = child.attrib["class"] + if class_name in seen_default_classes: + parent.remove(child) + else: + seen_default_classes.add(class_name) + visit(child) + index += 1 + else: + visit(child) + index += 1 + + visit(root) + ET.indent(root, space=" ") + return ET.tostring(root, encoding="unicode") + "\n" + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument("out", type=Path) + parser.add_argument("manifest", type=Path) + parser.add_argument("objects", type=Path) + parser.add_argument("metadata", type=Path) + return parser.parse_args() + + +def main() -> None: + """Generate the minimal runtime asset tree used by native MyoSuite.""" + args = _parse_args() + objects = set(args.objects.read_text().splitlines()) or set(_MYODM_OBJECTS) + args.out.mkdir(parents=True, exist_ok=True) + (args.out / "myosuite/simhive").mkdir(parents=True, exist_ok=True) + _copy_runtime_sources(args.out, args.manifest, objects) + _copy_runtime_metadata(args.out, args.metadata) + _generate_myodm_object_xml(args.out, objects) + _generate_arm_reach_xml(args.out) + _generate_tabletennis_xml(args.out) + + +if __name__ == "__main__": + main() diff --git a/third_party/myosuite/generate_task_metadata.py b/third_party/myosuite/generate_task_metadata.py new file mode 100644 index 000000000..99f062a3a --- /dev/null +++ b/third_party/myosuite/generate_task_metadata.py @@ -0,0 +1,290 @@ +#!/usr/bin/env python3 +# Copyright 2026 Garena Online Private Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Generate compact native MyoSuite task metadata. + +The input metadata is produced by: + + bazel run //envpool/mujoco:myosuite_oracle_probe -- \ + --mode metadata --out /tmp/myosuite_all_meta.json --task_id ... + +Only the compact fields consumed by the native C++ runtime are emitted here. +The full official package remains a test/doc oracle, not a runtime dependency. +""" + +from __future__ import annotations + +import argparse +import json +from pathlib import Path +from typing import Any + +SCALAR_DEFAULTS = { + "far_th": 0.0, + "goal_th": 0.0, + "hip_period": 0, + "max_rot": 0.0, + "min_height": 0.0, + "pose_thd": 0.0, + "target_x_vel": 0.0, + "target_y_vel": 0.0, +} + +ORACLE_BROKEN_SOURCE_METADATA = { + "myosuite.envs.myo.myochallenge.bimanual_v0:BimanualEnvV1": { + "obs_keys": [ + "time", + "myohand_qpos", + "myohand_qvel", + "pros_hand_qpos", + "pros_hand_qvel", + "object_qpos", + "object_qvel", + "touching_body", + ], + "rwd_keys_wt": { + "act": 0.0, + "fin_dis": -0.5, + "pass_err": -1.0, + "reach_dist": -0.1, + }, + }, + "myosuite.envs.myo.myochallenge.soccer_v0:SoccerEnvV0": { + "obs_keys": [ + "internal_qpos", + "internal_qvel", + "grf", + "torso_angle", + "ball_pos", + "model_root_pos", + "model_root_vel", + "muscle_length", + "muscle_velocity", + "muscle_force", + ], + "rwd_keys_wt": { + "act_reg": -100.0, + "goal_scored": 1000.0, + "pain": -10.0, + "time_cost": -0.01, + }, + }, +} + + +def _csv(items: list[Any] | None) -> str: + if not items: + return "" + return ",".join(str(item) for item in items) + + +def _flat(values: Any) -> list[Any]: + if values is None: + return [] + if isinstance(values, list): + out: list[Any] = [] + for value in values: + if isinstance(value, list): + out.extend(_flat(value)) + else: + out.append(value) + return out + return [values] + + +def _float_csv(values: Any) -> str: + return ",".join(f"{float(value):.17g}" for value in _flat(values)) + + +def _rwd_csv(weights: dict[str, Any] | None) -> str: + if not weights: + return "" + return ",".join( + f"{key}:{float(weights[key]):.17g}" for key in sorted(weights) + ) + + +def _reach_range_csv( + metadata: dict[str, Any], target_index: int +) -> tuple[str, str]: + reach_range = metadata.get("target_reach_range") + tip_sites = metadata.get("tip_sites") or [] + if not reach_range or target_index >= len(tip_sites): + return "", "" + span = reach_range.get(tip_sites[target_index]) + if span is None: + return "", "" + return _float_csv(span[0]), _float_csv(span[1]) + + +def _escape(value: str) -> str: + return value.replace("\\", "\\\\").replace('"', '\\"') + + +def _entry( + task: dict[str, Any], metadata: dict[str, Any] | None +) -> dict[str, Any]: + if metadata is None: + metadata = ORACLE_BROKEN_SOURCE_METADATA.get(task["entry_point"], {}) + low_ranges: list[str] = [] + high_ranges: list[str] = [] + for i, _target_site in enumerate(metadata.get("target_sites") or []): + low, high = _reach_range_csv(metadata, i) + low_ranges.append(low) + high_ranges.append(high) + entry = { + "id": task["id"], + "obs_keys": _csv(metadata.get("obs_keys")), + "rwd_keys_wt": _rwd_csv(metadata.get("rwd_keys_wt")), + "init_qpos": _float_csv(metadata.get("init_qpos")), + "init_qvel": _float_csv(metadata.get("init_qvel")), + "reset_qacc_warmstart": _float_csv( + (metadata.get("reset_state") or {}).get("qacc_warmstart") + ), + "target_jnt_value": _float_csv(metadata.get("target_jnt_value")), + "tip_sites": _csv(metadata.get("tip_sites")), + "target_sites": _csv(metadata.get("target_sites")), + "target_reach_low": ";".join(low_ranges), + "target_reach_high": ";".join(high_ranges), + "reset_type": str(metadata.get("reset_type", "")), + } + for key, default in SCALAR_DEFAULTS.items(): + value = metadata.get(key, default) + entry[key] = default if value is None else value + return entry + + +def _write_header(entries: list[dict[str, Any]], output: Path) -> None: + lines = [ + "// Copyright 2026 Garena Online Private Limited", + "//", + '// Licensed under the Apache License, Version 2.0 (the "License");', + "// you may not use this file except in compliance with the License.", + "// You may obtain a copy of the License at", + "//", + "// http://www.apache.org/licenses/LICENSE-2.0", + "//", + "// Unless required by applicable law or agreed to in writing, software", + '// distributed under the License is distributed on an "AS IS" BASIS,', + "// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.", + "// See the License for the specific language governing permissions and", + "// limitations under the License.", + "", + "// Generated from pinned MyoSuite oracle metadata; do not edit by hand.", + "#ifndef THIRD_PARTY_MYOSUITE_MYOSUITE_TASK_METADATA_H_", + "#define THIRD_PARTY_MYOSUITE_MYOSUITE_TASK_METADATA_H_", + "", + "#include ", + "#include ", + "#include ", + "", + "namespace third_party::myosuite {", + "", + "struct MyoSuiteTaskMetadata {", + " const char* id;", + " const char* obs_keys;", + " const char* rwd_keys_wt;", + " const char* init_qpos;", + " const char* init_qvel;", + " const char* reset_qacc_warmstart;", + " const char* target_jnt_value;", + " const char* tip_sites;", + " const char* target_sites;", + " const char* target_reach_low;", + " const char* target_reach_high;", + " const char* reset_type;", + " double far_th;", + " double goal_th;", + " int hip_period;", + " double max_rot;", + " double min_height;", + " double pose_thd;", + " double target_x_vel;", + " double target_y_vel;", + "};", + "", + "// clang-format off", + ( + f"inline constexpr std::array " + "kMyoSuiteTaskMetadata = {{" + ), + ] + for entry in entries: + lines.extend([ + " MyoSuiteTaskMetadata{", + f' "{_escape(entry["id"])}",', + f' "{_escape(entry["obs_keys"])}",', + f' "{_escape(entry["rwd_keys_wt"])}",', + f' "{_escape(entry["init_qpos"])}",', + f' "{_escape(entry["init_qvel"])}",', + f' "{_escape(entry["reset_qacc_warmstart"])}",', + f' "{_escape(entry["target_jnt_value"])}",', + f' "{_escape(entry["tip_sites"])}",', + f' "{_escape(entry["target_sites"])}",', + f' "{_escape(entry["target_reach_low"])}",', + f' "{_escape(entry["target_reach_high"])}",', + f' "{_escape(entry["reset_type"])}",', + f" {float(entry['far_th']):.17g},", + f" {float(entry['goal_th']):.17g},", + f" {int(entry['hip_period'])},", + f" {float(entry['max_rot']):.17g},", + f" {float(entry['min_height']):.17g},", + f" {float(entry['pose_thd']):.17g},", + f" {float(entry['target_x_vel']):.17g},", + f" {float(entry['target_y_vel']):.17g},", + " },", + ]) + lines.extend([ + "}};", + "// clang-format on", + "", + "inline const MyoSuiteTaskMetadata& GetMyoSuiteTaskMetadata(", + " std::string_view task_id) {", + " for (const auto& metadata : kMyoSuiteTaskMetadata) {", + " if (metadata.id == task_id) {", + " return metadata;", + " }", + " }", + ' throw std::runtime_error("Unknown MyoSuite task metadata.");', + "}", + "", + "} // namespace third_party::myosuite", + "", + "#endif // THIRD_PARTY_MYOSUITE_MYOSUITE_TASK_METADATA_H_", + "", + ]) + output.write_text("\n".join(lines)) + + +def main() -> None: + """Generate compact C++ and JSON task metadata.""" + parser = argparse.ArgumentParser() + parser.add_argument("--tasks", type=Path, required=True) + parser.add_argument("--oracle-metadata", type=Path, required=True) + parser.add_argument("--out-json", type=Path, required=True) + parser.add_argument("--out-header", type=Path, required=True) + args = parser.parse_args() + + tasks = json.loads(args.tasks.read_text()) + oracle = json.loads(args.oracle_metadata.read_text())["tasks"] + entries = [_entry(task, oracle.get(task["id"])) for task in tasks] + args.out_json.write_text( + json.dumps(entries, indent=2, sort_keys=True) + "\n" + ) + _write_header(entries, args.out_header) + + +if __name__ == "__main__": + main() diff --git a/third_party/myosuite/generate_task_registry.py b/third_party/myosuite/generate_task_registry.py new file mode 100644 index 000000000..02c3d8196 --- /dev/null +++ b/third_party/myosuite/generate_task_registry.py @@ -0,0 +1,211 @@ +#!/usr/bin/env python3 +# Copyright 2026 Garena Online Private Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Regenerate native MyoSuite task registry files. + +The source registry is the pinned MyoSuite task list checked in as JSON. The +optional oracle metadata input is produced by `myosuite_oracle_probe --mode +metadata` and refreshes space fields that can change with the pinned MuJoCo +runtime and upstream MjSpec patching. +""" + +from __future__ import annotations + +import argparse +import json +from pathlib import Path +from typing import Any + +ORACLE_VERSION = "2.11.6" +ORACLE_COMMIT = "05cb84678373f91271004f99602ebbf01e57d1a1" + +BROKEN_IDS: tuple[str, ...] = () + + +def _escape(value: str) -> str: + return value.replace("\\", "\\\\").replace('"', '\\"') + + +def _bool(value: bool) -> str: + return "true" if value else "false" + + +def _refresh_from_metadata( + tasks: list[dict[str, Any]], metadata_path: Path | None +) -> None: + if metadata_path is None: + return + oracle = json.loads(metadata_path.read_text()) + if oracle["version"] != ORACLE_VERSION: + raise ValueError( + f"expected MyoSuite {ORACLE_VERSION}, got {oracle['version']}" + ) + by_id = oracle["tasks"] + for task in tasks: + metadata = by_id.get(task["id"]) + if metadata is None: + continue + task["obs_dim"] = int(metadata["observation_shape"][0]) + task["action_dim"] = int(metadata["action_shape"][0]) + task["frame_skip"] = int(metadata["frame_skip"]) + task["oracle_numpy2_broken"] = False + for task in tasks: + if task["id"] in BROKEN_IDS: + task["oracle_numpy2_broken"] = True + + +def _write_json(tasks: list[dict[str, Any]], output: Path) -> None: + output.write_text(json.dumps(tasks, indent=2, sort_keys=True) + "\n") + + +def _write_header(tasks: list[dict[str, Any]], output: Path) -> None: + lines = [ + "// Copyright 2026 Garena Online Private Limited", + "//", + '// Licensed under the Apache License, Version 2.0 (the "License");', + "// you may not use this file except in compliance with the License.", + "// You may obtain a copy of the License at", + "//", + "// http://www.apache.org/licenses/LICENSE-2.0", + "//", + "// Unless required by applicable law or agreed to in writing, software", + '// distributed under the License is distributed on an "AS IS" BASIS,', + "// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.", + "// See the License for the specific language governing permissions and", + "// limitations under the License.", + "", + f"// Generated from MyoSuite v{ORACLE_VERSION} registry; do not edit by hand.", + "#ifndef THIRD_PARTY_MYOSUITE_MYOSUITE_TASKS_H_", + "#define THIRD_PARTY_MYOSUITE_MYOSUITE_TASKS_H_", + "", + "#include ", + "#include ", + "#include ", + "#include ", + "#include ", + "", + "namespace third_party::myosuite {", + "", + "enum class MyoSuiteTaskKind : std::uint8_t {", + " kPose,", + " kReach,", + " kWalkReach,", + " kWalk,", + " kTerrain,", + " kKeyTurn,", + " kObjHoldFixed,", + " kObjHoldRandom,", + " kPenTwirlFixed,", + " kPenTwirlRandom,", + " kTorsoPose,", + " kReorientSar,", + " kChallengeBaoding,", + " kChallengeBimanual,", + " kChallengeChaseTag,", + " kChallengeRelocate,", + " kChallengeReorient,", + " kChallengeRunTrack,", + " kChallengeSoccer,", + " kChallengeTableTennis,", + " kMyoDmTrack,", + "};", + "", + "enum class MyoSuiteMuscleCondition : std::uint8_t {", + " kNormal,", + " kSarcopenia,", + " kFatigue,", + " kReafferentation,", + "};", + "", + "struct MyoSuiteTaskDef {", + " const char* id;", + " const char* envpool_id;", + " const char* entry_point;", + " MyoSuiteTaskKind kind;", + " const char* model_path;", + " const char* reference_path;", + " const char* object_name;", + " int obs_dim;", + " int action_dim;", + " int max_episode_steps;", + " int frame_skip;", + " bool normalize_act;", + " MyoSuiteMuscleCondition muscle_condition;", + " bool oracle_numpy2_broken;", + "};", + "", + "// clang-format off", + ( + f"inline constexpr std::array " + "kMyoSuiteTasks = {{" + ), + ] + for task in tasks: + lines.extend([ + " MyoSuiteTaskDef{", + f' "{_escape(task["id"])}",', + f' "MyoSuite/{_escape(task["id"])}",', + f' "{_escape(task["entry_point"])}",', + f" MyoSuiteTaskKind::{task['kind']},", + f' "{_escape(task["model_path"])}",', + f' "{_escape(task["reference_path"])}",', + f' "{_escape(task["object_name"])}",', + f" {int(task['obs_dim'])},", + f" {int(task['action_dim'])},", + f" {int(task['max_episode_steps'])},", + f" {int(task['frame_skip'])},", + f" {_bool(bool(task['normalize_act']))},", + f" MyoSuiteMuscleCondition::{task['muscle']},", + f" {_bool(bool(task['oracle_numpy2_broken']))},", + " },", + ]) + lines.extend([ + "}};", + "// clang-format on", + "", + "inline const MyoSuiteTaskDef& GetMyoSuiteTask(std::string_view id) {", + " for (const auto& task : kMyoSuiteTasks) {", + " if (task.id == id || task.envpool_id == id) {", + " return task;", + " }", + " }", + ' throw std::runtime_error("Unknown MyoSuite task: " + std::string(id));', + "}", + "", + "} // namespace third_party::myosuite", + "", + "#endif // THIRD_PARTY_MYOSUITE_MYOSUITE_TASKS_H_", + "", + ]) + output.write_text("\n".join(lines)) + + +def main() -> None: + """Generate native task registry files.""" + parser = argparse.ArgumentParser() + parser.add_argument("--tasks-json", type=Path, required=True) + parser.add_argument("--oracle-metadata", type=Path) + parser.add_argument("--out-json", type=Path, required=True) + parser.add_argument("--out-header", type=Path, required=True) + args = parser.parse_args() + + tasks = json.loads(args.tasks_json.read_text()) + _refresh_from_metadata(tasks, args.oracle_metadata) + _write_json(tasks, args.out_json) + _write_header(tasks, args.out_header) + + +if __name__ == "__main__": + main() diff --git a/third_party/myosuite/mujoco36_mjspec_compat.patch b/third_party/myosuite/mujoco36_mjspec_compat.patch new file mode 100644 index 000000000..21535eda8 --- /dev/null +++ b/third_party/myosuite/mujoco36_mjspec_compat.patch @@ -0,0 +1,98 @@ +diff --git a/myosuite/envs/myo/myoedits/__init__.py b/myosuite/envs/myo/myoedits/__init__.py +index 3f8af30..d215f8f 100644 +--- a/myosuite/envs/myo/myoedits/__init__.py ++++ b/myosuite/envs/myo/myoedits/__init__.py +@@ -47 +47 @@ def edit_fn_arm_reaching(spec: mujoco.MjSpec) -> None: +- spec.detach_body(child_body) ++ spec.delete(child_body) +diff --git a/myosuite/envs/myo/myochallenge/tabletennis_v0.py b/myosuite/envs/myo/myochallenge/tabletennis_v0.py +index 4dc75f3..af9894e 100644 +--- a/myosuite/envs/myo/myochallenge/tabletennis_v0.py ++++ b/myosuite/envs/myo/myochallenge/tabletennis_v0.py +@@ -401,7 +401,7 @@ class TableTennisEnvV0(BaseV0): + warnings.warn("A paddle was found that was not a free body. Confirm this is intended.") + for s in spec.sensors: + if "pingpong" not in s.name and "paddle" not in s.name: +- s.delete() ++ spec.delete(s) + temp_model = spec.compile() + + removed_ids = recursive_immobilize(spec, temp_model, spec.body("femur_l"), remove_eqs=True) +@@ -419,15 +419,15 @@ class TableTennisEnvV0(BaseV0): + + spec_copy: mujoco.MjSpec = spec.copy() + attachment_frame = torso.add_frame(quat=[0.5, 0.5, -0.5, 0.5], + pos=[0.05, 0.373, -0.04]) +- [k.delete() for k in spec_copy.keys] +- [t.delete() for t in spec_copy.textures] +- [m.delete() for m in spec_copy.materials] +- [t.delete() for t in spec_copy.tendons] +- [a.delete() for a in spec_copy.actuators] +- [e.delete() for e in spec_copy.equalities] +- [s.delete() for s in spec_copy.sensors] +- [c.delete() for c in spec_copy.cameras] ++ [spec_copy.delete(k) for k in spec_copy.keys] ++ [spec_copy.delete(t) for t in spec_copy.textures] ++ [spec_copy.delete(m) for m in spec_copy.materials] ++ [spec_copy.delete(t) for t in spec_copy.tendons] ++ [spec_copy.delete(a) for a in spec_copy.actuators] ++ [spec_copy.delete(e) for e in spec_copy.equalities] ++ [spec_copy.delete(s) for s in spec_copy.sensors] ++ [spec_copy.delete(c) for c in spec_copy.cameras] +- recursive_immobilize(spec, temp_model, spec_copy.worldbody) ++ recursive_immobilize(spec_copy, temp_model, spec_copy.worldbody) + recursive_remove_contacts(spec_copy.worldbody, return_condition=None) + +@@ -438,7 +438,7 @@ class TableTennisEnvV0(BaseV0): + if mesh.name in meshes_to_mirror: + mesh.name += "_mirrored" + mesh.scale[1] *= -1 + else: +- mesh.delete() ++ spec_copy.delete(mesh) + + attachment_frame.attach_body(spec_copy.body("clavicle_mirrored")) +diff --git a/myosuite/utils/spec_processing.py b/myosuite/utils/spec_processing.py +index 6c64ee4..d4a0a4b 100644 +--- a/myosuite/utils/spec_processing.py ++++ b/myosuite/utils/spec_processing.py +@@ -3,18 +3,18 @@ import mujoco + def recursive_immobilize(spec, temp_model, parent, remove_eqs=False, remove_actuators=False): + removed_joint_ids = [] + for s in parent.sites: +- s.delete() ++ spec.delete(s) + for j in parent.joints: + removed_joint_ids.extend(temp_model.joint(j.name).qposadr) + if remove_eqs: + for e in spec.equalities: + if e.type == mujoco.mjtEq.mjEQ_JOINT and (e.name1 == j.name or e.name2 == j.name): +- e.delete() ++ spec.delete(e) + if remove_actuators: + for a in spec.actuators: + if a.trntype == mujoco.mjtTrn.mjTRN_JOINT and a.target == j.name: +- a.delete() +- j.delete() ++ spec.delete(a) ++ spec.delete(j) + for child in parent.bodies: + removed_joint_ids.extend( + recursive_immobilize(spec, temp_model, child, remove_eqs, remove_actuators) +@@ -38,7 +38,7 @@ def recursive_mirror(meshes_to_mirror, spec_copy, parent): + parent.name += "_mirrored" + for g in parent.geoms: + if g.type != mujoco.mjtGeom.mjGEOM_MESH: +- g.delete() ++ spec_copy.delete(g) + continue + g.pos[1] *= -1 + g.quat[[1, 3]] *= -1 +@@ -48,6 +48,6 @@ def recursive_mirror(meshes_to_mirror, spec_copy, parent): + g.meshname += "_mirrored" + for child in parent.bodies: + if "ping_pong" in child.name: +- spec_copy.detach_body(child) ++ spec_copy.delete(child) + continue + recursive_mirror(meshes_to_mirror, spec_copy, child) diff --git a/third_party/myosuite/myosuite_source.BUILD b/third_party/myosuite/myosuite_source.BUILD new file mode 100644 index 000000000..208590d80 --- /dev/null +++ b/third_party/myosuite/myosuite_source.BUILD @@ -0,0 +1,28 @@ +# Copyright 2026 Garena Online Private Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +filegroup( + name = "source", + srcs = glob(["myosuite/**"]), + visibility = ["//visibility:public"], +) + +filegroup( + name = "runtime_assets", + srcs = glob([ + "myosuite/envs/myo/assets/**", + "myosuite/envs/myo/myodm/data/**", + ]), + visibility = ["//visibility:public"], +) diff --git a/third_party/myosuite/oracle_requirements.bzl b/third_party/myosuite/oracle_requirements.bzl new file mode 100644 index 000000000..f41d86a63 --- /dev/null +++ b/third_party/myosuite/oracle_requirements.bzl @@ -0,0 +1,20 @@ +# Copyright 2026 Garena Online Private Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Thin wrapper around the pinned MyoSuite oracle requirement labels.""" + +load("@myosuite_oracle_requirements//:requirements.bzl", _requirement = "requirement") + +def oracle_requirement(name): + return _requirement(name) diff --git a/third_party/myosuite/oracle_requirements.txt b/third_party/myosuite/oracle_requirements.txt new file mode 100644 index 000000000..2ad27e434 --- /dev/null +++ b/third_party/myosuite/oracle_requirements.txt @@ -0,0 +1,41 @@ +absl-py==2.4.0 +attrs==26.1.0 +certifi==2026.4.22 +charset-normalizer==3.4.7 +click==8.3.3 +cloudpickle==3.1.2 +colorama==0.4.6 +dm-control==1.0.38 +dm-env==1.6 +dm-tree==0.1.10 +etils==1.14.0 +Farama-Notifications==0.0.6 +flatten-dict==0.5.0 +fsspec==2026.4.0 +gitdb==4.0.12 +gitpython==3.1.49 +glfw==2.10.0 +gymnasium==0.29.1 +h5py==3.16.0 +idna==3.13 +labmaze==1.0.6 +lxml==6.1.0 +mujoco==3.6.0 +numpy==1.26.4 +packaging==26.2 +pillow==12.2.0 +pink-noise-rl==2.0.1 +protobuf==7.34.1 +PyOpenGL==3.1.10 +pyparsing==3.3.2 +requests==2.33.1 +scipy==1.17.1 +setuptools==82.0.1 +sk-video==1.1.10 +smmap==5.0.3 +termcolor==3.3.0 +tqdm==4.67.3 +typing_extensions==4.15.0 +urllib3==2.6.3 +wrapt==2.1.2 +zipp==3.23.1 diff --git a/third_party/myosuite/oracle_workspace.bzl b/third_party/myosuite/oracle_workspace.bzl new file mode 100644 index 000000000..3bf798dba --- /dev/null +++ b/third_party/myosuite/oracle_workspace.bzl @@ -0,0 +1,44 @@ +# Copyright 2026 Garena Online Private Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Workspace setup for the pinned MyoSuite oracle dependencies.""" + +load("@python_versions//:pip.bzl", "multi_pip_parse") + +def myosuite_oracle_pip_workspace(): + """Configure official MyoSuite dependencies used only by tests/codegen.""" + if "myosuite_oracle_requirements" in native.existing_rules().keys(): + return + + # The official MyoSuite package is only a test/codegen oracle. Keep its + # dependency hub pinned to Python 3.12 for every EnvPool toolchain key so + # generic targets such as //:setup_py314 do not try to build oracle-only + # wheels for unsupported Python/platform pairs. + multi_pip_parse( + name = "myosuite_oracle_requirements", + default_version = "3.12", + python_interpreter_target = { + "3.11": "@python_versions_3_12_host//:python", + "3.12": "@python_versions_3_12_host//:python", + "3.13": "@python_versions_3_12_host//:python", + "3.14": "@python_versions_3_12_host//:python", + }, + requirements_lock = { + "3.11": "@envpool//third_party/myosuite:oracle_requirements.txt", + "3.12": "@envpool//third_party/myosuite:oracle_requirements.txt", + "3.13": "@envpool//third_party/myosuite:oracle_requirements.txt", + "3.14": "@envpool//third_party/myosuite:oracle_requirements.txt", + }, + quiet = False, + ) diff --git a/third_party/myosuite/simhive_source.BUILD b/third_party/myosuite/simhive_source.BUILD new file mode 100644 index 000000000..b572a40d4 --- /dev/null +++ b/third_party/myosuite/simhive_source.BUILD @@ -0,0 +1,19 @@ +# Copyright 2026 Garena Online Private Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +filegroup( + name = "source", + srcs = glob(["**"]), + visibility = ["//visibility:public"], +) diff --git a/third_party/re2c/BUILD b/third_party/re2c/BUILD new file mode 100644 index 000000000..f21140d2a --- /dev/null +++ b/third_party/re2c/BUILD @@ -0,0 +1,22 @@ +# Copyright 2026 Garena Online Private Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:defs.bzl", "cc_library") + +cc_library( + name = "config", + hdrs = ["config.h"], + strip_include_prefix = ".", + visibility = ["//visibility:public"], +) diff --git a/third_party/re2c/config.h b/third_party/re2c/config.h new file mode 100644 index 000000000..ec7820416 --- /dev/null +++ b/third_party/re2c/config.h @@ -0,0 +1,37 @@ +// Copyright 2026 Garena Online Private Limited +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_RE2C_CONFIG_H_ +#define THIRD_PARTY_RE2C_CONFIG_H_ + +#define PACKAGE_VERSION "4.5.1" + +#ifndef RE2C_STDLIB_DIR +#define RE2C_STDLIB_DIR "" +#endif + +#define HAVE_STDINT_H 1 +#define HAVE_STDLIB_H 1 +#define HAVE_STRING_H 1 + +#ifdef _WIN32 +#define HAVE_IO_H 1 +#else +#define HAVE_FCNTL_H 1 +#define HAVE_SYS_STAT_H 1 +#define HAVE_SYS_TYPES_H 1 +#define HAVE_UNISTD_H 1 +#endif + +#endif // THIRD_PARTY_RE2C_CONFIG_H_ diff --git a/third_party/re2c/re2c.BUILD b/third_party/re2c/re2c.BUILD new file mode 100644 index 000000000..31140f443 --- /dev/null +++ b/third_party/re2c/re2c.BUILD @@ -0,0 +1,48 @@ +# Copyright 2026 Garena Online Private Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:defs.bzl", "cc_binary") + +cc_binary( + name = "re2c", + srcs = glob( + ["src/**/*.cc"], + exclude = ["src/test/**"], + ) + [ + "bootstrap/src/msg/help_re2c.cc", + "bootstrap/src/options/parse_opts.cc", + "bootstrap/src/parse/conf_lexer.cc", + "bootstrap/src/parse/conf_parser.cc", + "bootstrap/src/parse/lexer.cc", + "bootstrap/src/parse/parser.cc", + ] + glob([ + "bootstrap/src/**/*.h", + "src/**/*.h", + ]), + copts = [ + "-DRE2C_STDLIB_DIR=\\\"\\\"", + ] + select({ + "@envpool//:windows": [ + "/D_CRT_SECURE_NO_WARNINGS", + "/DNOMINMAX", + ], + "//conditions:default": [], + }), + includes = [ + ".", + "bootstrap", + ], + visibility = ["//visibility:public"], + deps = ["@envpool//third_party/re2c:config"], +) diff --git a/third_party/sdl2/BUILD b/third_party/sdl2/BUILD index 12106201f..bbbb44029 100644 --- a/third_party/sdl2/BUILD +++ b/third_party/sdl2/BUILD @@ -11,3 +11,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +exports_files(["windows_xinput_stub.patch"]) diff --git a/third_party/sdl2/windows_xinput_stub.patch b/third_party/sdl2/windows_xinput_stub.patch new file mode 100644 index 000000000..8da8a0d7f --- /dev/null +++ b/third_party/sdl2/windows_xinput_stub.patch @@ -0,0 +1,22 @@ +--- src/joystick/windows/SDL_xinputjoystick.c ++++ src/joystick/windows/SDL_xinputjoystick.c +@@ -482,13 +482,19 @@ + int SDL_XINPUT_JoystickInit(void) + { + return 0; + } + + void SDL_XINPUT_JoystickDetect(JoyStick_DeviceData **pContext) + { + } + ++int SDL_XINPUT_GetSteamVirtualGamepadSlot(Uint8 userid) ++{ ++ (void)userid; ++ return -1; ++} ++ + int SDL_XINPUT_JoystickOpen(SDL_Joystick *joystick, JoyStick_DeviceData *joystickdevice) + { + return SDL_Unsupported(); + } diff --git a/third_party/vizdoom/vizdoom.BUILD b/third_party/vizdoom/vizdoom.BUILD index 5c8a42f48..4d7bae2df 100644 --- a/third_party/vizdoom/vizdoom.BUILD +++ b/third_party/vizdoom/vizdoom.BUILD @@ -38,6 +38,7 @@ genrule( srcs = [], outs = ["arith.h"], cmd = "$(execpath :arithchk) > $@", + cmd_bat = "$(execpath :arithchk) > $@", tools = [":arithchk"], ) @@ -54,6 +55,7 @@ genrule( srcs = [], outs = ["gd_qnan.h"], cmd = "$(execpath :qnan) > $@", + cmd_bat = "$(execpath :qnan) > $@", tools = [":qnan"], ) @@ -176,8 +178,14 @@ genrule( name = "sc_man_scanner", srcs = ["src/sc_man_scanner.re"], outs = ["src/sc_man_scanner.h"], - cmd = "$(execpath :re2c) --no-generation-date -s -o $@ $<", - tools = [":re2c"], + cmd = "$(execpath :re2c) --no-generation-date -s -o $@ " + + "$(location src/sc_man_scanner.re)", + cmd_bat = "$(execpath @re2c_4_5_1//:re2c) --no-generation-date " + + "-s -o $@ $(location src/sc_man_scanner.re)", + tools = select({ + "@envpool//:windows": ["@re2c_4_5_1//:re2c"], + "//conditions:default": [":re2c"], + }), ) genrule( @@ -185,6 +193,7 @@ genrule( srcs = ["tools/lemon/lempar.c"], outs = ["lempar.c"], cmd = "cp $(SRCS) $(RULEDIR)", + cmd_bat = "copy /Y $(SRCS) $(RULEDIR)", ) genrule( @@ -192,6 +201,7 @@ genrule( srcs = ["src/xlat/xlat_parser.y"], outs = ["xlat_parser.y"], cmd = "cp $< $@", + cmd_bat = "copy /Y $< $@", ) genrule( @@ -202,6 +212,7 @@ genrule( "xlat_parser.h", ], cmd = "$(execpath :lemon) $<", + cmd_bat = "$(execpath :lemon) $<", tools = [ ":lemon", ":lemon_deps", @@ -466,6 +477,7 @@ genrule( srcs = [":wadsrc"], outs = ["vizdoom.pk3"], cmd = "$(execpath zipdir) -udf $@ $<", + cmd_bat = "$(execpath zipdir) -udf $@ $<", tools = [":zipdir"], visibility = ["//visibility:public"], )