Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
[CLN] Refactor grids into its own modules
  • Loading branch information
Leguark committed May 25, 2025
commit 9fb075e686b1c8677534e92dbdefffa5762ae7a8
4 changes: 3 additions & 1 deletion gempy/core/data/grid_modules/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
from .grid_types import Sections, RegularGrid, CustomGrid
from .regular_grid import RegularGrid
from .custom_grid import CustomGrid
from .sections_grid import Sections
from .topography import Topography
33 changes: 33 additions & 0 deletions gempy/core/data/grid_modules/custom_grid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import dataclasses
import numpy as np
from pydantic import Field


@dataclasses.dataclass
class CustomGrid:
"""Object that contains arbitrary XYZ coordinates.

Args:
xyx_coords (numpy.ndarray like): XYZ (in columns) of the desired coordinates

Attributes:
values (np.ndarray): XYZ coordinates
"""

values: np.ndarray = Field(
exclude=True,
default_factory=lambda: np.zeros((0, 3)),
repr=False
)


def __post_init__(self):
custom_grid = np.atleast_2d(self.values)
assert type(custom_grid) is np.ndarray and custom_grid.shape[1] == 3, \
'The shape of new grid must be (n,3) where n is the number of' \
' points of the grid'


@property
def length(self):
return self.values.shape[0]
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,9 @@

import numpy as np

from ..core_utils import calculate_line_coordinates_2points
from ..encoders.converters import numpy_array_short_validator
from .... import optional_dependencies
from ....optional_dependencies import require_pandas
from gempy_engine.core.data.transforms import Transform, TransformOpsOrder
from gempy_engine.core.data.transforms import Transform


@dataclasses.dataclass
Expand Down Expand Up @@ -255,141 +253,3 @@ def plot_rotation(regular_grid, pivot, point_x_axis, point_y_axis):
plt.show()


@dataclasses.dataclass
class Sections:
"""
Object that creates a grid of cross sections between two points.

Args:
regular_grid: Model.grid.regular_grid
section_dict: {'section name': ([p1_x, p1_y], [p2_x, p2_y], [xyres, zres])}
"""

def __init__(self, regular_grid=None, z_ext=None, section_dict=None):
pd = require_pandas()
if regular_grid is not None:
self.z_ext = regular_grid.extent[4:]
else:
self.z_ext = z_ext

self.section_dict = section_dict
self.names = []
self.points = []
self.resolution = []
self.length = [0]
self.dist = []
self.df = pd.DataFrame()
self.df['dist'] = self.dist
self.values = np.empty((0, 3))
self.extent = None

if section_dict is not None:
self.set_sections(section_dict)

def _repr_html_(self):
return self.df.to_html()

def __repr__(self):
return self.df.to_string()

def show(self):
pass

def set_sections(self, section_dict, regular_grid=None, z_ext=None):
pd = require_pandas()
self.section_dict = section_dict
if regular_grid is not None:
self.z_ext = regular_grid.extent[4:]

self.names = np.array(list(self.section_dict.keys()))

self.get_section_params()
self.calculate_all_distances()
self.df = pd.DataFrame.from_dict(self.section_dict, orient='index', columns=['start', 'stop', 'resolution'])
self.df['dist'] = self.dist

self.compute_section_coordinates()

def get_section_params(self):
self.points = []
self.resolution = []
self.length = [0]

for i, section in enumerate(self.names):
points = [self.section_dict[section][0], self.section_dict[section][1]]
assert points[0] != points[
1], 'The start and end points of the section must not be identical.'

self.points.append(points)
self.resolution.append(self.section_dict[section][2])
self.length = np.append(self.length, self.section_dict[section][2][0] *
self.section_dict[section][2][1])
self.length = np.array(self.length).cumsum()

def calculate_all_distances(self):
self.coordinates = np.array(self.points).ravel().reshape(-1,
4) # axis are x1,y1,x2,y2
self.dist = np.sqrt(np.diff(self.coordinates[:, [0, 2]]) ** 2 + np.diff(
self.coordinates[:, [1, 3]]) ** 2)

def compute_section_coordinates(self):
for i in range(len(self.names)):
xy = calculate_line_coordinates_2points(self.coordinates[i, :2],
self.coordinates[i, 2:],
self.resolution[i][0])
zaxis = np.linspace(self.z_ext[0], self.z_ext[1], self.resolution[i][1],
dtype="float64")
X, Z = np.meshgrid(xy[:, 0], zaxis, indexing='ij')
Y, _ = np.meshgrid(xy[:, 1], zaxis, indexing='ij')
xyz = np.vstack((X.flatten(), Y.flatten(), Z.flatten())).T
if i == 0:
self.values = xyz
else:
self.values = np.vstack((self.values, xyz))

def generate_axis_coord(self):
for i, name in enumerate(self.names):
xy = calculate_line_coordinates_2points(
self.coordinates[i, :2],
self.coordinates[i, 2:],
self.resolution[i][0]
)
yield name, xy

def get_section_args(self, section_name: str):
where = np.where(self.names == section_name)[0][0]
return self.length[where], self.length[where + 1]

def get_section_grid(self, section_name: str):
l0, l1 = self.get_section_args(section_name)
return self.values[l0:l1]


@dataclasses.dataclass
class CustomGrid:
"""Object that contains arbitrary XYZ coordinates.

Args:
xyx_coords (numpy.ndarray like): XYZ (in columns) of the desired coordinates

Attributes:
values (np.ndarray): XYZ coordinates
"""

values: np.ndarray = Field(
exclude=True,
default_factory=lambda: np.zeros((0, 3)),
repr=False
)


def __post_init__(self):
custom_grid = np.atleast_2d(self.values)
assert type(custom_grid) is np.ndarray and custom_grid.shape[1] == 3, \
'The shape of new grid must be (n,3) where n is the number of' \
' points of the grid'


@property
def length(self):
return self.values.shape[0]
115 changes: 115 additions & 0 deletions gempy/core/data/grid_modules/sections_grid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import dataclasses
import numpy as np

from gempy.core.data.core_utils import calculate_line_coordinates_2points
from gempy.optional_dependencies import require_pandas


@dataclasses.dataclass
class Sections:
"""
Object that creates a grid of cross sections between two points.

Args:
regular_grid: Model.grid.regular_grid
section_dict: {'section name': ([p1_x, p1_y], [p2_x, p2_y], [xyres, zres])}
"""

def __init__(self, regular_grid=None, z_ext=None, section_dict=None):
pd = require_pandas()
if regular_grid is not None:
self.z_ext = regular_grid.extent[4:]
else:
self.z_ext = z_ext

self.section_dict = section_dict
self.names = []
self.points = []
self.resolution = []
self.length = [0]
self.dist = []
self.df = pd.DataFrame()
self.df['dist'] = self.dist
self.values = np.empty((0, 3))
self.extent = None

if section_dict is not None:
self.set_sections(section_dict)

def _repr_html_(self):
return self.df.to_html()

def __repr__(self):
return self.df.to_string()

def show(self):
pass

def set_sections(self, section_dict, regular_grid=None, z_ext=None):
pd = require_pandas()
self.section_dict = section_dict
if regular_grid is not None:
self.z_ext = regular_grid.extent[4:]

self.names = np.array(list(self.section_dict.keys()))

self.get_section_params()
self.calculate_all_distances()
self.df = pd.DataFrame.from_dict(self.section_dict, orient='index', columns=['start', 'stop', 'resolution'])
self.df['dist'] = self.dist

self.compute_section_coordinates()

def get_section_params(self):
self.points = []
self.resolution = []
self.length = [0]

for i, section in enumerate(self.names):
points = [self.section_dict[section][0], self.section_dict[section][1]]
assert points[0] != points[
1], 'The start and end points of the section must not be identical.'

self.points.append(points)
self.resolution.append(self.section_dict[section][2])
self.length = np.append(self.length, self.section_dict[section][2][0] *
self.section_dict[section][2][1])
self.length = np.array(self.length).cumsum()

def calculate_all_distances(self):
self.coordinates = np.array(self.points).ravel().reshape(-1,
4) # axis are x1,y1,x2,y2
self.dist = np.sqrt(np.diff(self.coordinates[:, [0, 2]]) ** 2 + np.diff(
self.coordinates[:, [1, 3]]) ** 2)

def compute_section_coordinates(self):
for i in range(len(self.names)):
xy = calculate_line_coordinates_2points(self.coordinates[i, :2],
self.coordinates[i, 2:],
self.resolution[i][0])
zaxis = np.linspace(self.z_ext[0], self.z_ext[1], self.resolution[i][1],
dtype="float64")
X, Z = np.meshgrid(xy[:, 0], zaxis, indexing='ij')
Y, _ = np.meshgrid(xy[:, 1], zaxis, indexing='ij')
xyz = np.vstack((X.flatten(), Y.flatten(), Z.flatten())).T
if i == 0:
self.values = xyz
else:
self.values = np.vstack((self.values, xyz))

def generate_axis_coord(self):
for i, name in enumerate(self.names):
xy = calculate_line_coordinates_2points(
self.coordinates[i, :2],
self.coordinates[i, 2:],
self.resolution[i][0]
)
yield name, xy

def get_section_args(self, section_name: str):
where = np.where(self.names == section_name)[0][0]
return self.length[where], self.length[where + 1]

def get_section_grid(self, section_name: str):
l0, l1 = self.get_section_args(section_name)
return self.values[l0:l1]
2 changes: 1 addition & 1 deletion gempy/core/data/grid_modules/topography.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import numpy as np

from .grid_types import RegularGrid
from .regular_grid import RegularGrid
from ....modules.grids.create_topography import _LoadDEMArtificial

from ....optional_dependencies import require_skimage
Expand Down