Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
[CLN] Add regions to structuctural_frame
  • Loading branch information
Leguark committed May 23, 2025
commit 52dd88fca9c0e96f10d07f3952ba80bc49afcc15
279 changes: 146 additions & 133 deletions gempy/core/data/structural_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,144 +34,12 @@ class StructuralFrame:
# ? Should I create some sort of structural options class? For example, the masking descriptor and faults relations pointer
is_dirty: bool = True

@model_validator(mode="wrap")
@classmethod
def deserialize_binary(cls, data: Union["StructuralFrame", dict], constructor: ModelWrapValidatorHandler["StructuralFrame"]) -> "StructuralFrame":
match data:
case StructuralFrame():
return data
case dict():
instance: StructuralFrame = constructor(data)
metadata = data.get('binary_meta_data', {})

context = loading_model_context.get()

if 'binary_body' not in context:
return instance

binary_array = context['binary_body']

sp_binary = binary_array[:metadata["sp_binary_length"]]
ori_binary = binary_array[metadata["sp_binary_length"]:]

# Reconstruct arrays
sp_data: np.ndarray = np.frombuffer(sp_binary, dtype=SurfacePointsTable.dt)
ori_data: np.ndarray = np.frombuffer(ori_binary, dtype=OrientationsTable.dt)

instance.surface_points = SurfacePointsTable(
data=sp_data,
name_id_map=instance.surface_points_copy.name_id_map
)

instance.orientations = OrientationsTable(
data=ori_data,
name_id_map=instance.orientations_copy.name_id_map
)

return instance
case _:
raise ValidationError(f"Invalid data type for StructuralFrame: {type(data)}")

# Access the context variable to get injected data

@model_validator(mode="after")
def deserialize_surface_points(self: "StructuralFrame"):
# Access the context variable to get injected data
context = loading_model_context.get()

if 'surface_points_binary' not in context:
return self

# Check if we have a binary payload to digest
binary_array = context['surface_points_binary']
if not isinstance(binary_array, np.ndarray):
return self
if binary_array.shape[0] < 1:
return self

self.surface_points = SurfacePointsTable(
data=binary_array,
name_id_map=self.surface_points_copy.name_id_map
)

return self

@model_validator(mode="after")
def deserialize_orientations(self: "StructuralFrame"):
# TODO: Check here the binary size of surface_points_binary

# Access the context variable to get injected data
context = loading_model_context.get()
if 'orientations_binary' not in context:
return self

# Check if we have a binary payload to digest
binary_array = context['orientations_binary']
if not isinstance(binary_array, np.ndarray):
return self

self.orientations = OrientationsTable(
data=binary_array,
name_id_map=self.orientations_copy.name_id_map
)

return self

@computed_field
def binary_meta_data(self) -> dict:
sp_data = self.surface_points_copy.data
ori_data = self.orientations_copy.data
return {
'sp_shape' : sp_data.shape,
'sp_dtype' : str(sp_data.dtype),
'sp_binary_length' : len(sp_data.tobytes()),
'ori_shape' : ori_data.shape,
'ori_dtype' : str(ori_data.dtype),
'ori_binary_length': len(ori_data.tobytes())
}

@computed_field
@property
def serialize_sp(self) -> int:
return int(hashlib.md5(self.surface_points_copy.data.tobytes()).hexdigest()[:8], 16)

@computed_field
@property
def serialize_orientations(self) -> int:
return int(hashlib.md5(self.orientations_copy.data.tobytes()).hexdigest()[:8], 16)
# region Constructor

def __init__(self, structural_groups: list[StructuralGroup], color_gen: ColorsGenerator):
self.structural_groups = structural_groups # ? This maybe could be optional
self.color_generator = color_gen

def get_element_by_name(self, element_name: str) -> StructuralElement:
elements: Generator = (group.get_element_by_name(element_name) for group in self.structural_groups)
valid_elements: Generator = (element for element in elements if element is not None)
element = next(valid_elements, None)
if element is None:
raise ValueError(f"Element with name {element_name} not found in the structural frame.")
return element

def get_group_by_name(self, group_name: str) -> StructuralGroup:
groups: Generator = (group for group in self.structural_groups if group.name == group_name)
group = next(groups, None)
if group is None:
raise ValueError(f"Group with name {group_name} not found in the structural frame.")
return group

def get_group_by_element(self, element: StructuralElement) -> StructuralGroup:
groups: Generator = (group for group in self.structural_groups if element in group.elements)
group = next(groups, None)
if group is None:
raise ValueError(f"Element {element.name} not found in any group in the structural frame.")
return group

def append_group(self, group: StructuralGroup):
self.structural_groups.append(group)

def insert_group(self, index: int, group: StructuralGroup):
self.structural_groups.insert(index, group)

@classmethod
def from_data_tables(cls, surface_points: SurfacePointsTable, orientations: OrientationsTable):
surface_points_groups: list[SurfacePointsTable] = surface_points.get_surface_points_by_id_groups()
Expand Down Expand Up @@ -245,6 +113,37 @@ def initialize_default_structure(cls) -> 'StructuralFrame':

return structural_frame

# endregion

# region Methods
def get_element_by_name(self, element_name: str) -> StructuralElement:
elements: Generator = (group.get_element_by_name(element_name) for group in self.structural_groups)
valid_elements: Generator = (element for element in elements if element is not None)
element = next(valid_elements, None)
if element is None:
raise ValueError(f"Element with name {element_name} not found in the structural frame.")
return element

def get_group_by_name(self, group_name: str) -> StructuralGroup:
groups: Generator = (group for group in self.structural_groups if group.name == group_name)
group = next(groups, None)
if group is None:
raise ValueError(f"Group with name {group_name} not found in the structural frame.")
return group

def get_group_by_element(self, element: StructuralElement) -> StructuralGroup:
groups: Generator = (group for group in self.structural_groups if element in group.elements)
group = next(groups, None)
if group is None:
raise ValueError(f"Element {element.name} not found in any group in the structural frame.")
return group

def append_group(self, group: StructuralGroup):
self.structural_groups.append(group)

def insert_group(self, index: int, group: StructuralGroup):
self.structural_groups.insert(index, group)

def __repr__(self):
structural_groups_repr = ',\n'.join([repr(g) for g in self.structural_groups])
fault_relations_str = np.array2string(self.fault_relations, precision=2, separator=', ', suppress_small=True) if self.fault_relations is not None else 'None'
Expand Down Expand Up @@ -285,6 +184,9 @@ def _repr_html_(self):
"""
return html

# endregion

# region Properties
@property
def structural_elements(self) -> list[StructuralElement]:
"""Returns a list of all structural elements across the structural groups."""
Expand Down Expand Up @@ -555,6 +457,117 @@ def surfaces_df(self) -> 'pd.DataFrame':

# endregion

# endregion
# region Pydantic

@model_validator(mode="wrap")
@classmethod
def deserialize_binary(cls, data: Union["StructuralFrame", dict], constructor: ModelWrapValidatorHandler["StructuralFrame"]) -> "StructuralFrame":
match data:
case StructuralFrame():
return data
case dict():
instance: StructuralFrame = constructor(data)
metadata = data.get('binary_meta_data', {})

context = loading_model_context.get()

if 'binary_body' not in context:
return instance

binary_array = context['binary_body']

sp_binary = binary_array[:metadata["sp_binary_length"]]
ori_binary = binary_array[metadata["sp_binary_length"]:]

# Reconstruct arrays
sp_data: np.ndarray = np.frombuffer(sp_binary, dtype=SurfacePointsTable.dt)
ori_data: np.ndarray = np.frombuffer(ori_binary, dtype=OrientationsTable.dt)

instance.surface_points = SurfacePointsTable(
data=sp_data,
name_id_map=instance.surface_points_copy.name_id_map
)

instance.orientations = OrientationsTable(
data=ori_data,
name_id_map=instance.orientations_copy.name_id_map
)

return instance
case _:
raise ValidationError(f"Invalid data type for StructuralFrame: {type(data)}")

# Access the context variable to get injected data

@model_validator(mode="after")
def deserialize_surface_points(self: "StructuralFrame"):
# Access the context variable to get injected data
context = loading_model_context.get()

if 'surface_points_binary' not in context:
return self

# Check if we have a binary payload to digest
binary_array = context['surface_points_binary']
if not isinstance(binary_array, np.ndarray):
return self
if binary_array.shape[0] < 1:
return self

self.surface_points = SurfacePointsTable(
data=binary_array,
name_id_map=self.surface_points_copy.name_id_map
)

return self

@model_validator(mode="after")
def deserialize_orientations(self: "StructuralFrame"):
# TODO: Check here the binary size of surface_points_binary

# Access the context variable to get injected data
context = loading_model_context.get()
if 'orientations_binary' not in context:
return self

# Check if we have a binary payload to digest
binary_array = context['orientations_binary']
if not isinstance(binary_array, np.ndarray):
return self

self.orientations = OrientationsTable(
data=binary_array,
name_id_map=self.orientations_copy.name_id_map
)

return self

@computed_field
def binary_meta_data(self) -> dict:
sp_data = self.surface_points_copy.data
ori_data = self.orientations_copy.data
return {
'sp_shape' : sp_data.shape,
'sp_dtype' : str(sp_data.dtype),
'sp_binary_length' : len(sp_data.tobytes()),
'ori_shape' : ori_data.shape,
'ori_dtype' : str(ori_data.dtype),
'ori_binary_length': len(ori_data.tobytes())
}

@computed_field
@property
def serialize_sp(self) -> int:
return int(hashlib.md5(self.surface_points_copy.data.tobytes()).hexdigest()[:8], 16)

@computed_field
@property
def serialize_orientations(self) -> int:
return int(hashlib.md5(self.orientations_copy.data.tobytes()).hexdigest()[:8], 16)

# endregion

def _validate_faults_relations(self):
"""Check that if there are any StackRelationType.FAULT in the structural groups the fault relation matrix is
given and shape is the right one, i.e. a square matrix of size equals to len(groups)"""
Expand Down