diff --git a/docs/index.md b/docs/index.md index ba4db72de..da8408e32 100644 --- a/docs/index.md +++ b/docs/index.md @@ -35,6 +35,7 @@ data_models/dimensions data_models/chunk_grids data_models/data_types data_models/compressors +template_registry ``` ```{toctree} diff --git a/docs/template_registry.md b/docs/template_registry.md new file mode 100644 index 000000000..3b51de811 --- /dev/null +++ b/docs/template_registry.md @@ -0,0 +1,254 @@ +# Template Registry Singleton + +A thread-safe singleton registry for managing dataset templates in MDIO applications. + +## Overview + +The `TemplateRegistry` implements the singleton pattern to ensure there's only one instance managing all dataset templates throughout the application lifecycle. This provides a centralized registry for template management with thread-safe operations. + +## Features + +- **Singleton Pattern**: Ensures only one registry instance exists +- **Thread Safety**: All operations are thread-safe using locks +- **Global Access**: Convenient global functions for common operations +- **Advanced Support**: Reset functionality for environment re-usability. 
+- **Default Templates**: The registry is instantiated with the default set of templates: + - PostStack2DTime + - PostStack3DTime + - PreStackCdpGathers3DTime + - PreStackShotGathers3DTime + - PostStack2DDepth + - PostStack3DDepth + - PreStackCdpGathers3DDepth + - PreStackShotGathers3DDepth + +## Usage + +### Basic Usage + +```python +from mdio.schemas.v1.templates.template_registry import TemplateRegistry + +# Get the singleton instance +registry = TemplateRegistry() + +# Or use the class method +registry = TemplateRegistry.get_instance() + +# Register a template +template = MyDatasetTemplate() +template_name=registry.register(template) +print(f"Registered template named {template_name}") + +# Retrieve a template using a well-known name +template = registry.get("my_template") +# Retrieve a template using the name returned when the template was registered +template = registry.get(template_name) + +# Check if template exists +if registry.is_registered("my_template"): + print("Template is registered") + +# List all templates +template_names = registry.list_all_templates() +``` + +### Global Functions + +For convenience, you can use global functions that operate on the singleton instance: + +```python +from mdio.schemas.v1.templates.template_registry import ( + register_template, + get_template, + is_template_registered, + list_templates +) + +# Register a template globally +register_template(Seismic3DTemplate()) + +# Get a template +template = get_template("seismic_3d") + +# Check registration +if is_template_registered("seismic_3d"): + print("Template available") + +# List all registered templates +templates = list_templates() +``` + +### Multiple Instantiation + +The singleton pattern ensures all instantiations return the same object: + +```python +registry1 = TemplateRegistry() +registry2 = TemplateRegistry() +registry3 = TemplateRegistry.get_instance() + +# All variables point to the same instance +assert registry1 is registry2 is registry3 +``` + +## API 
Reference + +### Core Methods + +#### `register(instance: AbstractDatasetTemplate) -> str` + +Registers a template instance and returns its name, as reported by the template's `get_name()`. + +- **Parameters:** + - `instance`: Template instance implementing `AbstractDatasetTemplate` +- **Returns:** The template name +- **Raises:** `ValueError` if template name is already registered + +#### `get(template_name: str) -> AbstractDatasetTemplate` + +Retrieves a registered template by name. + +- **Parameters:** + - `template_name`: Name of the template (case-sensitive; must match the registered name exactly) +- **Returns:** The registered template instance +- **Raises:** `KeyError` if template is not registered + +#### `unregister(template_name: str) -> None` + +Removes a template from the registry. + +- **Parameters:** + - `template_name`: Name of the template to remove +- **Raises:** `KeyError` if template is not registered + +#### `is_registered(template_name: str) -> bool` + +Checks if a template is registered. + +- **Parameters:** + - `template_name`: Name of the template to check +- **Returns:** `True` if template is registered, `False` otherwise + +#### `list_all_templates() -> list[str]` + +Returns a list of all registered template names. + +- **Returns:** List of template names + +#### `clear() -> None` + +Removes all registered templates. Useful for testing. + +### Class Methods + +#### `get_instance() -> TemplateRegistry` + +Alternative way to get the singleton instance. + +- **Returns:** The singleton registry instance + +### Global Functions + +#### `get_template_registry() -> TemplateRegistry` + +Returns the global singleton registry instance. + +#### `register_template(template: AbstractDatasetTemplate) -> str` + +Registers a template in the global registry. + +#### `get_template(name: str) -> AbstractDatasetTemplate` + +Gets a template from the global registry. + +#### `is_template_registered(name: str) -> bool` + +Checks if a template is registered in the global registry. 
+ +#### `list_templates() -> List[str]` + +Lists all templates in the global registry. + +## Thread Safety + +All operations on the registry are thread-safe: + +```python +import threading + +def register_templates(): + registry = TemplateRegistry() + for i in range(10): + template = MyTemplate(f"template_{i}") + registry.register(template) + +# Multiple threads can safely access the registry +threads = [threading.Thread(target=register_templates) for _ in range(5)] +for thread in threads: + thread.start() +for thread in threads: + thread.join() +``` + +## Best Practices + +1. **Use Global Functions**: For simple operations, prefer the global convenience functions +2. **Register Early**: Register all templates during application startup +3. **Thread Safety**: The registry is thread-safe, but individual templates may not be +4. **Testing Isolation**: Always reset the singleton in test setup/teardown + +## Example: Complete Template Management + +```python +from mdio.schemas.v1.templates.template_registry import TemplateRegistry +from mdio.schemas.v1.templates.seismic_3d_poststack import Seismic3DPostStackTemplate +from mdio.schemas.v1.templates.seismic_3d_prestack_time import Seismic3DPostStackTimeTemplate +from mdio.schemas.v1.templates.seismic_3d_prestack import Seismic3DPreStackTemplate + +def setup_templates(): + """Register MDIO templates runtime. 
+ Custom templates can be created in external projects and added without modifying the MDIO library code + """ + # Use strongly-typed template + template_name = TemplateRegistry().register(Seismic3DPostStackTimeTemplate()) + print(f"Registered template named {template_name}") + # Use parametrized template + template_name = TemplateRegistry().register(Seismic3DPostStackTemplate("Depth")) + print(f"Registered template named {template_name}") + template_name = TemplateRegistry().register(Seismic3DPreStackTemplate()) + print(f"Registered template named {template_name}") + + print(f"Registered templates: {TemplateRegistry().list_all_templates()}") + +# Application startup +setup_templates() + +# Later in the application +template = TemplateRegistry().get("PostStack3DDepth") +dataset = template.build_dataset(name="Seismic 3d m/m/ft", + sizes = [256, 512, 384], + coord_units = [ + AllUnits(units_v1=LengthUnitModel(length=LengthUnitEnum.METER)), + AllUnits(units_v1=LengthUnitModel(length=LengthUnitEnum.METER)), + AllUnits(units_v1=LengthUnitModel(length=LengthUnitEnum.FOOT))]) +``` + +## Error Handling + +The registry provides clear error messages: + +```python +# Template not registered +try: + template = get_template("nonexistent") +except KeyError as e: + print(f"Error: {e}") # "Template 'nonexistent' is not registered." + +# Duplicate registration +try: + register_template(template1) + register_template(template2)  # template2 has the same name as template1 +except ValueError as e: + print(f"Error: {e}") # "Template 'duplicate' is already registered." 
+``` diff --git a/src/mdio/schemas/v1/templates/abstract_dataset_template.py b/src/mdio/schemas/v1/templates/abstract_dataset_template.py new file mode 100644 index 000000000..776a65ce0 --- /dev/null +++ b/src/mdio/schemas/v1/templates/abstract_dataset_template.py @@ -0,0 +1,152 @@ +"""Template method pattern implementation for MDIO v1 dataset template.""" + +from abc import ABC +from abc import abstractmethod + +from mdio.schemas import compressors +from mdio.schemas.chunk_grid import RegularChunkGrid +from mdio.schemas.chunk_grid import RegularChunkShape +from mdio.schemas.dtype import ScalarType +from mdio.schemas.metadata import ChunkGridMetadata +from mdio.schemas.metadata import UserAttributes +from mdio.schemas.v1.dataset import Dataset +from mdio.schemas.v1.dataset_builder import MDIODatasetBuilder +from mdio.schemas.v1.units import AllUnits + + +class AbstractDatasetTemplate(ABC): + """Abstract base class that defines the template method for Dataset building factory. + + The template method defines the skeleton of the data processing algorithm, + while allowing subclasses to override specific steps. + """ + + def __init__(self, domain: str = "") -> None: + # Template attributes to be overridden by subclasses + # Domain of the seismic data, e.g. "time" or "depth" + self._trace_domain = domain.lower() + # Names of all coordinate dimensions in the dataset + # e.g. ["cdp"] for 2D post-stack depth + # e.g. ["inline", "crossline"] for 3D post-stack + # e.g. ["inline", "crossline"] for 3D pre-stack CDP gathers + # Note: For pre-stack Shot gathers, the coordinates are defined differently + # and are not directly tied to _coord_dim_names. + self._coord_dim_names = [] + # Names of all dimensions in the dataset + # e.g. ["cdp", "depth"] for 2D post-stack depth + # e.g. ["inline", "crossline", "depth"] for 3D post-stack depth + # e.g. ["inline", "crossline", "offset", "depth"] for 3D pre-stack depth CPD gathers + # e.g. 
["shot_point", "cable", "channel", "time"] for 3D pre-stack time Shot gathers + self._dim_names = [] + # Names of all coordinates in the dataset + # e.g. ["cdp-x", "cdp-y"] for 2D post-stack depth + # e.g. ["cdp-x", "cdp-y"] for 3D post-stack depth + # e.g. ["cdp-x", "cdp-y"] for 3D pre-stack CDP depth + # e.g. ["gun", "shot-x", "shot-y", "receiver-x", "receiver-y"] for 3D pre-stack + # time Shot gathers + self._coord_names = [] + # Name of the variable in the dataset + # e.g. "StackedAmplitude" for 2D post-stack depth + # e.g. "StackedAmplitude" for 3D post-stack depth + # e.g. "AmplitudeCDP" for 3D pre-stack CDP depth + # e.g. "AmplitudeShot" for 3D pre-stack time Shot gathers + self._var_name = "" + # Chunk shape for the variable in the dataset + # e.g. [1024, 1024] for 2D post-stack depth + # e.g. [128, 128, 128] for 3D post-stack depth + # e.g. [1, 1, 512, 4096] for 3D pre-stack CDP depth + # e.g. [1, 1, 512, 4096] for 3D pre-stack time Shot gathers + self._var_chunk_shape = [] + + # Variables instantiated when build_dataset() is called + self._builder: MDIODatasetBuilder = None + # Sizes of the dimensions in the dataset, to be set when build_dataset() is called + self._dim_sizes = [] + # Horizontal units for the coordinates (e.g., "m", "ft"), to be set when + # build_dataset() is called + self._coord_units = [] + + def build_dataset(self, name: str, sizes: list[int], coord_units: list[AllUnits]) -> Dataset: + """Template method that builds the dataset. + + Args: + name: The name of the dataset. + sizes: The sizes of the dimensions. + coord_units: The units for the coordinates. 
+ + Returns: + Dataset: The constructed dataset + """ + self._dim_sizes = sizes + self._coord_units = coord_units + + self._builder = MDIODatasetBuilder(name=name, attributes=self._load_dataset_attributes()) + self._add_dimensions() + self._add_coordinates() + self._add_variables() + return self._builder.build() + + def get_name(self) -> str: + """Returns the name of the template.""" + return self._get_name() + + @abstractmethod + def _get_name(self) -> str: + """Abstract method to get the name of the template. + + Must be implemented by subclasses. + + Returns: + str: The name of the template + """ + + @abstractmethod + def _load_dataset_attributes(self) -> UserAttributes: + """Abstract method to load dataset attributes. + + Must be implemented by subclasses. + + Returns: + UserAttributes: The dataset attributes + """ + + def _add_dimensions(self) -> None: + """A virtual method that can be overwritten by subclasses to add custom dimensions. + + Uses the class field 'builder' to add dimensions to the dataset. + """ + for i in range(len(self._dim_names)): + self._builder.add_dimension(self._dim_names[i], self._dim_sizes[i]) + + def _add_coordinates(self) -> None: + """A virtual method that can be overwritten by subclasses to add custom coordinates. + + Uses the class field 'builder' to add coordinates to the dataset. + """ + for i in range(len(self._coord_names)): + self._builder.add_coordinate( + self._coord_names[i], + dimensions=self._coord_dim_names, + data_type=ScalarType.FLOAT64, + metadata_info=[self._coord_units], + ) + + def _add_variables(self) -> None: + """A virtual method that can be overwritten by subclasses to add custom variables. + + Uses the class field 'builder' to add variables to the dataset. 
+ """ + self._builder.add_variable( + name=self._var_name, + dimensions=self._dim_names, + data_type=ScalarType.FLOAT32, + compressor=compressors.Blosc(algorithm=compressors.BloscAlgorithm.ZSTD), + coordinates=self._coord_names, + metadata_info=[ + ChunkGridMetadata( + chunk_grid=RegularChunkGrid( + configuration=RegularChunkShape(chunk_shape=self._var_chunk_shape) + ) + ) + ], + ) diff --git a/src/mdio/schemas/v1/templates/seismic_2d_poststack.py b/src/mdio/schemas/v1/templates/seismic_2d_poststack.py new file mode 100644 index 000000000..4f0d81b68 --- /dev/null +++ b/src/mdio/schemas/v1/templates/seismic_2d_poststack.py @@ -0,0 +1,29 @@ +"""Seismic2DPostStackTemplate MDIO v1 dataset templates.""" + +from mdio.schemas.metadata import UserAttributes +from mdio.schemas.v1.templates.abstract_dataset_template import AbstractDatasetTemplate + + +class Seismic2DPostStackTemplate(AbstractDatasetTemplate): + """Seismic post-stack 2D time or depth Dataset template.""" + + def __init__(self, domain: str): + super().__init__(domain=domain) + + self._coord_dim_names = ["cdp"] + self._dim_names = [*self._coord_dim_names, self._trace_domain] + self._coord_names = ["cdp-x", "cdp-y"] + self._var_name = "StackedAmplitude" + self._var_chunk_shape = [1024, 1024] + + def _get_name(self) -> str: + return f"PostStack2D{self._trace_domain.capitalize()}" + + def _load_dataset_attributes(self) -> UserAttributes: + return UserAttributes( + attributes={ + "surveyDimensionality": "2D", + "ensembleType": "line", + "processingStage": "post-stack", + } + ) diff --git a/src/mdio/schemas/v1/templates/seismic_3d_poststack.py b/src/mdio/schemas/v1/templates/seismic_3d_poststack.py new file mode 100644 index 000000000..3fd3950a3 --- /dev/null +++ b/src/mdio/schemas/v1/templates/seismic_3d_poststack.py @@ -0,0 +1,29 @@ +"""Seismic3DPostStackTemplate MDIO v1 dataset templates.""" + +from mdio.schemas.metadata import UserAttributes +from mdio.schemas.v1.templates.abstract_dataset_template import 
AbstractDatasetTemplate + + +class Seismic3DPostStackTemplate(AbstractDatasetTemplate): + """Seismic post-stack 3D time or depth Dataset template.""" + + def __init__(self, domain: str): + super().__init__(domain=domain) + # Template attributes to be overridden by subclasses + self._coord_dim_names = ["inline", "crossline"] + self._dim_names = [*self._coord_dim_names, self._trace_domain] + self._coord_names = ["cdp-x", "cdp-y"] + self._var_name = "StackedAmplitude" + self._var_chunk_shape = [128, 128, 128] + + def _get_name(self) -> str: + return f"PostStack3D{self._trace_domain.capitalize()}" + + def _load_dataset_attributes(self) -> UserAttributes: + return UserAttributes( + attributes={ + "surveyDimensionality": "3D", + "ensembleType": "line", + "processingStage": "post-stack", + } + ) diff --git a/src/mdio/schemas/v1/templates/seismic_3d_prestack_cdp.py b/src/mdio/schemas/v1/templates/seismic_3d_prestack_cdp.py new file mode 100644 index 000000000..fa95729b8 --- /dev/null +++ b/src/mdio/schemas/v1/templates/seismic_3d_prestack_cdp.py @@ -0,0 +1,29 @@ +"""Seismic3DPreStackCDPTemplate MDIO v1 dataset templates.""" + +from mdio.schemas.metadata import UserAttributes +from mdio.schemas.v1.templates.abstract_dataset_template import AbstractDatasetTemplate + + +class Seismic3DPreStackCDPTemplate(AbstractDatasetTemplate): + """Seismic CDP pre-stack 3D time or depth Dataset template.""" + + def __init__(self, domain: str): + super().__init__(domain=domain) + + self._coord_dim_names = ["inline", "crossline"] + self._dim_names = [*self._coord_dim_names, "offset", self._trace_domain] + self._coord_names = ["cdp-x", "cdp-y"] + self._var_name = "AmplitudeCDP" + self._var_chunk_shape = [1, 1, 512, 4096] + + def _get_name(self) -> str: + return f"PreStackCdpGathers3D{self._trace_domain.capitalize()}" + + def _load_dataset_attributes(self) -> UserAttributes: + return UserAttributes( + attributes={ + "surveyDimensionality": "3D", + "ensembleType": "cdp", + "processingStage": 
"pre-stack", + } + ) diff --git a/src/mdio/schemas/v1/templates/seismic_3d_prestack_shot.py b/src/mdio/schemas/v1/templates/seismic_3d_prestack_shot.py new file mode 100644 index 000000000..a39331276 --- /dev/null +++ b/src/mdio/schemas/v1/templates/seismic_3d_prestack_shot.py @@ -0,0 +1,63 @@ +"""Seismic3DPreStackShotTemplate MDIO v1 dataset templates.""" + +from mdio.schemas.dtype import ScalarType +from mdio.schemas.metadata import UserAttributes +from mdio.schemas.v1.templates.abstract_dataset_template import AbstractDatasetTemplate +from mdio.schemas.v1.units import AllUnits + + +class Seismic3DPreStackShotTemplate(AbstractDatasetTemplate): + """Seismic Shot pre-stack 3D time or depth Dataset template.""" + + def __init__(self, domain: str): + super().__init__(domain=domain) + + self._coord_dim_names = [] # Custom coordinate definition for shot gathers + self._dim_names = ["shot_point", "cable", "channel", self._trace_domain] + self._coord_names = ["gun", "shot-x", "shot-y", "receiver-x", "receiver-y"] + self._var_name = "AmplitudeShot" + self._var_chunk_shape = [1, 1, 512, 4096] + + def _get_name(self) -> str: + return f"PreStackShotGathers3D{self._trace_domain.capitalize()}" + + def _load_dataset_attributes(self) -> UserAttributes: + return UserAttributes( + attributes={ + "surveyDimensionality": "3D", + "ensembleType": "shot", + "processingStage": "pre-stack", + } + ) + + def _add_coordinates(self) -> None: + self._builder.add_coordinate( + "gun", + dimensions=["shot_point"], + data_type=ScalarType.UINT8, + metadata_info=[AllUnits(units_v1=None)], + ) + self._builder.add_coordinate( + "shot-x", + dimensions=["shot_point"], + data_type=ScalarType.FLOAT64, + metadata_info=[self._coord_units[0]], + ) + self._builder.add_coordinate( + "shot-y", + dimensions=["shot_point"], + data_type=ScalarType.FLOAT64, + metadata_info=[self._coord_units[1]], + ) + self._builder.add_coordinate( + "receiver-x", + dimensions=["shot_point", "cable", "channel"], + 
data_type=ScalarType.FLOAT64, + metadata_info=[self._coord_units[0]], + ) + self._builder.add_coordinate( + "receiver-y", + dimensions=["shot_point", "cable", "channel"], + data_type=ScalarType.FLOAT64, + metadata_info=[self._coord_units[1]], + ) diff --git a/src/mdio/schemas/v1/templates/template_registry.py b/src/mdio/schemas/v1/templates/template_registry.py new file mode 100644 index 000000000..c674b0c20 --- /dev/null +++ b/src/mdio/schemas/v1/templates/template_registry.py @@ -0,0 +1,212 @@ +"""Template registry for MDIO v1 dataset templates.""" + +import threading +from typing import Optional + +from mdio.schemas.v1.templates.abstract_dataset_template import AbstractDatasetTemplate +from mdio.schemas.v1.templates.seismic_2d_poststack import Seismic2DPostStackTemplate +from mdio.schemas.v1.templates.seismic_3d_poststack import Seismic3DPostStackTemplate +from mdio.schemas.v1.templates.seismic_3d_prestack_cdp import Seismic3DPreStackCDPTemplate +from mdio.schemas.v1.templates.seismic_3d_prestack_shot import Seismic3DPreStackShotTemplate + + +class TemplateRegistry: + """A thread-safe singleton registry for dataset templates.""" + + _instance: Optional["TemplateRegistry"] = None + _lock = threading.RLock() + _initialized = False + + def __new__(cls) -> "TemplateRegistry": + """Create or return the singleton instance. + + Uses double-checked locking pattern to ensure thread safety. + + Returns: + The singleton instance of TemplateRegistry. 
+ """ + if cls._instance is None: + with cls._lock: + # Double-checked locking pattern + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self) -> None: + if not self._initialized: + with self._lock: + if not self._initialized: + self._templates: dict[str, AbstractDatasetTemplate] = {} + self._registry_lock = threading.RLock() + self._register_default_templates() + TemplateRegistry._initialized = True + + def register(self, instance: AbstractDatasetTemplate) -> str: + """Register a template instance by its name. + + Args: + instance: An instance of template to register. + + Returns: + The name of the registered template. + + Raises: + ValueError: If the template name is already registered. + """ + with self._registry_lock: + name = instance.get_name() + if name in self._templates: + err = f"Template '{name}' is already registered." + raise ValueError(err) + self._templates[name] = instance + return name + + def _register_default_templates(self) -> None: + """Register default templates if needed. + + This method can be overridden by subclasses to register default templates. + """ + self.register(Seismic2DPostStackTemplate("time")) + self.register(Seismic2DPostStackTemplate("depth")) + + self.register(Seismic3DPostStackTemplate("time")) + self.register(Seismic3DPostStackTemplate("depth")) + + self.register(Seismic3DPreStackCDPTemplate("time")) + self.register(Seismic3DPreStackCDPTemplate("depth")) + + self.register(Seismic3DPreStackShotTemplate("time")) + self.register(Seismic3DPreStackShotTemplate("depth")) + + def get(self, template_name: str) -> AbstractDatasetTemplate: + """Get a template from the registry by its name. + + Args: + template_name: The name of the template to retrieve. + + Returns: + The template instance if found. + + Raises: + KeyError: If the template is not registered. 
+ """ + with self._registry_lock: + name = template_name + if name not in self._templates: + err = f"Template '{name}' is not registered." + raise KeyError(err) + return self._templates[name] + + def unregister(self, template_name: str) -> None: + """Unregister a template from the registry. + + Args: + template_name: The name of the template to unregister. + + Raises: + KeyError: If the template is not registered. + """ + with self._registry_lock: + name = template_name + if name not in self._templates: + err_msg = f"Template '{name}' is not registered." + raise KeyError(err_msg) + del self._templates[name] + + def is_registered(self, template_name: str) -> bool: + """Check if a template is registered in the registry. + + Args: + template_name: The name of the template to check. + + Returns: + True if the template is registered, False otherwise. + """ + with self._registry_lock: + name = template_name + return name in self._templates + + def list_all_templates(self) -> list[str]: + """Get all registered template names. + + Returns: + A list of all registered template names. + """ + with self._registry_lock: + return list(self._templates.keys()) + + def clear(self) -> None: + """Clear all registered templates (useful for testing).""" + with self._registry_lock: + self._templates.clear() + + @classmethod + def get_instance(cls) -> "TemplateRegistry": + """Get the singleton instance (alternative to constructor). + + Returns: + The singleton instance of TemplateRegistry. + """ + return cls() + + @classmethod + def _reset_instance(cls) -> None: + """Reset the singleton instance (useful for testing).""" + with cls._lock: + cls._instance = None + cls._initialized = False + + +# Global convenience functions +def get_template_registry() -> TemplateRegistry: + """Get the global template registry instance. + + Returns: + The singleton instance of TemplateRegistry. 
+ """ + return TemplateRegistry.get_instance() + + +def register_template(template: AbstractDatasetTemplate) -> str: + """Register a template in the global registry. + + Args: + template: An instance of AbstractDatasetTemplate to register. + + Returns: + The name of the registered template. + """ + return get_template_registry().register(template) + + +def get_template(name: str) -> AbstractDatasetTemplate: + """Get a template from the global registry. + + Args: + name: The name of the template to retrieve. + + Returns: + The template instance if found. + """ + return get_template_registry().get(name) + + +def is_template_registered(name: str) -> bool: + """Check if a template is registered in the global registry. + + Args: + name: The name of the template to check. + + Returns: + True if the template is registered, False otherwise. + """ + return get_template_registry().is_registered(name) + + +def list_templates() -> list[str]: + """List all registered template names. + + Returns: + A list of all registered template names. 
+ """ + return get_template_registry().list_all_templates() diff --git a/tests/unit/v1/helpers.py b/tests/unit/v1/helpers.py index d0ebe5d79..f9d209690 100644 --- a/tests/unit/v1/helpers.py +++ b/tests/unit/v1/helpers.py @@ -155,10 +155,10 @@ def output_path(file_dir: str, file_name: str, debugging: bool = False) -> str: return file_path -def make_seismic_poststack_3d_acceptance_dataset() -> Dataset: +def make_seismic_poststack_3d_acceptance_dataset(dataset_name: str) -> Dataset: """Create in-memory Seismic PostStack 3D Acceptance dataset.""" ds = MDIODatasetBuilder( - "campos_3d", + dataset_name, attributes=UserAttributes( attributes={ "textHeader": [ diff --git a/tests/unit/v1/templates/test_seismic_2d_poststack.py b/tests/unit/v1/templates/test_seismic_2d_poststack.py new file mode 100644 index 000000000..79f795854 --- /dev/null +++ b/tests/unit/v1/templates/test_seismic_2d_poststack.py @@ -0,0 +1,192 @@ +"""Unit tests for Seismic2DPostStackTemplate.""" + +from tests.unit.v1.helpers import validate_variable + +from mdio.schemas.chunk_grid import RegularChunkGrid +from mdio.schemas.dtype import ScalarType +from mdio.schemas.v1.templates.seismic_2d_poststack import Seismic2DPostStackTemplate +from mdio.schemas.v1.units import AllUnits +from mdio.schemas.v1.units import LengthUnitEnum +from mdio.schemas.v1.units import LengthUnitModel +from mdio.schemas.v1.units import TimeUnitEnum +from mdio.schemas.v1.units import TimeUnitModel + +_UNIT_METER = AllUnits(units_v1=LengthUnitModel(length=LengthUnitEnum.METER)) +_UNIT_SECOND = AllUnits(units_v1=TimeUnitModel(time=TimeUnitEnum.SECOND)) + + +class TestSeismic2DPostStackTemplate: + """Unit tests for Seismic2DPostStackTemplate.""" + + def test_configuration_depth(self) -> None: + """Test configuration of Seismic2DPostStackTemplate with depth domain.""" + t = Seismic2DPostStackTemplate("depth") + + # Template attributes + assert t._trace_domain == "depth" + assert t._coord_dim_names == ["cdp"] + assert t._dim_names == 
["cdp", "depth"] + assert t._coord_names == ["cdp-x", "cdp-y"] + assert t._var_name == "StackedAmplitude" + assert t._var_chunk_shape == [1024, 1024] + + # Variables instantiated when build_dataset() is called + assert t._builder is None + assert t._dim_sizes == [] + assert t._coord_units == [] + + # Verify dataset attributes + attrs = t._load_dataset_attributes() + assert attrs.attributes == { + "surveyDimensionality": "2D", + "ensembleType": "line", + "processingStage": "post-stack", + } + + def test_configuration_time(self) -> None: + """Test configuration of Seismic2DPostStackTemplate with time domain.""" + t = Seismic2DPostStackTemplate("time") + + # Template attributes + assert t._trace_domain == "time" + assert t._coord_dim_names == ["cdp"] + assert t._dim_names == ["cdp", "time"] + assert t._coord_names == ["cdp-x", "cdp-y"] + assert t._var_name == "StackedAmplitude" + assert t._var_chunk_shape == [1024, 1024] + + # Variables instantiated when build_dataset() is called + assert t._builder is None + assert t._dim_sizes == [] + assert t._coord_units == [] + + def test_domain_case_handling(self) -> None: + """Test that domain parameter handles different cases correctly.""" + # Test uppercase + t1 = Seismic2DPostStackTemplate("ELEVATION") + assert t1._trace_domain == "elevation" + assert t1.get_name() == "PostStack2DElevation" + + # Test mixed case + t2 = Seismic2DPostStackTemplate("elevatioN") + assert t2._trace_domain == "elevation" + assert t2.get_name() == "PostStack2DElevation" + + def test_build_dataset_depth(self) -> None: + """Test building a complete 2D depth dataset.""" + t = Seismic2DPostStackTemplate("depth") + + dataset = t.build_dataset( + "Seismic 2D Depth Line 001", + sizes=[2048, 4096], + coord_units=[_UNIT_METER, _UNIT_METER], # Both coordinates and depth in meters + ) + + # Verify dataset metadata + assert dataset.metadata.name == "Seismic 2D Depth Line 001" + assert dataset.metadata.attributes["surveyDimensionality"] == "2D" + assert 
dataset.metadata.attributes["ensembleType"] == "line" + assert dataset.metadata.attributes["processingStage"] == "post-stack" + + # 2 coordinate variables + 1 data variable = 3 variables + assert len(dataset.variables) == 3 + + # Verify coordinate variables + cdp_x = validate_variable( + dataset, + name="cdp-x", + dims=[("cdp", 2048)], + coords=["cdp-x"], + dtype=ScalarType.FLOAT64, + ) + assert cdp_x.metadata.units_v1.length == LengthUnitEnum.METER + + cdp_y = validate_variable( + dataset, + name="cdp-y", + dims=[("cdp", 2048)], + coords=["cdp-y"], + dtype=ScalarType.FLOAT64, + ) + assert cdp_y.metadata.units_v1.length == LengthUnitEnum.METER + + # Verify seismic variable + seismic = validate_variable( + dataset, + name="StackedAmplitude", + dims=[("cdp", 2048), ("depth", 4096)], + coords=["cdp-x", "cdp-y"], + dtype=ScalarType.FLOAT32, + ) + assert isinstance(seismic.metadata.chunk_grid, RegularChunkGrid) + assert seismic.metadata.chunk_grid.configuration.chunk_shape == [1024, 1024] + assert seismic.metadata.stats_v1 is None + + def test_build_dataset_time(self) -> None: + """Test building a complete 2D time dataset.""" + t = Seismic2DPostStackTemplate("time") + + dataset = t.build_dataset( + "Seismic 2D Time Line 001", + sizes=[2048, 4096], + coord_units=[_UNIT_METER, _UNIT_METER], # Coordinates in meters, time in seconds + ) + + # Verify dataset metadata + assert dataset.metadata.name == "Seismic 2D Time Line 001" + assert dataset.metadata.attributes["surveyDimensionality"] == "2D" + assert dataset.metadata.attributes["ensembleType"] == "line" + assert dataset.metadata.attributes["processingStage"] == "post-stack" + + # Verify variables count + assert len(dataset.variables) == 3 + + # Verify coordinate variables + v = validate_variable( + dataset, + name="cdp-x", + dims=[("cdp", 2048)], + coords=["cdp-x"], + dtype=ScalarType.FLOAT64, + ) + assert v.metadata.units_v1.length == LengthUnitEnum.METER + + v = validate_variable( + dataset, + name="cdp-y", + 
dims=[("cdp", 2048)], + coords=["cdp-y"], + dtype=ScalarType.FLOAT64, + ) + assert v.metadata.units_v1.length == LengthUnitEnum.METER + + # Verify seismic variable + v = validate_variable( + dataset, + name="StackedAmplitude", + dims=[("cdp", 2048), ("time", 4096)], + coords=["cdp-x", "cdp-y"], + dtype=ScalarType.FLOAT32, + ) + assert isinstance(v.metadata.chunk_grid, RegularChunkGrid) + assert v.metadata.chunk_grid.configuration.chunk_shape == [1024, 1024] + assert v.metadata.stats_v1 is None + + def test_time_vs_depth_comparison(self) -> None: + """Test differences between time and depth templates.""" + time_template = Seismic2DPostStackTemplate("time") + depth_template = Seismic2DPostStackTemplate("depth") + + # Different trace domains + assert time_template._trace_domain == "time" + assert depth_template._trace_domain == "depth" + + # Different names + assert time_template.get_name() == "PostStack2DTime" + assert depth_template.get_name() == "PostStack2DDepth" + + # Same other attributes + assert time_template._coord_dim_names == depth_template._coord_dim_names + assert time_template._coord_names == depth_template._coord_names + assert time_template._var_name == depth_template._var_name + assert time_template._var_chunk_shape == depth_template._var_chunk_shape diff --git a/tests/unit/v1/templates/test_seismic_3d_poststack.py b/tests/unit/v1/templates/test_seismic_3d_poststack.py new file mode 100644 index 000000000..3e205b0ae --- /dev/null +++ b/tests/unit/v1/templates/test_seismic_3d_poststack.py @@ -0,0 +1,182 @@ +"""Unit tests for Seismic3DPostStackTemplate.""" + +from tests.unit.v1.helpers import validate_variable + +from mdio.schemas.chunk_grid import RegularChunkGrid +from mdio.schemas.compressors import Blosc +from mdio.schemas.dtype import ScalarType +from mdio.schemas.v1.templates.seismic_3d_poststack import Seismic3DPostStackTemplate +from mdio.schemas.v1.units import AllUnits +from mdio.schemas.v1.units import LengthUnitEnum +from 
mdio.schemas.v1.units import LengthUnitModel +from mdio.schemas.v1.units import TimeUnitEnum +from mdio.schemas.v1.units import TimeUnitModel + +_UNIT_METER = AllUnits(units_v1=LengthUnitModel(length=LengthUnitEnum.METER)) +_UNIT_SECOND = AllUnits(units_v1=TimeUnitModel(time=TimeUnitEnum.SECOND)) + + +class TestSeismic3DPostStackTemplate: + """Unit tests for Seismic3DPostStackTemplate.""" + + def test_configuration_depth(self) -> None: + """Unit tests for Seismic3DPostStackTemplate with depth domain.""" + t = Seismic3DPostStackTemplate(domain="depth") + + # Template attributes to be overridden by subclasses + assert t._trace_domain == "depth" # Domain should be lowercased + assert t._coord_dim_names == ["inline", "crossline"] + assert t._dim_names == ["inline", "crossline", "depth"] + assert t._coord_names == ["cdp-x", "cdp-y"] + assert t._var_chunk_shape == [128, 128, 128] + + # Variables instantiated when build_dataset() is called + assert t._builder is None + assert t._dim_sizes == [] + assert t._coord_units == [] + + assert t._load_dataset_attributes().attributes == { + "surveyDimensionality": "3D", + "ensembleType": "line", + "processingStage": "post-stack", + } + + assert t.get_name() == "PostStack3DDepth" + + def test_configuration_time(self) -> None: + """Unit tests for Seismic3DPostStackTemplate with time domain.""" + t = Seismic3DPostStackTemplate(domain="time") + + # Template attributes to be overridden by subclasses + assert t._trace_domain == "time" # Domain should be lowercased + assert t._coord_dim_names == ["inline", "crossline"] + assert t._dim_names == ["inline", "crossline", "time"] + assert t._coord_names == ["cdp-x", "cdp-y"] + assert t._var_chunk_shape == [128, 128, 128] + + # Variables instantiated when build_dataset() is called + assert t._builder is None + assert t._dim_sizes == [] + assert t._coord_units == [] + + assert t._load_dataset_attributes().attributes == { + "surveyDimensionality": "3D", + "ensembleType": "line", + 
"processingStage": "post-stack", + } + + assert t.get_name() == "PostStack3DTime" + + def test_domain_case_handling(self) -> None: + """Test that domain parameter handles different cases correctly.""" + # Test uppercase + t1 = Seismic3DPostStackTemplate("ELEVATION") + assert t1._trace_domain == "elevation" + assert t1.get_name() == "PostStack3DElevation" + + # Test mixed case + t2 = Seismic3DPostStackTemplate("elevatioN") + assert t2._trace_domain == "elevation" + assert t2.get_name() == "PostStack3DElevation" + + def test_build_dataset_depth(self) -> None: + """Unit tests for Seismic3DPostStackTemplate build with depth domain.""" + t = Seismic3DPostStackTemplate(domain="depth") + + assert t.get_name() == "PostStack3DDepth" + dataset = t.build_dataset( + "Seismic 3D", sizes=[256, 512, 1024], coord_units=[_UNIT_METER, _UNIT_METER] + ) + + assert dataset.metadata.name == "Seismic 3D" + assert dataset.metadata.attributes["surveyDimensionality"] == "3D" + assert dataset.metadata.attributes["ensembleType"] == "line" + assert dataset.metadata.attributes["processingStage"] == "post-stack" + + # Verify variables + # 2 coordinate variables + 1 data variables = 3 variables + assert len(dataset.variables) == 3 + + # Verify coordinate variables + cdp_x = validate_variable( + dataset, + name="cdp-x", + dims=[("inline", 256), ("crossline", 512)], + coords=["cdp-x"], + dtype=ScalarType.FLOAT64, + ) + assert cdp_x.metadata.units_v1.length == LengthUnitEnum.METER + + cdp_y = validate_variable( + dataset, + name="cdp-y", + dims=[("inline", 256), ("crossline", 512)], + coords=["cdp-y"], + dtype=ScalarType.FLOAT64, + ) + assert cdp_y.metadata.units_v1.length == LengthUnitEnum.METER + + # Verify seismic variable + seismic = validate_variable( + dataset, + name="StackedAmplitude", + dims=[("inline", 256), ("crossline", 512), ("depth", 1024)], + coords=["cdp-x", "cdp-y"], + dtype=ScalarType.FLOAT32, + ) + assert isinstance(seismic.compressor, Blosc) + assert seismic.compressor.algorithm 
== "zstd" + assert isinstance(seismic.metadata.chunk_grid, RegularChunkGrid) + assert seismic.metadata.chunk_grid.configuration.chunk_shape == [128, 128, 128] + assert seismic.metadata.stats_v1 is None + + def test_build_dataset_time(self) -> None: + """Unit tests for Seismic3DPostStackTemplate build with time domain.""" + t = Seismic3DPostStackTemplate(domain="time") + + assert t.get_name() == "PostStack3DTime" + dataset = t.build_dataset( + "Seismic 3D", sizes=[256, 512, 1024], coord_units=[_UNIT_METER, _UNIT_METER] + ) + + assert dataset.metadata.name == "Seismic 3D" + assert dataset.metadata.attributes["surveyDimensionality"] == "3D" + assert dataset.metadata.attributes["ensembleType"] == "line" + assert dataset.metadata.attributes["processingStage"] == "post-stack" + + # Verify variables + # 2 coordinate variables + 1 data variable = 3 variables + assert len(dataset.variables) == 3 + + # Verify coordinate variables + cdp_x = validate_variable( + dataset, + name="cdp-x", + dims=[("inline", 256), ("crossline", 512)], + coords=["cdp-x"], + dtype=ScalarType.FLOAT64, + ) + assert cdp_x.metadata.units_v1.length == LengthUnitEnum.METER + + cdp_y = validate_variable( + dataset, + name="cdp-y", + dims=[("inline", 256), ("crossline", 512)], + coords=["cdp-y"], + dtype=ScalarType.FLOAT64, + ) + assert cdp_y.metadata.units_v1.length == LengthUnitEnum.METER + + # Verify seismic variable + seismic = validate_variable( + dataset, + name="StackedAmplitude", + dims=[("inline", 256), ("crossline", 512), ("time", 1024)], + coords=["cdp-x", "cdp-y"], + dtype=ScalarType.FLOAT32, + ) + assert isinstance(seismic.compressor, Blosc) + assert seismic.compressor.algorithm == "zstd" + assert isinstance(seismic.metadata.chunk_grid, RegularChunkGrid) + assert seismic.metadata.chunk_grid.configuration.chunk_shape == [128, 128, 128] + assert seismic.metadata.stats_v1 is None diff --git a/tests/unit/v1/templates/test_seismic_3d_prestack_cdp.py 
b/tests/unit/v1/templates/test_seismic_3d_prestack_cdp.py new file mode 100644 index 000000000..4f936940e --- /dev/null +++ b/tests/unit/v1/templates/test_seismic_3d_prestack_cdp.py @@ -0,0 +1,192 @@ +"""Unit tests for Seismic3DPreStackCDPTemplate.""" + +from tests.unit.v1.helpers import validate_variable + +from mdio.schemas.chunk_grid import RegularChunkGrid +from mdio.schemas.compressors import Blosc +from mdio.schemas.dtype import ScalarType +from mdio.schemas.v1.templates.seismic_3d_prestack_cdp import Seismic3DPreStackCDPTemplate +from mdio.schemas.v1.units import AllUnits +from mdio.schemas.v1.units import LengthUnitEnum +from mdio.schemas.v1.units import LengthUnitModel +from mdio.schemas.v1.units import TimeUnitEnum +from mdio.schemas.v1.units import TimeUnitModel + +_UNIT_METER = AllUnits(units_v1=LengthUnitModel(length=LengthUnitEnum.METER)) +_UNIT_SECOND = AllUnits(units_v1=TimeUnitModel(time=TimeUnitEnum.SECOND)) + + +class TestSeismic3DPreStackCDPTemplate: + """Unit tests for Seismic3DPreStackCDPTemplate.""" + + def test_configuration_depth(self) -> None: + """Unit tests for Seismic3DPreStackCDPTemplate.""" + t = Seismic3DPreStackCDPTemplate(domain="DEPTH") + + # Template attributes for prestack CDP + assert t._trace_domain == "depth" + assert t._coord_dim_names == ["inline", "crossline"] + assert t._dim_names == ["inline", "crossline", "offset", "depth"] + assert t._coord_names == ["cdp-x", "cdp-y"] + assert t._var_name == "AmplitudeCDP" + assert t._var_chunk_shape == [1, 1, 512, 4096] + + # Variables instantiated when build_dataset() is called + assert t._builder is None + assert t._dim_sizes == [] + assert t._coord_units == [] + + # Verify prestack CDP attributes + attrs = t._load_dataset_attributes() + assert attrs.attributes == { + "surveyDimensionality": "3D", + "ensembleType": "cdp", + "processingStage": "pre-stack", + } + + assert t.get_name() == "PreStackCdpGathers3DDepth" + + def test_configuration_time(self) -> None: + """Unit tests for 
Seismic3DPreStackCDPTemplate.""" + t = Seismic3DPreStackCDPTemplate(domain="TIME") + + # Template attributes for prestack CDP + assert t._trace_domain == "time" + assert t._coord_dim_names == ["inline", "crossline"] + assert t._dim_names == ["inline", "crossline", "offset", "time"] + assert t._coord_names == ["cdp-x", "cdp-y"] + assert t._var_name == "AmplitudeCDP" + assert t._var_chunk_shape == [1, 1, 512, 4096] + + # Variables instantiated when build_dataset() is called + assert t._builder is None + assert t._dim_sizes == [] + assert t._coord_units == [] + + # Verify prestack CDP attributes + attrs = t._load_dataset_attributes() + assert attrs.attributes == { + "surveyDimensionality": "3D", + "ensembleType": "cdp", + "processingStage": "pre-stack", + } + + assert t.get_name() == "PreStackCdpGathers3DTime" + + def test_domain_case_handling(self) -> None: + """Test that domain parameter handles different cases correctly.""" + # Test uppercase + t1 = Seismic3DPreStackCDPTemplate("ELEVATION") + assert t1._trace_domain == "elevation" + assert t1.get_name() == "PreStackCdpGathers3DElevation" + + # Test mixed case + t2 = Seismic3DPreStackCDPTemplate("elevatioN") + assert t2._trace_domain == "elevation" + assert t2.get_name() == "PreStackCdpGathers3DElevation" + + def test_build_dataset_depth(self) -> None: + """Unit tests for Seismic3DPreStackCDPTemplate build with depth domain.""" + t = Seismic3DPreStackCDPTemplate(domain="depth") + + assert t.get_name() == "PreStackCdpGathers3DDepth" + dataset = t.build_dataset( + "North Sea 3D Prestack Depth", + sizes=[512, 768, 36, 1536], + coord_units=[_UNIT_METER, _UNIT_METER], + ) + + assert dataset.metadata.name == "North Sea 3D Prestack Depth" + assert dataset.metadata.attributes["surveyDimensionality"] == "3D" + assert dataset.metadata.attributes["ensembleType"] == "cdp" + assert dataset.metadata.attributes["processingStage"] == "pre-stack" + + # Verify variables + # 2 coordinate variables + 1 data variable = 3 variables 
+ assert len(dataset.variables) == 3 + + # Verify coordinate variables + cdp_x = validate_variable( + dataset, + name="cdp-x", + dims=[("inline", 512), ("crossline", 768)], + coords=["cdp-x"], + dtype=ScalarType.FLOAT64, + ) + assert cdp_x.metadata.units_v1.length == LengthUnitEnum.METER + + cdp_y = validate_variable( + dataset, + name="cdp-y", + dims=[("inline", 512), ("crossline", 768)], + coords=["cdp-y"], + dtype=ScalarType.FLOAT64, + ) + assert cdp_y.metadata.units_v1.length == LengthUnitEnum.METER + + # Verify seismic variable (prestack depth data) + seismic = validate_variable( + dataset, + name="AmplitudeCDP", + dims=[("inline", 512), ("crossline", 768), ("offset", 36), ("depth", 1536)], + coords=["cdp-x", "cdp-y"], + dtype=ScalarType.FLOAT32, + ) + assert isinstance(seismic.compressor, Blosc) + assert seismic.compressor.algorithm == "zstd" + assert isinstance(seismic.metadata.chunk_grid, RegularChunkGrid) + assert seismic.metadata.chunk_grid.configuration.chunk_shape == [1, 1, 512, 4096] + assert seismic.metadata.stats_v1 is None + + def test_build_dataset_time(self) -> None: + """Unit tests for Seismic3DPreStackCDPTemplate build with time domain.""" + t = Seismic3DPreStackCDPTemplate(domain="time") + + assert t.get_name() == "PreStackCdpGathers3DTime" + dataset = t.build_dataset( + "Santos Basin 3D Prestack", + sizes=[512, 768, 36, 1536], + coord_units=[_UNIT_METER, _UNIT_METER], + ) + + assert dataset.metadata.name == "Santos Basin 3D Prestack" + assert dataset.metadata.attributes["surveyDimensionality"] == "3D" + assert dataset.metadata.attributes["ensembleType"] == "cdp" + assert dataset.metadata.attributes["processingStage"] == "pre-stack" + + # Verify variables + # 2 coordinate variables + 1 data variable = 3 variables + assert len(dataset.variables) == 3 + + # Verify coordinate variables + cdp_x = validate_variable( + dataset, + name="cdp-x", + dims=[("inline", 512), ("crossline", 768)], + coords=["cdp-x"], + dtype=ScalarType.FLOAT64, + ) + 
assert cdp_x.metadata.units_v1.length == LengthUnitEnum.METER + + cdp_y = validate_variable( + dataset, + name="cdp-y", + dims=[("inline", 512), ("crossline", 768)], + coords=["cdp-y"], + dtype=ScalarType.FLOAT64, + ) + assert cdp_y.metadata.units_v1.length == LengthUnitEnum.METER + + # Verify seismic variable (prestack time data) + seismic = validate_variable( + dataset, + name="AmplitudeCDP", + dims=[("inline", 512), ("crossline", 768), ("offset", 36), ("time", 1536)], + coords=["cdp-x", "cdp-y"], + dtype=ScalarType.FLOAT32, + ) + assert isinstance(seismic.compressor, Blosc) + assert seismic.compressor.algorithm == "zstd" + assert isinstance(seismic.metadata.chunk_grid, RegularChunkGrid) + assert seismic.metadata.chunk_grid.configuration.chunk_shape == [1, 1, 512, 4096] + assert seismic.metadata.stats_v1 is None diff --git a/tests/unit/v1/templates/test_seismic_3d_prestack_shot.py b/tests/unit/v1/templates/test_seismic_3d_prestack_shot.py new file mode 100644 index 000000000..eea66cf6b --- /dev/null +++ b/tests/unit/v1/templates/test_seismic_3d_prestack_shot.py @@ -0,0 +1,244 @@ +"""Unit tests for Seismic3DPreStackShotTemplate.""" + +from tests.unit.v1.helpers import validate_variable + +from mdio.schemas.chunk_grid import RegularChunkGrid +from mdio.schemas.compressors import Blosc +from mdio.schemas.dtype import ScalarType +from mdio.schemas.v1.templates.seismic_3d_prestack_shot import Seismic3DPreStackShotTemplate +from mdio.schemas.v1.units import AllUnits +from mdio.schemas.v1.units import LengthUnitEnum +from mdio.schemas.v1.units import LengthUnitModel +from mdio.schemas.v1.units import TimeUnitEnum +from mdio.schemas.v1.units import TimeUnitModel + +_UNIT_METER = AllUnits(units_v1=LengthUnitModel(length=LengthUnitEnum.METER)) +_UNIT_SECOND = AllUnits(units_v1=TimeUnitModel(time=TimeUnitEnum.SECOND)) + + +class TestSeismic3DPreStackShotTemplate: + """Unit tests for Seismic3DPreStackShotTemplate.""" + + def test_configuration_depth(self) -> None: + """Unit 
tests for Seismic3DPreStackShotTemplate in depth domain.""" + t = Seismic3DPreStackShotTemplate(domain="DEPTH") + + # Template attributes for prestack shot + assert t._trace_domain == "depth" + assert t._coord_dim_names == [] + assert t._dim_names == ["shot_point", "cable", "channel", "depth"] + assert t._coord_names == ["gun", "shot-x", "shot-y", "receiver-x", "receiver-y"] + assert t._var_name == "AmplitudeShot" + assert t._var_chunk_shape == [1, 1, 512, 4096] + + # Variables instantiated when build_dataset() is called + assert t._builder is None + assert t._dim_sizes == [] + assert t._coord_units == [] + + # Verify prestack shot attributes + attrs = t._load_dataset_attributes() + assert attrs.attributes == { + "surveyDimensionality": "3D", + "ensembleType": "shot", + "processingStage": "pre-stack", + } + + assert t.get_name() == "PreStackShotGathers3DDepth" + + def test_configuration_time(self) -> None: + """Unit tests for Seismic3DPreStackShotTemplate in time domain.""" + t = Seismic3DPreStackShotTemplate(domain="TIME") + + # Template attributes for prestack shot + assert t._trace_domain == "time" + assert t._coord_dim_names == [] + assert t._dim_names == ["shot_point", "cable", "channel", "time"] + assert t._coord_names == ["gun", "shot-x", "shot-y", "receiver-x", "receiver-y"] + assert t._var_name == "AmplitudeShot" + assert t._var_chunk_shape == [1, 1, 512, 4096] + + # Variables instantiated when build_dataset() is called + assert t._builder is None + assert t._dim_sizes == [] + assert t._coord_units == [] + + # Verify prestack shot attributes + attrs = t._load_dataset_attributes() + assert attrs.attributes == { + "surveyDimensionality": "3D", + "ensembleType": "shot", + "processingStage": "pre-stack", + } + + assert t.get_name() == "PreStackShotGathers3DTime" + + def test_domain_case_handling(self) -> None: + """Test that domain parameter handles different cases correctly.""" + # Test uppercase + t1 = Seismic3DPreStackShotTemplate("ELEVATION") + assert 
t1._trace_domain == "elevation" + assert t1.get_name() == "PreStackShotGathers3DElevation" + + # Test mixed case + t2 = Seismic3DPreStackShotTemplate("elevatioN") + assert t2._trace_domain == "elevation" + assert t2.get_name() == "PreStackShotGathers3DElevation" + + def test_build_dataset_depth(self) -> None: + """Unit tests for Seismic3DPreStackShotTemplate build in depth domain.""" + t = Seismic3DPreStackShotTemplate(domain="depth") + + assert t.get_name() == "PreStackShotGathers3DDepth" + dataset = t.build_dataset( + "Gulf of Mexico 3D Shot Depth", + sizes=[256, 512, 24, 2048], + coord_units=[_UNIT_METER, _UNIT_METER], + ) + + assert dataset.metadata.name == "Gulf of Mexico 3D Shot Depth" + assert dataset.metadata.attributes["surveyDimensionality"] == "3D" + assert dataset.metadata.attributes["ensembleType"] == "shot" + assert dataset.metadata.attributes["processingStage"] == "pre-stack" + + # Verify variables (including dimension variables) + # 5 coordinate variables + 1 data variable = 6 variables + assert len(dataset.variables) == 6 + + # Verify coordinate variables + validate_variable( + dataset, + name="gun", + dims=[("shot_point", 256)], + coords=["gun"], + dtype=ScalarType.UINT8, + ) + + shot_x = validate_variable( + dataset, + name="shot-x", + dims=[("shot_point", 256)], + coords=["shot-x"], + dtype=ScalarType.FLOAT64, + ) + assert shot_x.metadata.units_v1.length == LengthUnitEnum.METER + + shot_y = validate_variable( + dataset, + name="shot-y", + dims=[("shot_point", 256)], + coords=["shot-y"], + dtype=ScalarType.FLOAT64, + ) + assert shot_y.metadata.units_v1.length == LengthUnitEnum.METER + + receiver_x = validate_variable( + dataset, + name="receiver-x", + dims=[("shot_point", 256), ("cable", 512), ("channel", 24)], + coords=["receiver-x"], + dtype=ScalarType.FLOAT64, + ) + assert receiver_x.metadata.units_v1.length == LengthUnitEnum.METER + + receiver_y = validate_variable( + dataset, + name="receiver-y", + dims=[("shot_point", 256), ("cable", 512), 
("channel", 24)], + coords=["receiver-y"], + dtype=ScalarType.FLOAT64, + ) + assert receiver_y.metadata.units_v1.length == LengthUnitEnum.METER + + # Verify seismic variable (prestack shot depth data) + seismic = validate_variable( + dataset, + name="AmplitudeShot", + dims=[("shot_point", 256), ("cable", 512), ("channel", 24), ("depth", 2048)], + coords=["gun", "shot-x", "shot-y", "receiver-x", "receiver-y"], + dtype=ScalarType.FLOAT32, + ) + assert isinstance(seismic.compressor, Blosc) + assert seismic.compressor.algorithm == "zstd" + assert isinstance(seismic.metadata.chunk_grid, RegularChunkGrid) + assert seismic.metadata.chunk_grid.configuration.chunk_shape == [1, 1, 512, 4096] + assert seismic.metadata.stats_v1 is None + + def test_build_dataset_time(self) -> None: + """Unit tests for Seismic3DPreStackShotTemplate build in time domain.""" + t = Seismic3DPreStackShotTemplate(domain="time") + + assert t.get_name() == "PreStackShotGathers3DTime" + dataset = t.build_dataset( + "North Sea 3D Shot Time", + sizes=[256, 512, 24, 2048], + coord_units=[_UNIT_METER, _UNIT_METER], + ) + + assert dataset.metadata.name == "North Sea 3D Shot Time" + assert dataset.metadata.attributes["surveyDimensionality"] == "3D" + assert dataset.metadata.attributes["ensembleType"] == "shot" + assert dataset.metadata.attributes["processingStage"] == "pre-stack" + + # Verify variables (including dimension variables) + # 5 coordinate variables + 1 data variable = 6 variables + assert len(dataset.variables) == 6 + + # Verify coordinate variables + validate_variable( + dataset, + name="gun", + dims=[("shot_point", 256)], + coords=["gun"], + dtype=ScalarType.UINT8, + ) + + shot_x = validate_variable( + dataset, + name="shot-x", + dims=[("shot_point", 256)], + coords=["shot-x"], + dtype=ScalarType.FLOAT64, + ) + assert shot_x.metadata.units_v1.length == LengthUnitEnum.METER + + shot_y = validate_variable( + dataset, + name="shot-y", + dims=[("shot_point", 256)], + coords=["shot-y"], + 
dtype=ScalarType.FLOAT64, + ) + assert shot_y.metadata.units_v1.length == LengthUnitEnum.METER + + receiver_x = validate_variable( + dataset, + name="receiver-x", + dims=[("shot_point", 256), ("cable", 512), ("channel", 24)], + coords=["receiver-x"], + dtype=ScalarType.FLOAT64, + ) + assert receiver_x.metadata.units_v1.length == LengthUnitEnum.METER + + receiver_y = validate_variable( + dataset, + name="receiver-y", + dims=[("shot_point", 256), ("cable", 512), ("channel", 24)], + coords=["receiver-y"], + dtype=ScalarType.FLOAT64, + ) + assert receiver_y.metadata.units_v1.length == LengthUnitEnum.METER + + # Verify seismic variable (prestack shot time data) + seismic = validate_variable( + dataset, + name="AmplitudeShot", + dims=[("shot_point", 256), ("cable", 512), ("channel", 24), ("time", 2048)], + coords=["gun", "shot-x", "shot-y", "receiver-x", "receiver-y"], + dtype=ScalarType.FLOAT32, + ) + assert isinstance(seismic.compressor, Blosc) + assert seismic.compressor.algorithm == "zstd" + assert isinstance(seismic.metadata.chunk_grid, RegularChunkGrid) + assert seismic.metadata.chunk_grid.configuration.chunk_shape == [1, 1, 512, 4096] + assert seismic.metadata.stats_v1 is None diff --git a/tests/unit/v1/templates/test_seismic_templates.py b/tests/unit/v1/templates/test_seismic_templates.py new file mode 100644 index 000000000..46d2245a2 --- /dev/null +++ b/tests/unit/v1/templates/test_seismic_templates.py @@ -0,0 +1,59 @@ +"""Unit tests for concrete seismic dataset template implementations.""" + +# Import all concrete template classes +from mdio.schemas.v1.templates.abstract_dataset_template import AbstractDatasetTemplate +from mdio.schemas.v1.templates.seismic_2d_poststack import Seismic2DPostStackTemplate +from mdio.schemas.v1.templates.seismic_3d_poststack import Seismic3DPostStackTemplate +from mdio.schemas.v1.templates.seismic_3d_prestack_cdp import Seismic3DPreStackCDPTemplate +from mdio.schemas.v1.templates.seismic_3d_prestack_shot import 
Seismic3DPreStackShotTemplate + + +class TestSeismicTemplates: + """Test cases for all concrete seismic dataset templates.""" + + def test_get_name_time(self) -> None: + """Test get_name with domain.""" + time_template = Seismic2DPostStackTemplate("time") + dpth_template = Seismic2DPostStackTemplate("depth") + + assert time_template.get_name() == "PostStack2DTime" + assert dpth_template.get_name() == "PostStack2DDepth" + + time_template = Seismic3DPostStackTemplate("time") + dpth_template = Seismic3DPostStackTemplate("depth") + + assert time_template.get_name() == "PostStack3DTime" + assert dpth_template.get_name() == "PostStack3DDepth" + + time_template = Seismic3DPreStackCDPTemplate("time") + dpth_template = Seismic3DPreStackCDPTemplate("depth") + + assert time_template.get_name() == "PreStackCdpGathers3DTime" + assert dpth_template.get_name() == "PreStackCdpGathers3DDepth" + + time_template = Seismic3DPreStackShotTemplate("time") + dpth_template = Seismic3DPreStackShotTemplate("depth") + + assert time_template.get_name() == "PreStackShotGathers3DTime" + assert dpth_template.get_name() == "PreStackShotGathers3DDepth" + + def test_all_templates_inherit_from_abstract(self) -> None: + """Test that all concrete templates inherit from AbstractDatasetTemplate.""" + templates = [ + Seismic2DPostStackTemplate("time"), + Seismic3DPostStackTemplate("time"), + Seismic3DPreStackCDPTemplate("time"), + Seismic3DPreStackShotTemplate("time"), + Seismic2DPostStackTemplate("depth"), + Seismic3DPostStackTemplate("depth"), + Seismic3DPreStackCDPTemplate("depth"), + Seismic3DPreStackShotTemplate("depth"), + ] + + for template in templates: + assert isinstance(template, AbstractDatasetTemplate) + assert hasattr(template, "get_name") + assert hasattr(template, "build_dataset") + + names = [template.get_name() for template in templates] + assert len(names) == len(set(names)), f"Duplicate template names found: {names}" diff --git a/tests/unit/v1/templates/test_template_registry.py 
b/tests/unit/v1/templates/test_template_registry.py new file mode 100644 index 000000000..a40f5021c --- /dev/null +++ b/tests/unit/v1/templates/test_template_registry.py @@ -0,0 +1,395 @@ +"""Tests for the singleton TemplateRegistry implementation.""" + +import threading +import time + +import pytest + +from mdio.schemas.v1.templates.abstract_dataset_template import AbstractDatasetTemplate +from mdio.schemas.v1.templates.template_registry import TemplateRegistry +from mdio.schemas.v1.templates.template_registry import get_template +from mdio.schemas.v1.templates.template_registry import get_template_registry +from mdio.schemas.v1.templates.template_registry import is_template_registered +from mdio.schemas.v1.templates.template_registry import list_templates +from mdio.schemas.v1.templates.template_registry import register_template + + +class MockDatasetTemplate(AbstractDatasetTemplate): + """Mock template for testing.""" + + def __init__(self, name: str): + super().__init__() + self.template_name = name + + def _get_name(self) -> str: + return self.template_name + + def _load_dataset_attributes(self) -> None: + return None # Mock implementation + + def create_dataset(self) -> str: + """Create a mock dataset. + + Returns: + str: A message indicating the dataset creation. 
+ """ + return f"Mock dataset created by {self.template_name}" + + +class TestTemplateRegistrySingleton: + """Test cases for TemplateRegistry singleton pattern.""" + + def setup_method(self) -> None: + """Reset singleton before each test.""" + TemplateRegistry._reset_instance() + + def teardown_method(self) -> None: + """Clean up after each test.""" + if TemplateRegistry._instance: + TemplateRegistry._instance.clear() + TemplateRegistry._reset_instance() + + def test_singleton_same_instance(self) -> None: + """Test that multiple instantiations return the same instance.""" + registry1 = TemplateRegistry() + registry2 = TemplateRegistry() + registry3 = TemplateRegistry.get_instance() + + assert registry1 is registry2 + assert registry2 is registry3 + assert id(registry1) == id(registry2) == id(registry3) + + def test_singleton_thread_safety(self) -> None: + """Test that singleton is thread-safe during creation.""" + instances = [] + errors = [] + + def create_instance() -> None: + try: + instance = TemplateRegistry() + instances.append(instance) + time.sleep(0.001) # Small delay to increase contention + except Exception as e: + errors.append(e) + + # Create multiple threads trying to create instances + threads = [threading.Thread(target=create_instance) for _ in range(10)] + + for thread in threads: + thread.start() + + for thread in threads: + thread.join() + + # All instances should be the same + assert len(errors) == 0 + assert len(instances) == 10 + assert all(instance is instances[0] for instance in instances) + + def test_initialization_only_once(self) -> None: + """Test that internal state is initialized only once.""" + registry1 = TemplateRegistry() + template1 = MockDatasetTemplate("test_template") + registry1.register(template1) + + # Create another instance - should have same templates + registry2 = TemplateRegistry() + + assert registry1 is registry2 + assert registry2.is_registered("test_template") + assert registry2.get("test_template") is template1 + + 
def test_register_template(self) -> None: + """Test template registration.""" + registry = TemplateRegistry() + template = MockDatasetTemplate("test") + + name = registry.register(template) + + assert name == "test" # Should be the template name + assert registry.is_registered("test") + assert registry.get("test") is template + + def test_register_duplicate_template(self) -> None: + """Test error when registering duplicate template.""" + registry = TemplateRegistry() + template1 = MockDatasetTemplate("duplicate") + template2 = MockDatasetTemplate("duplicate") + + registry.register(template1) + + with pytest.raises(ValueError, match="Template 'duplicate' is already registered"): + registry.register(template2) + + def test_get_nonexistent_template(self) -> None: + """Test error when getting non-existent template.""" + registry = TemplateRegistry() + + with pytest.raises(KeyError, match="Template 'nonexistent' is not registered"): + registry.get("nonexistent") + + def test_unregister_template(self) -> None: + """Test template unregistration.""" + registry = TemplateRegistry() + template = MockDatasetTemplate("test_template") + + registry.register(template) + assert registry.is_registered("test_template") + + registry.unregister("test_template") + assert not registry.is_registered("test_template") + + def test_unregister_nonexistent_template(self) -> None: + """Test error when unregistering non-existent template.""" + registry = TemplateRegistry() + + with pytest.raises(KeyError, match="Template 'nonexistent' is not registered"): + registry.unregister("nonexistent") + + def test_list_all_templates(self) -> None: + """Test listing all registered templates.""" + registry = TemplateRegistry() + + # Default templates are always installed + templates = list_templates() + assert len(templates) == 8 + assert "PostStack2DTime" in templates + assert "PostStack3DTime" in templates + assert "PreStackCdpGathers3DTime" in templates + assert "PreStackShotGathers3DTime" in templates 
+ + assert "PostStack2DDepth" in templates + assert "PostStack3DDepth" in templates + assert "PreStackCdpGathers3DDepth" in templates + assert "PreStackShotGathers3DDepth" in templates + + # Add some templates + template1 = MockDatasetTemplate("Template_One") + template2 = MockDatasetTemplate("Template_Two") + + registry.register(template1) + registry.register(template2) + + templates = registry.list_all_templates() + assert len(templates) == 10 + assert "Template_One" in templates + assert "Template_Two" in templates + + def test_clear_templates(self) -> None: + """Test clearing all templates.""" + registry = TemplateRegistry() + + # Default templates are always installed + templates = list_templates() + assert len(templates) == 8 + + # Add some templates + template1 = MockDatasetTemplate("Template1") + template2 = MockDatasetTemplate("Template2") + + registry.register(template1) + registry.register(template2) + + assert len(registry.list_all_templates()) == 10 + + # Clear all + registry.clear() + + assert len(registry.list_all_templates()) == 0 + assert not registry.is_registered("Template1") + assert not registry.is_registered("Template2") + # default templates are also cleared + assert not registry.is_registered("PostStack2DTime") + assert not registry.is_registered("PostStack3DTime") + assert not registry.is_registered("PreStackCdpGathers3DTime") + assert not registry.is_registered("PreStackShotGathers3DTime") + assert not registry.is_registered("PostStack2DDepth") + assert not registry.is_registered("PostStack3DDepth") + assert not registry.is_registered("PreStackCdpGathers3DDepth") + assert not registry.is_registered("PreStackShotGathers3DDepth") + + def test_reset_instance(self) -> None: + """Test resetting the singleton instance.""" + registry1 = TemplateRegistry() + template = MockDatasetTemplate("test") + registry1.register(template) + + # Reset the instance + TemplateRegistry._reset_instance() + + # New instance should be different and contain default 
templates only + registry2 = TemplateRegistry() + + assert registry1 is not registry2 + assert not registry2.is_registered("test") + + # default templates are registered + assert len(registry2.list_all_templates()) == 8 + assert registry2.is_registered("PostStack2DTime") + assert registry2.is_registered("PostStack3DTime") + assert registry2.is_registered("PreStackCdpGathers3DTime") + assert registry2.is_registered("PreStackShotGathers3DTime") + assert registry2.is_registered("PostStack2DDepth") + assert registry2.is_registered("PostStack3DDepth") + assert registry2.is_registered("PreStackCdpGathers3DDepth") + assert registry2.is_registered("PreStackShotGathers3DDepth") + + +class TestGlobalFunctions: + """Test cases for global convenience functions.""" + + def setup_method(self) -> None: + """Reset singleton before each test.""" + TemplateRegistry._reset_instance() + + def teardown_method(self) -> None: + """Clean up after each test.""" + if TemplateRegistry._instance: + TemplateRegistry._instance.clear() + TemplateRegistry._reset_instance() + + def test_get_template_registry(self) -> None: + """Test global registry getter.""" + registry1 = get_template_registry() + registry2 = get_template_registry() + direct_registry = TemplateRegistry() + + assert registry1 is registry2 + assert registry1 is direct_registry + + def test_register_template_global(self) -> None: + """Test global template registration.""" + template = MockDatasetTemplate("global_test") + + name = register_template(template) + + assert name == "global_test" + assert is_template_registered("global_test") + assert get_template("global_test") is template + + def test_list_templates_global(self) -> None: + """Test global template listing.""" + # Default templates are always installed + templates = list_templates() + assert len(templates) == 8 + assert "PostStack2DTime" in templates + assert "PostStack3DTime" in templates + assert "PreStackCdpGathers3DTime" in templates + assert 
"PreStackShotGathers3DTime" in templates + + assert "PostStack2DDepth" in templates + assert "PostStack3DDepth" in templates + assert "PreStackCdpGathers3DDepth" in templates + assert "PreStackShotGathers3DDepth" in templates + + template1 = MockDatasetTemplate("template1") + template2 = MockDatasetTemplate("template2") + + register_template(template1) + register_template(template2) + + templates = list_templates() + assert len(templates) == 10 + assert "template1" in templates + assert "template2" in templates + + +class TestConcurrentAccess: + """Test concurrent access to the singleton registry.""" + + def setup_method(self) -> None: + """Reset singleton before each test.""" + TemplateRegistry._reset_instance() + + def teardown_method(self) -> None: + """Clean up after each test.""" + if TemplateRegistry._instance: + TemplateRegistry._instance.clear() + TemplateRegistry._reset_instance() + + def test_concurrent_registration(self) -> None: + """Test concurrent template registration.""" + registry = TemplateRegistry() + results = [] + errors = [] + + def register_template_worker(template_id: int) -> None: + try: + template = MockDatasetTemplate(f"template_{template_id}") + name = registry.register(template) + results.append((template_id, name)) + time.sleep(0.001) # Small delay + except Exception as e: + errors.append((template_id, e)) + + # Create multiple threads registering different templates + threads = [threading.Thread(target=register_template_worker, args=(i,)) for i in range(10)] + + for thread in threads: + thread.start() + + for thread in threads: + thread.join() + + # All registrations should succeed + assert len(errors) == 0 + assert len(results) == 10 + # Including 8 default templates + assert len(registry.list_all_templates()) == 18 + + # Check all templates are registered + for i in range(10): + assert registry.is_registered(f"template_{i}") + + def test_concurrent_access_mixed_operations(self) -> None: + """Test concurrent mixed operations 
(register, get, list).""" + registry = TemplateRegistry() + + # Pre-register some templates + for i in range(5): + template = MockDatasetTemplate(f"initial_{i}") + registry.register(template) + + results = [] + errors = [] + + def mixed_operations_worker(worker_id: int) -> None: + try: + operations_results = [] + + # Get existing template + if worker_id % 2 == 0: + template = registry.get("initial_0") + operations_results.append(("get", template.template_name)) + + # Register new template + if worker_id % 3 == 0: + new_template = MockDatasetTemplate(f"worker_{worker_id}") + name = registry.register(new_template) + operations_results.append(("register", name)) + + # List templates + templates = registry.list_all_templates() + operations_results.append(("list", len(templates))) + + results.append((worker_id, operations_results)) + + except Exception as e: + errors.append((worker_id, e)) + + # Run concurrent operations + threads = [threading.Thread(target=mixed_operations_worker, args=(i,)) for i in range(15)] + + for thread in threads: + thread.start() + + for thread in threads: + thread.join() + + # Check that operations completed without errors + assert len(errors) == 0 + assert len(results) == 15 + + # Verify final state is consistent + final_templates = registry.list_all_templates() + assert len(final_templates) >= 5 # At least the initial templates diff --git a/tests/unit/v1/test_dataset_builder_build.py b/tests/unit/v1/test_dataset_builder_build.py index 833fad981..43b87fb5f 100644 --- a/tests/unit/v1/test_dataset_builder_build.py +++ b/tests/unit/v1/test_dataset_builder_build.py @@ -45,10 +45,10 @@ def test_build() -> None: def test_build_seismic_poststack_3d_acceptance_dataset() -> None: # noqa: PLR0915 Too many statements (57 > 50) """Test building a Seismic PostStack 3D Acceptance dataset.""" - dataset = make_seismic_poststack_3d_acceptance_dataset() + dataset = make_seismic_poststack_3d_acceptance_dataset("Seismic") # Verify dataset structure - assert 
dataset.metadata.name == "campos_3d" + assert dataset.metadata.name == "Seismic" assert dataset.metadata.api_version == "1.0.0a1" assert dataset.metadata.attributes["foo"] == "bar" assert len(dataset.metadata.attributes["textHeader"]) == 3 diff --git a/tests/unit/v1/test_dataset_serializer.py b/tests/unit/v1/test_dataset_serializer.py index f8439585e..9db2deee1 100644 --- a/tests/unit/v1/test_dataset_serializer.py +++ b/tests/unit/v1/test_dataset_serializer.py @@ -398,7 +398,7 @@ def test_to_xarray_dataset(tmp_path: Path) -> None: def test_seismic_poststack_3d_acceptance_to_xarray_dataset(tmp_path: Path) -> None: """Test building a complete dataset.""" - dataset = make_seismic_poststack_3d_acceptance_dataset() + dataset = make_seismic_poststack_3d_acceptance_dataset("Seismic") xr_ds = to_xarray_dataset(dataset)