Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ jobs:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v5
- name: Set up Python
uses: actions/setup-python@v4
uses: actions/setup-python@v6
with:
python-version: '3.10'
python-version: '3.12'
- name: Install Poetry
run: |
python -m pip install --upgrade pip
Expand Down
6 changes: 4 additions & 2 deletions .github/workflows/ruff.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,7 @@ jobs:
ruff:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: chartboost/ruff-action@v1
- uses: actions/checkout@v5
- uses: astral-sh/ruff-action@v3
- run: ruff check --fix
- run: ruff format
8 changes: 4 additions & 4 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.10"] # Add more Python versions here, e.g. "3.9", "3.11"
python-version: ["3.10", "3.12"]

steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v5
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
uses: actions/setup-python@v6
with:
python-version: ${{ matrix.python-version }}
- name: Install Poetry
Expand All @@ -35,7 +35,7 @@ jobs:
run: |
.venv/bin/pytest --cov-report=xml
- name: Upload coverage report to Codecov
uses: codecov/codecov-action@v4
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
file: ./coverage.xml
Expand Down
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ exclude: |
)$
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v6.0.0
hooks:
- id: check-merge-conflict
- id: check-added-large-files
Expand All @@ -18,8 +18,8 @@ repos:
- id: check-yaml
- id: requirements-txt-fixer
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: 'v0.1.8'
rev: 'v0.14.0'
hooks:
- id: ruff
- id: ruff-check
args: [ --fix ]
- id: ruff-format
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@
Install the `jax-sph` library from PyPI as

```bash
python3.10 -m venv venv
python3.12 -m venv venv
source venv/bin/activate
pip install jax-sph
```

By default `jax-sph` is installed without GPU support. If you have a CUDA-capable GPU, follow the instructions in the [GPU support](#gpu-support) section.

### Clone
We recommend using a `poetry` or `python3-venv` environment.
We recommend using a `poetry` or `python3-venv` environment. Last tested with python3.10 + jax[cpu]==0.6.2 and python3.12 + jax[cpu]==0.7.2.

**Using Poetry**
```bash
Expand All @@ -51,7 +51,7 @@ Later, you just need to `source .venv/bin/activate` to activate the environment.

**Using `python3-venv`**
```bash
python3 -m venv venv
python3.12 -m venv venv
source venv/bin/activate
pip install -r requirements.txt
pip install -e . # to install jax_sph in interactive mode
Expand All @@ -62,7 +62,7 @@ Later, you just need to `source venv/bin/activate` to activate the environment.
If you want to use a CUDA GPU, you first need a running Nvidia driver. And then just follow the instructions [here](https://jax.readthedocs.io/en/latest/installation.html). The whole process could look like this:
```bash
source .venv/bin/activate
pip install -U "jax[cuda12]==0.4.29"
pip install -U "jax[cuda12]" # specify version as "jax[cuda12]==0.6.2"
```

## Getting Started
Expand Down
1 change: 0 additions & 1 deletion cases/ht.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"Heat Transfer over a flat plate setup"


import jax.numpy as jnp
import numpy as np
from omegaconf import DictConfig
Expand Down
1 change: 0 additions & 1 deletion cases/ldc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Lid-driven cavity case setup"""


import jax.numpy as jnp
import numpy as np
from omegaconf import DictConfig
Expand Down
1 change: 0 additions & 1 deletion cases/tgv.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Taylor-Green case setup"""


import jax.numpy as jnp
import numpy as np
from omegaconf import DictConfig
Expand Down
2 changes: 1 addition & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
absl-py>=2.1.0
h5py
jax[cpu]==0.4.29
jax[cpu]==0.6.2
jraph>=0.0.6.dev0
omegaconf
pandas
Expand Down
2 changes: 1 addition & 1 deletion jax_sph/case_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def initialize(self):
assert state[k][mask].shape == _state[k][_mask].shape, ValueError(
f"Shape mismatch for key {k} in state0 file."
)
state[k][mask] = _state[k][_mask]
state[k] = state[k].at[mask].set(_state[k][_mask])

# the following arguments are needed for dataset generation
cfg.case.c_ref, cfg.case.p_ref, cfg.case.p_bg = c_ref, p_ref, p_bg
Expand Down
12 changes: 6 additions & 6 deletions jax_sph/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,9 +172,9 @@ def check_cfg(cfg: DictConfig) -> None:
assert cfg.solver.dt is None or cfg.solver.dt > 0.0, "dt must be > 0.0."
assert cfg.solver.t_end > 0.0, "t_end must be > 0.0."
assert cfg.solver.artificial_alpha >= 0.0, "artificial_alpha must be >= 0.0."
assert (
cfg.solver.eta_limiter == -1 or cfg.solver.eta_limiter >= 0
), "eta_limiter must be -1 or >0."
assert cfg.solver.eta_limiter == -1 or cfg.solver.eta_limiter >= 0, (
"eta_limiter must be -1 or >0."
)
assert cfg.solver.kappa >= 0.0, "kappa must be >= 0.0."

assert cfg.kernel.name in ["QSK", "WC2K"], "Kernel must be one of 'QSK' or 'WC2K'."
Expand All @@ -189,7 +189,7 @@ def check_cfg(cfg: DictConfig) -> None:
assert cfg.nl.backend in _nl_backends, f"nl_backend must be in {_nl_backends}."

_io_trite_type = ["h5", "vtk"]
assert all(
[w in _io_trite_type for w in cfg.io.write_type]
), f"write_type must be in {_io_trite_type}."
assert all([w in _io_trite_type for w in cfg.io.write_type]), (
f"write_type must be in {_io_trite_type}."
)
assert cfg.io.write_every > 0, "write_every must be > 0."
1 change: 0 additions & 1 deletion jax_sph/eos.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Equation of state."""


from abc import ABC, abstractmethod


Expand Down
27 changes: 12 additions & 15 deletions jax_sph/jax_md/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import jraph
import numpy as onp
from absl import logging
from jax import eval_shape, jit, lax, ops, tree_map, vmap
from jax import eval_shape, jit, lax, ops, tree, vmap
from jax.core import ShapedArray

from jax_sph.jax_md import dataclasses, space, util
Expand Down Expand Up @@ -103,9 +103,9 @@ def kwarg_buffers(self):
@dataclasses.dataclass
class CellListFns:
allocate: Callable[..., CellList] = dataclasses.static_field()
update: Callable[
[Array, Union[CellList, int]], CellList
] = dataclasses.static_field()
update: Callable[[Array, Union[CellList, int]], CellList] = (
dataclasses.static_field()
)

def __iter__(self):
return iter((self.allocate, self.update))
Expand Down Expand Up @@ -492,9 +492,7 @@ def __str__(self) -> str:
return "Partition Error: Cell size too small"

if jnp.any(self.code & PEC.MALFORMED_BOX):
return (
"Partition Error: Incorrect box format. Expecting upper " "triangular."
)
return "Partition Error: Incorrect box format. Expecting upper triangular."

raise ValueError(f"Unexpected Error Code {self.code}.")

Expand Down Expand Up @@ -523,8 +521,7 @@ def _displacement_or_metric_to_metric_sq(
except ValueError:
continue
raise ValueError(
"Canonicalize displacement not implemented for spatial dimension larger"
"than 4."
"Canonicalize displacement not implemented for spatial dimension larger than 4."
)


Expand Down Expand Up @@ -654,9 +651,9 @@ class NeighborList:
format: NeighborListFormat = dataclasses.static_field()
cell_size: Optional[float] = dataclasses.static_field()
cell_list_fn: Callable[[Array, CellList], CellList] = dataclasses.static_field()
update_fn: Callable[
[Array, "NeighborList"], "NeighborList"
] = dataclasses.static_field()
update_fn: Callable[[Array, "NeighborList"], "NeighborList"] = (
dataclasses.static_field()
)

def update(self, position: Array, **kwargs) -> "NeighborList":
return self.update_fn(position, self, **kwargs)
Expand Down Expand Up @@ -1082,10 +1079,10 @@ def pad(x):
padding = jnp.zeros((1,) + x.shape[1:], dtype=x.dtype)
return jnp.concatenate((x, padding), axis=0)

nodes = tree_map(pad, nodes)
nodes = tree.map(pad, nodes)

# Pad the globals to add one fictitious global.
globals = tree_map(pad, globals)
globals = tree.map(pad, globals)

# If there is an additional mask, reorder the edges.
if mask is not None:
Expand All @@ -1099,7 +1096,7 @@ def pad(x):
def reorder_edges(x):
return jnp.zeros_like(x).at[index].set(x)

edges = tree_map(reorder_edges, edges)
edges = tree.map(reorder_edges, edges)
mask = receivers < N

return jraph.GraphsTuple(
Expand Down
7 changes: 3 additions & 4 deletions jax_sph/jax_md/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def inverse(box: Box) -> Box:
elif box.ndim == 2:
return jnp.linalg.inv(box)
raise ValueError(
("Box must be either: a scalar, a vector, or a matrix. " f"Found {box}.")
(f"Box must be either: a scalar, a vector, or a matrix. Found {box}.")
)


Expand Down Expand Up @@ -116,7 +116,7 @@ def raw_transform(box: Box, R: Array) -> Array:
right_indices = free_indices + "i"
return jnp.einsum(f"ij,{left_indices}->{right_indices}", box, R)
raise ValueError(
("Box must be either: a scalar, a vector, or a matrix. " f"Found {box}.")
(f"Box must be either: a scalar, a vector, or a matrix. Found {box}.")
)


Expand Down Expand Up @@ -473,6 +473,5 @@ def canonicalize_displacement_or_metric(displacement_or_metric):
except ValueError:
continue
raise ValueError(
"Canonicalize displacement not implemented for spatial dimension larger"
"than 4."
"Canonicalize displacement not implemented for spatial dimension larger than 4."
)
35 changes: 21 additions & 14 deletions notebooks/iclr24_grads.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,12 @@
"from jax_sph.defaults import defaults\n",
"from jax_sph.integrator import si_euler\n",
"from jax_sph.partition import neighbor_list\n",
"from jax_sph.solver import WCSPH\n"
"from jax_sph.solver import WCSPH"
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -61,7 +61,7 @@
"WARMUP = 10\n",
"N = 20\n",
"\n",
"dx = 1.0 / N0.000367\n",
"dx = 1.0 / N\n",
"dt = 0.0\n",
"\n",
"FD_EPS = 0.001 * dx"
Expand Down Expand Up @@ -137,8 +137,7 @@
"state0s = {}\n",
"\n",
"for ind, case_id in enumerate(CASES):\n",
"\n",
" cfg = OmegaConf.create({\"config\": \"../cases/\"+case_id.lower()+\".yaml\"})\n",
" cfg = OmegaConf.create({\"config\": \"../cases/\" + case_id.lower() + \".yaml\"})\n",
" cfg = OmegaConf.merge(defaults.copy(), OmegaConf.load(cfg.config), cfg)\n",
" cfg.seed = 0\n",
" cfg.case.dx = dx\n",
Expand All @@ -147,9 +146,9 @@
"\n",
" Case = load_case(os.path.dirname(cfg.config), cfg.case.source)\n",
" case = Case(cfg)\n",
" (\n",
" cfg, box_size, state, g_ext_fn, bc_fn, nw_fn, eos_fn, _, displ_fn, shift_fn\n",
" ) = case.initialize()\n",
" (cfg, box_size, state, g_ext_fn, bc_fn, nw_fn, eos_fn, _, displ_fn, shift_fn) = (\n",
" case.initialize()\n",
" )\n",
"\n",
" dt = cfg.solver.dt\n",
"\n",
Expand Down Expand Up @@ -258,16 +257,24 @@
" # https://github.com/tumaer/jax-sph/tree/v0.0.1\n",
" axs[0].scatter(state0[\"r\"][:, 0], state0[\"r\"][:, 1], s=10)\n",
" axs[0].quiver(\n",
" state0[\"r\"][:, 0], state0[\"r\"][:, 1], \n",
" sph_grad[:, 0], sph_grad[:, 1], \n",
" scale=15, width=0.005, headwidth=3, \n",
" state0[\"r\"][:, 0],\n",
" state0[\"r\"][:, 1],\n",
" sph_grad[:, 0],\n",
" sph_grad[:, 1],\n",
" scale=15,\n",
" width=0.005,\n",
" headwidth=3,\n",
" )\n",
" axs[0].set_title(\"Autograd\", fontsize=26)\n",
" axs[1].scatter(state0[\"r\"][:, 0], state0[\"r\"][:, 1], s=10)\n",
" axs[1].quiver(\n",
" state0[\"r\"][:, 0], state0[\"r\"][:, 1], \n",
" fd_grad[:, 0], fd_grad[:, 1], \n",
" scale=15, width=0.005, headwidth=3, \n",
" state0[\"r\"][:, 0],\n",
" state0[\"r\"][:, 1],\n",
" fd_grad[:, 0],\n",
" fd_grad[:, 1],\n",
" scale=15,\n",
" width=0.005,\n",
" headwidth=3,\n",
" )\n",
" axs[1].set_title(\"Finite differences\", fontsize=26)\n",
" axs[0].set_xticks([])\n",
Expand Down
Loading
Loading