Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
[CLN] Refactor Sections class and imports for clarity
Simplified imports and removed unused `SectionDefinition` class. Adjusted and clarified type annotations while ensuring internal fields are properly excluded from serialization.
  • Loading branch information
Leguark committed May 25, 2025
commit cfb0461c1f6a9513d0e8e7c3742f288bd4b1b527
45 changes: 22 additions & 23 deletions gempy/core/data/grid_modules/sections_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,16 @@
import dataclasses
import numpy as np

from gempy.core.data.core_utils import calculate_line_coordinates_2points
from gempy.optional_dependencies import require_pandas
from ..core_utils import calculate_line_coordinates_2points
from ..encoders.converters import short_array_type
from ....optional_dependencies import require_pandas

try:
import pandas as pd
except ImportError:
pandas = None


@dataclasses.dataclass
class SectionDefinition:
"""
A single cross‐section’s raw parameters.
"""
start: Tuple[float, float]
stop: Tuple[float, float]
resolution: Tuple[int, int]


@dataclasses.dataclass
class Sections:
"""
Expand All @@ -40,18 +31,25 @@ class Sections:

# user‐provided inputs

z_ext: Tuple[float, float]
section_dict: Dict[str, tuple[list[int]]]
z_ext: Tuple[float, float] | short_array_type
section_dict: Dict[
str,
Tuple[
Tuple[float, float], # start
Tuple[float, float], # stop
Tuple[int, int] # resolution
]
]

# computed/internal (will be serialized too unless excluded)
names: List[str] = Field(default_factory=list)
points: List[List[Tuple[float, float]]] = Field(default_factory=list)
resolution: List[Tuple[int, int]] = Field(default_factory=list)
length: np.ndarray = Field(default_factory=lambda: np.array([0]), exclude=False)
dist: np.ndarray = Field(default_factory=lambda: np.array([]), exclude=False)
df: Optional[pd.DataFrame] = Field(default_factory=None, exclude=False)
values: np.ndarray = Field(default_factory=lambda: np.empty((0, 3)), exclude=False)
extent: Optional[np.ndarray] = None
names: List[str] = Field(default_factory=list, exclude=True)
points: List[List[Tuple[float, float]]] = Field(default_factory=list, exclude=True)
resolution: List[Tuple[int, int]] = Field(default_factory=list, exclude=True)
length: np.ndarray = Field(default_factory=lambda: np.array([0]), exclude=True)
dist: np.ndarray = Field(default_factory=lambda: np.array([]), exclude=True)
df: Optional[pd.DataFrame] = Field(default=None, exclude=True)
values: np.ndarray = Field(default_factory=lambda: np.empty((0, 3)), exclude=True)
extent: Optional[np.ndarray] = Field(default=None, exclude=True)

# def __init__(self, regular_grid=None, z_ext=None, section_dict=None):
# pd = require_pandas()
Expand All @@ -75,12 +73,13 @@ class Sections:
# self.set_sections(section_dict)
def __post_init__(self):
self.initialize_computations()
pass

# @model_validator(mode="after")
# def init_class(self):
# self.initialize_computations()
# return self

#
def initialize_computations(self):
# copy names
self.names = list(self.section_dict.keys())
Expand Down