Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
[WIP] Playing with different optimizers
  • Loading branch information
Leguark committed May 26, 2025
commit e9911cda9b52cfabd0b40445169bb5402adf63d8
124 changes: 118 additions & 6 deletions gempy/modules/optimize_nuggets/_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def nugget_optimizer(
lr: float = 1e-2,
patience: int = 10,
min_impr: float = 0.01,
) -> float:
) -> GeoModel:
"""
Optimize the nugget effect scalar to achieve a target condition number.
Returns the final nugget effect value.
Expand Down Expand Up @@ -59,9 +59,25 @@ def nugget_optimizer(
except ContinueEpoch:
# Keep only top 10% gradients
if False:
_gradient_masking(nugget)
_gradient_masking(
nugget=nugget,
focus=0.01
)
elif True:
if epoch % 5 == 0:
# if True:
grads = nugget.grad.abs().view(-1)
q1, q3 = grads.quantile(0.25), grads.quantile(0.75)
iqr = q3 - q1
thresh = q3 + 1.5 * iqr
mask = grads > thresh

# print the indices of mask
print(f"Outliers: {torch.nonzero(mask)}")

_gradient_foo(nugget_effect_scalar=nugget, mask=mask)
else:
clip_grad_norm_(parameters=[nugget], max_norm=1.0)
clip_grad_norm_(parameters=[nugget], max_norm=0.0001)

# Step & clamp safely
opt.step()
Expand All @@ -77,13 +93,20 @@ def nugget_optimizer(
prev_cond = cur_cond

model.interpolation_options.kernel_options.optimizing_condition_number = False
return nugget.item()
return model


def _gradient_masking(nugget):
def _gradient_foo(nugget_effect_scalar: torch.Tensor, mask):

# amplify outliers if you want bigger jumps
nugget_effect_scalar.grad[mask] *= 5.0
# zero all other gradients
nugget_effect_scalar.grad[~mask] = 0

def _gradient_masking(nugget, focus = 0.01):
"""Old way of avoiding exploding gradients."""
grads = nugget.grad.abs()
k = int(grads.numel() * 0.1)
k = int(grads.numel() * focus)
top_vals, top_idx = torch.topk(grads, k, largest=True)
mask = torch.zeros_like(grads)
mask[top_idx] = 1
Expand All @@ -105,3 +128,92 @@ def _has_converged(
rel_impr = abs(current - previous) / max(previous, 1e-8)
return rel_impr < min_improvement
return False


# region legacy
def nugget_optimizer__legacy(target_cond_num, engine_cfg, model, max_epochs):
geo_model: GeoModel = model
convergence_criteria = target_cond_num
engine_config = engine_cfg

BackendTensor.change_backend_gempy(
engine_backend=engine_config.backend,
use_gpu=engine_config.use_gpu,
dtype=engine_config.dtype
)
import torch
from gempy_engine.core.data.continue_epoch import ContinueEpoch

interpolation_input: InterpolationInput = interpolation_input_from_structural_frame(geo_model)
geo_model.taped_interpolation_input = interpolation_input
nugget_effect_scalar: torch.Tensor = geo_model.taped_interpolation_input.surface_points.nugget_effect_scalar
nugget_effect_scalar.requires_grad = True
optimizer = torch.optim.Adam(
params=[nugget_effect_scalar],
lr=0.01,
)
# Optimization loop
geo_model.interpolation_options.kernel_options.optimizing_condition_number = True

previous_condition_number = 0
for epoch in range(max_epochs):
optimizer.zero_grad()
try:
# geo_model.taped_interpolation_input.grid = geo_model.interpolation_input_copy.grid

gempy_engine.compute_model(
interpolation_input=geo_model.taped_interpolation_input,
options=geo_model.interpolation_options,
data_descriptor=geo_model.input_data_descriptor,
geophysics_input=geo_model.geophysics_input,
)
except ContinueEpoch:
# Get absolute values of gradients
grad_magnitudes = torch.abs(nugget_effect_scalar.grad)

# Get indices of the 10 largest gradients
grad_magnitudes.size

# * This ignores 90 percent of the gradients
# To int
n_values = int(grad_magnitudes.size()[0] * 0.9)
_, indices = torch.topk(grad_magnitudes, n_values, largest=False)

# Zero out gradients that are not in the top 10
mask = torch.ones_like(nugget_effect_scalar.grad)
mask[indices] = 0
nugget_effect_scalar.grad *= mask

# Update the vector
optimizer.step()
nugget_effect_scalar.data = nugget_effect_scalar.data.clamp_(min=1e-7) # Replace negative values with 0

# optimizer.zero_grad()
# Monitor progress
if epoch % 1 == 0:
# print(f"Epoch {epoch}: Condition Number = {condition_number.item()}")
print(f"Epoch {epoch}")

if _check_convergence_criterion(
conditional_number=geo_model.interpolation_options.kernel_options.condition_number,
condition_number_old=previous_condition_number,
conditional_number_target=convergence_criteria,
epoch=epoch
):
break
previous_condition_number = geo_model.interpolation_options.kernel_options.condition_number
continue
geo_model.interpolation_options.kernel_options.optimizing_condition_number = False
return geo_model


def _check_convergence_criterion(conditional_number: float, condition_number_old: float, conditional_number_target: float = 1e5, epoch: int = 0):
import torch
reached_conditional_target = conditional_number < conditional_number_target
if reached_conditional_target == False and epoch > 10:
condition_number_change = torch.abs(conditional_number - condition_number_old) / condition_number_old
if condition_number_change < 0.01:
reached_conditional_target = True
return reached_conditional_target

# endregion
11 changes: 8 additions & 3 deletions test/test_private/test_terranigma/test_nuggets/_aux_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,22 @@ def initialize_geo_model(structural_elements: list[gp.data.StructuralElement], e
structural_relation=gp.data.StackRelationType.ERODE
)

structural_groups = [structural_group_intrusion, structural_group_green, structural_group_blue, structural_group_red]
structural_groups = [
# structural_group_intrusion,
# structural_group_green,
# structural_group_blue,
structural_group_red
]
structural_frame = gp.data.StructuralFrame(
structural_groups=structural_groups[2:],
structural_groups=structural_groups,
color_gen=gp.data.ColorsGenerator()
)
# TODO: If elements do not have color maybe loop them on structural frame constructor?

geo_model: gp.data.GeoModel = gp.create_geomodel(
project_name='Tutorial_ch1_1_Basics',
extent=extent,
resolution=[20, 10, 20],
# resolution=[20, 10, 20],
refinement=5, # * Here we define the number of octree levels. If octree levels are defined, the resolution is ignored.
structural_frame=structural_frame
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

dotenv.load_dotenv()


# %%
# Config
seed = 123456
Expand All @@ -27,44 +26,57 @@
# Load necessary configuration and paths from environment variables
path = os.getenv("PATH_TO_NUGGET_TEST_MODEL")


def test_optimize_nugget_effect():
# Initialize lists to store structural elements for the geological model
structural_elements = []
global_extent = None
color_gen = gp.data.ColorsGenerator()

for filename in os.listdir(path):
base, ext = os.path.splitext(filename)
if ext == '.nc':
structural_element, global_extent = process_file(os.path.join(path, filename), global_extent, color_gen)
structural_elements.append(structural_element)


import xarray as xr
geo_model: gp.data.GeoModel = initialize_geo_model(
structural_elements=structural_elements,
extent=(np.array(global_extent)),
topography=(xr.open_dataset(os.path.join(path, "Topography.nc")))
)

if False:
gpv.plot_3d(geo_model, show_data=True, image=True)


geo_model.interpolation_options.cache_mode = gp.data.InterpolationOptions.CacheMode.NO_CACHE
gp.API.compute_API.optimize_and_compute(
geo_model=geo_model,
engine_config=gp.data.GemPyEngineConfig(
backend=gp.data.AvailableBackends.PYTORCH,
),
max_epochs=100,
convergence_criteria=1e5
)

nugget_effect = geo_model.taped_interpolation_input.surface_points.nugget_effect_scalar.detach().numpy()
if True:
gp.API.compute_API.optimize_and_compute(
geo_model=geo_model,
engine_config=gp.data.GemPyEngineConfig(
backend=gp.data.AvailableBackends.PYTORCH,
),
max_epochs=100,
convergence_criteria=1e5
)

print(f"Final cond number: {geo_model.interpolation_options.kernel_options.condition_number}")
nugget_effect = geo_model.taped_interpolation_input.surface_points.nugget_effect_scalar.detach().numpy()

else:
gp.compute_model(
gempy_model=geo_model,
engine_config=gp.data.GemPyEngineConfig(
backend=gp.data.AvailableBackends.PYTORCH,
),
validate_serialization=False
)


nugget_effect = geo_model.taped_interpolation_input.surface_points.nugget_effect_scalar.detach().numpy()

if plot_evaluation:=True:
if plot_evaluation := True:
import matplotlib.pyplot as plt

plt.hist(nugget_effect, bins=50, color='black', alpha=0.7, log=True)
Expand All @@ -73,15 +85,15 @@ def test_optimize_nugget_effect():
plt.title('Histogram of Eigenvalues (nugget-grad)')
plt.show()


if plot_result:=True:
if plot_result := True:
import gempy_viewer as gpv
import pyvista as pv

gempy_vista = gpv.plot_3d(
model=geo_model,
show=False,
show_boundaries=True,
show_topography=False,
kwargs_plot_structured_grid={'opacity': 0.3}
)

Expand Down