
nakamotoo/dsrl_pi0


DSRL for π₀: Diffusion Steering via Reinforcement Learning

Overview

This repository provides the official implementation for our paper: Steering Your Diffusion Policy with Latent Space Reinforcement Learning (CoRL 2025).

Specifically, it contains a JAX-based implementation of DSRL (Diffusion Steering via Reinforcement Learning) for steering a pre-trained generalist policy, π₀, across various environments, including:

  • Simulation: Libero, Aloha
  • Real Robot: Franka
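The core idea of DSRL is that the RL agent never touches the pretrained policy's weights. Instead, it picks the initial noise that the frozen diffusion/flow policy starts sampling from, so the frozen policy becomes part of a "noise-space" environment. The sketch below illustrates that framing in JAX under loudly stated assumptions: `frozen_velocity` is a hypothetical toy stand-in for π₀'s action expert, `latent_actor` is a hypothetical linear noise-space actor, and `toy_env_step` is a made-up point-mass environment. None of these names come from this repository.

```python
import jax
import jax.numpy as jnp

ACTION_DIM = 2       # toy action size (assumption)
NUM_FLOW_STEPS = 10  # Euler steps of the frozen sampler (assumption)


def frozen_velocity(obs, x, t):
    """Hypothetical stand-in for a pretrained, frozen flow-matching network.

    RL never updates it; DSRL only needs to call it. `t` is unused in this toy.
    """
    target = jnp.tanh(obs[:ACTION_DIM])
    return target - x


def base_policy(obs, noise):
    """Sample an action by Euler-integrating the frozen field, starting at `noise`.

    The same (obs, noise) pair always yields the same action.
    """
    dt = 1.0 / NUM_FLOW_STEPS

    def euler_step(x, i):
        return x + dt * frozen_velocity(obs, x, i * dt), None

    action, _ = jax.lax.scan(euler_step, noise, jnp.arange(NUM_FLOW_STEPS))
    return action


def latent_actor(params, obs, max_noise=1.5):
    """Hypothetical DSRL actor: outputs bounded initial noise, not the action."""
    return max_noise * jnp.tanh(obs @ params["w"] + params["b"])


def latent_env_step(env_step, state, noise):
    """Noise-space MDP: the frozen policy is folded into the environment."""
    action = base_policy(state, noise)
    return env_step(state, action)


def toy_env_step(state, action):
    """Toy point-mass env: move by `action`; reward is minus distance to origin."""
    next_state = state + 0.1 * action
    reward = -jnp.linalg.norm(next_state)
    return next_state, reward


# Usage: one step in the noise-space MDP.
key = jax.random.PRNGKey(0)
obs = jnp.array([0.5, -0.3])
params = {"w": 0.01 * jax.random.normal(key, (2, ACTION_DIM)),
          "b": jnp.zeros(ACTION_DIM)}
noise = latent_actor(params, obs)
next_state, reward = latent_env_step(toy_env_step, obs, noise)
```

Transitions of the form (state, noise, reward, next_state) can then be fed to an off-policy RL algorithm (the paper's DSRL uses a SAC-style actor-critic); the toy above only shows the action-space change, not the learning update.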

If you find this repository useful for your research, please cite:

@article{wagenmaker2025steering,
  author    = {Andrew Wagenmaker and Mitsuhiko Nakamoto and Yunchu Zhang and Seohong Park and Waleed Yagoub and Anusha Nagabandi and Abhishek Gupta and Sergey Levine},
  title     = {Steering Your Diffusion Policy with Latent Space Reinforcement Learning},
  journal   = {Conference on Robot Learning (CoRL)},
  year      = {2025},
}

Installation

  1. Create a conda environment:
conda create -n dsrl_pi0 python=3.11.11
conda activate dsrl_pi0
  2. Clone this repo with all submodules:
git clone git@github.com:nakamotoo/dsrl_pi0.git --recurse-submodules
cd dsrl_pi0
  3. Install all packages and dependencies:
pip install -e .
pip install -r requirements.txt
pip install "jax[cuda12]==0.5.0"

# install openpi
pip install -e openpi
pip install -e openpi/packages/openpi-client

# install Libero
pip install -e LIBERO
pip install torch==2.6.0 --index-url https://download.pytorch.org/whl/cpu # needed for libero
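After installation, a quick sanity check can confirm that the pinned JAX build imports and that XLA can compile and run a small jitted function. This is a minimal sketch, not part of the repository; on a machine where the `jax[cuda12]` wheel found a GPU, the reported backend should be `gpu`, otherwise JAX silently falls back to `cpu`.

```python
import jax
import jax.numpy as jnp

print("JAX version:", jax.__version__)
print("Default backend:", jax.default_backend())  # "gpu" if CUDA was picked up
print("Devices:", jax.devices())


# Tiny jitted computation to confirm compilation works on the active backend.
@jax.jit
def matmul_sum(x):
    return (x @ x.T).sum()


result = matmul_sum(jnp.ones((4, 4)))
print("Result:", float(result))
```

If the backend prints `cpu` on a GPU machine, the CUDA 12 wheel most likely did not install correctly, and training will be far slower.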

Training (Simulation)

Libero

bash examples/scripts/run_libero.sh

Aloha

bash examples/scripts/run_aloha.sh

Training Logs

We provide sample W&B runs and logs: https://wandb.ai/mitsuhiko/DSRL_pi0_public

Training (Real)

For real-world experiments, we use openpi's remote policy-hosting feature, which lets us serve the π₀ model from a higher-spec remote server when the robot's client machine is not powerful enough to run it locally.

  1. Set up the Franka robot and install the DROID package.

  2. [On the remote server] Host the π₀-DROID model:

cd openpi && python scripts/serve_policy.py --env=DROID
  3. [On your robot client machine] Run DSRL:
bash examples/scripts/run_real.sh
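The split above keeps the heavy π₀ forward pass on the server while the robot client runs only a lightweight noise-space actor and the RL bookkeeping. The sketch below illustrates that division of labor; it is a hypothetical illustration, not the code behind `run_real.sh`. In particular, `remote_policy_infer` is a local stand-in (not the openpi client API), and the sketch assumes the served policy can accept the initial noise as an input. `NOISE_SHAPE`, `client_latent_actor`, and `toy_env_step` are likewise invented for illustration.

```python
import numpy as np

NOISE_SHAPE = (4, 2)  # hypothetical (action_horizon, action_dim)


def remote_policy_infer(obs, noise):
    """Hypothetical stand-in for the remote π₀ call (expensive, server-side).

    A real deployment would send `obs` and `noise` over the network; a fixed
    function keeps this sketch runnable offline.
    """
    return np.tanh(noise + obs.mean())


def client_latent_actor(obs, rng, scale=1.0):
    """Hypothetical lightweight DSRL actor run on the robot client."""
    return scale * np.tanh(rng.standard_normal(NOISE_SHAPE) + obs.mean())


def toy_env_step(obs, action_chunk):
    """Toy environment: apply the summed action chunk; reward is -distance."""
    next_obs = obs + 0.01 * action_chunk.sum(axis=0)
    return next_obs, -float(np.linalg.norm(next_obs))


def collect_step(obs, rng, replay_buffer, env_step):
    noise = client_latent_actor(obs, rng)           # cheap, on the client
    action_chunk = remote_policy_infer(obs, noise)  # expensive, on the server
    next_obs, reward = env_step(obs, action_chunk)
    # The RL "action" stored for training is the noise, not the robot action.
    replay_buffer.append((obs, noise, reward, next_obs))
    return next_obs


rng = np.random.default_rng(0)
buffer = []
obs = np.array([0.2, -0.1])
for _ in range(3):
    obs = collect_step(obs, rng, buffer, toy_env_step)
```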

Credits

This repository is built upon the jaxrl2 and PTR repositories. For any questions, bugs, suggestions, or improvements, please feel free to contact me at nakamoto[at]berkeley[dot]edu.

About

Official implementation for pi0 steering via DSRL, Steering Your Diffusion Policy with Latent Space Reinforcement Learning (CoRL 2025)
