diff --git a/src/mdio/builder/template_registry.py b/src/mdio/builder/template_registry.py index b5be5c104..0208f7d61 100644 --- a/src/mdio/builder/template_registry.py +++ b/src/mdio/builder/template_registry.py @@ -76,10 +76,11 @@ def _register_default_templates(self) -> None: self.register(Seismic3DPostStackTemplate("depth")) # CDP/CMP Ordered Data - self.register(Seismic2DPreStackCDPTemplate("time")) - self.register(Seismic2DPreStackCDPTemplate("depth")) - self.register(Seismic3DPreStackCDPTemplate("time")) - self.register(Seismic3DPreStackCDPTemplate("depth")) + for data_domain in ("time", "depth"): + for gather_domain in ("offset", "angle"): + self.register(Seismic3DPreStackCDPTemplate(data_domain, gather_domain)) + self.register(Seismic2DPreStackCDPTemplate(data_domain, gather_domain)) + self.register(Seismic3DPreStackCocaTemplate("time")) self.register(Seismic3DPreStackCocaTemplate("depth")) diff --git a/src/mdio/builder/templates/abstract_dataset_template.py b/src/mdio/builder/templates/abstract_dataset_template.py index feb4186dd..0f3ec2367 100644 --- a/src/mdio/builder/templates/abstract_dataset_template.py +++ b/src/mdio/builder/templates/abstract_dataset_template.py @@ -15,6 +15,7 @@ from mdio.builder.schemas.v1.units import LengthUnitModel from mdio.builder.schemas.v1.variable import CoordinateMetadata from mdio.builder.schemas.v1.variable import VariableMetadata +from mdio.builder.templates.types import SeismicDataDomain class AbstractDatasetTemplate(ABC): @@ -24,8 +25,13 @@ class AbstractDatasetTemplate(ABC): to override specific steps. """ - def __init__(self, domain: str = "") -> None: - self._trace_domain = domain.lower() + def __init__(self, data_domain: SeismicDataDomain) -> None: + self._data_domain = data_domain.lower() + + if self._data_domain not in ["depth", "time"]: + msg = "domain must be 'depth' or 'time'" + raise ValueError(msg) + self._coord_dim_names = () self._dim_names = () self._coord_names = () @@ -80,7 +86,7 @@ def default_variable_name(self) -> str: @property def trace_domain(self) -> str: """Returns the name of the trace domain.""" - return self._trace_domain + return self._data_domain @property def dimension_names(self) -> tuple[str, ...]: diff --git a/src/mdio/builder/templates/seismic_2d_poststack.py b/src/mdio/builder/templates/seismic_2d_poststack.py index a08fe6489..878b17d96 100644 --- a/src/mdio/builder/templates/seismic_2d_poststack.py +++ b/src/mdio/builder/templates/seismic_2d_poststack.py @@ -3,22 +3,23 @@ from typing import Any from mdio.builder.templates.abstract_dataset_template import AbstractDatasetTemplate +from mdio.builder.templates.abstract_dataset_template import SeismicDataDomain class Seismic2DPostStackTemplate(AbstractDatasetTemplate): """Seismic post-stack 2D time or depth Dataset template.""" - def __init__(self, domain: str): - super().__init__(domain=domain) + def __init__(self, data_domain: SeismicDataDomain): + super().__init__(data_domain=data_domain) self._coord_dim_names = ("cdp",) - self._dim_names = (*self._coord_dim_names, self._trace_domain) + self._dim_names = (*self._coord_dim_names, self._data_domain) self._coord_names = ("cdp_x", "cdp_y") self._var_chunk_shape = (1024, 1024) @property def _name(self) -> str: - return f"PostStack2D{self._trace_domain.capitalize()}" + return f"PostStack2D{self._data_domain.capitalize()}" def _load_dataset_attributes(self) -> dict[str, Any]: - return {"surveyDimensionality": "2D", "ensembleType": "line", "processingStage": "post-stack"} + return {"surveyType": "2D", "gatherType": "stacked"} diff --git a/src/mdio/builder/templates/seismic_2d_prestack_cdp.py b/src/mdio/builder/templates/seismic_2d_prestack_cdp.py index c3631b074..99343d81c 100644 --- a/src/mdio/builder/templates/seismic_2d_prestack_cdp.py +++ b/src/mdio/builder/templates/seismic_2d_prestack_cdp.py @@ -3,22 +3,31 @@ from typing import Any from mdio.builder.templates.abstract_dataset_template import AbstractDatasetTemplate +from mdio.builder.templates.abstract_dataset_template import SeismicDataDomain +from mdio.builder.templates.types import CdpGatherDomain class Seismic2DPreStackCDPTemplate(AbstractDatasetTemplate): """Seismic CDP pre-stack 2D time or depth Dataset template.""" - def __init__(self, domain: str): - super().__init__(domain=domain) + def __init__(self, data_domain: SeismicDataDomain, gather_domain: CdpGatherDomain): + super().__init__(data_domain=data_domain) + self._gather_domain = gather_domain.lower() - self._coord_dim_names = ("cdp", "offset") - self._dim_names = (*self._coord_dim_names, self._trace_domain) + if self._gather_domain not in ["offset", "angle"]: + msg = "gather_type must be 'offset' or 'angle'" + raise ValueError(msg) + + self._coord_dim_names = ("cdp", self._gather_domain) + self._dim_names = (*self._coord_dim_names, self._data_domain) self._coord_names = ("cdp_x", "cdp_y") - self._var_chunk_shape = (1, 512, 4096) + self._var_chunk_shape = (16, 64, 1024) @property def _name(self) -> str: - return f"PreStackCdpGathers2D{self._trace_domain.capitalize()}" + gather_domain_suffix = self._gather_domain.capitalize() + data_domain_suffix = self._data_domain.capitalize() + return f"PreStackCdp{gather_domain_suffix}Gathers2D{data_domain_suffix}" def _load_dataset_attributes(self) -> dict[str, Any]: - return {"surveyDimensionality": "2D", "ensembleType": "cdp", "processingStage": "pre-stack"} + return {"surveyType": "2D", "gatherType": "cdp"} diff --git a/src/mdio/builder/templates/seismic_2d_prestack_shot.py b/src/mdio/builder/templates/seismic_2d_prestack_shot.py index 3bfc60450..252e23bd2 100644 --- a/src/mdio/builder/templates/seismic_2d_prestack_shot.py +++ b/src/mdio/builder/templates/seismic_2d_prestack_shot.py @@ -5,25 +5,26 @@ from mdio.builder.schemas.dtype import ScalarType from mdio.builder.schemas.v1.variable import CoordinateMetadata from mdio.builder.templates.abstract_dataset_template import AbstractDatasetTemplate +from mdio.builder.templates.abstract_dataset_template import SeismicDataDomain class Seismic2DPreStackShotTemplate(AbstractDatasetTemplate): """Seismic Shot pre-stack 2D time or depth Dataset template.""" - def __init__(self, domain: str): - super().__init__(domain=domain) + def __init__(self, data_domain: SeismicDataDomain): + super().__init__(data_domain=data_domain) - self._coord_dim_names = ("shot_point", "channel") # Custom coordinate definition for shot gathers - self._dim_names = (*self._coord_dim_names, self._trace_domain) + self._coord_dim_names = ("shot_point", "channel") + self._dim_names = (*self._coord_dim_names, self._data_domain) self._coord_names = ("gun", "source_coord_x", "source_coord_y", "group_coord_x", "group_coord_y") - self._var_chunk_shape = (1, 512, 4096) + self._var_chunk_shape = (16, 64, 1024) @property def _name(self) -> str: - return f"PreStackShotGathers2D{self._trace_domain.capitalize()}" + return f"PreStackShotGathers2D{self._data_domain.capitalize()}" def _load_dataset_attributes(self) -> dict[str, Any]: - return {"surveyDimensionality": "2D", "ensembleType": "shot_point", "processingStage": "pre-stack"} + return {"surveyType": "2D", "ensembleType": "common_source"} def _add_coordinates(self) -> None: # Add dimension coordinates @@ -34,18 +35,18 @@ def _add_coordinates(self) -> None: coordinate_metadata = CoordinateMetadata(units_v1=self._horizontal_coord_unit) self._builder.add_coordinate( "gun", - dimensions=("shot_point", "channel"), + dimensions=("shot_point",), data_type=ScalarType.UINT8, ) self._builder.add_coordinate( "source_coord_x", - dimensions=("shot_point", "channel"), + dimensions=("shot_point",), data_type=ScalarType.FLOAT64, metadata=coordinate_metadata, ) self._builder.add_coordinate( "source_coord_y", - dimensions=("shot_point", "channel"), + dimensions=("shot_point",), data_type=ScalarType.FLOAT64, metadata=coordinate_metadata, ) diff --git a/src/mdio/builder/templates/seismic_3d_poststack.py b/src/mdio/builder/templates/seismic_3d_poststack.py index 77893b0e4..ec5329b82 100644 --- a/src/mdio/builder/templates/seismic_3d_poststack.py +++ b/src/mdio/builder/templates/seismic_3d_poststack.py @@ -3,22 +3,24 @@ from typing import Any from mdio.builder.templates.abstract_dataset_template import AbstractDatasetTemplate +from mdio.builder.templates.abstract_dataset_template import SeismicDataDomain class Seismic3DPostStackTemplate(AbstractDatasetTemplate): """Seismic post-stack 3D time or depth Dataset template.""" - def __init__(self, domain: str): - super().__init__(domain=domain) - # Template attributes to be overridden by subclasses + def __init__(self, data_domain: SeismicDataDomain): + super().__init__(data_domain=data_domain) + self._coord_dim_names = ("inline", "crossline") - self._dim_names = (*self._coord_dim_names, self._trace_domain) + self._dim_names = (*self._coord_dim_names, self._data_domain) self._coord_names = ("cdp_x", "cdp_y") self._var_chunk_shape = (128, 128, 128) @property def _name(self) -> str: - return f"PostStack3D{self._trace_domain.capitalize()}" + domain_suffix = self._data_domain.capitalize() + return f"PostStack3D{domain_suffix}" def _load_dataset_attributes(self) -> dict[str, Any]: - return {"surveyDimensionality": "3D", "ensembleType": "line", "processingStage": "post-stack"} + return {"surveyType": "3D", "gatherType": "stacked"} diff --git a/src/mdio/builder/templates/seismic_3d_prestack_cdp.py b/src/mdio/builder/templates/seismic_3d_prestack_cdp.py index 40346004d..637b689dd 100644 --- a/src/mdio/builder/templates/seismic_3d_prestack_cdp.py +++ b/src/mdio/builder/templates/seismic_3d_prestack_cdp.py @@ -3,22 +3,31 @@ from typing import Any from mdio.builder.templates.abstract_dataset_template import AbstractDatasetTemplate +from mdio.builder.templates.abstract_dataset_template import SeismicDataDomain +from mdio.builder.templates.types import CdpGatherDomain class Seismic3DPreStackCDPTemplate(AbstractDatasetTemplate): - """Seismic CDP pre-stack 3D time or depth Dataset template.""" + """Seismic CDP pre-stack 3D gathers Dataset template.""" - def __init__(self, domain: str): - super().__init__(domain=domain) + def __init__(self, data_domain: SeismicDataDomain, gather_domain: CdpGatherDomain): + super().__init__(data_domain=data_domain) + self._gather_domain = gather_domain.lower() - self._coord_dim_names = ("inline", "crossline", "offset") - self._dim_names = (*self._coord_dim_names, self._trace_domain) + if self._gather_domain not in ["offset", "angle"]: + msg = "gather_type must be 'offset' or 'angle'" + raise ValueError(msg) + + self._coord_dim_names = ("inline", "crossline", self._gather_domain) + self._dim_names = (*self._coord_dim_names, self._data_domain) self._coord_names = ("cdp_x", "cdp_y") - self._var_chunk_shape = (1, 1, 512, 4096) + self._var_chunk_shape = (8, 8, 32, 512) @property def _name(self) -> str: - return f"PreStackCdpGathers3D{self._trace_domain.capitalize()}" + gather_domain_suffix = self._gather_domain.capitalize() + data_domain_suffix = self._data_domain.capitalize() + return f"PreStackCdp{gather_domain_suffix}Gathers3D{data_domain_suffix}" def _load_dataset_attributes(self) -> dict[str, Any]: - return {"surveyDimensionality": "3D", "ensembleType": "cdp", "processingStage": "pre-stack"} + return {"surveyType": "3D", "gatherType": "cdp"} diff --git a/src/mdio/builder/templates/seismic_3d_prestack_coca.py b/src/mdio/builder/templates/seismic_3d_prestack_coca.py index 0514ddfdb..2d6fd081c 100644 --- a/src/mdio/builder/templates/seismic_3d_prestack_coca.py +++ b/src/mdio/builder/templates/seismic_3d_prestack_coca.py @@ -6,25 +6,26 @@ from mdio.builder.schemas.v1.units import AngleUnitModel from mdio.builder.schemas.v1.variable import CoordinateMetadata from mdio.builder.templates.abstract_dataset_template import AbstractDatasetTemplate +from mdio.builder.templates.abstract_dataset_template import SeismicDataDomain class Seismic3DPreStackCocaTemplate(AbstractDatasetTemplate): """Seismic Shot pre-stack 3D time or depth Dataset template.""" - def __init__(self, domain: str): - super().__init__(domain=domain) + def __init__(self, data_domain: SeismicDataDomain): + super().__init__(data_domain=data_domain) self._coord_dim_names = ("inline", "crossline", "offset", "azimuth") - self._dim_names = (*self._coord_dim_names, self._trace_domain) + self._dim_names = (*self._coord_dim_names, self._data_domain) self._coord_names = ("cdp_x", "cdp_y") self._var_chunk_shape = (8, 8, 32, 1, 1024) @property def _name(self) -> str: - return f"PreStackCocaGathers3D{self._trace_domain.capitalize()}" + return f"PreStackCocaGathers3D{self._data_domain.capitalize()}" def _load_dataset_attributes(self) -> dict[str, Any]: - return {"surveyDimensionality": "3D", "ensembleType": "cdp_coca", "processingStage": "pre-stack"} + return {"surveyType": "3D", "gatherType": "common_offset_common_azimuth"} def _add_coordinates(self) -> None: # Add dimension coordinates diff --git a/src/mdio/builder/templates/seismic_3d_prestack_shot.py b/src/mdio/builder/templates/seismic_3d_prestack_shot.py index 778134bbb..c59c13694 100644 --- a/src/mdio/builder/templates/seismic_3d_prestack_shot.py +++ b/src/mdio/builder/templates/seismic_3d_prestack_shot.py @@ -5,25 +5,26 @@ from mdio.builder.schemas.dtype import ScalarType from mdio.builder.schemas.v1.variable import CoordinateMetadata from mdio.builder.templates.abstract_dataset_template import AbstractDatasetTemplate +from mdio.builder.templates.abstract_dataset_template import SeismicDataDomain class Seismic3DPreStackShotTemplate(AbstractDatasetTemplate): """Seismic Shot pre-stack 3D time or depth Dataset template.""" - def __init__(self, domain: str): - super().__init__(domain=domain) + def __init__(self, data_domain: SeismicDataDomain): + super().__init__(data_domain=data_domain) - self._coord_dim_names = ("shot_point", "cable", "channel") # Custom coordinates for shot gathers - self._dim_names = (*self._coord_dim_names, self._trace_domain) + self._coord_dim_names = ("shot_point", "cable", "channel") + self._dim_names = (*self._coord_dim_names, self._data_domain) self._coord_names = ("gun", "source_coord_x", "source_coord_y", "group_coord_x", "group_coord_y") - self._var_chunk_shape = (1, 1, 512, 4096) + self._var_chunk_shape = (8, 2, 128, 1024) @property def _name(self) -> str: - return f"PreStackShotGathers3D{self._trace_domain.capitalize()}" + return f"PreStackShotGathers3D{self._data_domain.capitalize()}" def _load_dataset_attributes(self) -> dict[str, Any]: - return {"surveyDimensionality": "3D", "ensembleType": "shot_point", "processingStage": "pre-stack"} + return {"surveyType": "3D", "ensembleType": "common_source"} def _add_coordinates(self) -> None: # Add dimension coordinates @@ -33,18 +34,18 @@ def _add_coordinates(self) -> None: # Add non-dimension coordinates self._builder.add_coordinate( "gun", - dimensions=("shot_point", "cable", "channel"), + dimensions=("shot_point",), data_type=ScalarType.UINT8, ) self._builder.add_coordinate( "source_coord_x", - dimensions=("shot_point", "cable", "channel"), + dimensions=("shot_point",), data_type=ScalarType.FLOAT64, metadata=CoordinateMetadata(units_v1=self._horizontal_coord_unit), ) self._builder.add_coordinate( "source_coord_y", - dimensions=("shot_point", "cable", "channel"), + dimensions=("shot_point",), data_type=ScalarType.FLOAT64, metadata=CoordinateMetadata(units_v1=self._horizontal_coord_unit), ) diff --git a/src/mdio/builder/templates/types.py b/src/mdio/builder/templates/types.py new file mode 100644 index 000000000..ed60392cd --- /dev/null +++ b/src/mdio/builder/templates/types.py @@ -0,0 +1,8 @@ +"""Module that contains type aliases for templates.""" + +from typing import Literal +from typing import TypeAlias + +SeismicDataDomain: TypeAlias = Literal["depth", "time"] + +CdpGatherDomain: TypeAlias = Literal["offset", "angle"] diff --git a/tests/integration/test_segy_import_export.py b/tests/integration/test_segy_import_export.py index 2b3f154ef..d5f938546 100644 --- a/tests/integration/test_segy_import_export.py +++ b/tests/integration/test_segy_import_export.py @@ -255,12 +255,11 @@ def test_dataset_metadata(self, zarr_tmp: Path) -> None: attributes = ds.attrs["attributes"] assert attributes is not None - assert len(attributes) == 6 + assert len(attributes) == 5 # Validate all attributes provided by the abstract template assert attributes["defaultVariableName"] == "amplitude" - assert attributes["surveyDimensionality"] == "3D" - assert attributes["ensembleType"] == "line" - assert attributes["processingStage"] == "post-stack" + assert attributes["surveyType"] == "3D" + assert attributes["gatherType"] == "stacked" assert attributes["textHeader"] == text_header_teapot_dome() assert attributes["binaryHeader"] == binary_header_teapot_dome() diff --git a/tests/integration/test_segy_import_export_masked.py b/tests/integration/test_segy_import_export_masked.py index a1536a6a6..a31545891 100644 --- a/tests/integration/test_segy_import_export_masked.py +++ b/tests/integration/test_segy_import_export_masked.py @@ -280,9 +280,9 @@ def test_import(self, test_conf: MaskedExportConfig, export_masked_path: Path) - case "3d_stack": template_name = "PostStack3D" + domain case "2d_gather": - template_name = "PreStackCdpGathers2D" + domain + template_name = "PreStackCdpOffsetGathers2D" + domain case "3d_gather": - template_name = "PreStackCdpGathers3D" + domain + template_name = "PreStackCdpOffsetGathers3D" + domain case "2d_streamer": template_name = "PreStackShotGathers2D" + domain case "3d_streamer": diff --git a/tests/unit/v1/templates/test_seismic_2d_poststack.py b/tests/unit/v1/templates/test_seismic_2d_poststack.py index 668451af0..ba6b3cb6f 100644 --- a/tests/unit/v1/templates/test_seismic_2d_poststack.py +++ b/tests/unit/v1/templates/test_seismic_2d_poststack.py @@ -1,5 +1,6 @@ """Unit tests for Seismic2DPostStackTemplate.""" +import pytest from tests.unit.v1.helpers import validate_variable from mdio.builder.schemas.chunk_grid import RegularChunkGrid @@ -11,6 +12,7 @@ from mdio.builder.schemas.v1.units import TimeUnitEnum from mdio.builder.schemas.v1.units import TimeUnitModel from mdio.builder.templates.seismic_2d_poststack import Seismic2DPostStackTemplate +from mdio.builder.templates.types import SeismicDataDomain UNITS_METER = LengthUnitModel(length=LengthUnitEnum.METER) UNITS_SECOND = TimeUnitModel(time=TimeUnitEnum.SECOND) @@ -34,14 +36,14 @@ def _validate_coordinates_headers_trace_mask(dataset: Dataset, headers: Structur ) # Verify dimension coordinate variables - inline = validate_variable( + cdp = validate_variable( dataset, name="cdp", dims=[("cdp", 2048)], coords=["cdp"], dtype=ScalarType.INT32, ) - assert inline.metadata is None + assert cdp.metadata is None domain = validate_variable( dataset, @@ -72,17 +74,18 @@ def _validate_coordinates_headers_trace_mask(dataset: Dataset, headers: Structur assert cdp_y.metadata.units_v1 == UNITS_METER +@pytest.mark.parametrize("data_domain", ["time", "depth"]) class TestSeismic2DPostStackTemplate: """Unit tests for Seismic2DPostStackTemplate.""" - def test_configuration_depth(self) -> None: - """Test configuration of Seismic2DPostStackTemplate with depth domain.""" - t = Seismic2DPostStackTemplate("depth") + def test_configuration(self, data_domain: SeismicDataDomain) -> None: + """Test configuration of Seismic2DPostStackTemplate.""" + t = Seismic2DPostStackTemplate(data_domain=data_domain) # Template attributes - assert t._trace_domain == "depth" + assert t._data_domain == data_domain assert t._coord_dim_names == ("cdp",) - assert t._dim_names == ("cdp", "depth") + assert t._dim_names == ("cdp", data_domain) assert t._coord_names == ("cdp_x", "cdp_y") assert t._var_chunk_shape == (1024, 1024) @@ -93,85 +96,13 @@ def test_configuration_depth(self) -> None: # Verify dataset attributes attrs = t._load_dataset_attributes() - assert attrs == { - "surveyDimensionality": "2D", - "ensembleType": "line", - "processingStage": "post-stack", - } + assert attrs == {"surveyType": "2D", "gatherType": "stacked"} assert t.default_variable_name == "amplitude" - def test_configuration_time(self) -> None: - """Test configuration of Seismic2DPostStackTemplate with time domain.""" - t = Seismic2DPostStackTemplate("time") - - # Template attributes - assert t._trace_domain == "time" - assert t._coord_dim_names == ("cdp",) - assert t._dim_names == ("cdp", "time") - assert t._coord_names == ("cdp_x", "cdp_y") - assert t._var_chunk_shape == (1024, 1024) - - # Variables instantiated when build_dataset() is called - assert t._builder is None - assert t._dim_sizes == () - assert t._horizontal_coord_unit is None - - # Verify dataset attributes - attrs = t._load_dataset_attributes() - assert attrs == { - "surveyDimensionality": "2D", - "ensembleType": "line", - "processingStage": "post-stack", - } - assert t.default_variable_name == "amplitude" - - def test_domain_case_handling(self) -> None: - """Test that domain parameter handles different cases correctly.""" - # Test uppercase - t1 = Seismic2DPostStackTemplate("ELEVATION") - assert t1._trace_domain == "elevation" - assert t1.name == "PostStack2DElevation" - - # Test mixed case - t2 = Seismic2DPostStackTemplate("elevatioN") - assert t2._trace_domain == "elevation" - assert t2.name == "PostStack2DElevation" - - def test_build_dataset_depth(self, structured_headers: StructuredType) -> None: - """Test building a complete 2D depth dataset.""" - t = Seismic2DPostStackTemplate("depth") - - dataset = t.build_dataset( - "Seismic 2D Depth Line 001", - sizes=(2048, 4096), - horizontal_coord_unit=UNITS_METER, - headers=structured_headers, - ) - - # Verify dataset metadata - assert dataset.metadata.name == "Seismic 2D Depth Line 001" - assert dataset.metadata.attributes["surveyDimensionality"] == "2D" - assert dataset.metadata.attributes["ensembleType"] == "line" - assert dataset.metadata.attributes["processingStage"] == "post-stack" - - _validate_coordinates_headers_trace_mask(dataset, structured_headers, "depth") - - # Verify seismic variable - seismic = validate_variable( - dataset, - name="amplitude", - dims=[("cdp", 2048), ("depth", 4096)], - coords=["cdp_x", "cdp_y"], - dtype=ScalarType.FLOAT32, - ) - assert isinstance(seismic.metadata.chunk_grid, RegularChunkGrid) - assert seismic.metadata.chunk_grid.configuration.chunk_shape == (1024, 1024) - assert seismic.metadata.stats_v1 is None - - def test_build_dataset_time(self, structured_headers: StructuredType) -> None: + def test_build_dataset_time(self, data_domain: SeismicDataDomain, structured_headers: StructuredType) -> None: """Test building a complete 2D time dataset.""" - t = Seismic2DPostStackTemplate("time") + t = Seismic2DPostStackTemplate(data_domain=data_domain) dataset = t.build_dataset( "Seismic 2D Time Line 001", @@ -182,17 +113,16 @@ def test_build_dataset_time(self, structured_headers: StructuredType) -> None: # Verify dataset metadata assert dataset.metadata.name == "Seismic 2D Time Line 001" - assert dataset.metadata.attributes["surveyDimensionality"] == "2D" - assert dataset.metadata.attributes["ensembleType"] == "line" - assert dataset.metadata.attributes["processingStage"] == "post-stack" + assert dataset.metadata.attributes["surveyType"] == "2D" + assert dataset.metadata.attributes["gatherType"] == "stacked" - _validate_coordinates_headers_trace_mask(dataset, structured_headers, "time") + _validate_coordinates_headers_trace_mask(dataset, structured_headers, data_domain) # Verify seismic variable v = validate_variable( dataset, name="amplitude", - dims=[("cdp", 2048), ("time", 4096)], + dims=[("cdp", 2048), (data_domain, 4096)], coords=["cdp_x", "cdp_y"], dtype=ScalarType.FLOAT32, ) @@ -200,20 +130,12 @@ def test_build_dataset_time(self, structured_headers: StructuredType) -> None: assert v.metadata.chunk_grid.configuration.chunk_shape == (1024, 1024) assert v.metadata.stats_v1 is None - def test_time_vs_depth_comparison(self) -> None: - """Test differences between time and depth templates.""" - time_template = Seismic2DPostStackTemplate("time") - depth_template = Seismic2DPostStackTemplate("depth") - - # Different trace domains - assert time_template._trace_domain == "time" - assert depth_template._trace_domain == "depth" - # Different names - assert time_template.name == "PostStack2DTime" - assert depth_template.name == "PostStack2DDepth" +@pytest.mark.parametrize("data_domain", ["Time", "DePTh"]) +def test_domain_case_handling(data_domain: str) -> None: + """Test that domain parameter handles different cases correctly.""" + template = Seismic2DPostStackTemplate(data_domain=data_domain) + assert template._data_domain == data_domain.lower() - # Same other attributes - assert time_template._coord_dim_names == depth_template._coord_dim_names - assert time_template._coord_names == depth_template._coord_names - assert time_template._var_chunk_shape == depth_template._var_chunk_shape + data_domain_suffix = data_domain.lower().capitalize() + assert template.name == f"PostStack2D{data_domain_suffix}" diff --git a/tests/unit/v1/templates/test_seismic_2d_prestack_cdp.py b/tests/unit/v1/templates/test_seismic_2d_prestack_cdp.py new file mode 100644 index 000000000..5eee1c021 --- /dev/null +++ b/tests/unit/v1/templates/test_seismic_2d_prestack_cdp.py @@ -0,0 +1,174 @@ +"""Unit tests for Seismic2DPreStackCDPTemplate.""" + +import pytest +from tests.unit.v1.helpers import validate_variable + +from mdio.builder.schemas.chunk_grid import RegularChunkGrid +from mdio.builder.schemas.compressors import Blosc +from mdio.builder.schemas.compressors import BloscCname +from mdio.builder.schemas.dtype import ScalarType +from mdio.builder.schemas.dtype import StructuredType +from mdio.builder.schemas.v1.dataset import Dataset +from mdio.builder.schemas.v1.units import LengthUnitEnum +from mdio.builder.schemas.v1.units import LengthUnitModel +from mdio.builder.schemas.v1.units import TimeUnitEnum +from mdio.builder.schemas.v1.units import TimeUnitModel +from mdio.builder.templates.seismic_2d_prestack_cdp import Seismic2DPreStackCDPTemplate +from mdio.builder.templates.types import CdpGatherDomain +from mdio.builder.templates.types import SeismicDataDomain + +UNITS_METER = LengthUnitModel(length=LengthUnitEnum.METER) +UNITS_SECOND = TimeUnitModel(time=TimeUnitEnum.SECOND) + + +def validate_coordinates_headers_trace_mask( + dataset: Dataset, + headers: StructuredType, + data_domain: SeismicDataDomain, + gather_domain: CdpGatherDomain, +) -> None: + """A helper method to validate coordinates, headers, and trace mask.""" + # Verify variables + # 3 dim coords + 2 non-dim coords + 1 data + 1 trace mask + 1 headers = 8 variables + assert len(dataset.variables) == 8 + + # Verify trace headers + validate_variable( + dataset, + name="headers", + dims=[("cdp", 512), (gather_domain, 36)], + coords=["cdp_x", "cdp_y"], + dtype=headers, + ) + + validate_variable( + dataset, + name="trace_mask", + dims=[("cdp", 512), (gather_domain, 36)], + coords=["cdp_x", "cdp_y"], + dtype=ScalarType.BOOL, + ) + + # Verify dimension coordinate variables + cdp = validate_variable( + dataset, + name="cdp", + dims=[("cdp", 512)], + coords=["cdp"], + dtype=ScalarType.INT32, + ) + assert cdp.metadata is None + + domain = validate_variable( + dataset, + name=gather_domain, + dims=[(gather_domain, 36)], + coords=[gather_domain], + dtype=ScalarType.INT32, + ) + assert domain.metadata is None + + domain = validate_variable( + dataset, + name=data_domain, + dims=[(data_domain, 1536)], + coords=[data_domain], + dtype=ScalarType.INT32, + ) + assert domain.metadata is None + + # Verify non-dimension coordinate variables + cdp_x = validate_variable( + dataset, + name="cdp_x", + dims=[("cdp", 512), (gather_domain, 36)], + coords=["cdp_x"], + dtype=ScalarType.FLOAT64, + ) + assert cdp_x.metadata.units_v1.length == LengthUnitEnum.METER + + cdp_y = validate_variable( + dataset, + name="cdp_y", + dims=[("cdp", 512), (gather_domain, 36)], + coords=["cdp_y"], + dtype=ScalarType.FLOAT64, + ) + assert cdp_y.metadata.units_v1.length == LengthUnitEnum.METER + + +@pytest.mark.parametrize("data_domain", ["depth", "time"]) +@pytest.mark.parametrize("gather_domain", ["offset", "angle"]) +class TestSeismic2DPreStackCDPTemplate: + """Unit tests for Seismic2DPreStackCDPTemplate.""" + + def test_configuration(self, data_domain: SeismicDataDomain, gather_domain: CdpGatherDomain) -> None: + """Unit tests for Seismic2DPreStackCDPTemplate.""" + t = Seismic2DPreStackCDPTemplate(data_domain, gather_domain) + + # Template attributes for prestack CDP + assert t._coord_dim_names == ("cdp", gather_domain) + assert t._dim_names == ("cdp", gather_domain, data_domain) + assert t._coord_names == ("cdp_x", "cdp_y") + assert t._var_chunk_shape == (16, 64, 1024) + + # Variables instantiated when build_dataset() is called + assert t._builder is None + assert t._dim_sizes == () + assert t._horizontal_coord_unit is None + + # Verify prestack CDP attributes + attrs = t._load_dataset_attributes() + assert attrs == {"surveyType": "2D", "gatherType": "cdp"} + assert t.default_variable_name == "amplitude" + + def test_build_dataset( + self, + data_domain: SeismicDataDomain, + gather_domain: CdpGatherDomain, + structured_headers: StructuredType, + ) -> None: + """Unit tests for Seismic2DPreStackCDPDepthTemplate build.""" + t = Seismic2DPreStackCDPTemplate(data_domain, gather_domain) + + gather_domain_suffix = gather_domain.capitalize() + data_domain_suffix = data_domain.capitalize() + assert t.name == f"PreStackCdp{gather_domain_suffix}Gathers2D{data_domain_suffix}" + dataset = t.build_dataset( + "North Sea 2D Prestack", + sizes=(512, 36, 1536), + horizontal_coord_unit=UNITS_METER, + headers=structured_headers, + ) + + assert dataset.metadata.name == "North Sea 2D Prestack" + assert dataset.metadata.attributes["surveyType"] == "2D" + assert dataset.metadata.attributes["gatherType"] == "cdp" + + validate_coordinates_headers_trace_mask(dataset, structured_headers, data_domain, gather_domain) + + # Verify seismic variable (prestack depth data) + seismic = validate_variable( + dataset, + name="amplitude", + dims=[("cdp", 512), (gather_domain, 36), (data_domain, 1536)], + coords=["cdp_x", "cdp_y"], + dtype=ScalarType.FLOAT32, + ) + assert isinstance(seismic.compressor, Blosc) + assert seismic.compressor.cname == BloscCname.zstd + assert isinstance(seismic.metadata.chunk_grid, RegularChunkGrid) + assert seismic.metadata.chunk_grid.configuration.chunk_shape == (16, 64, 1024) + assert seismic.metadata.stats_v1 is None + + +@pytest.mark.parametrize("data_domain", ["Time", "DePTh"]) +@pytest.mark.parametrize("gather_domain", ["Offset", "OffSeT"]) +def test_domain_case_handling(data_domain: str, gather_domain: str) -> None: + """Test that domain parameter handles different cases correctly.""" + template = Seismic2DPreStackCDPTemplate(data_domain, gather_domain) + assert template._data_domain == data_domain.lower() + + gather_domain_suffix = gather_domain.lower().capitalize() + data_domain_suffix = data_domain.lower().capitalize() + assert template.name == f"PreStackCdp{gather_domain_suffix}Gathers2D{data_domain_suffix}" diff --git a/tests/unit/v1/templates/test_seismic_2d_prestack_shot.py b/tests/unit/v1/templates/test_seismic_2d_prestack_shot.py new file mode 100644 index 000000000..f79648f59 --- /dev/null +++ b/tests/unit/v1/templates/test_seismic_2d_prestack_shot.py @@ -0,0 +1,179 @@ +"""Unit tests for Seismic2DPreStackShotTemplate.""" + +import pytest +from tests.unit.v1.helpers import validate_variable + +from mdio.builder.schemas.chunk_grid import RegularChunkGrid +from mdio.builder.schemas.compressors import Blosc +from mdio.builder.schemas.compressors import BloscCname +from mdio.builder.schemas.dtype import ScalarType +from mdio.builder.schemas.dtype import StructuredType +from mdio.builder.schemas.v1.dataset import Dataset +from mdio.builder.schemas.v1.units import LengthUnitEnum +from mdio.builder.schemas.v1.units import LengthUnitModel +from mdio.builder.templates.seismic_2d_prestack_shot import Seismic2DPreStackShotTemplate + +UNITS_METER = LengthUnitModel(length=LengthUnitEnum.METER) + + +def _validate_coordinates_headers_trace_mask(dataset: Dataset, headers: StructuredType, domain: str) -> None: + """Validate the coordinate, headers, trace_mask variables in the dataset.""" + # Verify variables + # 4 dim coords + 5 non-dim coords + 1 data + 1 trace mask + 1 headers = 12 variables + assert len(dataset.variables) == 11 + + # Verify trace headers + validate_variable( + dataset, + name="headers", + dims=[("shot_point", 256), ("channel", 24)], + coords=["gun", "source_coord_x", "source_coord_y", "group_coord_x", "group_coord_y"], + dtype=headers, + ) + + validate_variable( + dataset, + name="trace_mask", + dims=[("shot_point", 256), ("channel", 24)], + coords=["gun", "source_coord_x", "source_coord_y", "group_coord_x", "group_coord_y"], + dtype=ScalarType.BOOL, + ) + + # Verify dimension coordinate variables + shot_point = validate_variable( + dataset, + name="shot_point", + dims=[("shot_point", 256)], + coords=["shot_point"], + dtype=ScalarType.INT32, + ) + assert shot_point.metadata is None + + channel = validate_variable( + dataset, + name="channel", + dims=[("channel", 24)], + coords=["channel"], + dtype=ScalarType.INT32, + ) + assert channel.metadata is None + + domain = validate_variable( + dataset, + name=domain, + dims=[(domain, 2048)], + coords=[domain], + dtype=ScalarType.INT32, + ) + assert domain.metadata is None + + # Verify non-dimension coordinate variables + validate_variable( + dataset, + name="gun", + dims=[("shot_point", 256)], + coords=["gun"], + dtype=ScalarType.UINT8, + ) + + source_coord_x = validate_variable( + dataset, + name="source_coord_x", + dims=[("shot_point", 256)], + coords=["source_coord_x"], + dtype=ScalarType.FLOAT64, + ) + assert source_coord_x.metadata.units_v1.length == LengthUnitEnum.METER + + source_coord_y = validate_variable( + dataset, + name="source_coord_y", + dims=[("shot_point", 256)], + coords=["source_coord_y"], + dtype=ScalarType.FLOAT64, + ) + assert source_coord_y.metadata.units_v1.length == LengthUnitEnum.METER + + group_coord_x = validate_variable( + dataset, + name="group_coord_x", + dims=[("shot_point", 256), ("channel", 24)], + coords=["group_coord_x"], + dtype=ScalarType.FLOAT64, + ) + assert group_coord_x.metadata.units_v1.length == LengthUnitEnum.METER + + group_coord_y = validate_variable( + dataset, + name="group_coord_y", + dims=[("shot_point", 256), ("channel", 24)], + coords=["group_coord_y"], + dtype=ScalarType.FLOAT64, + ) + assert group_coord_y.metadata.units_v1.length == LengthUnitEnum.METER + + +class TestSeismic2DPreStackShotTemplate: + """Unit tests for Seismic2DPreStackShotTemplate.""" + + def test_configuration(self) -> None: + """Unit tests for Seismic2DPreStackShotTemplate.""" + t = Seismic2DPreStackShotTemplate(data_domain="time") + + # Template attributes for prestack shot + assert t._data_domain == "time" + assert t._coord_dim_names == ("shot_point", "channel") + assert t._dim_names == ("shot_point", "channel", "time") + assert t._coord_names == ("gun", "source_coord_x", "source_coord_y", "group_coord_x", "group_coord_y") + assert t._var_chunk_shape == (16, 64, 1024) + + # Variables instantiated when build_dataset() is called + assert t._builder is None + assert t._dim_sizes == () + assert t._horizontal_coord_unit is None + + # Verify prestack shot attributes + attrs = t._load_dataset_attributes() + assert attrs == {"surveyType": "2D", "ensembleType": "common_source"} + + assert t.name == "PreStackShotGathers2DTime" + + def test_build_dataset(self, structured_headers: StructuredType) -> None: + """Unit tests for Seismic2DPreStackShotTemplate build in time domain.""" + t = Seismic2DPreStackShotTemplate(data_domain="time") + + assert t.name == "PreStackShotGathers2DTime" + dataset = t.build_dataset( + "North Sea 2D Shot Time", + sizes=(256, 24, 2048), + horizontal_coord_unit=UNITS_METER, + headers=structured_headers, + ) + + assert dataset.metadata.name == "North Sea 2D Shot Time" + assert dataset.metadata.attributes["surveyType"] == "2D" + assert dataset.metadata.attributes["ensembleType"] == "common_source" + + _validate_coordinates_headers_trace_mask(dataset, structured_headers, "time") + + # Verify seismic variable (prestack shot time data) + seismic = validate_variable( + dataset, + name="amplitude", + dims=[("shot_point", 256), ("channel", 24), ("time", 2048)], + coords=["gun", "source_coord_x", "source_coord_y", "group_coord_x", "group_coord_y"], + dtype=ScalarType.FLOAT32, + ) + assert isinstance(seismic.compressor, Blosc) + assert seismic.compressor.cname == BloscCname.zstd + assert isinstance(seismic.metadata.chunk_grid, RegularChunkGrid) + assert seismic.metadata.chunk_grid.configuration.chunk_shape == (16, 64, 1024) + assert seismic.metadata.stats_v1 is None + + +@pytest.mark.parametrize("data_domain", ["Time", "TiME"]) +def test_domain_case_handling(data_domain: str) -> None: + """Test that domain parameter handles different cases correctly.""" + template = Seismic2DPreStackShotTemplate(data_domain=data_domain) + assert template._data_domain == data_domain.lower() + assert template.name.endswith(data_domain.capitalize()) diff --git a/tests/unit/v1/templates/test_seismic_3d_poststack.py b/tests/unit/v1/templates/test_seismic_3d_poststack.py index 2094f5d34..f128c72fe 100644 --- a/tests/unit/v1/templates/test_seismic_3d_poststack.py +++ b/tests/unit/v1/templates/test_seismic_3d_poststack.py @@ -1,5 +1,6 @@ """Unit tests for Seismic3DPostStackTemplate.""" +import pytest from tests.unit.v1.helpers import validate_variable from mdio.builder.schemas.chunk_grid import RegularChunkGrid @@ -13,6 +14,7 @@ from mdio.builder.schemas.v1.units import TimeUnitEnum from mdio.builder.schemas.v1.units import TimeUnitModel from mdio.builder.templates.seismic_3d_poststack import Seismic3DPostStackTemplate +from mdio.builder.templates.types import SeismicDataDomain UNITS_METER = LengthUnitModel(length=LengthUnitEnum.METER) UNITS_SECOND = TimeUnitModel(time=TimeUnitEnum.SECOND) @@ -89,17 +91,18 @@ def _validate_coordinates_headers_trace_mask(dataset: Dataset, headers: Structur assert cdp_y.metadata.units_v1.length == LengthUnitEnum.METER +@pytest.mark.parametrize("data_domain", ["depth", "time"]) class TestSeismic3DPostStackTemplate: """Unit tests for Seismic3DPostStackTemplate.""" - def test_configuration_depth(self) -> None: - """Unit tests for Seismic3DPostStackTemplate with depth domain.""" - t = Seismic3DPostStackTemplate(domain="depth") + def test_configuration(self, data_domain: SeismicDataDomain) -> None: + """Unit tests for Seismic3DPostStackTemplate.""" + t = Seismic3DPostStackTemplate(data_domain=data_domain) # Template attributes to be overridden by subclasses - assert t._trace_domain == "depth" # Domain should be lowercased + assert t._data_domain == data_domain # Domain should be lowercased assert t._coord_dim_names == ("inline", "crossline") - assert t._dim_names == ("inline", "crossline", "depth") + assert t._dim_names == ("inline", "crossline", data_domain) assert t._coord_names == ("cdp_x", "cdp_y") assert t._var_chunk_shape == (128, 128, 128) @@ -110,54 +113,15 @@ def test_configuration_depth(self) -> None: # Verify dataset attributes attrs = t._load_dataset_attributes() - assert attrs == { - "surveyDimensionality": "3D", - "ensembleType": "line", - "processingStage": "post-stack", - } + assert attrs == {"surveyType": "3D", "gatherType": "stacked"} assert t.default_variable_name == "amplitude" - def test_configuration_time(self) -> None: - """Unit tests for Seismic3DPostStackTemplate with time domain.""" - t = Seismic3DPostStackTemplate(domain="time") + def test_build_dataset(self, data_domain: SeismicDataDomain, structured_headers: StructuredType) -> None: + """Unit tests for Seismic3DPostStackTemplate build.""" + t = Seismic3DPostStackTemplate(data_domain=data_domain) - # Template attributes to be overridden by subclasses - assert t._trace_domain == "time" # Domain should be lowercased - assert t._coord_dim_names == ("inline", "crossline") - assert t._dim_names == ("inline", "crossline", "time") - assert t._coord_names == ("cdp_x", "cdp_y") - assert t._var_chunk_shape == (128, 128, 128) - - # Variables instantiated when build_dataset() is called - assert t._builder is None - assert t._dim_sizes == () - assert t._horizontal_coord_unit is None - - assert t._load_dataset_attributes() == { - "surveyDimensionality": "3D", - "ensembleType": "line", - "processingStage": "post-stack", - } - - assert t.name == "PostStack3DTime" - - def test_domain_case_handling(self) -> None: - """Test that domain parameter handles different cases correctly.""" - # Test uppercase - t1 = Seismic3DPostStackTemplate("ELEVATION") - assert t1._trace_domain == "elevation" - assert t1.name == "PostStack3DElevation" - - # Test mixed case - t2 = Seismic3DPostStackTemplate("elevatioN") - assert t2._trace_domain == "elevation" - assert t2.name == "PostStack3DElevation" - - def test_build_dataset_depth(self, structured_headers: StructuredType) -> None: - """Unit tests for Seismic3DPostStackTemplate build with depth domain.""" - t = Seismic3DPostStackTemplate(domain="depth") - - assert t.name == "PostStack3DDepth" + data_domain_suffix = data_domain.capitalize() + assert t.name == f"PostStack3D{data_domain_suffix}" dataset = t.build_dataset( "Seismic 3D", sizes=(256, 512, 1024), @@ -166,17 +130,16 @@ def test_build_dataset_depth(self, structured_headers: StructuredType) -> None: ) assert dataset.metadata.name == "Seismic 3D" - assert dataset.metadata.attributes["surveyDimensionality"] == "3D" - assert dataset.metadata.attributes["ensembleType"] == "line" - assert dataset.metadata.attributes["processingStage"] == "post-stack" + assert dataset.metadata.attributes["surveyType"] == "3D" + assert dataset.metadata.attributes["gatherType"] == "stacked" - _validate_coordinates_headers_trace_mask(dataset, structured_headers, "depth") + _validate_coordinates_headers_trace_mask(dataset, structured_headers, data_domain) # Verify seismic variable seismic = validate_variable( dataset, name="amplitude", - dims=[("inline", 256), ("crossline", 512), ("depth", 1024)], + dims=[("inline", 256), ("crossline", 512), (data_domain, 1024)], coords=["cdp_x", "cdp_y"], dtype=ScalarType.FLOAT32, ) @@ -186,35 +149,12 @@ def test_build_dataset_depth(self, structured_headers: StructuredType) -> None: assert seismic.metadata.chunk_grid.configuration.chunk_shape == (128, 128, 128) assert seismic.metadata.stats_v1 is None - def test_build_dataset_time(self, structured_headers: StructuredType) -> None: - """Unit tests for Seismic3DPostStackTimeTemplate build with time domain.""" - t = Seismic3DPostStackTemplate(domain="time") - - assert t.name == "PostStack3DTime" - dataset = t.build_dataset( - "Seismic 3D", - sizes=(256, 512, 1024), - horizontal_coord_unit=UNITS_METER, - headers=structured_headers, - ) - - assert dataset.metadata.name == "Seismic 3D" - assert dataset.metadata.attributes["surveyDimensionality"] == "3D" - assert dataset.metadata.attributes["ensembleType"] == "line" - assert dataset.metadata.attributes["processingStage"] == "post-stack" - _validate_coordinates_headers_trace_mask(dataset, structured_headers, "time") +@pytest.mark.parametrize("data_domain", ["Time", "DePTh"]) +def test_domain_case_handling(data_domain: str) -> None: + """Test that domain parameter handles different cases correctly.""" + template = Seismic3DPostStackTemplate(data_domain=data_domain) + assert template._data_domain == data_domain.lower() - # Verify seismic variable - seismic = validate_variable( - dataset, - name="amplitude", - dims=[("inline", 256), ("crossline", 512), ("time", 1024)], - coords=["cdp_x", "cdp_y"], - dtype=ScalarType.FLOAT32, - ) - assert isinstance(seismic.compressor, Blosc) - assert seismic.compressor.cname == BloscCname.zstd - assert isinstance(seismic.metadata.chunk_grid, RegularChunkGrid) - assert seismic.metadata.chunk_grid.configuration.chunk_shape == (128, 128, 128) - assert seismic.metadata.stats_v1 is None + data_domain_suffix = data_domain.lower().capitalize() + assert template.name == f"PostStack3D{data_domain_suffix}" diff --git a/tests/unit/v1/templates/test_seismic_3d_prestack_cdp.py b/tests/unit/v1/templates/test_seismic_3d_prestack_cdp.py index 8f1e1ea77..df552a974 100644 --- a/tests/unit/v1/templates/test_seismic_3d_prestack_cdp.py +++ b/tests/unit/v1/templates/test_seismic_3d_prestack_cdp.py @@ -1,5 +1,6 @@ """Unit tests for Seismic3DPreStackCDPTemplate.""" +import pytest from tests.unit.v1.helpers import validate_variable from mdio.builder.schemas.chunk_grid import RegularChunkGrid @@ -13,12 +14,19 @@ from mdio.builder.schemas.v1.units import TimeUnitEnum from mdio.builder.schemas.v1.units import TimeUnitModel from mdio.builder.templates.seismic_3d_prestack_cdp import Seismic3DPreStackCDPTemplate +from mdio.builder.templates.types import CdpGatherDomain +from mdio.builder.templates.types import SeismicDataDomain UNITS_METER = LengthUnitModel(length=LengthUnitEnum.METER) UNITS_SECOND = TimeUnitModel(time=TimeUnitEnum.SECOND) -def validate_coordinates_headers_trace_mask(dataset: Dataset, headers: StructuredType, domain: str) -> None: +def validate_coordinates_headers_trace_mask( + dataset: Dataset, + headers: StructuredType, + data_domain: SeismicDataDomain, + gather_domain: CdpGatherDomain, +) -> None: """A helper method to validate coordinates, headers, and trace mask.""" # Verify variables # 4 dim coords + 2 non-dim coords + 1 data + 1 trace mask + 1 headers = 8 variables @@ -28,7 +36,7 @@ def validate_coordinates_headers_trace_mask(dataset: Dataset, headers: Structure validate_variable( dataset, name="headers", - dims=[("inline", 512), ("crossline", 768), ("offset", 36)], + dims=[("inline", 512), ("crossline", 768), (gather_domain, 36)], coords=["cdp_x", "cdp_y"], dtype=headers, ) @@ -36,7 +44,7 @@ def validate_coordinates_headers_trace_mask(dataset: Dataset, headers: Structure validate_variable( dataset, name="trace_mask", - dims=[("inline", 512), ("crossline", 768), ("offset", 36)], + dims=[("inline", 512), ("crossline", 768), (gather_domain, 36)], coords=["cdp_x", "cdp_y"], dtype=ScalarType.BOOL, ) @@ -60,20 +68,20 @@ def validate_coordinates_headers_trace_mask(dataset: Dataset, headers: Structure ) assert crossline.metadata is None - crossline = validate_variable( + domain = validate_variable( dataset, - name="offset", - dims=[("offset", 36)], - coords=["offset"], + name=gather_domain, + dims=[(gather_domain, 36)], + coords=[gather_domain], dtype=ScalarType.INT32, ) - assert crossline.metadata is None + assert domain.metadata is None domain = validate_variable( dataset, - name=domain, - dims=[(domain, 1536)], - coords=[domain], + name=data_domain, + dims=[(data_domain, 1536)], + coords=[data_domain], dtype=ScalarType.INT32, ) assert domain.metadata is None @@ -82,7 +90,7 @@ def validate_coordinates_headers_trace_mask(dataset: Dataset, headers: Structure cdp_x = validate_variable( dataset, name="cdp_x", - dims=[("inline", 512), ("crossline", 768), ("offset", 36)], + dims=[("inline", 512), ("crossline", 768), (gather_domain, 36)], coords=["cdp_x"], dtype=ScalarType.FLOAT64, ) @@ -91,26 +99,27 @@ def validate_coordinates_headers_trace_mask(dataset: Dataset, headers: Structure cdp_y = validate_variable( dataset, name="cdp_y", - dims=[("inline", 512), ("crossline", 768), ("offset", 36)], + dims=[("inline", 512), ("crossline", 768), (gather_domain, 36)], coords=["cdp_y"], dtype=ScalarType.FLOAT64, ) assert cdp_y.metadata.units_v1.length == LengthUnitEnum.METER +@pytest.mark.parametrize("data_domain", ["depth", "time"]) +@pytest.mark.parametrize("gather_domain", ["offset", "angle"]) class TestSeismic3DPreStackCDPTemplate: """Unit tests for Seismic3DPreStackCDPTemplate.""" - def test_configuration_depth(self) -> None: + def test_configuration(self, data_domain: SeismicDataDomain, gather_domain: CdpGatherDomain) -> None: """Unit tests for Seismic3DPreStackCDPTemplate.""" - t = Seismic3DPreStackCDPTemplate(domain="DEPTH") + t = Seismic3DPreStackCDPTemplate(data_domain, gather_domain) # Template attributes for prestack CDP - assert t._trace_domain == "depth" - assert t._coord_dim_names == ("inline", "crossline", "offset") - assert t._dim_names == ("inline", "crossline", "offset", "depth") + assert t._coord_dim_names == ("inline", "crossline", gather_domain) + assert t._dim_names == ("inline", "crossline", gather_domain, data_domain) assert t._coord_names == ("cdp_x", "cdp_y") - assert t._var_chunk_shape == (1, 1, 512, 4096) + assert t._var_chunk_shape == (8, 8, 32, 512) # Variables instantiated when build_dataset() is called assert t._builder is None @@ -119,113 +128,56 @@ def test_configuration_depth(self) -> None: # Verify prestack CDP attributes attrs = t._load_dataset_attributes() - assert attrs == { - "surveyDimensionality": "3D", - "ensembleType": "cdp", - "processingStage": "pre-stack", - } + assert attrs == {"surveyType": "3D", "gatherType": "cdp"} assert t.default_variable_name == "amplitude" - def test_configuration_time(self) -> None: - """Unit tests for Seismic3DPreStackCDPTemplate.""" - t = Seismic3DPreStackCDPTemplate(domain="TIME") - - # Template attributes for prestack CDP - assert t._trace_domain == "time" - assert t._coord_dim_names == ("inline", "crossline", "offset") - assert t._dim_names == ("inline", "crossline", "offset", "time") - assert t._coord_names == ("cdp_x", "cdp_y") - assert t._var_chunk_shape == (1, 1, 512, 4096) - - # Variables instantiated when build_dataset() is called - assert t._builder is None - assert t._dim_sizes == () - assert t._horizontal_coord_unit is None - - # Verify prestack CDP attributes - attrs = t._load_dataset_attributes() - assert attrs == { - "surveyDimensionality": "3D", - "ensembleType": "cdp", - "processingStage": "pre-stack", - } - - assert t.name == "PreStackCdpGathers3DTime" - - def test_domain_case_handling(self) -> None: - """Test that domain parameter handles different cases correctly.""" - # Test uppercase - t1 = Seismic3DPreStackCDPTemplate("ELEVATION") - assert t1._trace_domain == "elevation" - assert t1.name == "PreStackCdpGathers3DElevation" - - # Test mixed case - t2 = Seismic3DPreStackCDPTemplate("elevatioN") - assert t2._trace_domain == "elevation" - assert t2.name == "PreStackCdpGathers3DElevation" - - def test_build_dataset_depth(self, structured_headers: StructuredType) -> None: - """Unit tests for Seismic3DPreStackCDPDepthTemplate build with depth domain.""" - t = Seismic3DPreStackCDPTemplate(domain="depth") - - assert t.name == "PreStackCdpGathers3DDepth" + def test_build_dataset( + self, + data_domain: SeismicDataDomain, + gather_domain: CdpGatherDomain, + structured_headers: StructuredType, + ) -> None: + """Unit tests for Seismic3DPreStackCDPDepthTemplate build.""" + t = Seismic3DPreStackCDPTemplate(data_domain, gather_domain) + + gather_domain_suffix = gather_domain.capitalize() + data_domain_suffix = data_domain.capitalize() + assert t.name == f"PreStackCdp{gather_domain_suffix}Gathers3D{data_domain_suffix}" dataset = t.build_dataset( - "North Sea 3D Prestack Depth", + "North Sea 3D Prestack", sizes=(512, 768, 36, 1536), horizontal_coord_unit=UNITS_METER, headers=structured_headers, ) - assert dataset.metadata.name == "North Sea 3D Prestack Depth" - assert dataset.metadata.attributes["surveyDimensionality"] == "3D" - assert dataset.metadata.attributes["ensembleType"] == "cdp" - assert dataset.metadata.attributes["processingStage"] == "pre-stack" + assert dataset.metadata.name == "North Sea 3D Prestack" + assert dataset.metadata.attributes["surveyType"] == "3D" + assert dataset.metadata.attributes["gatherType"] == "cdp" - validate_coordinates_headers_trace_mask(dataset, structured_headers, "depth") + validate_coordinates_headers_trace_mask(dataset, structured_headers, data_domain, gather_domain) # Verify seismic variable (prestack depth data) seismic = validate_variable( dataset, name="amplitude", - dims=[("inline", 512), ("crossline", 768), ("offset", 36), ("depth", 1536)], + dims=[("inline", 512), ("crossline", 768), (gather_domain, 36), (data_domain, 1536)], coords=["cdp_x", "cdp_y"], dtype=ScalarType.FLOAT32, ) assert isinstance(seismic.compressor, Blosc) assert seismic.compressor.cname == BloscCname.zstd assert isinstance(seismic.metadata.chunk_grid, RegularChunkGrid) - assert seismic.metadata.chunk_grid.configuration.chunk_shape == (1, 1, 512, 4096) + assert seismic.metadata.chunk_grid.configuration.chunk_shape == (8, 8, 32, 512) assert seismic.metadata.stats_v1 is None - def test_build_dataset_time(self, structured_headers: StructuredType) -> None: - """Unit tests for Seismic3DPreStackCDPTimeTemplate build with time domain.""" - t = Seismic3DPreStackCDPTemplate(domain="time") - assert t.name == "PreStackCdpGathers3DTime" - dataset = t.build_dataset( - "Santos Basin 3D Prestack", - sizes=(512, 768, 36, 1536), - horizontal_coord_unit=UNITS_METER, - headers=structured_headers, - ) +@pytest.mark.parametrize("data_domain", ["Time", "DePTh"]) +@pytest.mark.parametrize("gather_domain", ["Offset", "OffSeT"]) +def test_domain_case_handling(data_domain: str, gather_domain: str) -> None: + """Test that domain parameter handles different cases correctly.""" + template = Seismic3DPreStackCDPTemplate(data_domain, gather_domain) + assert template._data_domain == data_domain.lower() - assert dataset.metadata.name == "Santos Basin 3D Prestack" - assert dataset.metadata.attributes["surveyDimensionality"] == "3D" - assert dataset.metadata.attributes["ensembleType"] == "cdp" - assert dataset.metadata.attributes["processingStage"] == "pre-stack" - - validate_coordinates_headers_trace_mask(dataset, structured_headers, "time") - - # Verify seismic variable (prestack time data) - seismic = validate_variable( - dataset, - name="amplitude", - dims=[("inline", 512), ("crossline", 768), ("offset", 36), ("time", 1536)], - coords=["cdp_x", "cdp_y"], - dtype=ScalarType.FLOAT32, - ) - assert isinstance(seismic.compressor, Blosc) - assert seismic.compressor.cname == BloscCname.zstd - assert isinstance(seismic.metadata.chunk_grid, RegularChunkGrid) - assert seismic.metadata.chunk_grid.configuration.chunk_shape == (1, 1, 512, 4096) - assert seismic.metadata.stats_v1 is None + gather_domain_suffix = gather_domain.lower().capitalize() + data_domain_suffix = data_domain.lower().capitalize() + assert template.name == f"PreStackCdp{gather_domain_suffix}Gathers3D{data_domain_suffix}" diff --git a/tests/unit/v1/templates/test_seismic_3d_prestack_coca.py b/tests/unit/v1/templates/test_seismic_3d_prestack_coca.py index c3576ef62..693de9988 100644 --- a/tests/unit/v1/templates/test_seismic_3d_prestack_coca.py +++ b/tests/unit/v1/templates/test_seismic_3d_prestack_coca.py @@ -1,5 +1,6 @@ """Unit tests for Seismic3DPreStackCocaTemplate.""" +import pytest from tests.unit.v1.helpers import validate_variable from mdio.builder.schemas.chunk_grid import RegularChunkGrid @@ -14,6 +15,7 @@ from mdio.builder.schemas.v1.units import TimeUnitEnum from mdio.builder.schemas.v1.units import TimeUnitModel from mdio.builder.templates.seismic_3d_prestack_coca import Seismic3DPreStackCocaTemplate +from mdio.builder.templates.types import SeismicDataDomain UNITS_METER = LengthUnitModel(length=LengthUnitEnum.METER) UNITS_SECOND = TimeUnitModel(time=TimeUnitEnum.SECOND) @@ -108,16 +110,17 @@ def _validate_coordinates_headers_trace_mask(dataset: Dataset, headers: Structur assert cdp_y.metadata.units_v1.length == LengthUnitEnum.METER +@pytest.mark.parametrize("data_domain", ["depth", "time"]) class TestSeismic3DPreStackCocaTemplate: """Unit tests for Seismic3DPreStackCocaTemplate.""" - def test_configuration_time(self) -> None: - """Unit tests for Seismic3DPreStackCocaTemplate in time domain.""" - t = Seismic3DPreStackCocaTemplate(domain="time") + def test_configuration(self, data_domain: SeismicDataDomain) -> None: + """Unit tests for Seismic3DPreStackCocaTemplate.""" + t = Seismic3DPreStackCocaTemplate(data_domain=data_domain) # Template attributes assert t._coord_dim_names == ("inline", "crossline", "offset", "azimuth") - assert t._dim_names == ("inline", "crossline", "offset", "azimuth", "time") + assert t._dim_names == ("inline", "crossline", "offset", "azimuth", data_domain) assert t._coord_names == ("cdp_x", "cdp_y") assert t._var_chunk_shape == (8, 8, 32, 1, 1024) @@ -128,16 +131,12 @@ def test_configuration_time(self) -> None: # Verify dataset attributes attrs = t._load_dataset_attributes() - assert attrs == { - "surveyDimensionality": "3D", - "ensembleType": "cdp_coca", - "processingStage": "pre-stack", - } + assert attrs == {"surveyType": "3D", "gatherType": "common_offset_common_azimuth"} assert t.default_variable_name == "amplitude" - def test_build_dataset_time(self, structured_headers: StructuredType) -> None: - """Unit tests for Seismic3DPreStackShotTemplate build in time domain.""" - t = Seismic3DPreStackCocaTemplate(domain="time") + def test_build_dataset(self, data_domain: SeismicDataDomain, structured_headers: StructuredType) -> None: + """Unit tests for Seismic3DPreStackShotTemplate build.""" + t = Seismic3DPreStackCocaTemplate(data_domain=data_domain) dataset = t.build_dataset( "Permian Basin 3D CDP Coca Gathers", @@ -147,17 +146,16 @@ def test_build_dataset_time(self, structured_headers: StructuredType) -> None: ) assert dataset.metadata.name == "Permian Basin 3D CDP Coca Gathers" - assert dataset.metadata.attributes["surveyDimensionality"] == "3D" - assert dataset.metadata.attributes["ensembleType"] == "cdp_coca" - assert dataset.metadata.attributes["processingStage"] == "pre-stack" + assert dataset.metadata.attributes["surveyType"] == "3D" + assert dataset.metadata.attributes["gatherType"] == "common_offset_common_azimuth" - _validate_coordinates_headers_trace_mask(dataset, structured_headers, "time") + _validate_coordinates_headers_trace_mask(dataset, structured_headers, data_domain) # Verify seismic variable (prestack shot depth data) seismic = validate_variable( dataset, name="amplitude", - dims=[("inline", 256), ("crossline", 256), ("offset", 100), ("azimuth", 6), ("time", 2048)], + dims=[("inline", 256), ("crossline", 256), ("offset", 100), ("azimuth", 6), (data_domain, 2048)], coords=["cdp_x", "cdp_y"], dtype=ScalarType.FLOAT32, ) @@ -166,3 +164,13 @@ def test_build_dataset_time(self, structured_headers: StructuredType) -> None: assert isinstance(seismic.metadata.chunk_grid, RegularChunkGrid) assert seismic.metadata.chunk_grid.configuration.chunk_shape == (8, 8, 32, 1, 1024) assert seismic.metadata.stats_v1 is None + + +@pytest.mark.parametrize("data_domain", ["Time", "DePTh"]) +def test_domain_case_handling(data_domain: str) -> None: + """Test that domain parameter handles different cases correctly.""" + template = Seismic3DPreStackCocaTemplate(data_domain=data_domain) + assert template._data_domain == data_domain.lower() + + data_domain_suffix = data_domain.lower().capitalize() + assert template.name == f"PreStackCocaGathers3D{data_domain_suffix}" diff --git a/tests/unit/v1/templates/test_seismic_3d_prestack_shot.py b/tests/unit/v1/templates/test_seismic_3d_prestack_shot.py index 0948a1976..3ee729275 100644 --- a/tests/unit/v1/templates/test_seismic_3d_prestack_shot.py +++ b/tests/unit/v1/templates/test_seismic_3d_prestack_shot.py @@ -1,5 +1,6 @@ """Unit tests for Seismic3DPreStackShotTemplate.""" +import pytest from tests.unit.v1.helpers import validate_variable from mdio.builder.schemas.chunk_grid import RegularChunkGrid @@ -39,32 +40,32 @@ def _validate_coordinates_headers_trace_mask(dataset: Dataset, headers: Structur ) # Verify dimension coordinate variables - inline = validate_variable( + shot_point = validate_variable( dataset, name="shot_point", dims=[("shot_point", 256)], coords=["shot_point"], dtype=ScalarType.INT32, ) - assert inline.metadata is None + assert shot_point.metadata is None - crossline = validate_variable( + cable = validate_variable( dataset, name="cable", dims=[("cable", 512)], coords=["cable"], dtype=ScalarType.INT32, ) - assert crossline.metadata is None + assert cable.metadata is None - crossline = validate_variable( + channel = validate_variable( dataset, name="channel", dims=[("channel", 24)], coords=["channel"], dtype=ScalarType.INT32, ) - assert crossline.metadata is None + assert channel.metadata is None domain = validate_variable( dataset, @@ -79,7 +80,7 @@ def _validate_coordinates_headers_trace_mask(dataset: Dataset, headers: Structur validate_variable( dataset, name="gun", - dims=[("shot_point", 256), ("cable", 512), ("channel", 24)], + dims=[("shot_point", 256)], coords=["gun"], dtype=ScalarType.UINT8, ) @@ -87,7 +88,7 @@ def _validate_coordinates_headers_trace_mask(dataset: Dataset, headers: Structur source_coord_x = validate_variable( dataset, name="source_coord_x", - dims=[("shot_point", 256), ("cable", 512), ("channel", 24)], + dims=[("shot_point", 256)], coords=["source_coord_x"], dtype=ScalarType.FLOAT64, ) @@ -96,7 +97,7 @@ def _validate_coordinates_headers_trace_mask(dataset: Dataset, headers: Structur source_coord_y = validate_variable( dataset, name="source_coord_y", - dims=[("shot_point", 256), ("cable", 512), ("channel", 24)], + dims=[("shot_point", 256)], coords=["source_coord_y"], dtype=ScalarType.FLOAT64, ) @@ -124,41 +125,16 @@ def _validate_coordinates_headers_trace_mask(dataset: Dataset, headers: Structur class TestSeismic3DPreStackShotTemplate: """Unit tests for Seismic3DPreStackShotTemplate.""" - def test_configuration_depth(self) -> None: - """Unit tests for Seismic3DPreStackShotTemplate in depth domain.""" - t = Seismic3DPreStackShotTemplate(domain="DEPTH") - - # Template attributes for prestack shot - assert t._trace_domain == "depth" - assert t._coord_dim_names == ("shot_point", "cable", "channel") - assert t._dim_names == ("shot_point", "cable", "channel", "depth") - assert t._coord_names == ("gun", "source_coord_x", "source_coord_y", "group_coord_x", "group_coord_y") - assert t._var_chunk_shape == (1, 1, 512, 4096) - - # Variables instantiated when build_dataset() is called - assert t._builder is None - assert t._dim_sizes == () - assert t._horizontal_coord_unit is None - - # Verify prestack shot attributes - attrs = t._load_dataset_attributes() - assert attrs == { - "surveyDimensionality": "3D", - "ensembleType": "shot_point", - "processingStage": "pre-stack", - } - assert t.default_variable_name == "amplitude" - - def test_configuration_time(self) -> None: + def test_configuration(self) -> None: """Unit tests for Seismic3DPreStackShotTemplate in time domain.""" - t = Seismic3DPreStackShotTemplate(domain="TIME") + t = Seismic3DPreStackShotTemplate(data_domain="time") # Template attributes for prestack shot - assert t._trace_domain == "time" + assert t._data_domain == "time" assert t._coord_dim_names == ("shot_point", "cable", "channel") assert t._dim_names == ("shot_point", "cable", "channel", "time") assert t._coord_names == ("gun", "source_coord_x", "source_coord_y", "group_coord_x", "group_coord_y") - assert t._var_chunk_shape == (1, 1, 512, 4096) + assert t._var_chunk_shape == (8, 2, 128, 1024) # Variables instantiated when build_dataset() is called assert t._builder is None @@ -167,62 +143,13 @@ def test_configuration_time(self) -> None: # Verify prestack shot attributes attrs = t._load_dataset_attributes() - assert attrs == { - "surveyDimensionality": "3D", - "ensembleType": "shot_point", - "processingStage": "pre-stack", - } + assert attrs == {"surveyType": "3D", "ensembleType": "common_source"} assert t.name == "PreStackShotGathers3DTime" - def test_domain_case_handling(self) -> None: - """Test that domain parameter handles different cases correctly.""" - # Test uppercase - t1 = Seismic3DPreStackShotTemplate("ELEVATION") - assert t1._trace_domain == "elevation" - assert t1.name == "PreStackShotGathers3DElevation" - - # Test mixed case - t2 = Seismic3DPreStackShotTemplate("elevatioN") - assert t2._trace_domain == "elevation" - assert t2.name == "PreStackShotGathers3DElevation" - - def test_build_dataset_depth(self, structured_headers: StructuredType) -> None: - """Unit tests for Seismic3DPreStackShotTemplate build in depth domain.""" - t = Seismic3DPreStackShotTemplate(domain="depth") - - assert t.name == "PreStackShotGathers3DDepth" - dataset = t.build_dataset( - "Gulf of Mexico 3D Shot Depth", - sizes=(256, 512, 24, 2048), - horizontal_coord_unit=UNITS_METER, - headers=structured_headers, - ) - - assert dataset.metadata.name == "Gulf of Mexico 3D Shot Depth" - assert dataset.metadata.attributes["surveyDimensionality"] == "3D" - assert dataset.metadata.attributes["ensembleType"] == "shot_point" - assert dataset.metadata.attributes["processingStage"] == "pre-stack" - - _validate_coordinates_headers_trace_mask(dataset, structured_headers, "depth") - - # Verify seismic variable (prestack shot depth data) - seismic = validate_variable( - dataset, - name="amplitude", - dims=[("shot_point", 256), ("cable", 512), ("channel", 24), ("depth", 2048)], - coords=["gun", "source_coord_x", "source_coord_y", "group_coord_x", "group_coord_y"], - dtype=ScalarType.FLOAT32, - ) - assert isinstance(seismic.compressor, Blosc) - assert seismic.compressor.cname == BloscCname.zstd - assert isinstance(seismic.metadata.chunk_grid, RegularChunkGrid) - assert seismic.metadata.chunk_grid.configuration.chunk_shape == (1, 1, 512, 4096) - assert seismic.metadata.stats_v1 is None - - def test_build_dataset_time(self, structured_headers: StructuredType) -> None: + def test_build_dataset(self, structured_headers: StructuredType) -> None: """Unit tests for Seismic3DPreStackShotTemplate build in time domain.""" - t = Seismic3DPreStackShotTemplate(domain="time") + t = Seismic3DPreStackShotTemplate(data_domain="time") assert t.name == "PreStackShotGathers3DTime" dataset = t.build_dataset( @@ -233,9 +160,8 @@ def test_build_dataset_time(self, structured_headers: StructuredType) -> None: ) assert dataset.metadata.name == "North Sea 3D Shot Time" - assert dataset.metadata.attributes["surveyDimensionality"] == "3D" - assert dataset.metadata.attributes["ensembleType"] == "shot_point" - assert dataset.metadata.attributes["processingStage"] == "pre-stack" + assert dataset.metadata.attributes["surveyType"] == "3D" + assert dataset.metadata.attributes["ensembleType"] == "common_source" _validate_coordinates_headers_trace_mask(dataset, structured_headers, "time") @@ -250,5 +176,13 @@ def test_build_dataset_time(self, structured_headers: StructuredType) -> None: assert isinstance(seismic.compressor, Blosc) assert seismic.compressor.cname == BloscCname.zstd assert isinstance(seismic.metadata.chunk_grid, RegularChunkGrid) - assert seismic.metadata.chunk_grid.configuration.chunk_shape == (1, 1, 512, 4096) + assert seismic.metadata.chunk_grid.configuration.chunk_shape == (8, 2, 128, 1024) assert seismic.metadata.stats_v1 is None + + +@pytest.mark.parametrize("data_domain", ["Time", "TiME"]) +def test_domain_case_handling(data_domain: str) -> None: + """Test that domain parameter handles different cases correctly.""" + template = Seismic3DPreStackShotTemplate(data_domain=data_domain) + assert template._data_domain == data_domain.lower() + assert template.name.endswith(data_domain.capitalize()) diff --git a/tests/unit/v1/templates/test_seismic_templates.py b/tests/unit/v1/templates/test_seismic_templates.py index 3cdfeea14..ac45fcfac 100644 --- a/tests/unit/v1/templates/test_seismic_templates.py +++ b/tests/unit/v1/templates/test_seismic_templates.py @@ -22,7 +22,7 @@ def test_custom_data_variable_name(self) -> None: # Define a template with a custom data variable name 'velocity' class Velocity2DPostStackTemplate(Seismic2DPostStackTemplate): def __init__(self, domain: str): - super().__init__(domain=domain) + super().__init__(data_domain=domain) @property def _default_variable_name(self) -> str: @@ -30,7 +30,7 @@ def _default_variable_name(self) -> str: @property def _name(self) -> str: - return f"Velocity2D{self._trace_domain.capitalize()}" + return f"Velocity2D{self._data_domain.capitalize()}" t = Velocity2DPostStackTemplate("depth") assert t.name == "Velocity2DDepth" @@ -55,8 +55,8 @@ def test_get_name_time(self) -> None: assert Seismic3DPostStackTemplate("time").name == "PostStack3DTime" assert Seismic3DPostStackTemplate("depth").name == "PostStack3DDepth" - assert Seismic3DPreStackCDPTemplate("time").name == "PreStackCdpGathers3DTime" - assert Seismic3DPreStackCDPTemplate("depth").name == "PreStackCdpGathers3DDepth" + assert Seismic3DPreStackCDPTemplate("time", "angle").name == "PreStackCdpAngleGathers3DTime" + assert Seismic3DPreStackCDPTemplate("depth", "offset").name == "PreStackCdpOffsetGathers3DDepth" assert Seismic3DPreStackShotTemplate("time").name == "PreStackShotGathers3DTime" diff --git a/tests/unit/v1/templates/test_template_registry.py b/tests/unit/v1/templates/test_template_registry.py index b1eea1b2a..6aaa1aaed 100644 --- a/tests/unit/v1/templates/test_template_registry.py +++ b/tests/unit/v1/templates/test_template_registry.py @@ -12,16 +12,21 @@ from mdio.builder.template_registry import list_templates from mdio.builder.template_registry import register_template from mdio.builder.templates.abstract_dataset_template import AbstractDatasetTemplate +from mdio.builder.templates.types import SeismicDataDomain EXPECTED_DEFAULT_TEMPLATE_NAMES = [ "PostStack2DTime", "PostStack2DDepth", "PostStack3DTime", "PostStack3DDepth", - "PreStackCdpGathers2DTime", - "PreStackCdpGathers2DDepth", - "PreStackCdpGathers3DTime", - "PreStackCdpGathers3DDepth", + "PreStackCdpOffsetGathers2DTime", + "PreStackCdpAngleGathers2DDepth", + "PreStackCdpOffsetGathers2DTime", + "PreStackCdpAngleGathers2DDepth", + "PreStackCdpOffsetGathers3DTime", + "PreStackCdpAngleGathers3DTime", + "PreStackCdpOffsetGathers3DDepth", + "PreStackCdpAngleGathers3DDepth", "PreStackCocaGathers3DTime", "PreStackCocaGathers3DDepth", "PreStackShotGathers2DTime", @@ -32,8 +37,8 @@ class MockDatasetTemplate(AbstractDatasetTemplate): """Mock template for testing.""" - def __init__(self, name: str): - super().__init__() + def __init__(self, name: str, data_domain: SeismicDataDomain = "time"): + super().__init__(data_domain=data_domain) self.template_name = name @property @@ -184,7 +189,7 @@ def test_list_all_templates(self) -> None: registry.register(template2) templates = registry.list_all_templates() - assert len(templates) == 12 + 2 # 12 default + 2 custom + assert len(templates) == 16 + 2 # 16 default + 2 custom assert "Template_One" in templates assert "Template_Two" in templates @@ -194,7 +199,7 @@ def test_clear_templates(self) -> None: # Default templates are always installed templates = list_templates() - assert len(templates) == 12 + assert len(templates) == 16 # Add some templates template1 = MockDatasetTemplate("Template1") @@ -203,7 +208,7 @@ def test_clear_templates(self) -> None: registry.register(template1) registry.register(template2) - assert len(registry.list_all_templates()) == 12 + 2 # 12 default + 2 custom + assert len(registry.list_all_templates()) == 16 + 2 # 16 default + 2 custom # Clear all registry.clear() @@ -281,7 +286,7 @@ def test_list_templates_global(self) -> None: register_template(template2) templates = list_templates() - assert len(templates) == 14 # 12 default + 2 custom + assert len(templates) == 18 # 16 default + 2 custom assert "template1" in templates assert "template2" in templates @@ -327,7 +332,7 @@ def register_template_worker(template_id: int) -> None: assert len(errors) == 0 assert len(results) == 10 # Including 8 default templates - assert len(registry.list_all_templates()) == 22 # 12 default + 10 registered + assert len(registry.list_all_templates()) == 26 # 16 default + 10 registered # Check all templates are registered for i in range(10):