Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
rename te, tr to tm, tc
  • Loading branch information
pimpale committed Jun 13, 2023
commit e320c06d7332aa6413ff39229d3e15ab20174617
40 changes: 20 additions & 20 deletions metadrive/envs/marl_envs/marl_tollgate.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,9 +273,9 @@ def get_single_observation(self, vehicle_config):
return o

def step(self, actions):
o, r, te, tr, i = super(MultiAgentTollgateEnv, self).step(actions)
o, r, tm, tc, i = super(MultiAgentTollgateEnv, self).step(actions)
self.stay_time_manager.record(self.agent_manager.active_agents, self.episode_step)
return o, r, te, tr, i
return o, r, tm, tc, i

def setup_engine(self):
super(MultiAgentTollgateEnv, self).setup_engine()
Expand Down Expand Up @@ -317,13 +317,13 @@ def _expert():
total_r = 0
ep_s = 0
for i in range(1, 100000):
o, r, te, tr, info = env.step(env.action_space.sample())
o, r, tm, tc, info = env.step(env.action_space.sample())
for r_ in r.values():
total_r += r_
ep_s += 1
te.update({"total_r": total_r, "episode length": ep_s})
tm.update({"total_r": total_r, "episode length": ep_s})
# env.render(text=d)
if te["__all__"]:
if tm["__all__"]:
print(
"Finish! Current step {}. Group Reward: {}. Average reward: {}".format(
i, total_r, total_r / env.agent_manager.next_agent_count
Expand Down Expand Up @@ -361,7 +361,7 @@ def _vis_debug_respawn():
ep_s = 0
for i in range(1, 100000):
action = {k: [.0, 1.0] for k in env.vehicles.keys()}
o, r, te, tr, info = env.step(action)
o, r, tm, tc, info = env.step(action)
for r_ in r.values():
total_r += r_
ep_s += 1
Expand All @@ -374,7 +374,7 @@ def _vis_debug_respawn():
"cam_z": env.main_camera.top_down_camera_height
}
env.render(text=render_text)
if te["__all__"]:
if tm["__all__"]:
print(
"Finish! Current step {}. Group Reward: {}. Average reward: {}".format(
i, total_r, total_r / env.agent_manager.next_agent_count
Expand Down Expand Up @@ -415,7 +415,7 @@ def _vis():
total_r = 0
ep_s = 0
for i in range(1, 100000):
o, r, te, tr, info = env.step({k: [0, 1] for k in env.vehicles.keys()})
o, r, tm, tc, info = env.step({k: [0, 1] for k in env.vehicles.keys()})
for r_ in r.values():
total_r += r_
ep_s += 1
Expand All @@ -436,7 +436,7 @@ def _vis():
render_text["lane"] = env.current_track_vehicle.lane_index
render_text["block"] = env.current_track_vehicle.navigation.current_road.block_ID()
env.render(text=render_text)
if te["__all__"]:
if tm["__all__"]:
print(info)
print(
"Finish! Current step {}. Group Reward: {}. Average reward: {}".format(
Expand All @@ -457,12 +457,12 @@ def _profile():
obs, _ = env.reset()
start = time.time()
for s in range(10000):
o, r, te, tr, i = env.step(env.action_space.sample())
o, r, tm, tc, i = env.step(env.action_space.sample())

# mask_ratio = env.engine.detector_mask.get_mask_ratio()
# print("Mask ratio: ", mask_ratio)

if all(te.values()):
if all(tm.values()):
env.reset()
if (s + 1) % 100 == 0:
print(
Expand All @@ -471,7 +471,7 @@ def _profile():
time.time() - start, (s + 1) / (time.time() - start)
)
)
print(f"(MetaDriveEnv) Total Time Elapse: {time.time() - start}")
print(f"(MetaDriveEnv) Total Time Elapsed: {time.time() - start}")


def _long_run():
Expand Down Expand Up @@ -500,30 +500,30 @@ def _long_run():
assert env.observation_space.contains(obs)
for step in range(10000):
act = env.action_space.sample()
o, r, te, tr, i = env.step(act)
o, r, tm, tc, i = env.step(act)
if step == 0:
assert not any(te.values())
assert not any(tm.values())

if any(te.values()):
print("Current Done: {}\nReward: {}".format(te, r))
for kkk, ddd in te.items():
if any(tm.values()):
print("Current Done: {}\nReward: {}".format(tm, r))
for kkk, ddd in tm.items():
if ddd and kkk != "__all__":
print("Info {}: {}\n".format(kkk, i[kkk]))
print("\n")

for kkk, rrr in r.items():
if rrr == -_out_of_road_penalty:
assert te[kkk]
assert tm[kkk]

if (step + 1) % 200 == 0:
print(
"{}/{} Agents: {} {}\nO: {}\nR: {}\nD: {}\nI: {}\n\n".format(
step + 1, 10000, len(env.vehicles), list(env.vehicles.keys()),
{k: (oo.shape, oo.mean(), oo.min(), oo.max())
for k, oo in o.items()}, r, te, i
for k, oo in o.items()}, r, tm, i
)
)
if te["__all__"]:
if tm["__all__"]:
print('Current step: ', step)
break
finally:
Expand Down
20 changes: 10 additions & 10 deletions metadrive/envs/marl_envs/tinyinter.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,13 +367,13 @@ def _get_reset_return(self):
return self.agent_manager.filter_RL_agents(org)

def step(self, actions):
o, r, te, tr, i = super(MultiAgentTinyInter, self).step(actions)
o, r, tm, tc, i = super(MultiAgentTinyInter, self).step(actions)

# if self.num_RL_agents == self.num_agents:
# return o, r, te, tr, i
# return o, r, tm, tc, i

original_done_dict = copy.deepcopy(te)
d = self.agent_manager.filter_RL_agents(te, original_done_dict=original_done_dict)
original_done_dict = copy.deepcopy(tm)
d = self.agent_manager.filter_RL_agents(tm, original_done_dict=original_done_dict)
if "__all__" in d:
d.pop("__all__")
# assert len(d) == self.agent_manager.num_RL_agents, d
Expand All @@ -382,7 +382,7 @@ def step(self, actions):
self.agent_manager.filter_RL_agents(o, original_done_dict=original_done_dict),
self.agent_manager.filter_RL_agents(r, original_done_dict=original_done_dict),
d,
self.agent_manager.filter_RL_agents(tr, original_done_dict=original_done_dict),
self.agent_manager.filter_RL_agents(tc, original_done_dict=original_done_dict),
self.agent_manager.filter_RL_agents(i, original_done_dict=original_done_dict),
)

Expand Down Expand Up @@ -445,11 +445,11 @@ def get_single_observation(self, vehicle_config: "Config"):
ep_reward_sum = 0.0
ep_success_reward_sum = 0.0
for i in range(1, 100000):
o, r, te, tr, info = env.step({k: [0.0, 1.0] for k in env.action_space.sample().keys()})
o, r, tm, tc, info = env.step({k: [0.0, 1.0] for k in env.action_space.sample().keys()})
# env.render("top_down", camera_position=(42.5, 0), film_size=(500, 500))
vehicles = env.vehicles

for k, v in te.items():
for k, v in tm.items():
if v and k in info:
ep_success += int(info[k]["arrive_dest"])
ep_reward_sum += int(info[k]["episode_reward"])
Expand All @@ -461,10 +461,10 @@ def get_single_observation(self, vehicle_config: "Config"):
# assert sum(
# [env.engine.get_policy(v.name).__class__.__name__ == "EnvInputPolicy" for k, v in vehicles.items()]
# ) == env.config["num_RL_agents"]
if any(te.values()):
print("Somebody dead.", te, info)
if any(tm.values()):
print("Somebody dead.", tm, info)
# print("Step {}. Policies: {}".format(i, {k: v['policy'] for k, v in info.items()}))
if te["__all__"]:
if tm["__all__"]:
# assert i >= 1000
print("Reset. ", i, info)
# break
Expand Down
4 changes: 2 additions & 2 deletions metadrive/envs/real_data_envs/waymo_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def _merge_extra_config(self, config):
env.reset(force_seed=i)
while True:
step_start = time.time()
o, r, te, tr, info = env.step([0, 0])
o, r, tm, tc, info = env.step([0, 0])
assert env.observation_space.contains(o)
# c_lane = env.vehicle.lane
# long, lat, = c_lane.local_coordinates(env.vehicle.position)
Expand All @@ -75,7 +75,7 @@ def _merge_extra_config(self, config):
# mode="topdown"
)

if te or tr:
if tm or tc:
if info["arrive_dest"]:
print("seed:{}, success".format(env.engine.global_random_seed))
break
Loading