Skip to content

Commit 26b2e4e

Browse files
authored
[vizdoom] expose combined dm action as discrete (#350)
## Summary
- add explicit discrete metadata to C++ specs and plumb it into Python array specs
- use that metadata only for Vizdoom combined actions so dm_env exposes a discrete action spec even though the transport dtype is float
- add a regression test for `make_spec("D1Basic-v1", use_combined_action=True).action_spec()`

## Testing
- python3 -m py_compile envpool/python/data.py envpool/python/protocol.py envpool/vizdoom/vizdoom_test.py
- devbox validation in progress: `make lint`, targeted Bazel tests, then `make bazel-test`
1 parent 996e97d commit 26b2e4e

6 files changed

Lines changed: 66 additions & 34 deletions

File tree

envpool/core/py_envpool.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,8 @@ template <typename Spec>
104104
struct SpecTupleHelper {
105105
static decltype(auto) Make(const Spec& spec) {
106106
return std::make_tuple(py::dtype::of<typename Spec::dtype>(), spec.shape,
107-
spec.bounds, spec.elementwise_bounds);
107+
spec.bounds, spec.elementwise_bounds,
108+
spec.is_discrete);
108109
}
109110
};
110111

@@ -121,7 +122,8 @@ struct SpecTupleHelper<Spec<Container<dtype>>> {
121122
return std::make_tuple(py::dtype::of<dtype>(),
122123
std::make_tuple(spec.shape, spec.inner_spec.shape),
123124
spec.inner_spec.bounds,
124-
spec.inner_spec.elementwise_bounds);
125+
spec.inner_spec.elementwise_bounds,
126+
spec.inner_spec.is_discrete);
125127
}
126128
};
127129

envpool/core/spec.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ template <typename D>
5959
class Spec : public ShapeSpec {
6060
public:
6161
using dtype = D; // NOLINT
62+
bool is_discrete{false};
6263
std::tuple<dtype, dtype> bounds = {std::numeric_limits<dtype>::min(),
6364
std::numeric_limits<dtype>::max()};
6465
std::tuple<std::vector<dtype>, std::vector<dtype>> elementwise_bounds;
@@ -109,6 +110,12 @@ class Spec : public ShapeSpec {
109110
}
110111
};
111112

113+
template <typename D>
114+
Spec<D> MarkDiscrete(Spec<D> spec) {
115+
spec.is_discrete = true;
116+
return spec;
117+
}
118+
112119
template <typename dtype>
113120
class TArray;
114121

envpool/python/data.py

Lines changed: 44 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,34 @@
2828
ACTION_THRESHOLD = 2**20
2929

3030

31+
def _maybe_scalar_int(value: Any) -> int | None:
32+
arr = np.asarray(value)
33+
if arr.size != 1:
34+
return None
35+
scalar = arr.item()
36+
integer = int(scalar)
37+
if not np.isclose(scalar, integer):
38+
return None
39+
return integer
40+
41+
42+
def _maybe_discrete_range(
    spec: ArraySpec, spec_type: str
) -> tuple[int, int] | None:
    """Return ``(start, num_values)`` if ``spec`` describes one discrete value.

    A spec qualifies when it holds exactly one element, its dtype is an
    integer type (or, for ``spec_type == "act"`` only, it is explicitly
    flagged with ``is_discrete``), both bounds are integral scalars, and the
    upper bound is below ``ACTION_THRESHOLD``.

    Args:
        spec: the array spec to inspect.
        spec_type: ``"act"`` for action specs; any other value disables the
            ``is_discrete`` opt-in and requires an integer dtype.

    Returns:
        ``(minimum, maximum - minimum + 1)`` for a discrete spec, otherwise
        ``None``.
    """
    if np.prod(np.abs(spec.shape)) != 1:
        return None
    # Decide discreteness from dtype / explicit flag *before* touching the
    # bounds, so continuous float specs are never pushed through the integer
    # conversion below.
    has_integer_dtype = np.issubdtype(spec.dtype, np.integer)
    if not (
        has_integer_dtype or (spec_type == "act" and spec.is_discrete)
    ):
        return None
    minimum = _maybe_scalar_int(spec.minimum)
    maximum = _maybe_scalar_int(spec.maximum)
    if minimum is None or maximum is None or maximum >= ACTION_THRESHOLD:
        return None
    return minimum, maximum - minimum + 1
57+
58+
3159
def to_nested_dict(
3260
flatten_dict: dict[str, Any], generator: type = dict
3361
) -> dict[str, Any]:
@@ -70,16 +98,15 @@ def dm_spec_transform(
7098
name: str, spec: ArraySpec, spec_type: str
7199
) -> dm_env.specs.Array:
72100
"""Transform ArraySpec to dm_env compatible specs."""
73-
if (
74-
np.prod(np.abs(spec.shape)) == 1
75-
and np.isclose(spec.minimum, 0)
76-
and spec.maximum < ACTION_THRESHOLD
77-
):
78-
# special treatment for discrete action space
101+
discrete_range = _maybe_discrete_range(spec, spec_type)
102+
if discrete_range is not None and discrete_range[0] == 0:
103+
# dm_env only supports zero-based discrete arrays.
79104
return dm_env.specs.DiscreteArray(
80105
name=name,
81-
dtype=spec.dtype,
82-
num_values=int(spec.maximum - spec.minimum + 1),
106+
dtype=spec.dtype
107+
if np.issubdtype(spec.dtype, np.integer)
108+
else np.int32,
109+
num_values=discrete_range[1],
83110
)
84111
return dm_env.specs.BoundedArray(
85112
name=name,
@@ -92,19 +119,13 @@ def dm_spec_transform(
92119

93120
def gym_spec_transform(name: str, spec: ArraySpec, spec_type: str) -> gym.Space:
94121
"""Transform ArraySpec to gym.Env compatible spaces."""
95-
if (
96-
np.prod(np.abs(spec.shape)) == 1
97-
and np.isclose(spec.minimum, 0)
98-
and spec.maximum < ACTION_THRESHOLD
99-
):
100-
# special treatment for discrete action space
101-
discrete_range = int(spec.maximum - spec.minimum + 1)
122+
discrete_range = _maybe_discrete_range(spec, spec_type)
123+
if discrete_range is not None:
124+
start, num_values = discrete_range
102125
try:
103-
return gym.spaces.Discrete(
104-
n=discrete_range, start=int(spec.minimum)
105-
)
126+
return gym.spaces.Discrete(n=num_values, start=start)
106127
except TypeError: # old gym version doesn't have `start`
107-
return gym.spaces.Discrete(n=discrete_range)
128+
return gym.spaces.Discrete(n=num_values)
108129
return gym.spaces.Box(
109130
shape=[s for s in spec.shape if s != -1],
110131
dtype=spec.dtype,
@@ -117,16 +138,10 @@ def gymnasium_spec_transform(
117138
name: str, spec: ArraySpec, spec_type: str
118139
) -> gymnasium.Space:
119140
"""Transform ArraySpec to gymnasium.Env compatible spaces."""
120-
if (
121-
np.prod(np.abs(spec.shape)) == 1
122-
and np.isclose(spec.minimum, 0)
123-
and spec.maximum < ACTION_THRESHOLD
124-
):
125-
# special treatment for discrete action space
126-
discrete_range = int(spec.maximum - spec.minimum + 1)
127-
return gymnasium.spaces.Discrete(
128-
n=discrete_range, start=int(spec.minimum)
129-
)
141+
discrete_range = _maybe_discrete_range(spec, spec_type)
142+
if discrete_range is not None:
143+
start, num_values = discrete_range
144+
return gymnasium.spaces.Discrete(n=num_values, start=start)
130145
return gymnasium.spaces.Box(
131146
shape=[s for s in spec.shape if s != -1],
132147
dtype=spec.dtype,

envpool/python/protocol.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,10 +96,12 @@ def __init__(
9696
shape: list[int],
9797
bounds: tuple[Any, Any],
9898
element_wise_bounds: tuple[Any, Any],
99+
is_discrete: bool = False,
99100
):
100101
"""Constructor of ArraySpec."""
101102
self.dtype = dtype
102103
self.shape = shape
104+
self.is_discrete = is_discrete
103105
if element_wise_bounds[0]:
104106
self.minimum = np.array(element_wise_bounds[0])
105107
else:

envpool/vizdoom/vizdoom_env.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,8 @@ class VizdoomEnvFns {
123123
}
124124
auto action_set =
125125
BuildActionSet(button_list, conf["force_speed"_], delta_config);
126-
return MakeDict(
127-
"action"_.Bind(Spec<double>({-1}, {0.0, action_set.size() - 1.0})));
126+
return MakeDict("action"_.Bind(
127+
MarkDiscrete(Spec<double>({-1}, {0.0, action_set.size() - 1.0}))));
128128
}
129129
};
130130

envpool/vizdoom/vizdoom_test.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from absl.testing import absltest
2121

2222
import envpool.vizdoom.registration # noqa: F401
23-
from envpool.registration import make_dm, make_gym
23+
from envpool.registration import make_dm, make_gym, make_spec
2424

2525

2626
class _VizdoomEnvPoolBasicTest(absltest.TestCase):
@@ -155,6 +155,12 @@ def test_obs_space(self) -> None:
155155
== 1 * 4
156156
)
157157

158+
def test_action_spec(self) -> None:
    """Check the combined-action spec reports 6 values and an integer dtype."""
    env_spec = make_spec("D1Basic-v1", use_combined_action=True)
    combined_action = env_spec.action_spec()
    assert combined_action.num_values == 6
    assert np.issubdtype(combined_action.dtype, np.integer)
163+
158164
def test_explicit_reset_with_episodic_life_gymnasium(self) -> None:
159165
env = make_gym(
160166
"D1Basic-v1",

0 commit comments

Comments
 (0)