Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
[ENH] Add default ColorsGenerator and improve condition handling
Introduce `Field(default_factory=ColorsGenerator)` to `StructuralFrame` for streamlined initialization. Convert condition number to numpy array in optimization logic for compatibility. Minor adjustments made for improved model validation and notes added for future enhancements.
  • Loading branch information
Leguark committed May 26, 2025
commit 99bfdb79f8c6d02646bc4a2d24fadb790b4d5f77
2 changes: 1 addition & 1 deletion gempy/core/color_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class ColorsGenerator:
)
_index: int = 0

def __init__(self):
def __post_init__(self):
self.regenerate_color_palette()

@staticmethod
Expand Down
5 changes: 3 additions & 2 deletions gempy/core/data/structural_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import warnings
from dataclasses import dataclass
from pydantic import model_validator, computed_field, ValidationError
from pydantic import model_validator, computed_field, ValidationError, Field
from pydantic.functional_validators import ModelWrapValidatorHandler
from typing import Generator, Union

Expand Down Expand Up @@ -32,11 +32,12 @@ class StructuralFrame:
"""

structural_groups: list[StructuralGroup]
color_generator: ColorsGenerator = Field(default_factory=ColorsGenerator)
# ? Should I create some sort of structural options class? For example, the masking descriptor and faults relations pointer
is_dirty: bool = True

# region Constructor

#
def __init__(self, structural_groups: list[StructuralGroup], color_gen: ColorsGenerator):
self.structural_groups = structural_groups # ? This maybe could be optional
self.color_generator = color_gen
Expand Down
4 changes: 3 additions & 1 deletion gempy/modules/optimize_nuggets/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ def run_optimization(lr, max_epochs, min_impr, model, nugget, patience, target_c
if _has_converged(cur_cond, prev_cond, target_cond_num, epoch, min_impr, patience):
break
prev_cond = cur_cond


# Condition number to numpy
model.interpolation_options.kernel_options.condition_number = model.interpolation_options.kernel_options.condition_number.detach().numpy()
return nugget


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,10 @@ def test_optimize_nugget_effect():
engine_config=gp.data.GemPyEngineConfig(
backend=gp.data.AvailableBackends.PYTORCH,
),
validate_serialization=False
validate_serialization=True
)

# TODO: Save model

print(f"Final cond number: {geo_model.interpolation_options.kernel_options.condition_number}")
nugget_effect = geo_model.taped_interpolation_input.surface_points.nugget_effect_scalar.detach().numpy()
Expand Down