Skip to content

Commit c27791f

Browse files
committed
More experiments.
1 parent 42c8500 commit c27791f

File tree

15 files changed

+5128
-1249
lines changed

15 files changed

+5128
-1249
lines changed

jupyterbook/content/experimental/cifar/gf_cifar10_naive_tfds_splines.ipynb

Lines changed: 2321 additions & 0 deletions
Large diffs are not rendered by default.

jupyterbook/content/experimental/details/marginal_gauss.ipynb

Lines changed: 235 additions & 240 deletions
Large diffs are not rendered by default.

jupyterbook/content/experimental/mnist/gf_demo_mnist_naive_tfds_splines.ipynb

Lines changed: 138 additions & 183 deletions
Large diffs are not rendered by default.

jupyterbook/content/experimental/plane/gf_demo_plane_splines.ipynb

Lines changed: 160 additions & 468 deletions
Large diffs are not rendered by default.

jupyterbook/content/public/gf_demo_plane.ipynb

Lines changed: 169 additions & 146 deletions
Large diffs are not rendered by default.

jupyterbook/content/public/rbig_building_blocks.ipynb

Lines changed: 168 additions & 93 deletions
Large diffs are not rendered by default.

jupyterbook/content/public/rbig_other_transforms.ipynb

Lines changed: 1599 additions & 0 deletions
Large diffs are not rendered by default.

rbig_jax/models/gaussflow.py

Lines changed: 106 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
from distrax._src.distributions.normal import Normal
1212

1313
from rbig_jax.transforms.base import Bijector, BijectorChain
14-
from rbig_jax.transforms.inversecdf import InitInverseGaussCDF
15-
from rbig_jax.transforms.logit import InitLogitTransform
14+
from rbig_jax.transforms.inversecdf import InitInverseGaussCDF, InitGaussCDF
15+
from rbig_jax.transforms.logit import InitLogitTransform, InitSigmoidTransform
1616
from rbig_jax.transforms.parametric.householder import InitHouseHolder
1717
from rbig_jax.transforms.parametric.mixture.gaussian import InitMixtureGaussianCDF
1818
from rbig_jax.transforms.parametric.mixture.logistic import InitMixtureLogisticCDF
@@ -253,6 +253,110 @@ def init_gf_spline_model(
253253
return gf_model
254254

255255

256+
def init_gf_composite_spline_model(
257+
shape: tuple,
258+
X: Array = None,
259+
n_blocks: int = 4,
260+
n_bins: int = 20,
261+
range_min: float = 0.0,
262+
range_max: float = 1.0,
263+
init_rotation: str = "random",
264+
n_reflections: int = 10,
265+
squeeze: str = "sigmoid",
266+
**kwargs,
267+
):
268+
269+
n_features = shape[0]
270+
rng = jax.random.PRNGKey(42)
271+
# rng, _ = jax.random.split(jax.random.PRNGKey(123), 2)
272+
# =====================
273+
# Composite Transform
274+
# ======================
275+
init_nl_forward_f = InitGaussCDF()
276+
init_nl_inverse_f = InitInverseGaussCDF()
277+
278+
# =====================
279+
# RQ Spline
280+
# ======================
281+
init_rq_f = InitPiecewiseRationalQuadraticCDF(
282+
n_bins=n_bins, range_min=range_min, range_max=range_max, **kwargs
283+
)
284+
# =====================
285+
# HouseHolder Transform
286+
# ======================
287+
n_reflections = n_reflections
288+
# initialize init function
289+
init_hh_f = InitHouseHolder(n_reflections=n_reflections, method=init_rotation)
290+
291+
block_rngs = jax.random.split(rng, num=n_blocks)
292+
# rng = jax.random.split(jax.random.PRNGKey(42), n_blocks)
293+
# block_rngs = jax.random.split(jax.random.PRNGKey(42), n_blocks)
294+
295+
itercount = itertools.count()
296+
bijectors = []
297+
298+
X_g = X.copy()
299+
300+
pbar = tqdm.tqdm(block_rngs)
301+
with pbar:
302+
for iblock, irng in enumerate(pbar):
303+
304+
pbar.set_description(
305+
f"Initializing - Block: {iblock+1} | Layer {next(itercount)}"
306+
)
307+
# ======================
308+
# Forward Squeezing Transform
309+
# ======================
310+
# intialize bijector and transformation
311+
X_g, layer = init_nl_forward_f.transform_and_bijector(inputs=X_g,)
312+
# add bijector to list
313+
bijectors.append(layer)
314+
# ======================
315+
# RQ Spline
316+
# ======================
317+
# create keys for all inits
318+
irng, irq_rng = jax.random.split(irng, 2)
319+
320+
# intialize bijector and transformation
321+
X_g, layer = init_rq_f.transform_and_bijector(
322+
inputs=X_g, rng=irq_rng, shape=X.shape[1:]
323+
)
324+
325+
# add bijector to list
326+
bijectors.append(layer)
327+
328+
# ======================
329+
# Inverse Squeezing Transform
330+
# ======================
331+
# intialize bijector and transformation
332+
X_g, layer = init_nl_inverse_f.transform_and_bijector(inputs=X_g,)
333+
# add bijector to list
334+
bijectors.append(layer)
335+
336+
# ======================
337+
# HOUSEHOLDER
338+
# ======================
339+
pbar.set_description(
340+
f"Initializing - Block: {iblock+1} | Layer {next(itercount)}"
341+
)
342+
# create keys for all inits
343+
irng, hh_rng = jax.random.split(irng, 2)
344+
345+
# intialize bijector and transformation
346+
X_g, layer = init_hh_f.transform_and_bijector(
347+
inputs=X_g, rng=hh_rng, n_features=n_features
348+
)
349+
350+
bijectors.append(layer)
351+
352+
# create base dist
353+
base_dist = Normal(jnp.zeros((n_features,)), jnp.ones((n_features,)))
354+
355+
# create flow model
356+
gf_model = GaussianizationFlow(base_dist=base_dist, bijectors=bijectors)
357+
return gf_model
358+
359+
256360
def add_gf_model_args(parser):
257361
# ====================
258362
# Model Args

rbig_jax/plots.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,6 @@ def plot_info_loss(
8484
plt.show()
8585

8686

87-
88-
8987
def plot_image_grid(image, image_shape: Optional = None):
9088

9189
fig = plt.figure(figsize=(10.0, 10.0))

rbig_jax/transforms/inversecdf.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -129,31 +129,32 @@ def InitGaussCDF(eps: float = 1e-5, jitted=False):
129129
else:
130130
f = bijector.forward
131131

132-
def init_bijector(inputs, **kwargs):
132+
def transform(inputs, **kwargs):
133+
134+
outputs = f(inputs)
135+
136+
return outputs
137+
138+
def bijector(inputs=None, **kwargs):
133139

134140
return GaussCDF(eps=eps)
135141

136-
def bijector_and_transform(inputs, **kwargs):
142+
def transform_and_bijector(inputs, **kwargs):
137143
outputs = f(inputs)
138144
return outputs, GaussCDF(eps=eps)
139145

140-
def transform(inputs, **kwargs):
141-
outputs = f(inputs)
142-
return outputs
146+
def transform_gradient_bijector(inputs, **kwargs):
147+
bijector = GaussCDF(eps=eps)
143148

144-
def params(inputs, **kwargs):
145-
return ()
149+
outputs, logabsdet = bijector.forward_and_log_det(inputs)
146150

147-
def params_and_transform(inputs, **kwargs):
148-
outputs = f(inputs)
149-
return outputs, ()
151+
return outputs, logabsdet, bijector
150152

151153
return InitLayersFunctions(
152-
bijector=init_bijector,
153-
bijector_and_transform=bijector_and_transform,
154154
transform=transform,
155-
params=params,
156-
params_and_transform=params_and_transform,
155+
bijector=bijector,
156+
transform_and_bijector=transform_and_bijector,
157+
transform_gradient_bijector=transform_gradient_bijector,
157158
)
158159

159160

0 commit comments

Comments
 (0)