forked from MemTensor/MemOS
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
213 lines (177 loc) · 8.49 KB
/
utils.py
File metadata and controls
213 lines (177 loc) · 8.49 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
import inspect
from abc import ABC
from typing import Any
import pytest
from pydantic import BaseModel
from pydantic.aliases import PydanticUndefined
def check_module_base_class(cls: Any) -> None:
    """
    General function to test the correctness of an abstract base class.

    - It should inherit from ABC.
    - It should expose at least one Python function. This covers its own and
      inherited methods, including private helpers and dunders such as
      ``__init__`` when they are defined in Python.
    - Every public method, plus ``__init__`` when one is defined in Python, should
      be marked as @abstractmethod. Other methods whose name starts with ``_`` are
      treated as helpers and may be concrete.
    - It should not be instantiable.
    - Every collected method, helpers and dunders included, should have a docstring.

    Args:
        cls: The abstract base class to test.

    Raises:
        AssertionError: If any of the checks above fails.
    """
    # Check 1: Ensure this is an abstract base class
    assert issubclass(cls, ABC), f"{cls.__name__} should inherit from ABC"

    # Collect every Python function reachable on the class (own and inherited).
    # Nothing is filtered out here. Built-in slot wrappers inherited from `object`
    # are not Python functions, so they never appear; Python-defined dunders and
    # private helpers do.
    methods = inspect.getmembers(cls, predicate=inspect.isfunction)

    # Check 2: Ensure the class defines methods
    assert methods, f"{cls.__name__} should define at least one method"

    # Check 3: Verify abstract methods
    for method_name, method in methods:
        # Private helpers (leading underscore) may be concrete; __init__ is still checked
        if method_name.startswith("_") and method_name != "__init__":
            continue
        assert getattr(method, "__isabstractmethod__", False), (
            f"The method '{method_name}' in {cls.__name__} should be marked as @abstractmethod"
        )

    # Check 4: Test that the class cannot be instantiated directly
    with pytest.raises(TypeError) as excinfo:
        cls()
    assert "abstract" in str(excinfo.value).lower(), (
        f"{cls.__name__} should not be instantiable as it's an abstract base class"
    )

    # Check 5: Ensure every collected method (helpers and dunders included) has a docstring
    for method_name, method in methods:
        assert method.__doc__, f"Method '{method_name}' in {cls.__name__} should have a docstring"
def check_module_factory_class(cls: Any) -> None:
    """
    Generic function to test factory classes.

    - It should have exactly one direct base class. That base class is the one
      every registered backend class must derive from.
    - It should have a ``backend_to_class`` attribute that is a dictionary.
    - It should have a callable ``from_config`` attribute.
    - Every key of ``backend_to_class`` should be a string.
    - Every value of ``backend_to_class`` should be a class that subclasses the
      base class.

    Args:
        cls: The module factory class to test.

    Raises:
        AssertionError: If any of the checks above fails.
    """
    # Check 1: The factory must have a single, unambiguous base class
    assert len(cls.__bases__) == 1, "Factory class should have exactly one base class"
    base_class = cls.__bases__[0]

    # Check 2: backend_to_class must exist and be a dict
    assert hasattr(cls, "backend_to_class"), "Factory class should have backend_to_class attribute"
    backend_to_module_mapping = cls.backend_to_class
    assert isinstance(backend_to_module_mapping, dict), "backend_to_class should be a dictionary"

    # Check 3: from_config must exist and be callable
    assert hasattr(cls, "from_config"), "Factory class should have from_config method"
    assert callable(cls.from_config), "from_config should be callable"

    # Check 4: Every registered backend maps a string name to a valid subclass
    for backend, module_class in backend_to_module_mapping.items():
        assert isinstance(backend, str), f"Backend '{backend}' should be a string"
        # issubclass() raises TypeError for non-class arguments; assert first so a
        # bad registration is reported as a readable assertion failure instead.
        assert inspect.isclass(module_class), (
            f"Backend '{backend}' should map to a class, got {module_class!r}"
        )
        assert issubclass(module_class, base_class), (
            f"{module_class} should be a subclass of {base_class}"
        )
def check_config_base_class(
    cls: type[BaseModel],
    factory_fields: list[str] | None = None,
    required_fields: list[str] | None = None,
    optional_fields: list[str] | None = None,
    reserved_fields: list[str] | None = None,
) -> None:
    """
    Check if a configuration class is properly defined.

    - It should be a class that inherits from Pydantic's BaseModel.
    - Its ``model_config`` should equal ``{"extra": "forbid", "strict": True}``.
    - Its ``model_fields``, excluding ``reserved_fields``, should split exactly
      into the given factory, required and optional field names.

    Each non-reserved field is classified as:
      * factory  - it has a ``default_factory``;
      * required - it has no ``default_factory`` and no ``default``;
      * optional - it has no ``default_factory`` but does have a ``default``.

    Args:
        cls: The config class to check.
        factory_fields: Names of fields expected to have a ``default_factory``.
        required_fields: Names of fields expected to be required, with no
            default and no ``default_factory``.
        optional_fields: Names of fields expected to have a plain default value
            and no ``default_factory``.
        reserved_fields: Names of fields the classification checks ignore, such
            as fields defined in `memos.configs.base.BaseConfig`. Defaults to
            ``["model_schema"]``.

    Raises:
        AssertionError: If any of the checks above fails.
    """
    if reserved_fields is None:
        reserved_fields = ["model_schema"]

    # Check that cls is a Pydantic model class (not an instance)
    assert inspect.isclass(cls), f"{cls} is not a class"
    assert issubclass(cls, BaseModel), f"{cls} is not a Pydantic BaseModel"

    # Check model_config
    assert cls.model_config == {"extra": "forbid", "strict": True}, (
        f"{cls} does not have the correct model_config"
    )

    # Classify the model's actual fields
    actual: dict[str, list[str]] = {"factory": [], "required": [], "optional": []}
    for field_name, field_info in cls.model_fields.items():
        if field_name in reserved_fields:
            continue
        if field_info.default_factory is not None:
            actual["factory"].append(field_name)
        # PydanticUndefined is a sentinel singleton. Compare by identity so that a
        # default value with an unusual __eq__ cannot make the check misfire.
        elif field_info.default is PydanticUndefined:
            actual["required"].append(field_name)
        else:
            actual["optional"].append(field_name)

    expected: dict[str, list[str]] = {
        "factory": factory_fields or [],
        "required": required_fields or [],
        "optional": optional_fields or [],
    }
    # "expected" is what the caller declared and "got" is what the model defines
    for kind in ("factory", "required", "optional"):
        assert set(actual[kind]) == set(expected[kind]), (
            f"{cls} has incorrect {kind} fields: "
            f"expected {expected[kind]}, got {actual[kind]}"
        )
def check_config_factory_class(
    cls: type[BaseModel], expected_backends: list[str] | None = None
) -> None:
    """
    Check if a configuration factory is properly defined.

    - It should be a class that inherits from Pydantic's BaseModel.
    - It should have a ``backend_to_class`` attribute that is a dict.
    - It should declare ``backend`` and ``config`` model fields.
    - It should have ``validate_backend`` and ``create_config`` attributes.
    - Every name in ``expected_backends`` should be a key of ``backend_to_class``.

    Args:
        cls: The config factory class to check.
        expected_backends: Backend names that must be supported. ``None`` or an
            empty list skips this check.

    Raises:
        AssertionError: If any of the checks above fails. Unsupported backends
            are all reported together in a single message.
    """
    assert inspect.isclass(cls), f"{cls} is not a class"
    assert issubclass(cls, BaseModel), f"{cls} is not a Pydantic BaseModel"

    # Check required attributes
    assert hasattr(cls, "backend_to_class"), f"{cls} has no backend_to_class attribute"
    assert isinstance(cls.backend_to_class, dict), f"{cls.backend_to_class} is not a dict"

    # Check required fields
    assert "backend" in cls.model_fields, f"{cls} is missing 'backend' field"
    assert "config" in cls.model_fields, f"{cls} is missing 'config' field"

    # Check validators
    assert hasattr(cls, "validate_backend"), f"{cls} has no validate_backend method"
    assert hasattr(cls, "create_config"), f"{cls} has no create_config method"

    # Check supported backends, collecting every missing one before failing
    missing = [
        backend for backend in (expected_backends or []) if backend not in cls.backend_to_class
    ]
    assert not missing, f"{cls} does not support backend(s): {missing}"
def check_config_instantiation_valid(cls: type[BaseModel], valid_config: dict) -> None:
    """
    Test that a valid configuration can be instantiated.

    Args:
        cls: The config class to test.
        valid_config: Dictionary of configuration values expected to validate.

    Raises:
        AssertionError: If ``cls.model_validate`` returns something that is not
            an instance of ``cls``.

    Any exception raised by ``cls.model_validate`` propagates unchanged. For
    Pydantic models that is a ``ValidationError``.
    """
    config = cls.model_validate(valid_config)
    assert isinstance(config, cls), (
        f"{cls}.model_validate returned {type(config)}, expected an instance of {cls}"
    )
def check_config_instantiation_invalid(
    cls: type[BaseModel], invalid_config: dict | None = None
) -> None:
    """
    Test that invalid configurations are rejected by ``cls.model_validate``.

    A fixed set of configs made only of unknown field names is always checked.
    A model configured with ``extra="forbid"``, for example, rejects each of
    them. ``invalid_config``, when given, is checked as well.

    Args:
        cls: The config class to test.
        invalid_config: An additional invalid configuration to check.

    Raises:
        pytest's failure exception ("DID NOT RAISE") if any config is accepted.

    Exceptions other than ``ValueError`` or ``TypeError`` propagate unchanged, so
    unrelated bugs are not mistaken for validation failures.
    """
    invalid_configs: list[dict] = [
        {"impossible_field": "invalid_value"},
        {"another_impossible_field": 2},
        {"abcdef": 0.1, "ghijk": "lmn"},
    ]
    if invalid_config is not None:
        invalid_configs.append(invalid_config)

    for candidate in invalid_configs:
        # pydantic's ValidationError subclasses ValueError. Plain ``Exception`` is
        # deliberately not accepted, because it would also match unrelated errors
        # such as AttributeError.
        with pytest.raises((ValueError, TypeError)):
            cls.model_validate(candidate)