forked from DelinQu/SimplerEnv-OpenVLA
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathstart_openvla.py
More file actions
73 lines (57 loc) · 2.46 KB
/
start_openvla.py
File metadata and controls
73 lines (57 loc) · 2.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import simpler_env
from simpler_env.utils.env.observation_utils import get_image_from_maniskill2_obs_dict
from simpler_env.policies.openvla.openvla_model import OpenVLAInference
import mediapy as media
import numpy as np
# --- Episode configuration -------------------------------------------------
# The task name doubles as the video-file suffix so the output is always
# labelled with the environment it came from. Swap in
# 'google_robot_pick_coke_can' (with POLICY_SETUP = "google_robot") to run
# the Google-robot variant instead.
TASK_NAME = "widowx_carrot_on_plate"
POLICY_SETUP = "widowx_bridge"
MODEL_ID = "openvla/openvla-7b"  # Hugging Face model ID instead of a local path
VIDEO_FPS = 10


def main() -> None:
    """Roll out one OpenVLA episode in SimplerEnv and save it as a video.

    Builds the environment, reads the language instruction, steps the policy
    until the episode terminates or truncates, then writes an MP4 of the
    collected observation frames and prints the episode statistics.
    """
    # Set up environment
    env = simpler_env.make(TASK_NAME)
    try:
        obs, reset_info = env.reset()
        instruction = env.unwrapped.get_language_instruction()
        print("Reset info", reset_info)
        print("Instruction", instruction)
        print("Control mode:", env.unwrapped.agent.control_mode)

        # Initialize OpenVLA policy
        openvla_policy = OpenVLAInference(
            saved_model_path=MODEL_ID,
            policy_setup=POLICY_SETUP,
        )
        openvla_policy.reset(instruction)

        frames = []  # observation images collected for the output video
        done, truncated = False, False
        info = {}  # defensive default so the stats lookup below cannot NameError
        while not (done or truncated):
            # Get the observation image for OpenVLA
            image = get_image_from_maniskill2_obs_dict(env.unwrapped, obs)

            # Generate action using OpenVLA policy (raw model output is unused)
            _, action = openvla_policy.step(image=image, task_description=instruction)

            # Flatten the structured action into [dx, dy, dz, rx, ry, rz, gripper]
            # — position deltas, axis-angle rotation, then gripper state — the
            # 7-element layout the environment's controller expects.
            transformed_action = np.array(
                action["world_vector"].tolist()
                + action["rot_axangle"].tolist()
                + [action["gripper"][0]],
                dtype=np.float32,
            )

            # Record the frame, then step the environment
            frames.append(image)
            obs, reward, done, truncated, info = env.step(transformed_action)

            # Long-horizon tasks can switch instructions mid-episode; reset the
            # policy's context whenever that happens.
            new_instruction = env.unwrapped.get_language_instruction()
            if new_instruction != instruction:
                instruction = new_instruction
                print("New Instruction", instruction)
                openvla_policy.reset(instruction)

        # Save the video of the episode
        video_path = f"openvla_policy_video_{TASK_NAME}.mp4"
        media.write_video(video_path, np.array(frames), fps=VIDEO_FPS)
        print(f"Video saved at {video_path}")

        # Episode stats
        episode_stats = info.get("episode_stats", {})
        print("Episode stats", episode_stats)
    finally:
        # Release simulator resources even if the rollout raised
        env.close()


if __name__ == "__main__":
    main()