# Upper bound on the maximum of a scalar spec for it to be considered a
# discrete space; specs whose maximum reaches this value are treated as
# bounded/continuous by the spec transforms in this module.
ACTION_THRESHOLD = 2 ** 20
2929
3030
31+ def _maybe_scalar_int (value : Any ) -> int | None :
32+ arr = np .asarray (value )
33+ if arr .size != 1 :
34+ return None
35+ scalar = arr .item ()
36+ integer = int (scalar )
37+ if not np .isclose (scalar , integer ):
38+ return None
39+ return integer
40+
41+
def _maybe_discrete_range(
    spec: ArraySpec, spec_type: str
) -> tuple[int, int] | None:
    """Return ``(start, num_values)`` if ``spec`` is a discrete scalar.

    Args:
        spec: The array spec to inspect; its ``shape``, ``minimum``,
            ``maximum``, ``dtype`` and (for actions) ``is_discrete``
            attributes are read.
        spec_type: ``"act"`` for action specs; any other value is handled
            as a non-action spec, which must have an integer dtype.

    Returns:
        A tuple of the smallest value and the number of distinct values,
        or ``None`` when the spec should not be treated as discrete.
    """
    # abs() makes negative placeholder dimensions (e.g. -1) count as one.
    if np.prod(np.abs(spec.shape)) != 1:
        return None
    minimum = _maybe_scalar_int(spec.minimum)
    maximum = _maybe_scalar_int(spec.maximum)
    if minimum is None or maximum is None or maximum >= ACTION_THRESHOLD:
        return None
    num_values = maximum - minimum + 1
    # Reject inverted bounds (non-positive count) and ranges that are huge
    # only because the minimum is very negative.
    if num_values <= 0 or num_values > ACTION_THRESHOLD:
        return None
    if spec_type == "act":
        if not (spec.is_discrete or np.issubdtype(spec.dtype, np.integer)):
            return None
    elif not np.issubdtype(spec.dtype, np.integer):
        return None
    return minimum, num_values
57+
58+
3159def to_nested_dict (
3260 flatten_dict : dict [str , Any ], generator : type = dict
3361) -> dict [str , Any ]:
@@ -70,16 +98,15 @@ def dm_spec_transform(
7098 name : str , spec : ArraySpec , spec_type : str
7199) -> dm_env .specs .Array :
72100 """Transform ArraySpec to dm_env compatible specs."""
73- if (
74- np .prod (np .abs (spec .shape )) == 1
75- and np .isclose (spec .minimum , 0 )
76- and spec .maximum < ACTION_THRESHOLD
77- ):
78- # special treatment for discrete action space
101+ discrete_range = _maybe_discrete_range (spec , spec_type )
102+ if discrete_range is not None and discrete_range [0 ] == 0 :
103+ # dm_env only supports zero-based discrete arrays.
79104 return dm_env .specs .DiscreteArray (
80105 name = name ,
81- dtype = spec .dtype ,
82- num_values = int (spec .maximum - spec .minimum + 1 ),
106+ dtype = spec .dtype
107+ if np .issubdtype (spec .dtype , np .integer )
108+ else np .int32 ,
109+ num_values = discrete_range [1 ],
83110 )
84111 return dm_env .specs .BoundedArray (
85112 name = name ,
@@ -92,19 +119,13 @@ def dm_spec_transform(
92119
93120def gym_spec_transform (name : str , spec : ArraySpec , spec_type : str ) -> gym .Space :
94121 """Transform ArraySpec to gym.Env compatible spaces."""
95- if (
96- np .prod (np .abs (spec .shape )) == 1
97- and np .isclose (spec .minimum , 0 )
98- and spec .maximum < ACTION_THRESHOLD
99- ):
100- # special treatment for discrete action space
101- discrete_range = int (spec .maximum - spec .minimum + 1 )
122+ discrete_range = _maybe_discrete_range (spec , spec_type )
123+ if discrete_range is not None :
124+ start , num_values = discrete_range
102125 try :
103- return gym .spaces .Discrete (
104- n = discrete_range , start = int (spec .minimum )
105- )
126+ return gym .spaces .Discrete (n = num_values , start = start )
106127 except TypeError : # old gym version doesn't have `start`
107- return gym .spaces .Discrete (n = discrete_range )
128+ return gym .spaces .Discrete (n = num_values )
108129 return gym .spaces .Box (
109130 shape = [s for s in spec .shape if s != - 1 ],
110131 dtype = spec .dtype ,
@@ -117,16 +138,10 @@ def gymnasium_spec_transform(
117138 name : str , spec : ArraySpec , spec_type : str
118139) -> gymnasium .Space :
119140 """Transform ArraySpec to gymnasium.Env compatible spaces."""
120- if (
121- np .prod (np .abs (spec .shape )) == 1
122- and np .isclose (spec .minimum , 0 )
123- and spec .maximum < ACTION_THRESHOLD
124- ):
125- # special treatment for discrete action space
126- discrete_range = int (spec .maximum - spec .minimum + 1 )
127- return gymnasium .spaces .Discrete (
128- n = discrete_range , start = int (spec .minimum )
129- )
141+ discrete_range = _maybe_discrete_range (spec , spec_type )
142+ if discrete_range is not None :
143+ start , num_values = discrete_range
144+ return gymnasium .spaces .Discrete (n = num_values , start = start )
130145 return gymnasium .spaces .Box (
131146 shape = [s for s in spec .shape if s != - 1 ],
132147 dtype = spec .dtype ,
0 commit comments