@@ -4033,7 +4033,61 @@ def _stack_non_tensor(
40334033 return NonTensorStack (* list_of_non_tensor , stack_dim = dim )
40344034
40354035
4036- class MetaData (NonTensorDataBase ):
4036+ class _MetaDataMeta (_TensorClassMeta ):
4037+ def __new__ (
4038+ mcs ,
4039+ name ,
4040+ bases ,
4041+ namespace ,
4042+ datatype = None ,
4043+ ** kwargs ,
4044+ ):
4045+ # Create the class using the parent's __new__ method
4046+ cls = super ().__new__ (mcs , name , bases , namespace , ** kwargs )
4047+ if datatype is not None :
4048+ cls ._datatype = datatype
4049+ # Initialize cache for typed classes
4050+ if not hasattr (cls , "_typed_class_cache" ):
4051+ cls ._typed_class_cache = {}
4052+ return cls
4053+
4054+ def __getitem__ (cls , item ):
4055+ """Create a typed version of MetaData that validates the data type."""
4056+ if cls .__name__ != "MetaData" :
4057+ # Only allow type specification on the base MetaData class
4058+ raise TypeError (f"Cannot specify type for { cls .__name__ } " )
4059+
4060+ # Check cache first
4061+ if item in cls ._typed_class_cache : # type: ignore
4062+ return cls ._typed_class_cache [item ] # type: ignore
4063+
4064+ # Create a new class that validates the data type
4065+ type_name = getattr (item , "__name__" , str (item ))
4066+ class_name = f"MetaData[{ type_name } ]"
4067+
4068+ class TypedMetaData (cls ):
4069+ _expected_type = item
4070+
4071+ def __post_init__ (self ):
4072+ super ().__post_init__ ()
4073+ # Validate the data type
4074+ if not isinstance (self .data , item ):
4075+ expected_name = getattr (item , "__name__" , str (item ))
4076+ actual_name = type (self .data ).__name__
4077+ raise TypeError (
4078+ f"Expected data of type { expected_name } , got { actual_name } "
4079+ )
4080+
4081+ TypedMetaData .__name__ = class_name
4082+ TypedMetaData .__qualname__ = class_name
4083+
4084+ # Cache the class
4085+ cls ._typed_class_cache [item ] = TypedMetaData # type: ignore
4086+
4087+ return TypedMetaData
4088+
4089+
4090+ class MetaData (NonTensorDataBase , metaclass = _MetaDataMeta ):
40374091 """A non-tensor, metadata carrier class for `TensorDict`.
40384092
40394093 This class mainly behaves as :class:`~tensordict.NonTensorData`, except for indexing,
@@ -4049,6 +4103,8 @@ class MetaData(NonTensorDataBase):
40494103
40504104 """
40514105
4106+ # Remove the __class_getitem__ method since the metaclass handles it
4107+
40524108 _load_memmap = classmethod (_load_memmap )
40534109 _from_dict = classmethod (_from_dict )
40544110 _from_tensordict = classmethod (_from_tensordict )
0 commit comments