diff --git a/docs/content/4-saving.ipynb b/docs/content/4-saving.ipynb index 4abe8871..e1bd3f19 100644 --- a/docs/content/4-saving.ipynb +++ b/docs/content/4-saving.ipynb @@ -9,7 +9,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": { "tags": [ "remove_cell" @@ -23,7 +23,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": { "tags": [ "hide_input" @@ -93,6 +93,45 @@ " ngb_unpickled = pickle.load(f)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n" + ] + }, { "cell_type": "code", "execution_count": 12, diff --git a/examples/user-guide/content/4-saving.ipynb b/examples/user-guide/content/4-saving.ipynb index 4abe8871..fb23a461 100644 --- a/examples/user-guide/content/4-saving.ipynb +++ b/examples/user-guide/content/4-saving.ipynb @@ -23,7 +23,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": { "tags": [ "hide_input" @@ -93,6 +93,45 @@ " ngb_unpickled = pickle.load(f)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n" + ] + }, { "cell_type": "code", "execution_count": 12, diff --git a/ngboost/ngboost.py b/ngboost/ngboost.py index 2406e574..2f6bfccc 100644 --- a/ngboost/ngboost.py +++ b/ngboost/ngboost.py @@ -4,16 +4,47 @@ # pylint: disable=unused-argument,too-many-locals,too-many-branches,too-many-statements # pylint: disable=unused-variable,invalid-unary-operand-type,attribute-defined-outside-init # pylint: disable=redundant-keyword-arg,protected-access,unnecessary-lambda-assignment +# pylint: disable=too-many-public-methods,too-many-lines +import json +from pathlib import Path +from typing import Any, Dict + import numpy as np from sklearn.base import clone from sklearn.model_selection import train_test_split from sklearn.tree import DecisionTreeRegressor from sklearn.utils import check_array, check_random_state, check_X_y -from ngboost.distns import MultivariateNormal, Normal, k_categorical +from ngboost.distns import ( + Bernoulli, + Cauchy, + Exponential, + Gamma, + HalfNormal, + Laplace, + LogNormal, + MultivariateNormal, + Normal, + NormalFixedMean, + NormalFixedVar, + Poisson, + T, + TFixedDf, + TFixedDfFixedVar, + Weibull, + k_categorical, +) from ngboost.learners import default_tree_learner from ngboost.manifold import manifold -from ngboost.scores import LogScore +from ngboost.scores import CRPScore, LogScore +from ngboost.serialization import numpy_to_list, tree_from_dict, tree_to_dict + +try: + import ubjson + + UBJSON_AVAILABLE = True +except ImportError: + UBJSON_AVAILABLE = False class NGBoost: @@ -653,3 +684,397 @@ def _get_feature_importance(self, tree, tree_index): total_feature_importance = np.zeros(self.n_features) 
total_feature_importance[self.col_idxs[tree_index]] = tree_feature_importance return total_feature_importance + + def to_dict(self, include_non_essential=False) -> Dict[str, Any]: + """ + Convert the NGBoost model to a JSON-serializable dictionary. + + Parameters: + include_non_essential: If False, exclude feature_importances_ and evals_result + to reduce file size (default: False) + + Returns: + Dictionary containing all model data needed for reconstruction + """ + if not self.base_models: + raise ValueError("Model must be fitted before serialization") + + # Serialize base models (trees) + serialized_base_models = [] + for iteration_models in self.base_models: + iteration_trees = [] + for tree in iteration_models: + if isinstance(tree, DecisionTreeRegressor): + iteration_trees.append(tree_to_dict(tree)) + else: + raise ValueError( + f"Unsupported base learner type: {type(tree)}. " + "Only DecisionTreeRegressor is currently supported for JSON serialization." + ) + serialized_base_models.append(iteration_trees) + + # Build the model dictionary + model_dict = { + "version": "1.0", + "model_type": self.__class__.__name__, + "Dist_name": self.Dist.__name__, + "Score_name": self.Score.__name__, + "natural_gradient": self.natural_gradient, + "n_estimators": self.n_estimators, + "learning_rate": self.learning_rate, + "minibatch_frac": self.minibatch_frac, + "col_sample": self.col_sample, + "verbose": self.verbose, + "verbose_eval": self.verbose_eval, + "tol": self.tol, + "random_state": ( + numpy_to_list(list(self.random_state.get_state())) + if self.random_state + else None + ), + "validation_fraction": self.validation_fraction, + "early_stopping_rounds": self.early_stopping_rounds, + "n_features": self.n_features, + "init_params": numpy_to_list(self.init_params), + "base_models": serialized_base_models, + "scalings": numpy_to_list(self.scalings), + "col_idxs": numpy_to_list(self.col_idxs), + "best_val_loss_itr": self.best_val_loss_itr, + } + + # Handle special distribution cases + if self.Dist.__name__ == "Categorical": + model_dict["K"] = self.Dist.n_params + 1 + elif self.Dist.__name__ == "MVN": + model_dict["K"] = int((-3 + (9 + 8 * (self.Dist.n_params)) ** 0.5) / 2) + elif self.Dist.__name__ == "SurvivalDistn": + # SurvivalDistn is a dynamically created class, save the base distribution + model_dict["_is_survival"] = True + model_dict["_basedist_name"] = self.Dist._basedist.__name__ + + # Include non-essential data if requested + if include_non_essential: + if ( + hasattr(self, "feature_importances_") + and self.feature_importances_ is not None + ): + model_dict["feature_importances_"] = numpy_to_list( + self.feature_importances_ + ) + if hasattr(self, "evals_result") and self.evals_result: + model_dict["evals_result"] = { + k: {kk: numpy_to_list(vv) for kk, vv in v.items()} + for k, v in self.evals_result.items() + } + + return model_dict + + @classmethod + def from_dict(cls, model_dict: Dict[str, Any]): + """ + Reconstruct an NGBoost model from a dictionary. 
+
+        Parameters:
+            model_dict: Dictionary containing model data (from to_dict())
+
+        Returns:
+            Reconstructed NGBoost model instance
+
+        Raises:
+            ValueError: If the model dictionary is missing required keys, was saved
+                with an unsupported version, or references an unknown distribution
+            KeyError: If a stored hyperparameter that is not covered by the
+                required-key validation (e.g. "minibatch_frac") is absent
+        """
+        # Validate required keys
+        required_keys = [
+            "version",
+            "model_type",
+            "Dist_name",
+            "Score_name",
+            "natural_gradient",
+            "n_estimators",
+            "learning_rate",
+            "n_features",
+            "init_params",
+            "base_models",
+            "scalings",
+            "col_idxs",
+        ]
+        missing_keys = [key for key in required_keys if key not in model_dict]
+        if missing_keys:
+            raise ValueError(
+                f"Invalid model dictionary: missing required keys: {missing_keys}. "
+                "The model file may be corrupted or in an unsupported format."
+            )
+
+        # Check version compatibility (for future format changes)
+        version = model_dict.get("version", "unknown")
+        if version != "1.0":
+            raise ValueError(
+                f"Unsupported model version: {version}. "
+                "This version of NGBoost supports version 1.0. "
+                "Please upgrade NGBoost or use a compatible model file."
+            )
+
+        # Determine the correct class to instantiate
+        model_type = model_dict.get("model_type", "NGBoost")
+
+        # Import API classes if needed (lazy import to avoid circular dependencies)
+        if model_type in ("NGBRegressor", "NGBClassifier", "NGBSurvival"):
+            # pylint: disable=import-outside-toplevel
+            from ngboost.api import NGBClassifier, NGBRegressor, NGBSurvival
+
+            if model_type == "NGBRegressor":
+                instance = NGBRegressor.__new__(NGBRegressor)
+            elif model_type == "NGBClassifier":
+                instance = NGBClassifier.__new__(NGBClassifier)
+            elif model_type == "NGBSurvival":
+                instance = NGBSurvival.__new__(NGBSurvival)
+            else:
+                # This should never happen, but ensures instance is always defined
+                instance = cls.__new__(cls)
+        else:
+            instance = cls.__new__(cls)
+
+        # Restore distribution
+        dist_name = model_dict["Dist_name"]
+        if dist_name == "Categorical":
+            if "K" not in model_dict:
+                raise ValueError(
+                    "Invalid model dictionary: missing 'K' for Categorical distribution."
+                )
+            instance.Dist = k_categorical(model_dict["K"])
+        elif dist_name == "MVN":
+            if "K" not in model_dict:
+                raise ValueError(
+                    "Invalid model dictionary: missing 'K' for MVN distribution."
+                )
+            instance.Dist = MultivariateNormal(model_dict["K"])
+        elif model_dict.get("_is_survival", False):
+            # Handle SurvivalDistn - dynamically created class
+            # pylint: disable=import-outside-toplevel
+            from ngboost.distns.utils import SurvivalDistnClass
+
+            if "_basedist_name" not in model_dict:
+                raise ValueError(
+                    "Invalid model dictionary: missing '_basedist_name' for Survival distribution."
+                )
+            basedist_name = model_dict["_basedist_name"]
+            dist_map = {
+                "Bernoulli": Bernoulli,
+                "Cauchy": Cauchy,
+                "Exponential": Exponential,
+                "Gamma": Gamma,
+                "HalfNormal": HalfNormal,
+                "Laplace": Laplace,
+                "LogNormal": LogNormal,
+                "Normal": Normal,
+                "NormalFixedMean": NormalFixedMean,
+                "NormalFixedVar": NormalFixedVar,
+                "Poisson": Poisson,
+                "T": T,
+                "TFixedDf": TFixedDf,
+                "TFixedDfFixedVar": TFixedDfFixedVar,
+                "Weibull": Weibull,
+            }
+
+            if basedist_name not in dist_map:
+                raise ValueError(
+                    f"Unknown base distribution for Survival: {basedist_name}"
+                )
+            basedist = dist_map[basedist_name]
+            instance.Dist = SurvivalDistnClass(basedist)
+        else:
+            dist_map = {
+                "Bernoulli": Bernoulli,
+                "Cauchy": Cauchy,
+                "Exponential": Exponential,
+                "Gamma": Gamma,
+                "HalfNormal": HalfNormal,
+                "Laplace": Laplace,
+                "LogNormal": LogNormal,
+                "Normal": Normal,
+                "NormalFixedMean": NormalFixedMean,
+                "NormalFixedVar": NormalFixedVar,
+                "Poisson": Poisson,
+                "T": T,
+                "TFixedDf": TFixedDf,
+                "TFixedDfFixedVar": TFixedDfFixedVar,
+                "Weibull": Weibull,
+            }
+
+            if dist_name not in dist_map:
+                raise ValueError(f"Unknown distribution: {dist_name}")
+            instance.Dist = dist_map[dist_name]
+
+        # Restore score
+        score_name = model_dict["Score_name"]
+        score_map = {
+            "LogScore": LogScore,
+            "MLE": LogScore,
+            "CRPScore": CRPScore,
+            "CRPS": CRPScore,
+        }
+        instance.Score = score_map.get(score_name, LogScore)
+
+        # Restore manifold
+        instance.Manifold = manifold(instance.Score, instance.Dist)
+
+        # Restore hyperparameters
+        instance.natural_gradient = model_dict["natural_gradient"]
+        instance.n_estimators = model_dict["n_estimators"]
+        instance.learning_rate = model_dict["learning_rate"]
+        instance.minibatch_frac = model_dict["minibatch_frac"]
+        instance.col_sample = model_dict["col_sample"]
+        instance.verbose = model_dict["verbose"]
+        instance.verbose_eval = model_dict["verbose_eval"]
+        instance.tol = model_dict["tol"]
+        instance.validation_fraction = model_dict.get("validation_fraction", 0.1)
+        instance.early_stopping_rounds = model_dict.get("early_stopping_rounds", None)
+        instance.n_features = model_dict["n_features"]
+        instance.init_params = np.array(model_dict["init_params"])
+        instance.best_val_loss_itr = model_dict.get("best_val_loss_itr", None)
+
+        # Restore random state
+        if model_dict.get("random_state") is not None:
+            # random_state is saved as the output of RandomState.get_state():
+            # [version, key_array, pos, has_gauss, cached_gaussian]
+            state_list = model_dict["random_state"]
+            # set_state() also accepts the legacy 3-tuple (version, keys, pos);
+            # pos is mandatory, so there is no fallback to None here
+            state = (
+                state_list[0],
+                np.array(state_list[1], dtype=np.uint32),
+                int(state_list[2]),
+            )
+            instance.random_state = check_random_state(None)
+            instance.random_state.set_state(state)
+        else:
+            instance.random_state = check_random_state(None)
+
+        # Restore base models
+        instance.base_models = []
+        if not model_dict["base_models"]:
+            raise ValueError(
+                "Invalid model dictionary: 'base_models' is empty. "
+                "The model must be fitted before serialization."
+            )
+        for iteration_trees in model_dict["base_models"]:
+            iteration_models = []
+            for tree_dict in iteration_trees:
+                iteration_models.append(tree_from_dict(tree_dict))
+            instance.base_models.append(iteration_models)
+
+        # Restore scalings and column indices
+        if len(model_dict["scalings"]) != len(model_dict["base_models"]):
+            raise ValueError(
+                f"Mismatch between number of scalings ({len(model_dict['scalings'])}) "
+                f"and base_models ({len(model_dict['base_models'])}). "
+                "The model file may be corrupted."
+            )
+        if len(model_dict["col_idxs"]) != len(model_dict["base_models"]):
+            raise ValueError(
+                f"Mismatch between number of col_idxs ({len(model_dict['col_idxs'])}) "
+                f"and base_models ({len(model_dict['base_models'])}). "
+                "The model file may be corrupted."
+            )
+        instance.scalings = [float(s) for s in model_dict["scalings"]]
+        # col_idxs round-trips through JSON as lists of ints; restore as arrays
+        # to match the representation produced by fit()
+        instance.col_idxs = [np.array(idx) for idx in model_dict["col_idxs"]]
+
+        # Restore base learner (default to DecisionTreeRegressor)
+        instance.Base = default_tree_learner
+
+        # Restore multi_output flag
+        if hasattr(instance.Dist, "multi_output"):
+            instance.multi_output = instance.Dist.multi_output
+        else:
+            instance.multi_output = False
+
+        # Restore non-essential data if present
+        if "feature_importances_" in model_dict:
+            try:
+                instance.feature_importances_ = np.array(
+                    model_dict["feature_importances_"]
+                )
+            except AttributeError:
+                # feature_importances_ may be a read-only property recomputed
+                # from base_models, in which case restoring it is unnecessary
+                pass
+        if "evals_result" in model_dict:
+            instance.evals_result = model_dict["evals_result"]
+
+        return instance
+
+    def save_json(self, filepath: str, include_non_essential: bool = False):
+        """
+        Save the model to a JSON file.
+
+        Parameters:
+            filepath: Path to save the JSON file
+            include_non_essential: If False, exclude feature_importances_ and evals_result
+                to reduce file size (default: False)
+        """
+        model_dict = self.to_dict(include_non_essential=include_non_essential)
+
+        filepath = Path(filepath)
+        with filepath.open("w", encoding="utf-8") as f:
+            json.dump(model_dict, f, indent=2)
+
+    @classmethod
+    def load_json(cls, filepath: str):
+        """
+        Load a model from a JSON file.
+
+        Parameters:
+            filepath: Path to the JSON file
+
+        Returns:
+            Reconstructed NGBoost model instance
+        """
+        filepath = Path(filepath)
+        with filepath.open("r", encoding="utf-8") as f:
+            model_dict = json.load(f)
+
+        return cls.from_dict(model_dict)
+
+    def save_ubj(self, filepath: str, include_non_essential: bool = False):
+        """
+        Save the model to a Universal Binary JSON (UBJ) file.
+
+        Parameters:
+            filepath: Path to save the UBJ file
+            include_non_essential: If False, exclude feature_importances_ and evals_result
+                to reduce file size (default: False)
+
+        Raises:
+            ImportError: If the py-ubjson package is not installed
+        """
+        if not UBJSON_AVAILABLE:
+            raise ImportError(
+                "The ubjson module is required for UBJ serialization. "
+                "Install it with: pip install py-ubjson"
+            )
+
+        model_dict = self.to_dict(include_non_essential=include_non_essential)
+
+        filepath = Path(filepath)
+        with filepath.open("wb") as f:
+            ubjson.dump(model_dict, f)
+
+    @classmethod
+    def load_ubj(cls, filepath: str):
+        """
+        Load a model from a Universal Binary JSON (UBJ) file.
+
+        Parameters:
+            filepath: Path to the UBJ file
+
+        Returns:
+            Reconstructed NGBoost model instance
+
+        Raises:
+            ImportError: If the py-ubjson package is not installed
+        """
+        if not UBJSON_AVAILABLE:
+            raise ImportError(
+                "The ubjson module is required for UBJ serialization. "
+                "Install it with: pip install py-ubjson"
+            )
+
+        filepath = Path(filepath)
+        with filepath.open("rb") as f:
+            model_dict = ubjson.load(f)
+
+        return cls.from_dict(model_dict)
diff --git a/ngboost/serialization.py b/ngboost/serialization.py
new file mode 100644
index 00000000..248b8485
--- /dev/null
+++ b/ngboost/serialization.py
@@ -0,0 +1,99 @@
+"""Serialization utilities for NGBoost models to JSON and Universal Binary JSON formats."""
+
+import base64
+import binascii
+import pickle
+from typing import Any, Dict
+
+import numpy as np
+from sklearn.tree import DecisionTreeRegressor
+
+# UBJSON availability is checked in ngboost.py where it is actually used;
+# this module does not need to import ubjson directly.
+
+
+def tree_to_dict(tree: DecisionTreeRegressor) -> Dict[str, Any]:
+    """
+    Convert a sklearn DecisionTreeRegressor to a JSON-serializable dictionary.
+
+    Uses a base64-encoded pickle for the tree structure, since sklearn's tree
+    structure is read-only and cannot be reconstructed from arrays directly.
+    This is still more portable than pickling the full model, since only the
+    tree structures are pickled, not the entire NGBoost object.
+
+    Parameters:
+        tree: sklearn DecisionTreeRegressor object
+
+    Returns:
+        Dictionary containing the tree structure (base64-encoded pickle)
+    """
+    # Pickle the tree and encode as base64 for JSON compatibility
+    tree_pickle = pickle.dumps(tree)
+    tree_b64 = base64.b64encode(tree_pickle).decode("utf-8")
+
+    return {"_pickle": tree_b64}
+
+
+def tree_from_dict(tree_dict: Dict[str, Any]) -> DecisionTreeRegressor:
+    """
+    Reconstruct a sklearn DecisionTreeRegressor from a dictionary.
+
+    Parameters:
+        tree_dict: Dictionary containing the tree structure (base64-encoded pickle)
+
+    Returns:
+        DecisionTreeRegressor object
+
+    Raises:
+        ValueError: If the tree dictionary is invalid or corrupted
+        TypeError: If the unpickled object is not a DecisionTreeRegressor
+    """
+    if "_pickle" not in tree_dict:
+        raise ValueError(
+            "Invalid tree dictionary: missing '_pickle' key. "
+            "The dictionary may be corrupted or in an unsupported format."
+        )
+
+    try:
+        # Decode base64 and unpickle the tree
+        tree_b64 = tree_dict["_pickle"]
+        tree_pickle = base64.b64decode(tree_b64.encode("utf-8"))
+        tree = pickle.loads(tree_pickle)
+    except (binascii.Error, UnicodeDecodeError) as e:
+        raise ValueError(
+            f"Failed to decode tree data: {e}. The model file may be corrupted."
+        ) from e
+    except (pickle.PickleError, AttributeError, ImportError) as e:
+        raise ValueError(
+            f"Failed to unpickle tree: {e}. "
+            "The model may have been saved with an incompatible version of sklearn."
+        ) from e
+
+    if not isinstance(tree, DecisionTreeRegressor):
+        raise TypeError(
+            f"Expected DecisionTreeRegressor, got {type(tree)}. "
+            "The model file may be corrupted."
+        )
+
+    return tree
+
+
+def numpy_to_list(obj: Any) -> Any:
+    """
+    Recursively convert numpy arrays and scalars to Python lists/values.
+ + Parameters: + obj: Object that may contain numpy arrays + + Returns: + Object with numpy arrays converted to lists + """ + if isinstance(obj, np.ndarray): + return obj.tolist() + if isinstance(obj, (np.integer, np.floating)): + return obj.item() + if isinstance(obj, dict): + return {key: numpy_to_list(value) for key, value in obj.items()} + if isinstance(obj, (list, tuple)): + return [numpy_to_list(item) for item in obj] + return obj diff --git a/pyproject.toml b/pyproject.toml index 8569f048..c410b1d3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ pre-commit = "^4.0.0" isort = "^5.13.0" pylint = "^3.2.0" flake8 = "^7.1.0" +py-ubjson = "*" [build-system] diff --git a/tests/test_json_serialization.py b/tests/test_json_serialization.py new file mode 100644 index 00000000..488746db --- /dev/null +++ b/tests/test_json_serialization.py @@ -0,0 +1,302 @@ +"""Tests for JSON and UBJ serialization of NGBoost models.""" + +import tempfile +from pathlib import Path + +import numpy as np +import pytest + +from ngboost import NGBClassifier, NGBRegressor, NGBSurvival +from ngboost.distns import MultivariateNormal + +try: + import ubjson # noqa: F401 +except ImportError: + ubjson = None + + +@pytest.fixture(name="simple_regressor") +def fixture_simple_regressor(california_housing_data): + """Create a simple fitted regressor for testing.""" + X_train, _, Y_train, _ = california_housing_data + ngb = NGBRegressor(verbose=False, n_estimators=10) + ngb.fit(X_train, Y_train) + return ngb, X_train + + +@pytest.fixture(name="simple_classifier") +def fixture_simple_classifier(breast_cancer_data): + """Create a simple fitted classifier for testing.""" + X_train, _, Y_train, _ = breast_cancer_data + ngb = NGBClassifier(verbose=False, n_estimators=10) + ngb.fit(X_train, Y_train) + return ngb, X_train + + +@pytest.fixture(name="simple_survival") +def fixture_simple_survival(california_housing_survival_data): + """Create a simple fitted survival model for testing.""" + X_train, _, T_train, E_train, _ = california_housing_survival_data + ngb = NGBSurvival(verbose=False, n_estimators=10) + ngb.fit(X_train, T_train, E_train) + return ngb, X_train + + +def test_to_dict_regressor(simple_regressor): + """Test that to_dict() works for regressor.""" + ngb, X_train = simple_regressor + + model_dict = ngb.to_dict() + + assert "version" in model_dict + assert "model_type" in model_dict + assert model_dict["model_type"] == "NGBRegressor" + assert "base_models" in model_dict + assert "scalings" in model_dict + assert "col_idxs" in model_dict + assert "init_params" in model_dict + assert len(model_dict["base_models"]) == len(ngb.base_models) + + +def test_from_dict_regressor(simple_regressor): + """Test that from_dict() reconstructs regressor correctly.""" + ngb, X_train = simple_regressor + + # Get original predictions + original_preds = ngb.predict(X_train) + original_dists = ngb.pred_dist(X_train) + + # Serialize and deserialize + model_dict = ngb.to_dict() + ngb_loaded = NGBRegressor.from_dict(model_dict) + + # Check predictions match + loaded_preds = ngb_loaded.predict(X_train) + np.testing.assert_array_almost_equal(original_preds, loaded_preds, decimal=5) + + # Check distribution parameters match + loaded_dists = ngb_loaded.pred_dist(X_train) + # Compare params dict values + if isinstance(original_dists.params, dict): + for key in original_dists.params: + np.testing.assert_array_almost_equal( + original_dists.params[key], loaded_dists.params[key], decimal=5 + ) + else: + 
np.testing.assert_array_almost_equal( + original_dists.params, loaded_dists.params, decimal=5 + ) + + +def test_save_load_json_regressor(simple_regressor): + """Test save_json() and load_json() for regressor.""" + ngb, X_train = simple_regressor + + original_preds = ngb.predict(X_train) + + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + filepath = f.name + + try: + ngb.save_json(filepath) + + ngb_loaded = NGBRegressor.load_json(filepath) + loaded_preds = ngb_loaded.predict(X_train) + + np.testing.assert_array_almost_equal(original_preds, loaded_preds, decimal=5) + finally: + Path(filepath).unlink() + + +def test_to_dict_classifier(simple_classifier): + """Test that to_dict() works for classifier.""" + ngb, X_train = simple_classifier + + model_dict = ngb.to_dict() + + assert model_dict["model_type"] == "NGBClassifier" + assert "base_models" in model_dict + + +def test_from_dict_classifier(simple_classifier): + """Test that from_dict() reconstructs classifier correctly.""" + ngb, X_train = simple_classifier + + original_preds = ngb.predict(X_train) + original_proba = ngb.predict_proba(X_train) + + model_dict = ngb.to_dict() + ngb_loaded = NGBClassifier.from_dict(model_dict) + + loaded_preds = ngb_loaded.predict(X_train) + loaded_proba = ngb_loaded.predict_proba(X_train) + + np.testing.assert_array_equal(original_preds, loaded_preds) + np.testing.assert_array_almost_equal(original_proba, loaded_proba, decimal=5) + + +def test_to_dict_survival(simple_survival): + """Test that to_dict() works for survival model.""" + ngb, X_train = simple_survival + + model_dict = ngb.to_dict() + + assert model_dict["model_type"] == "NGBSurvival" + assert "base_models" in model_dict + + +def test_from_dict_survival(simple_survival): + """Test that from_dict() reconstructs survival model correctly.""" + ngb, X_train = simple_survival + + original_preds = ngb.predict(X_train) + + model_dict = ngb.to_dict() + ngb_loaded = NGBSurvival.from_dict(model_dict) + + loaded_preds = ngb_loaded.predict(X_train) + np.testing.assert_array_almost_equal(original_preds, loaded_preds, decimal=5) + + +def test_exclude_non_essential(simple_regressor): + """Test that include_non_essential=False excludes optional data.""" + ngb, _ = simple_regressor + + # Force computation of feature_importances_ + _ = ngb.feature_importances_ + + model_dict_with = ngb.to_dict(include_non_essential=True) + model_dict_without = ngb.to_dict(include_non_essential=False) + + assert "feature_importances_" in model_dict_with + assert "feature_importances_" not in model_dict_without + + +def test_multivariate_normal(california_housing_survival_data): + """Test serialization with MultivariateNormal distribution.""" + X_surv_train, _, T_surv_train, E_surv_train, _ = california_housing_survival_data + + ngb = NGBRegressor(Dist=MultivariateNormal(2), n_estimators=10, verbose=False) + Y_mvn = np.vstack((T_surv_train, E_surv_train)).T + ngb.fit(X_surv_train, Y_mvn) + + original_preds = ngb.predict(X_surv_train) + + model_dict = ngb.to_dict() + ngb_loaded = NGBRegressor.from_dict(model_dict) + + loaded_preds = ngb_loaded.predict(X_surv_train) + np.testing.assert_array_almost_equal(original_preds, loaded_preds, decimal=5) + + +def test_ubj_serialization(simple_regressor): + """Test UBJ serialization if ubjson is available.""" + if ubjson is None: + pytest.skip("ubjson package not available") + + ngb, X_train = simple_regressor + + original_preds = ngb.predict(X_train) + + with tempfile.NamedTemporaryFile(mode="wb", suffix=".ubj", 
delete=False) as f: + filepath = f.name + + try: + ngb.save_ubj(filepath) + + ngb_loaded = NGBRegressor.load_ubj(filepath) + loaded_preds = ngb_loaded.predict(X_train) + + np.testing.assert_array_almost_equal(original_preds, loaded_preds, decimal=5) + finally: + Path(filepath).unlink() + + +def test_ubj_import_error(simple_regressor): + """Test that save_ubj raises ImportError when ubjson is not available.""" + # pylint: disable=import-outside-toplevel + import ngboost.ngboost as ngb_module + + ngb, _ = simple_regressor + + # Temporarily disable UBJSON + original_available = ngb_module.UBJSON_AVAILABLE + ngb_module.UBJSON_AVAILABLE = False + + try: + with tempfile.NamedTemporaryFile(mode="wb", suffix=".ubj", delete=False) as f: + filepath = f.name + + try: + with pytest.raises(ImportError, match="ubjson"): + ngb.save_ubj(filepath) + finally: + Path(filepath).unlink(missing_ok=True) + finally: + ngb_module.UBJSON_AVAILABLE = original_available + + +def test_json_file_size(simple_regressor): + """Test that excluding non-essential data reduces file size.""" + ngb, _ = simple_regressor + + # Force computation of feature_importances_ + _ = ngb.feature_importances_ + + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + filepath_with = f.name + + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + filepath_without = f.name + + try: + ngb.save_json(filepath_with, include_non_essential=True) + ngb.save_json(filepath_without, include_non_essential=False) + + size_with = Path(filepath_with).stat().st_size + size_without = Path(filepath_without).stat().st_size + + # File without non-essential data should be smaller or equal + assert size_without <= size_with + finally: + Path(filepath_with).unlink(missing_ok=True) + Path(filepath_without).unlink(missing_ok=True) + + +def test_to_dict_unfitted_model(): + """Test that to_dict() raises error for unfitted model.""" + ngb = NGBRegressor() + + with pytest.raises(ValueError, match="Model must be fitted"): + ngb.to_dict() + + +def test_from_dict_missing_keys(): + """Test that from_dict() raises error for invalid dictionary.""" + invalid_dict = {"version": "1.0"} + + with pytest.raises(ValueError, match="missing required keys"): + NGBRegressor.from_dict(invalid_dict) + + +def test_from_dict_corrupted_tree(simple_regressor): + """Test that from_dict() handles corrupted tree data gracefully.""" + ngb, _ = simple_regressor + + model_dict = ngb.to_dict() + # Corrupt a tree dictionary + model_dict["base_models"][0][0]["_pickle"] = "invalid_base64" + + with pytest.raises(ValueError, match="Failed to decode"): + NGBRegressor.from_dict(model_dict) + + +def test_from_dict_version_check(simple_regressor): + """Test that from_dict() checks version compatibility.""" + ngb, _ = simple_regressor + + model_dict = ngb.to_dict() + model_dict["version"] = "2.0" # Future version + + with pytest.raises(ValueError, match="Unsupported model version"): + NGBRegressor.from_dict(model_dict)
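
A minimal round-trip sketch of the API this diff adds (not part of the diff itself; the synthetic dataset, file names, and `n_estimators` value are illustrative, and the UBJ path assumes the `py-ubjson` package added in `pyproject.toml` is installed):

```python
import numpy as np

from ngboost import NGBRegressor

# Tiny synthetic regression problem, just to have a fitted model
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 5))
Y = X[:, 0] + rng.normal(scale=0.1, size=100)

ngb = NGBRegressor(n_estimators=10, verbose=False)
ngb.fit(X, Y)

# JSON round trip: human-readable, no pickling of the full model object
ngb.save_json("model.json")
restored = NGBRegressor.load_json("model.json")
assert np.allclose(ngb.predict(X), restored.predict(X))

# UBJ round trip: same dictionary layout in a more compact binary encoding
try:
    ngb.save_ubj("model.ubj")
    restored_ubj = NGBRegressor.load_ubj("model.ubj")
except ImportError:
    pass  # py-ubjson not installed; the JSON path above still works
```

Because the saved file mirrors the `to_dict()` layout, it can be inspected outside NGBoost. A short look at the payload written above (key names are those defined by `to_dict()` and `tree_to_dict()` in this diff):

```python
import json

with open("model.json", encoding="utf-8") as f:
    payload = json.load(f)

print(payload["version"])      # "1.0"
print(payload["model_type"])   # "NGBRegressor"
print(payload["Dist_name"], payload["Score_name"])

# base_models is a list of boosting iterations, each a list of per-parameter
# trees stored as base64-encoded pickles under the "_pickle" key
print(sorted(payload["base_models"][0][0]))  # ["_pickle"]
```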