Skip to content

Commit 466aa99

Browse files
authored
[Feature] Typed MetaData (#1428)
1 parent e9d8439 commit 466aa99

File tree

2 files changed

+69
-1
lines changed

2 files changed

+69
-1
lines changed

tensordict/tensorclass.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4033,7 +4033,61 @@ def _stack_non_tensor(
40334033
return NonTensorStack(*list_of_non_tensor, stack_dim=dim)
40344034

40354035

4036-
class MetaData(NonTensorDataBase):
4036+
class _MetaDataMeta(_TensorClassMeta):
    """Metaclass that adds ``MetaData[T]`` typed variants to ``MetaData``.

    Subscripting the base ``MetaData`` class returns a subclass whose
    ``__post_init__`` checks ``isinstance(self.data, T)`` and raises
    ``TypeError`` on mismatch. The generated subclasses are cached per ``T``.
    """

    def __new__(
        mcs,
        name,
        bases,
        namespace,
        datatype=None,
        **kwargs,
    ):
        """Create the class and make sure it has a typed-class cache.

        Args:
            mcs: This metaclass.
            name: Name of the class being created.
            bases: Base classes of the class being created.
            namespace: Namespace of the class body.
            datatype: Optional value stored as ``cls._datatype`` when it is
                not ``None``.
            **kwargs: Forwarded to the parent metaclass ``__new__``.

        Returns:
            The newly created class.
        """
        # Create the class using the parent's __new__ method
        cls = super().__new__(mcs, name, bases, namespace, **kwargs)
        if datatype is not None:
            cls._datatype = datatype
        # ``hasattr`` also sees inherited attributes, so a subclass reuses the
        # cache dict of its parent instead of getting a fresh one.
        if not hasattr(cls, "_typed_class_cache"):
            cls._typed_class_cache = {}
        return cls

    def __getitem__(cls, item):
        """Create a typed version of MetaData that validates the data type.

        Args:
            item: A type, tuple of types, or any other object accepted as the
                second argument of ``isinstance()``.

        Returns:
            A subclass of ``cls`` named ``MetaData[<type name>]``; the same
            class object is returned for repeated requests with the same
            ``item``.

        Raises:
            TypeError: If ``cls`` is not the base ``MetaData`` class, or if
                ``item`` cannot be used with ``isinstance()``.
        """
        if cls.__name__ != "MetaData":
            # Only allow type specification on the base MetaData class
            raise TypeError(f"Cannot specify type for {cls.__name__}")

        # Fail fast with a clear message: without this probe an unusable
        # ``item`` (e.g. ``list[int]`` or a string) would only blow up later,
        # inside ``__post_init__`` or the cache lookup, with an opaque error.
        try:
            isinstance(None, item)
        except TypeError as err:
            raise TypeError(
                "MetaData[...] expects a type (or tuple of types) usable "
                f"with isinstance(), got {item!r}"
            ) from err

        # Check cache first
        cached = cls._typed_class_cache.get(item)  # type: ignore
        if cached is not None:
            return cached

        # Create a new class that validates the data type
        type_name = getattr(item, "__name__", str(item))
        class_name = f"MetaData[{type_name}]"

        class TypedMetaData(cls):
            _expected_type = item

            def __post_init__(self):
                super().__post_init__()
                # Validate the data type
                if not isinstance(self.data, item):
                    actual_name = type(self.data).__name__
                    raise TypeError(
                        f"Expected data of type {type_name}, got {actual_name}"
                    )

        TypedMetaData.__name__ = class_name
        TypedMetaData.__qualname__ = class_name

        # Cache the class
        cls._typed_class_cache[item] = TypedMetaData  # type: ignore

        return TypedMetaData
4088+
4089+
4090+
class MetaData(NonTensorDataBase, metaclass=_MetaDataMeta):
40374091
"""A non-tensor, metadata carrier class for `TensorDict`.
40384092
40394093
This class mainly behaves as :class:`~tensordict.NonTensorData`, except for indexing,
@@ -4049,6 +4103,8 @@ class MetaData(NonTensorDataBase):
40494103
40504104
"""
40514105

4106+
# No __class_getitem__ is defined here: `MetaData[T]` is resolved by the metaclass's __getitem__ (_MetaDataMeta).
4107+
40524108
_load_memmap = classmethod(_load_memmap)
40534109
_from_dict = classmethod(_from_dict)
40544110
_from_tensordict = classmethod(_from_tensordict)

test/test_tensordict.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12853,6 +12853,18 @@ def test_where(self):
1285312853

1285412854

1285512855
class TestMetaData:
12856+
def test_typed_metadata(self):
    """Exercise ``MetaData[int]``: construction, type validation and class reuse."""
    # Direct construction through the subscripted class.
    meta = MetaData[int](0, batch_size=(3,))
    assert meta.data == 0
    assert isinstance(meta, MetaData[int])

    # Data of the wrong type must be rejected.
    with pytest.raises(TypeError, match="Expected data of type int, got str"):
        MetaData[int]("a string")

    # The typed class can be bound to a name and is a MetaData subclass.
    int_meta_cls = MetaData[int]
    assert issubclass(int_meta_cls, MetaData)
    meta = int_meta_cls(0, batch_size=(3,))
    assert isinstance(meta, MetaData)
    # Test caching
    assert isinstance(meta, MetaData[int])
1285612868

1285712869
def test_expand(self):
1285812870
d = MetaData(0, batch_size=(3,))

0 commit comments

Comments
 (0)