Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
2ca9a94
rm kwargs from RF
radka-j Oct 6, 2025
3ce9677
save mlp kwargs in ensembles
radka-j Oct 6, 2025
cb44c83
save all MLP and GP input params
radka-j Oct 6, 2025
5346644
update HMW to work with Emulator as well Result object, rename result…
radka-j Oct 6, 2025
23596b1
add option to pass transformed_emulator_params
radka-j Oct 6, 2025
94b7029
Merge branch 'iss867/update_gp_factory' into reinitialise
radka-j Oct 6, 2025
d9503f1
make mlp_kwargs a keyword argument in MLP ensembles
radka-j Oct 6, 2025
4e6e742
make scheduler_kwargs a keyword argument
radka-j Oct 6, 2025
c08f610
add scheduler_cls input keyword arg
radka-j Oct 6, 2025
e0ea75b
check x/y standardization from emulator object
radka-j Oct 6, 2025
dacefba
update correlated GP
radka-j Oct 7, 2025
5e25da2
Merge branch 'main' into reinitialise
radka-j Oct 8, 2025
8439e4a
fix scheduler kwarg passing to scheduler_setup
radka-j Oct 8, 2025
38283a8
update scheduler_setup method
radka-j Oct 8, 2025
7a925b2
fix test
radka-j Oct 8, 2025
3c523a6
update scheduler tests
radka-j Oct 8, 2025
f6a2377
add reinitialize method
radka-j Oct 8, 2025
5b0d9f2
add reinitialize method
radka-j Oct 8, 2025
4102bbd
add tensor conversion and device handling to TransformedEmulator
radka-j Oct 8, 2025
d698bb8
fix var order
radka-j Oct 8, 2025
36a8b64
update Emulator.fit to expect InputLike, not TensorLike
radka-j Oct 8, 2025
12371a6
refactor fit_from_reinitialised function
radka-j Oct 8, 2025
264d6a6
update learners tests
radka-j Oct 8, 2025
aa18dca
use fit_from_initialized in learners
radka-j Oct 8, 2025
dee04f5
revert changes in learners
radka-j Oct 8, 2025
e6d12fc
accept both emulator and result as keyword args to HMW
radka-j Oct 8, 2025
f8bf4ee
update test
radka-j Oct 8, 2025
e606532
revert doc change
radka-j Oct 8, 2025
fcd2827
revert Emulator data input types to TensorLike (not InputLike)
radka-j Oct 9, 2025
c10826f
rm now unnecessary tensor transform from TransformedEmulator
radka-j Oct 9, 2025
41b0972
rm now unnecessary tensor transform from Emulator
radka-j Oct 9, 2025
e602188
update fit_from_reinitialize to expect TensorLike not InputLike
radka-j Oct 9, 2025
3c8d3dd
ensure tensor is float
radka-j Oct 9, 2025
61db6eb
ensure tensors are floats
radka-j Oct 9, 2025
0c238b0
revert scheduler_params back to kwargs
radka-j Oct 9, 2025
b77e283
use convert_to_tensors method
radka-j Oct 9, 2025
06c7ca0
rename scheduler_kwargs to scheduler_params
radka-j Oct 9, 2025
39797ad
use fit_from_initialized in AL, change types to DistributionLike from…
radka-j Oct 13, 2025
a5fae9b
Update case_studies/patient_calibration/patient_calibration_case_stud…
radka-j Oct 14, 2025
ae942e5
update docstrings
radka-j Oct 14, 2025
8498cb3
avoid code repetition
radka-j Oct 14, 2025
85bf20d
revert to emulator.fit instead of fit_from_reinitialised in stream ba…
radka-j Oct 14, 2025
ce6ea04
Update case_studies/patient_calibration/patient_calibration_case_stud…
radka-j Oct 14, 2025
b5b902d
Update autoemulate/calibration/history_matching.py
radka-j Oct 14, 2025
bcd2939
Update autoemulate/calibration/history_matching.py
radka-j Oct 14, 2025
704fe37
add option to change whether fit from reinitialized or not in AL
radka-j Oct 14, 2025
43039f1
increase learning rate in AL tutorial, set posterior_predictive=False
radka-j Oct 14, 2025
40d7530
Update autoemulate/calibration/history_matching.py
radka-j Oct 14, 2025
597977d
raise error if output distribution does not have a variance property
radka-j Oct 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
use fit_from_initialized in AL, change types to DistributionLike from…
… GaussianLike to match TranformedEmulator predict type
  • Loading branch information
radka-j committed Oct 13, 2025
commit 39797ad50ee5f3f59d0aa50cb196cf912da7aae1
34 changes: 23 additions & 11 deletions autoemulate/learners/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@

import torch
from anytree import Node, RenderTree
from torch.distributions import MultivariateNormal
from torcheval.metrics import MeanSquaredError, R2Score

from autoemulate.core.logging_config import get_configured_logger
from autoemulate.core.reinitialize import fit_from_reinitialized
from autoemulate.data.utils import ValidationMixin
from autoemulate.emulators.base import Emulator
from autoemulate.simulations.base import Simulator

from ..core.types import GaussianLike, TensorLike
from ..core.types import DistributionLike, TensorLike


@dataclass(kw_only=True)
Expand Down Expand Up @@ -48,7 +48,12 @@ def __post_init__(self):
log_level = getattr(self, "log_level", "progress_bar")
self.logger, self.progress_bar = get_configured_logger(log_level)
self.logger.info("Initializing Learner with training data.")
self.emulator.fit(self.x_train, self.y_train)
self.emulator = fit_from_reinitialized(
self.x_train,
self.y_train,
emulator=self.emulator,
device=self.emulator.device,
)
self.logger.info("Emulator fitted with initial training data.")
self.in_dim = self.x_train.shape[1]
self.out_dim = self.y_train.shape[1]
Expand Down Expand Up @@ -140,11 +145,9 @@ def fit(self, *args):
x, output, extra = self.query(*args)
if isinstance(output, TensorLike):
y_pred = output
elif isinstance(output, GaussianLike):
elif isinstance(output, DistributionLike):
assert output.variance.ndim == 2
y_pred, _ = output.mean, output.variance
elif isinstance(output, GaussianLike):
y_pred, _ = output.loc, None
y_pred = output.mean
else:
msg = (
f"Output must be either `Tensor` or `MultivariateNormal` but got "
Expand All @@ -159,7 +162,12 @@ def fit(self, *args):
assert isinstance(y_true, TensorLike)
self.x_train = torch.cat([self.x_train, x])
self.y_train = torch.cat([self.y_train, y_true])
self.emulator.fit(self.x_train, self.y_train)
self.emulator = fit_from_reinitialized(
self.x_train,
self.y_train,
emulator=self.emulator,
device=self.emulator.device,
)
self.mse.update(y_pred, y_true)
self.r2.update(y_pred, y_true)
self.n_queries += 1
Expand All @@ -181,7 +189,7 @@ def fit(self, *args):

# If Gaussian output
# TODO: check generality for other GPs (e.g. with full covariance)
if isinstance(output, MultivariateNormal):
if isinstance(output, DistributionLike):
assert isinstance(output.variance, TensorLike)
assert output.variance.ndim == 2
assert output.variance.shape[1] == self.out_dim
Expand Down Expand Up @@ -257,7 +265,11 @@ def summary(self):
@abstractmethod
def query(
self, x: TensorLike | None = None
) -> tuple[TensorLike | None, TensorLike | GaussianLike, dict[str, float]]:
) -> tuple[
TensorLike | None,
TensorLike | DistributionLike,
dict[str, float],
]:
"""
Abstract method to query new samples.

Expand All @@ -268,7 +280,7 @@ def query(

Returns
-------
tuple[TensorLike or None, TensorLike, TensorLike, Dict[str, list[Any]]]
tuple[TensorLike | None, TensorLike | DistributionLike, dict[str, float]]
A tuple containing:
- The queried samples (or None if no query is made),
- The predicted outputs,
Expand Down
35 changes: 20 additions & 15 deletions autoemulate/learners/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from autoemulate.data.utils import set_random_seed

from ..core.types import GaussianLike, TensorLike
from ..core.types import DistributionLike, TensorLike
from .base import Active


Expand All @@ -26,7 +26,7 @@ class Stream(Active):
@abstractmethod
def query(
self, x: TensorLike | None = None
) -> tuple[TensorLike | None, TensorLike | GaussianLike, dict[str, float]]:
) -> tuple[TensorLike | None, TensorLike | DistributionLike, dict[str, float]]:
"""
Abstract method to query new samples from a stream.

Expand All @@ -37,7 +37,7 @@ def query(

Returns
-------
Tuple[torch.Tensor or None, torch.Tensor, torch.Tensor, Dict[str, List[Any]]]
tuple[TensorLike | None, TensorLike | DistributionLike, dict[str, float]]
A tuple containing:
- The queried samples (or None if no query is made),
- The predicted outputs,
Expand All @@ -64,9 +64,7 @@ def fit_samples(self, x: torch.Tensor):
)
):
self.fit(xi.reshape(1, -1))
pb.set_postfix(
ordered_dict={key: val[-1] for key, val in self.metrics.items()}
)
pb.set_postfix(ordered_dict={k: v[-1] for k, v in self.metrics.items()})

def fit_batches(self, x: torch.Tensor, batch_size: int):
"""
Expand Down Expand Up @@ -114,7 +112,11 @@ def query(
self,
x: TensorLike | None = None,
random_seed: int | None = None,
) -> tuple[torch.Tensor | None, TensorLike | GaussianLike, dict[str, float]]:
) -> tuple[
torch.Tensor | None,
TensorLike | DistributionLike,
dict[str, float],
]:
"""
Query new samples randomly based on a fixed probability.

Expand All @@ -125,7 +127,7 @@ def query(

Returns
-------
Tuple[torch.Tensor or None, torch.Tensor, torch.Tensor, Dict[str, List[Any]]]
tuple[TensorLike | None, TensorLike | DistributionLike, dict[str, float]]
A tuple containing:
- The queried samples (or None if the random condition is not met),
- The predicted outputs,
Expand All @@ -135,8 +137,7 @@ def query(
assert isinstance(x, TensorLike)
# TODO: move handling to check method in base class
output = self.emulator.predict(x)
assert isinstance(output, TensorLike | GaussianLike)
# assert isinstance(output, TensorLike | DistributionLike)
assert isinstance(output, TensorLike | DistributionLike)
if random_seed is not None:
set_random_seed(seed=random_seed)
x = x if np.random.rand() < self.p_query else None
Expand Down Expand Up @@ -199,7 +200,7 @@ def query(
self, x: TensorLike | None = None
) -> tuple[
torch.Tensor | None,
torch.Tensor | GaussianLike,
torch.Tensor | DistributionLike,
dict[str, float],
]:
"""
Expand All @@ -212,7 +213,7 @@ def query(

Returns
-------
Tuple[torch.Tensor or None, torch.Tensor, torch.Tensor, Dict[str, List[Any]]]
tuple[TensorLike | None, TensorLike | DistributionLike, dict[str, float]]
A tuple containing:
- The queried samples (or None if the score does not exceed the threshold),
- The predicted outputs,
Expand All @@ -222,7 +223,7 @@ def query(
# TODO: move handling to check method in base class
assert isinstance(x, torch.Tensor)
output = self.emulator.predict(x)
assert isinstance(output, GaussianLike)
assert isinstance(output, DistributionLike)
assert isinstance(output.variance, torch.Tensor)
score = self.score(x, output.mean, output.variance)
x = x if score > self.threshold else None
Expand Down Expand Up @@ -457,7 +458,11 @@ def __post_init__(self):

def query(
self, x: TensorLike | None = None
) -> tuple[TensorLike | None, TensorLike | GaussianLike, dict[str, float]]:
) -> tuple[
TensorLike | None,
TensorLike | DistributionLike,
dict[str, float],
]:
"""
Query new samples based on the adaptive threshold.

Expand All @@ -471,7 +476,7 @@ def query(

Returns
-------
Tuple[torch.Tensor or None, torch.Tensor, torch.Tensor, Dict[str, List[Any]]]
tuple[TensorLike | None, TensorLike | DistributionLike, dict[str, float]]
A tuple containing:
- The queried samples,
- The predicted outputs,
Expand Down