Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
[WIP] Better configuration of parameters
  • Loading branch information
Leguark committed May 26, 2025
commit d1efe2866694ac8ddaf01a29fe3ccbe7bd3089ba
54 changes: 27 additions & 27 deletions gempy/modules/optimize_nuggets/_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,12 @@
torch = None



def nugget_optimizer(
target_cond_num: float,
engine_cfg: GemPyEngineConfig,
model: GeoModel,
max_epochs: int,
lr: float = 1e-2,
lr: float = .1,
patience: int = 10,
min_impr: float = 0.01,
) -> GeoModel:
Expand All @@ -40,9 +39,12 @@ def nugget_optimizer(
interp_in: InterpolationInput = interpolation_input_from_structural_frame(model)
model.taped_interpolation_input = interp_in
nugget: torch.Tensor = interp_in.surface_points.nugget_effect_scalar
nugget_ori: torch.Tensor = interp_in.orientations.nugget_effect_grad

nugget.requires_grad_(True)
nugget_ori.requires_grad_(True)

opt = torch.optim.Adam(params=[nugget], lr=lr)
opt = torch.optim.Adam(params=[nugget, nugget_ori], lr=lr)

model.interpolation_options.kernel_options.optimizing_condition_number = True

Expand All @@ -58,26 +60,16 @@ def nugget_optimizer(
)
except ContinueEpoch:
# Keep only top 10% gradients
if False:
_gradient_masking(
nugget=nugget,
focus=0.01
)
elif True:
if epoch % 5 == 0:
# if True:
grads = nugget.grad.abs().view(-1)
q1, q3 = grads.quantile(0.25), grads.quantile(0.75)
iqr = q3 - q1
thresh = q3 + 1.5 * iqr
mask = grads > thresh

# print the indices of mask
print(f"Outliers: {torch.nonzero(mask)}")

_gradient_foo(nugget_effect_scalar=nugget, mask=mask)
else:
clip_grad_norm_(parameters=[nugget], max_norm=0.0001)
if epoch % 5 == 0:
mask_sp = _mask_iqr(nugget.grad.abs().view(-1))
mask_ori = _mask_iqr(nugget_ori.grad.abs().view(-1))

# print the indices of mask
print(f"Outliers sp: {torch.nonzero(mask_sp)}")
print(f"Outliers ori: {torch.nonzero(mask_ori)}")

_gradient_foo(nugget_effect_scalar=nugget, mask=mask_sp)
_gradient_foo(nugget_effect_scalar=nugget_ori, mask=mask_ori)

# Step & clamp safely
opt.step()
Expand All @@ -96,14 +88,22 @@ def nugget_optimizer(
return model


def _mask_iqr(grads):
q1, q3 = grads.quantile(0.25), grads.quantile(0.75)
iqr = q3 - q1
thresh = q3 + 1.5 * iqr
mask = grads > thresh
return mask


def _gradient_foo(nugget_effect_scalar: torch.Tensor, mask):

# amplify outliers if you want bigger jumps
nugget_effect_scalar.grad[mask] *= 5.0
# zero all other gradients
nugget_effect_scalar.grad[~mask] = 0

def _gradient_masking(nugget, focus = 0.01):

def _gradient_masking(nugget, focus=0.01):
"""Old way of avoiding exploding gradients."""
grads = nugget.grad.abs()
k = int(grads.numel() * focus)
Expand Down Expand Up @@ -135,15 +135,15 @@ def nugget_optimizer__legacy(target_cond_num, engine_cfg, model, max_epochs):
geo_model: GeoModel = model
convergence_criteria = target_cond_num
engine_config = engine_cfg

BackendTensor.change_backend_gempy(
engine_backend=engine_config.backend,
use_gpu=engine_config.use_gpu,
dtype=engine_config.dtype
)
import torch
from gempy_engine.core.data.continue_epoch import ContinueEpoch

interpolation_input: InterpolationInput = interpolation_input_from_structural_frame(geo_model)
geo_model.taped_interpolation_input = interpolation_input
nugget_effect_scalar: torch.Tensor = geo_model.taped_interpolation_input.surface_points.nugget_effect_scalar
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,4 +110,15 @@ def test_optimize_nugget_effect():
point_size=25,
)

if False:
ori_cloud = pv.PolyData(geo_model.orientations_copy.df[['X', 'Y', 'Z']].to_numpy())
ori_cloud['values2'] = geo_model.taped_interpolation_input.orientations.nugget_effect_grad.detach().numpy()

gempy_vista.p.add_mesh(
ori_cloud,
scalars='values2',
cmap='viridis',
point_size=20,
)

gempy_vista.p.show()