Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
[ENH] Refactor nugget optimization for modularity and group handling
Streamlined nugget optimization logic by delegating it to a standalone `run_optimization` function, improving reusability and clarity. Added support for optimizing specific structural groups, enhancing flexibility in nugget effect optimization workflows.
  • Loading branch information
Leguark committed May 26, 2025
commit c85d5f6eb9d3f83a4a8bacfd78995cc5f3bc29ca
25 changes: 17 additions & 8 deletions gempy/API/compute_API.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
import dotenv
import numpy as np
import os

from typing import Optional

import numpy as np

import gempy_engine
from gempy_engine.core.backend_tensor import BackendTensor
from gempy.API.gp2_gp3_compatibility.gp3_to_gp2_input import gempy3_to_gempy2
from gempy_engine.config import AvailableBackends
from gempy_engine.core.backend_tensor import BackendTensor
from gempy_engine.core.data import Solutions
from gempy_engine.core.data.interpolation_input import InterpolationInput
from .grid_API import set_custom_grid
from ..core.data import StructuralGroup
from ..core.data.gempy_engine_config import GemPyEngineConfig
from ..core.data.geo_model import GeoModel
from ..modules.data_manipulation import interpolation_input_from_structural_frame
Expand Down Expand Up @@ -96,18 +94,29 @@ def compute_model_at(gempy_model: GeoModel, at: np.ndarray,
return sol.raw_arrays.custom


def optimize_and_compute(geo_model: GeoModel, engine_config: GemPyEngineConfig, max_epochs: int = 10,
convergence_criteria: float = 1e5):
def optimize_nuggets(geo_model: GeoModel, engine_config: GemPyEngineConfig, max_epochs: int = 10,
convergence_criteria: float = 1e5, only_groups:list[StructuralGroup] | None = None) -> GeoModel:
"""
Optimize the nuggets of the interpolation input of the provided model.
"""

if engine_config.backend != AvailableBackends.PYTORCH:
raise ValueError(f'Only PyTorch backend is supported for optimization. Received {engine_config.backend}')

geo_model = nugget_optimizer(
target_cond_num=convergence_criteria,
engine_cfg=engine_config,
model=geo_model,
max_epochs=max_epochs,
only_groups=only_groups
)

return geo_model

def optimize_and_compute(geo_model: GeoModel, engine_config: GemPyEngineConfig, max_epochs: int = 10,
convergence_criteria: float = 1e5):

optimize_nuggets(geo_model, engine_config, max_epochs, convergence_criteria)
geo_model.solutions = gempy_engine.compute_model(
interpolation_input=geo_model.taped_interpolation_input,
options=geo_model.interpolation_options,
Expand Down
2 changes: 1 addition & 1 deletion gempy/core/data/geo_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from gempy_engine.core.data.interpolation_input import InterpolationInput
from gempy_engine.core.data.raw_arrays_solution import RawArraysSolution
from gempy_engine.core.data.transforms import Transform, GlobalAnisotropy
from gempy_engine.modules.geophysics.gravity_gradient import calculate_gravity_gradient
from .encoders.converters import instantiate_if_necessary
from .encoders.json_geomodel_encoder import encode_numpy_array
from .grid import Grid
Expand Down Expand Up @@ -319,6 +318,7 @@ def deserialize_properties(cls, data: Union["GeoModel", dict], constructor: Mode
# * Reset geophysics if necessary
centered_grid = instance.grid.centered_grid
if centered_grid is not None and instance.geophysics_input is not None:
from gempy_engine.modules.geophysics.gravity_gradient import calculate_gravity_gradient
instance.geophysics_input.tz = calculate_gravity_gradient(centered_grid)

return instance
Expand Down
91 changes: 91 additions & 0 deletions gempy/modules/optimize_nuggets/_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import torch

import gempy_engine
from gempy_engine.core.data.continue_epoch import ContinueEpoch


def run_optimization(lr, max_epochs, min_impr, model, nugget, patience, target_cond_num):
opt = torch.optim.Adam(
params=[
nugget,
],
lr=lr
)
prev_cond = float('inf')
for epoch in range(max_epochs):
opt.zero_grad()
try:
gempy_engine.compute_model(
interpolation_input=model.taped_interpolation_input,
options=model.interpolation_options,
data_descriptor=model.input_data_descriptor,
geophysics_input=model.geophysics_input,
)
except ContinueEpoch:
if True:
# Keep only top 10% gradients
_gradient_masking(nugget, focus=0.01)
else:
if epoch % 1 == 0:
mask_sp = _mask_iqr(nugget.grad.abs().view(-1), multiplier=3)
print(f"Outliers sp: {torch.nonzero(mask_sp)}")
_apply_outlier_gradients(tensor=nugget, mask=mask_sp)

# Step & clamp safely
opt.step()
with torch.no_grad():
nugget.clamp_(min=1e-7)

# Evaluate condition number
cur_cond = model.interpolation_options.kernel_options.condition_number
print(f"[Epoch {epoch}] cond. num. = {cur_cond:.2e}")

if _has_converged(cur_cond, prev_cond, target_cond_num, epoch, min_impr, patience):
break
prev_cond = cur_cond

return nugget


def _mask_iqr(grads, multiplier: float = 1.5) -> torch.BoolTensor:
q1, q3 = grads.quantile(0.25), grads.quantile(0.75)
thresh = q3 + multiplier * (q3 - q1)
return grads > thresh

def _apply_outlier_gradients(
tensor: torch.Tensor,
mask: torch.BoolTensor,
amplification: float = 1.0,
):
# wrap in no_grad if you prefer, but .grad modifications are fine
tensor.grad.view(-1)[mask] *= amplification
tensor.grad.view(-1)[~mask] = 0



def _gradient_masking(nugget, focus=0.01):
"""Old way of avoiding exploding gradients."""
grads = nugget.grad.abs()
k = int(grads.numel() * focus)
top_vals, top_idx = torch.topk(grads, k, largest=True)
mask = torch.zeros_like(grads)
mask[top_idx] = 1
nugget.grad.mul_(mask)


def _has_converged(
current: float,
previous: float,
target: float = 1e5,
epoch: int = 0,
min_improvement: float = 0.01,
patience: int = 10,
) -> bool:
if current < target:
return True
if epoch > patience:
# relative improvement
rel_impr = abs(current - previous) / max(previous, 1e-8)
return rel_impr < min_improvement
return False

Loading