From 2183bc716addc6b9943fff35c41879d01e54fb67 Mon Sep 17 00:00:00 2001 From: Martin Fitzner Date: Thu, 6 Mar 2025 17:55:52 +0100 Subject: [PATCH 01/17] Add posterior_statistics --- baybe/campaign.py | 60 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/baybe/campaign.py b/baybe/campaign.py index aca581029..96059f2bb 100644 --- a/baybe/campaign.py +++ b/baybe/campaign.py @@ -19,6 +19,7 @@ from baybe.constraints.base import DiscreteConstraint from baybe.exceptions import IncompatibilityError, NotEnoughPointsLeftError from baybe.objectives.base import Objective, to_objective +from baybe.objectives.desirability import DesirabilityObjective from baybe.parameters.base import Parameter from baybe.recommenders.base import RecommenderProtocol from baybe.recommenders.meta.base import MetaRecommender @@ -533,6 +534,65 @@ def posterior(self, candidates: pd.DataFrame) -> Posterior: with torch.no_grad(): return surrogate.posterior(candidates) + def posterior_statistics( + self, candidates: pd.DataFrame, std_instead_of_var: bool = True + ) -> pd.DataFrame: + """Return common posterior statistics for each target. + + Args: + candidates: The candidate points in experimental recommendations. + For details, see :meth:`baybe.surrogates.base.Surrogate.posterior`. + std_instead_of_var: Flag deciding if the standard deviation or variance is + returned (if supported by the posterior). + + Raises: + TypeError: If the posterior utilized by the surrogate does not support + any of the possible statistics. + + Returns: + Data frame with prediction statistics for each target for each candidate. + """ + posterior = self.posterior(candidates) + + considered_stats = ["mean", "variance", "mode"] + supported_stats = [x for x in considered_stats if hasattr(posterior, x)] + if not supported_stats: + raise TypeError( + f"The utilized posterior is of type {posterior.__class__.__name__} and " + f"does not support any of the possible statistics: {considered_stats}. " + f"To call {self.posterior_statistics.__name__}, at least one of these " + f"statistics must be supported by the surrogate posterior." + ) + + assert self.objective is not None + match self.objective: + case DesirabilityObjective(): + # TODO: Once desirability also supports posterior transforms this check + # here will have to depend on the configuration of the obejctive and + # whether it uses the transforms or not. + targets = ["desirability"] + case _: + targets = [t.name for t in self.objective.targets] + + stats = pd.DataFrame(index=candidates.index) + for i, t in enumerate(targets): + for stat in supported_stats: + vals = ( + getattr(posterior, stat) + .cpu() + .numpy() + .reshape((len(stats), len(targets))) + ) + if stat == "variance" and std_instead_of_var: + stat_name = "std" + vals = np.sqrt(vals) + else: + stat_name = stat + + stats[f"{t}_{stat_name}"] = vals[:, i] + + return stats + def get_surrogate( self, batch_size: int | None = None, From f87d5adb2d12371455757fdeeb8dc6d4d64e6ed5 Mon Sep 17 00:00:00 2001 From: Martin Fitzner Date: Thu, 6 Mar 2025 18:57:59 +0100 Subject: [PATCH 02/17] Add test --- baybe/campaign.py | 2 +- tests/conftest.py | 4 +- tests/test_campaign.py | 88 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 91 insertions(+), 3 deletions(-) diff --git a/baybe/campaign.py b/baybe/campaign.py index 96059f2bb..a77f3dd2d 100644 --- a/baybe/campaign.py +++ b/baybe/campaign.py @@ -570,7 +570,7 @@ def posterior_statistics( # TODO: Once desirability also supports posterior transforms this check # here will have to depend on the configuration of the obejctive and # whether it uses the transforms or not. - targets = ["desirability"] + targets = ["Desirability"] case _: targets = [t.name for t in self.objective.targets] diff --git a/tests/conftest.py b/tests/conftest.py index f644a67b4..e13832a10 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -24,7 +24,7 @@ from torch._C import _LinAlgError from baybe._optional.info import CHEM_INSTALLED -from baybe.acquisition import qExpectedImprovement +from baybe.acquisition import qLogExpectedImprovement from baybe.campaign import Campaign from baybe.constraints import ( ContinuousCardinalityConstraint, @@ -700,7 +700,7 @@ def fixture_default_streaming_sequential_meta_recommender(): @pytest.fixture(name="acqf") def fixture_default_acquisition_function(): """The default acquisition function to be used if not specified differently.""" - return qExpectedImprovement() + return qLogExpectedImprovement() @pytest.fixture(name="lengthscale_prior") diff --git a/tests/test_campaign.py b/tests/test_campaign.py index f501e04f5..721ec1776 100644 --- a/tests/test_campaign.py +++ b/tests/test_campaign.py @@ -4,17 +4,25 @@ import pandas as pd import pytest +from pandas.testing import assert_index_equal from pytest import param +from baybe.acquisition import qLogEI, qLogNEHVI, qTS from baybe.campaign import _EXCLUDED, Campaign from baybe.constraints.conditions import SubSelectionCondition from baybe.constraints.discrete import DiscreteExcludeConstraint +from baybe.objectives import DesirabilityObjective, ParetoObjective from baybe.parameters.numerical import ( NumericalContinuousParameter, NumericalDiscreteParameter, ) from baybe.searchspace.core import SearchSpaceType from baybe.searchspace.discrete import SubspaceDiscrete +from baybe.surrogates import ( + BetaBernoulliMultiArmedBanditSurrogate, + GaussianProcessSurrogate, +) +from baybe.targets import BinaryTarget, NumericalTarget from baybe.utils.basic import UNSPECIFIED from .conftest import run_iterations @@ -113,3 +121,83 @@ def test_setting_allow_flags(flag, space_type, value): with pytest.raises(ValueError) if expect_error else nullcontext(): Campaign(parameter, **kwargs) + + +@pytest.mark.parametrize( + ("parameter_names", "objective", "surrogate_model", "acqf", "batch_size"), + [ + param( + ["Categorical_1", "Num_Disc_1", "Conti_finite1"], + NumericalTarget("t1", "MAX").to_objective(), + GaussianProcessSurrogate(), + qLogEI(), + 3, + id="single_target", + ), + param( + ["Categorical_1", "Num_Disc_1", "Conti_finite1"], + DesirabilityObjective( + ( + NumericalTarget("t1", "MAX", bounds=(0, 1)), + NumericalTarget("t2", "MIN", bounds=(0, 1)), + ) + ), + GaussianProcessSurrogate(), + qLogEI(), + 3, + id="desirability", + ), + param( + ["Categorical_1", "Num_Disc_1", "Conti_finite1"], + ParetoObjective( + (NumericalTarget("t1", "MAX"), NumericalTarget("t2", "MIN")) + ), + GaussianProcessSurrogate(), + qLogNEHVI(), + 3, + id="pareto", + ), + param( + ["Categorical_1"], + BinaryTarget(name="Target_binary").to_objective(), + BetaBernoulliMultiArmedBanditSurrogate(), + qTS(), + 1, + id="bernoulli", + ), + ], +) +@pytest.mark.parametrize("std_instead_of_var", [True, False], ids=["std", "var"]) +@pytest.mark.parametrize("n_grid_points", [5], ids=["g5"]) +@pytest.mark.parametrize("n_iterations", [1], ids=["i1"]) +def test_posterior_statistics( + ongoing_campaign, n_iterations, batch_size, std_instead_of_var +): + """Posterior statistics can have expected shape, index and columns.""" + stats = ongoing_campaign.posterior_statistics( + ongoing_campaign.measurements, std_instead_of_var + ) + print(stats) + + # Assert number of entries and index + ( + assert_index_equal(ongoing_campaign.measurements.index, stats.index), + (ongoing_campaign.measurements.index, stats.index), + ) + + # Assert expected columns are present + # mode is not tested as Pareto posteriors do not provide it. + match ongoing_campaign.objective: + case DesirabilityObjective(): + targets = ["Desirability"] + case _: + targets = [t.name for t in ongoing_campaign.objective.targets] + tested_stats = {"mean"} | ({"std"} if std_instead_of_var else {"variance"}) + for t in targets: + for stat in tested_stats: + assert ( + sum(f"{t}_{stat}" in x for x in stats.columns) == 1 + ), f"{t}_{stat} not in the returned posterior statistics" + + # Assert no NaN's present + assert not stats.isna().any().any() From 43dc9a40a3862d4c21047d90c2cc8cf3f49190ca Mon Sep 17 00:00:00 2001 From: Martin Fitzner Date: Thu, 6 Mar 2025 19:29:56 +0100 Subject: [PATCH 03/17] Update user guide --- docs/userguide/campaigns.md | 45 +++++++++++++++++++++++++++++-------- 1 file changed, 36 insertions(+), 9 deletions(-) diff --git a/docs/userguide/campaigns.md b/docs/userguide/campaigns.md index 7b65fc48e..f58d6d26d 100644 --- a/docs/userguide/campaigns.md +++ b/docs/userguide/campaigns.md @@ -13,9 +13,9 @@ It further serves as the primary interface for interacting with BayBE as a user since it is responsible for handling experimental data, making recommendations, adding measurements, and most other user-related tasks. -## Creating a campaign +## Creating a Campaign -### Basic creation +### Basic Creation Creating a campaign requires specifying at least two pieces of information that describe the underlying optimization problem at hand: @@ -40,7 +40,7 @@ campaign = Campaign( ) ~~~ -### Creation from a JSON config +### Creation From a JSON Config Instead of using the default constructor, it is also possible to create a `Campaign` from a JSON configuration string via [`Campaign.from_config`](baybe.campaign.Campaign.from_config). @@ -52,7 +52,7 @@ instantiating the object, which skips the potentially costly search space creati For more details and a full exemplary config, we refer to the corresponding [example](./../../examples/Serialization/create_from_config). -## Getting recommendations +## Getting Recommendations ### Basics @@ -76,7 +76,7 @@ with the three parameters `Categorical_1`, `Categorical_2` and `Num_disc_1`: | 18 | C | bad | 1 | | 9 | B | bad | 1 | -```{admonition} Batch optimization +```{admonition} Batch Optimization :class: important In general, the parameter configurations in a recommended batch are **jointly** optimized and therefore tailored to the specific batch size requested. @@ -100,7 +100,7 @@ is not capable of joint optimization. Currently, the is the only recommender available that performs joint optimization. ``` -```{admonition} Sequential vs. parallel experimentation +```{admonition} Sequential vs. Parallel Experimentation :class: note If you have a fixed experimental budget but the luxury of choosing whether to run your experiments sequentially or in parallel, sequential @@ -125,7 +125,34 @@ far. This is done by setting the following Boolean flags: `pending_experiments` can be recommended (see [asynchronous workflows](PENDING_EXPERIMENTS)). -### Caching of recommendations +### Prediction Statistics +You might be interested in statistics about the predicted target values for your +recommendations, or indeed for any set of possible candidate points. The +[`posterior`](baybe.campaign.Campaign.posterior) and +[`posterior_statistics`](baybe.campaign.Campaign.posterior_statistics) methods provide +a simple interface to look at the resulting statistics: +~~~python +stats = campaign.posterior_statistics(rec) +~~~ + +This will return a table with mean and standard deviation (and possibly other +statistics) of the target predictions for the provided candidates: + +| | Yield_mean | Yield_std | Selectivity_mean | Selectivity_std | ... | +|---:|:-----------|:----------|:-----------------|:-----------------|-----| +| 15 | 83.54 | 5.23 | 91.22 | 7.42 |.....| +| 18 | 56.12 | 2.34 | 87.32 | 12.38 |.....| +| 9 | 59.10 | 5.34 | 83.72 | 9.62 |.....| + +```{admonition} Posterior Statistics with Desirability Objectives +:class: note +A [`DesirabilityObjective`](baybe.objectives.desirability.DesirabilityObjective) +scalarizes all targets into one single quantity called "Desirability". As a result, +the posterior statistics are only shown for this quantity, and not for individual +targets. +``` + +### Caching of Recommendations The `Campaign` object caches the last batch of recommendations returned, in order to avoid unnecessary computations for subsequent queries between which the status @@ -136,7 +163,7 @@ The latter is necessary because each batch is optimized for the specific number experiments requested (see note above). (AM)= -## Adding measurements +## Adding Measurements Available experimental data can be added at any time during the campaign lifecycle using the [`add_measurements`](baybe.campaign.Campaign.add_measurements) method, @@ -200,7 +227,7 @@ experimentation at a later point in time: 5. Run your (potentially lengthy) real-world experiments 6. Repeat -## Further information +## Further Information Campaigns are created as a first step in most of our [examples](./../../examples/examples). From e9a2c45d659ab6f493ccc614c7d4be657b0726d5 Mon Sep 17 00:00:00 2001 From: Martin Fitzner Date: Thu, 6 Mar 2025 19:32:10 +0100 Subject: [PATCH 04/17] Update CHANGELOG.md --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 235244537..774bf7a93 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,6 +36,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `SubstanceParameter`, `CustomDisreteParameter` and `CategoricalParameter` now also support restricting the search space via `active_values`, while `values` continue to identify allowed measurement inputs +- `Campaign.posterior_statistics` as convenience for providing statistical measures + about the target predictions of a given set of candidates ### Changed - Acquisition function indicator `is_mc` has been removed in favor of new indicators From f7e5eb1c21fef28f963cb9bbac655e9be163f6df Mon Sep 17 00:00:00 2001 From: Martin Fitzner Date: Mon, 10 Mar 2025 18:13:59 +0100 Subject: [PATCH 05/17] Rework statistics selection --- baybe/campaign.py | 82 +++++++++++++++++++++---------------- docs/userguide/campaigns.md | 13 ++++++ tests/test_campaign.py | 49 ++++++++++++++++------ 3 files changed, 95 insertions(+), 49 deletions(-) diff --git a/baybe/campaign.py b/baybe/campaign.py index a77f3dd2d..67eee7a0e 100644 --- a/baybe/campaign.py +++ b/baybe/campaign.py @@ -6,11 +6,12 @@ import json from collections.abc import Callable, Collection from functools import reduce -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal, Sequence, TypeAlias import cattrs import numpy as np import pandas as pd +import torch from attrs import Attribute, Factory, define, evolve, field, fields from attrs.converters import optional from attrs.validators import instance_of @@ -89,6 +90,9 @@ def _validate_allow_flag(campaign: Campaign, attribute: Attribute, value: Any) - ) +Statistic: TypeAlias = float | Literal["mean", "std", "variance", "mode"] + + @define class Campaign(SerialMixin): """Main class for interaction with BayBE. @@ -529,69 +533,75 @@ def posterior(self, candidates: pd.DataFrame) -> Posterior: f"provide a '{method_name}' method." ) - import torch - with torch.no_grad(): return surrogate.posterior(candidates) def posterior_statistics( - self, candidates: pd.DataFrame, std_instead_of_var: bool = True + self, candidates: pd.DataFrame, statistics: Sequence[Statistic] | None = None ) -> pd.DataFrame: """Return common posterior statistics for each target. Args: - candidates: The candidate points in experimental recommendations. + candidates: The candidate points in experimental representation. For details, see :meth:`baybe.surrogates.base.Surrogate.posterior`. - std_instead_of_var: Flag deciding if the standard deviation or variance is - returned (if supported by the posterior). + statistics: Sequence indicating which statistics to compute. Also accepts + floats, for which the corresponding quantile point will be computed. Raises: + ValueError: If a requested quantile is outside the open interval (0,1). TypeError: If the posterior utilized by the surrogate does not support - any of the possible statistics. + a requested statistic. Returns: Data frame with prediction statistics for each target for each candidate. """ + statistics = statistics or ["mean", "std"] + for stat in (x for x in statistics if isinstance(x, float)): + if not 0 < stat < 1.0: + raise ValueError( + f"Posterior quantile statistics can only be computed for quantiles " + f"between 0 and 1 (non-inclusive). Provided value: '{stat}' as " + f"part of {statistics=}'." + ) posterior = self.posterior(candidates) - considered_stats = ["mean", "variance", "mode"] - supported_stats = [x for x in considered_stats if hasattr(posterior, x)] - if not supported_stats: - raise TypeError( - f"The utilized posterior is of type {posterior.__class__.__name__} and " - f"does not support any of the possible statistics: {considered_stats}. " - f"To call {self.posterior_statistics.__name__}, at least one of these " - f"statistics must be supported by the surrogate posterior." - ) - assert self.objective is not None match self.objective: case DesirabilityObjective(): # TODO: Once desirability also supports posterior transforms this check - # here will have to depend on the configuration of the obejctive and + # here will have to depend on the configuration of the objective and # whether it uses the transforms or not. targets = ["Desirability"] case _: targets = [t.name for t in self.objective.targets] - stats = pd.DataFrame(index=candidates.index) + result = pd.DataFrame(index=candidates.index) for i, t in enumerate(targets): - for stat in supported_stats: - vals = ( - getattr(posterior, stat) - .cpu() - .numpy() - .reshape((len(stats), len(targets))) - ) - if stat == "variance" and std_instead_of_var: - stat_name = "std" - vals = np.sqrt(vals) - else: - stat_name = stat - - stats[f"{t}_{stat_name}"] = vals[:, i] - - return stats + for stat in statistics: + stat_name = f"Q_{stat}" if isinstance(stat, float) else stat + + try: + vals = ( + posterior.quantile(torch.tensor(stat)) + if isinstance(stat, float) + else getattr(posterior, stat if stat != "std" else "variance") + ) + if stat == "std": + vals = torch.sqrt(vals) + except (AttributeError, NotImplementedError) as e: + # We could arrive here because an invalid statistics string has + # been requested or because a quantile point has been requested, + # but the posterior type does not implement quantiles. + raise TypeError( + f"The utilized posterior of type " + f"{posterior.__class__.__name__} does not support the " + f"statistic associated with the requested input '{stat}'." + ) from e + + vals = vals.cpu().numpy().reshape((len(result), len(targets))) + result[f"{t}_{stat_name}"] = vals[:, i] + + return result def get_surrogate( self, diff --git a/docs/userguide/campaigns.md b/docs/userguide/campaigns.md index f58d6d26d..a22e21b2f 100644 --- a/docs/userguide/campaigns.md +++ b/docs/userguide/campaigns.md @@ -144,6 +144,19 @@ statistics) of the target predictions for the provided candidates: | 18 | 56.12 | 2.34 | 87.32 | 12.38 |.....| | 9 | 59.10 | 5.34 | 83.72 | 9.62 |.....| +You can also provide an optional sequence of statistic names to compute other +statistics. If a float is provided, the corresponding quantile points will be +calculated: +~~~python +stats = campaign.posterior_statistics(rec, statistics=["mode", 0.5]) +~~~ + +| | Yield_mode | Yield_Q_0.5 | Selectivity_mode | Selectivity_Q_0.5 | ... | +|---:|:-----------|:------------|:-----------------|:------------------|-----| +| 15 | 83.54 | 83.54 | 91.22 | 91.22 |.....| +| 18 | 56.12 | 56.12 | 87.32 | 87.32 |.....| +| 9 | 59.10 | 59.10 | 83.72 | 83.72 |.....| + ```{admonition} Posterior Statistics with Desirability Objectives :class: note A [`DesirabilityObjective`](baybe.objectives.desirability.DesirabilityObjective) diff --git a/tests/test_campaign.py b/tests/test_campaign.py index 721ec1776..68fd03421 100644 --- a/tests/test_campaign.py +++ b/tests/test_campaign.py @@ -167,17 +167,22 @@ def test_setting_allow_flags(flag, space_type, value): ), ], ) -@pytest.mark.parametrize("std_instead_of_var", [True, False], ids=["std", "var"]) @pytest.mark.parametrize("n_grid_points", [5], ids=["g5"]) @pytest.mark.parametrize("n_iterations", [1], ids=["i1"]) -def test_posterior_statistics( - ongoing_campaign, n_iterations, batch_size, std_instead_of_var -): +def test_posterior_statistics(ongoing_campaign, n_iterations, batch_size): """Posterior statistics can have expected shape, index and columns.""" + objective = ongoing_campaign.objective + tested_stats = {"mean", "std"} + test_quantiles = not ( + isinstance(objective, ParetoObjective) + or isinstance(objective.targets[0], BinaryTarget) + ) + if test_quantiles: + tested_stats |= {0.05, 0.95} + stats = ongoing_campaign.posterior_statistics( - ongoing_campaign.measurements, std_instead_of_var + ongoing_campaign.measurements, tested_stats ) - print(stats) # Assert number of entries and index ( @@ -185,19 +190,37 @@ def test_posterior_statistics( (ongoing_campaign.measurements.index, stats.index), ) - # Assert expected columns are present - # mode is not tested as Pareto posteriors do not provide it. - match ongoing_campaign.objective: + # Assert expected columns are present. + match objective: case DesirabilityObjective(): targets = ["Desirability"] case _: - targets = [t.name for t in ongoing_campaign.objective.targets] - tested_stats = {"mean"} | ({"std"} if std_instead_of_var else {"variance"}) + targets = [t.name for t in objective.targets] + for t in targets: for stat in tested_stats: + stat_name = f"Q_{stat}" if isinstance(stat, float) else stat assert ( - sum(f"{t}_{stat}" in x for x in stats.columns) == 1 - ), f"{t}_{stat} not in the returned posterior statistics" + sum(f"{t}_{stat_name}" in x for x in stats.columns) == 1 + ), f"{t}_{stat_name} not in the returned posterior statistics" # Assert no NaN's present assert not stats.isna().any().any() + + # Assert correct error for unsupported statistics + with pytest.raises(TypeError, match="does not support the statistic associated"): + ongoing_campaign.posterior_statistics( + ongoing_campaign.measurements, ["invalid"] + ) + + if test_quantiles: + # Assert correct error for invalid quantiles + with pytest.raises(ValueError, match="quantile statistics can only be"): + ongoing_campaign.posterior_statistics(ongoing_campaign.measurements, [-0.1]) + ongoing_campaign.posterior_statistics(ongoing_campaign.measurements, [1.1]) + else: + # Assert correct error for unsupported quantile calculation + with pytest.raises( + TypeError, match="does not support the statistic associated" + ): + ongoing_campaign.posterior_statistics(ongoing_campaign.measurements, [0.1]) From f6519b342c836ff8bc52e00dd83966296fe7b505 Mon Sep 17 00:00:00 2001 From: Martin Fitzner Date: Mon, 10 Mar 2025 18:18:06 +0100 Subject: [PATCH 06/17] Rename method --- CHANGELOG.md | 2 +- baybe/campaign.py | 14 +++++++------- docs/userguide/campaigns.md | 6 +++--- tests/test_campaign.py | 16 +++++++--------- 4 files changed, 18 insertions(+), 20 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 774bf7a93..b66742528 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,7 +36,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `SubstanceParameter`, `CustomDisreteParameter` and `CategoricalParameter` now also support restricting the search space via `active_values`, while `values` continue to identify allowed measurement inputs -- `Campaign.posterior_statistics` as convenience for providing statistical measures +- `Campaign.posterior_stats` as convenience for providing statistical measures about the target predictions of a given set of candidates ### Changed diff --git a/baybe/campaign.py b/baybe/campaign.py index 67eee7a0e..979f20213 100644 --- a/baybe/campaign.py +++ b/baybe/campaign.py @@ -536,15 +536,15 @@ def posterior(self, candidates: pd.DataFrame) -> Posterior: with torch.no_grad(): return surrogate.posterior(candidates) - def posterior_statistics( - self, candidates: pd.DataFrame, statistics: Sequence[Statistic] | None = None + def posterior_stats( + self, candidates: pd.DataFrame, stats: Sequence[Statistic] | None = None ) -> pd.DataFrame: """Return common posterior statistics for each target. Args: candidates: The candidate points in experimental representation. For details, see :meth:`baybe.surrogates.base.Surrogate.posterior`. - statistics: Sequence indicating which statistics to compute. Also accepts + stats: Sequence indicating which statistics to compute. Also accepts floats, for which the corresponding quantile point will be computed. Raises: @@ -555,13 +555,13 @@ def posterior_statistics( Returns: Data frame with prediction statistics for each target for each candidate. """ - statistics = statistics or ["mean", "std"] - for stat in (x for x in statistics if isinstance(x, float)): + stats = stats or ["mean", "std"] + for stat in (x for x in stats if isinstance(x, float)): if not 0 < stat < 1.0: raise ValueError( f"Posterior quantile statistics can only be computed for quantiles " f"between 0 and 1 (non-inclusive). Provided value: '{stat}' as " - f"part of {statistics=}'." + f"part of {stats=}'." ) posterior = self.posterior(candidates) @@ -577,7 +577,7 @@ def posterior_statistics( result = pd.DataFrame(index=candidates.index) for i, t in enumerate(targets): - for stat in statistics: + for stat in stats: stat_name = f"Q_{stat}" if isinstance(stat, float) else stat try: diff --git a/docs/userguide/campaigns.md b/docs/userguide/campaigns.md index a22e21b2f..2a75a44b6 100644 --- a/docs/userguide/campaigns.md +++ b/docs/userguide/campaigns.md @@ -129,10 +129,10 @@ far. This is done by setting the following Boolean flags: You might be interested in statistics about the predicted target values for your recommendations, or indeed for any set of possible candidate points. The [`posterior`](baybe.campaign.Campaign.posterior) and -[`posterior_statistics`](baybe.campaign.Campaign.posterior_statistics) methods provide +[`posterior_stats`](baybe.campaign.Campaign.posterior_stats) methods provide a simple interface to look at the resulting statistics: ~~~python -stats = campaign.posterior_statistics(rec) +stats = campaign.posterior_stats(rec) ~~~ This will return a table with mean and standard deviation (and possibly other @@ -148,7 +148,7 @@ You can also provide an optional sequence of statistic names to compute other statistics. If a float is provided, the corresponding quantile points will be calculated: ~~~python -stats = campaign.posterior_statistics(rec, statistics=["mode", 0.5]) +stats = campaign.posterior_stats(rec, stats=["mode", 0.5]) ~~~ | | Yield_mode | Yield_Q_0.5 | Selectivity_mode | Selectivity_Q_0.5 | ... | diff --git a/tests/test_campaign.py b/tests/test_campaign.py index 68fd03421..4a96a6aff 100644 --- a/tests/test_campaign.py +++ b/tests/test_campaign.py @@ -169,8 +169,8 @@ def test_setting_allow_flags(flag, space_type, value): ) @pytest.mark.parametrize("n_grid_points", [5], ids=["g5"]) @pytest.mark.parametrize("n_iterations", [1], ids=["i1"]) -def test_posterior_statistics(ongoing_campaign, n_iterations, batch_size): - """Posterior statistics can have expected shape, index and columns.""" +def test_posterior_stats(ongoing_campaign, n_iterations, batch_size): + """Posterior statistics have expected shape, index and columns.""" objective = ongoing_campaign.objective tested_stats = {"mean", "std"} test_quantiles = not ( @@ -180,7 +180,7 @@ def test_posterior_statistics(ongoing_campaign, n_iterations, batch_size): if test_quantiles: tested_stats |= {0.05, 0.95} - stats = ongoing_campaign.posterior_statistics( + stats = ongoing_campaign.posterior_stats( ongoing_campaign.measurements, tested_stats ) @@ -209,18 +209,16 @@ def test_posterior_statistics(ongoing_campaign, n_iterations, batch_size): # Assert correct error for unsupported statistics with pytest.raises(TypeError, match="does not support the statistic associated"): - ongoing_campaign.posterior_statistics( - ongoing_campaign.measurements, ["invalid"] - ) + ongoing_campaign.posterior_stats(ongoing_campaign.measurements, ["invalid"]) if test_quantiles: # Assert correct error for invalid quantiles with pytest.raises(ValueError, match="quantile statistics can only be"): - ongoing_campaign.posterior_statistics(ongoing_campaign.measurements, [-0.1]) - ongoing_campaign.posterior_statistics(ongoing_campaign.measurements, [1.1]) + ongoing_campaign.posterior_stats(ongoing_campaign.measurements, [-0.1]) + ongoing_campaign.posterior_stats(ongoing_campaign.measurements, [1.1]) else: # Assert correct error for unsupported quantile calculation with pytest.raises( TypeError, match="does not support the statistic associated" ): - ongoing_campaign.posterior_statistics(ongoing_campaign.measurements, [0.1]) + ongoing_campaign.posterior_stats(ongoing_campaign.measurements, [0.1]) From 43d31420973fbdfbd4d0f5842dd98057a806645d Mon Sep 17 00:00:00 2001 From: Martin Fitzner Date: Mon, 10 Mar 2025 18:32:36 +0100 Subject: [PATCH 07/17] Silence MyPy --- baybe/campaign.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/baybe/campaign.py b/baybe/campaign.py index 979f20213..4c55d0204 100644 --- a/baybe/campaign.py +++ b/baybe/campaign.py @@ -4,9 +4,9 @@ import gc import json -from collections.abc import Callable, Collection +from collections.abc import Callable, Collection, Sequence from functools import reduce -from typing import TYPE_CHECKING, Any, Literal, Sequence, TypeAlias +from typing import TYPE_CHECKING, Any, Literal, TypeAlias import cattrs import numpy as np @@ -576,8 +576,8 @@ def posterior_stats( targets = [t.name for t in self.objective.targets] result = pd.DataFrame(index=candidates.index) - for i, t in enumerate(targets): - for stat in stats: + for k, t in enumerate(targets): + for stat in stats: # type: ignore[assignment] stat_name = f"Q_{stat}" if isinstance(stat, float) else stat try: @@ -599,7 +599,7 @@ def posterior_stats( ) from e vals = vals.cpu().numpy().reshape((len(result), len(targets))) - result[f"{t}_{stat_name}"] = vals[:, i] + result[f"{t}_{stat_name}"] = vals[:, k] return result From fbd4abc00ad6d1144479e801709335786824f24d Mon Sep 17 00:00:00 2001 From: Martin Fitzner Date: Tue, 11 Mar 2025 14:46:17 +0100 Subject: [PATCH 08/17] Import torch locally --- baybe/campaign.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/baybe/campaign.py b/baybe/campaign.py index 4c55d0204..6792a3929 100644 --- a/baybe/campaign.py +++ b/baybe/campaign.py @@ -11,7 +11,6 @@ import cattrs import numpy as np import pandas as pd -import torch from attrs import Attribute, Factory, define, evolve, field, fields from attrs.converters import optional from attrs.validators import instance_of @@ -533,6 +532,8 @@ def posterior(self, candidates: pd.DataFrame) -> Posterior: f"provide a '{method_name}' method." ) + import torch + with torch.no_grad(): return surrogate.posterior(candidates) @@ -575,6 +576,8 @@ def posterior_stats( case _: targets = [t.name for t in self.objective.targets] + import torch + result = pd.DataFrame(index=candidates.index) for k, t in enumerate(targets): for stat in stats: # type: ignore[assignment] From eefb091cc4a736a3a37bc25ecc78219260bd202a Mon Sep 17 00:00:00 2001 From: Martin Fitzner Date: Tue, 11 Mar 2025 14:53:08 +0100 Subject: [PATCH 09/17] Use sequences in test --- tests/test_campaign.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_campaign.py b/tests/test_campaign.py index 4a96a6aff..fc31ef466 100644 --- a/tests/test_campaign.py +++ b/tests/test_campaign.py @@ -168,17 +168,17 @@ def test_setting_allow_flags(flag, space_type, value): ], ) @pytest.mark.parametrize("n_grid_points", [5], ids=["g5"]) -@pytest.mark.parametrize("n_iterations", [1], ids=["i1"]) +@pytest.mark.parametrize("n_iterations", [2], ids=["i2"]) def test_posterior_stats(ongoing_campaign, n_iterations, batch_size): """Posterior statistics have expected shape, index and columns.""" objective = ongoing_campaign.objective - tested_stats = {"mean", "std"} + tested_stats = ["mean", "std"] test_quantiles = not ( isinstance(objective, ParetoObjective) or isinstance(objective.targets[0], BinaryTarget) ) if test_quantiles: - tested_stats |= {0.05, 0.95} + tested_stats += [0.05, 0.95] stats = ongoing_campaign.posterior_stats( ongoing_campaign.measurements, tested_stats From 9662c275de6122e5cc806d0af90ee6e1b51e8431 Mon Sep 17 00:00:00 2001 From: Martin Fitzner Date: Tue, 11 Mar 2025 15:05:04 +0100 Subject: [PATCH 10/17] Rework typing --- baybe/campaign.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/baybe/campaign.py b/baybe/campaign.py index 6792a3929..04ba05840 100644 --- a/baybe/campaign.py +++ b/baybe/campaign.py @@ -56,6 +56,11 @@ _EXCLUDED = "excluded" _METADATA_COLUMNS = [_RECOMMENDED, _MEASURED, _EXCLUDED] +Statistic: TypeAlias = float | Literal["mean", "std", "variance", "mode"] +"""Type alias for requestable posterior statistics. + +A float will result in the corresponding quantile points.""" + def _make_allow_flag_default_factory( default: bool, @@ -89,9 +94,6 @@ def _validate_allow_flag(campaign: Campaign, attribute: Attribute, value: Any) - ) -Statistic: TypeAlias = float | Literal["mean", "std", "variance", "mode"] - - @define class Campaign(SerialMixin): """Main class for interaction with BayBE. @@ -538,7 +540,7 @@ def posterior(self, candidates: pd.DataFrame) -> Posterior: return surrogate.posterior(candidates) def posterior_stats( - self, candidates: pd.DataFrame, stats: Sequence[Statistic] | None = None + self, candidates: pd.DataFrame, stats: Sequence[Statistic] = ("mean", "std") ) -> pd.DataFrame: """Return common posterior statistics for each target. @@ -556,7 +558,7 @@ def posterior_stats( Returns: Data frame with prediction statistics for each target for each candidate. """ - stats = stats or ["mean", "std"] + stat: Statistic for stat in (x for x in stats if isinstance(x, float)): if not 0 < stat < 1.0: raise ValueError( @@ -580,7 +582,7 @@ def posterior_stats( result = pd.DataFrame(index=candidates.index) for k, t in enumerate(targets): - for stat in stats: # type: ignore[assignment] + for stat in stats: stat_name = f"Q_{stat}" if isinstance(stat, float) else stat try: From 66184a1094d961ba11487a53886d089f83377fb7 Mon Sep 17 00:00:00 2001 From: Martin Fitzner <17951239+Scienfitz@users.noreply.github.com> Date: Tue, 11 Mar 2025 15:11:27 +0100 Subject: [PATCH 11/17] Apply suggestions from code review Co-authored-by: AdrianSosic --- baybe/campaign.py | 29 ++++++++++++++--------------- docs/userguide/campaigns.md | 2 +- 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/baybe/campaign.py b/baybe/campaign.py index 04ba05840..2e3419dd5 100644 --- a/baybe/campaign.py +++ b/baybe/campaign.py @@ -556,15 +556,15 @@ def posterior_stats( a requested statistic. Returns: - Data frame with prediction statistics for each target for each candidate. + A dataframe with posterior statistics for each target and candidate. """ stat: Statistic for stat in (x for x in stats if isinstance(x, float)): - if not 0 < stat < 1.0: + if not 0.0 < stat < 1.0: raise ValueError( f"Posterior quantile statistics can only be computed for quantiles " f"between 0 and 1 (non-inclusive). Provided value: '{stat}' as " - f"part of {stats=}'." + f"part of '{stats=}'." ) posterior = self.posterior(candidates) @@ -581,30 +581,29 @@ def posterior_stats( import torch result = pd.DataFrame(index=candidates.index) - for k, t in enumerate(targets): + for k, target_name in enumerate(targets): for stat in stats: - stat_name = f"Q_{stat}" if isinstance(stat, float) else stat - try: - vals = ( - posterior.quantile(torch.tensor(stat)) - if isinstance(stat, float) - else getattr(posterior, stat if stat != "std" else "variance") - ) - if stat == "std": - vals = torch.sqrt(vals) + if isinstance(stat, float): + stat_name = f"Q_{stat}" + vals = posterior.quantile(torch.tensor(stat)) + else: + stat_name = stat + vals = getattr(posterior, stat if stat != "std" else "variance") except (AttributeError, NotImplementedError) as e: # We could arrive here because an invalid statistics string has # been requested or because a quantile point has been requested, # but the posterior type does not implement quantiles. raise TypeError( f"The utilized posterior of type " - f"{posterior.__class__.__name__} does not support the " + f"'{posterior.__class__.__name__}' does not support the " f"statistic associated with the requested input '{stat}'." ) from e + if stat == "std": + vals = torch.sqrt(vals) vals = vals.cpu().numpy().reshape((len(result), len(targets))) - result[f"{t}_{stat_name}"] = vals[:, k] + result[f"{target_name}_{stat_name}"] = vals[:, k] return result diff --git a/docs/userguide/campaigns.md b/docs/userguide/campaigns.md index 2a75a44b6..8885f3665 100644 --- a/docs/userguide/campaigns.md +++ b/docs/userguide/campaigns.md @@ -130,7 +130,7 @@ You might be interested in statistics about the predicted target values for your recommendations, or indeed for any set of possible candidate points. The [`posterior`](baybe.campaign.Campaign.posterior) and [`posterior_stats`](baybe.campaign.Campaign.posterior_stats) methods provide -a simple interface to look at the resulting statistics: +a simple interface for this: ~~~python stats = campaign.posterior_stats(rec) ~~~ From 736b939c465cd48bfc50df4f7c73aee2418760e7 Mon Sep 17 00:00:00 2001 From: Martin Fitzner Date: Tue, 11 Mar 2025 15:53:54 +0100 Subject: [PATCH 12/17] Make candidates optional --- baybe/campaign.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/baybe/campaign.py b/baybe/campaign.py index 2e3419dd5..1e584969a 100644 --- a/baybe/campaign.py +++ b/baybe/campaign.py @@ -512,12 +512,14 @@ def recommend( return rec - def posterior(self, candidates: pd.DataFrame) -> Posterior: + def posterior(self, candidates: pd.DataFrame | None = None) -> Posterior: """Get the posterior predictive distribution for the given candidates. Args: - candidates: The candidate points in experimental recommendations. - For details, see :meth:`baybe.surrogates.base.Surrogate.posterior`. + candidates: The candidate points in experimental recommendations. If not + provided, the posterior for the existing campaign measurements is + returned. For details, see + :meth:`baybe.surrogates.base.Surrogate.posterior`. Raises: IncompatibilityError: If the underlying surrogate model exposes no @@ -527,6 +529,9 @@ def posterior(self, candidates: pd.DataFrame) -> Posterior: Posterior: The corresponding posterior object. For details, see :meth:`baybe.surrogates.base.Surrogate.posterior`. """ + if candidates is None: + candidates = self.measurements[[p.name for p in self.parameters]] + surrogate = self.get_surrogate() if not hasattr(surrogate, method_name := "posterior"): raise IncompatibilityError( @@ -540,13 +545,18 @@ def posterior(self, candidates: pd.DataFrame) -> Posterior: return surrogate.posterior(candidates) def posterior_stats( - self, candidates: pd.DataFrame, stats: Sequence[Statistic] = ("mean", "std") + self, + candidates: pd.DataFrame | None = None, + /, + stats: Sequence[Statistic] = ("mean", "std"), ) -> pd.DataFrame: """Return common posterior statistics for each target. Args: - candidates: The candidate points in experimental representation. - For details, see :meth:`baybe.surrogates.base.Surrogate.posterior`. + candidates: The candidate points in experimental representation. If not + provided, the statistics of the existing campaign measurements are + calculated. For details, see + :meth:`baybe.surrogates.base.Surrogate.posterior`. stats: Sequence indicating which statistics to compute. Also accepts floats, for which the corresponding quantile point will be computed. @@ -558,6 +568,9 @@ def posterior_stats( Returns: A dataframe with posterior statistics for each target and candidate. """ + if candidates is None: + candidates = self.measurements[[p.name for p in self.parameters]] + stat: Statistic for stat in (x for x in stats if isinstance(x, float)): if not 0.0 < stat < 1.0: From 085213ac57bd43e314b88c6ae318aa204cd81a5d Mon Sep 17 00:00:00 2001 From: Martin Fitzner Date: Tue, 11 Mar 2025 16:03:59 +0100 Subject: [PATCH 13/17] Change `variance` to `var` --- baybe/campaign.py | 7 +++++-- tests/test_campaign.py | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/baybe/campaign.py b/baybe/campaign.py index 1e584969a..4df730f6a 100644 --- a/baybe/campaign.py +++ b/baybe/campaign.py @@ -56,7 +56,7 @@ _EXCLUDED = "excluded" _METADATA_COLUMNS = [_RECOMMENDED, _MEASURED, _EXCLUDED] -Statistic: TypeAlias = float | Literal["mean", "std", "variance", "mode"] +Statistic: TypeAlias = float | Literal["mean", "std", "var", "mode"] """Type alias for requestable posterior statistics. A float will result in the corresponding quantile points.""" @@ -602,7 +602,10 @@ def posterior_stats( vals = posterior.quantile(torch.tensor(stat)) else: stat_name = stat - vals = getattr(posterior, stat if stat != "std" else "variance") + vals = getattr( + posterior, + stat if stat not in ["std", "var"] else "variance", + ) except (AttributeError, NotImplementedError) as e: # We could arrive here because an invalid statistics string has # been requested or because a quantile point has been requested, diff --git a/tests/test_campaign.py b/tests/test_campaign.py index fc31ef466..dbfc1c62f 100644 --- a/tests/test_campaign.py +++ b/tests/test_campaign.py @@ -172,7 +172,7 @@ def test_setting_allow_flags(flag, space_type, value): def test_posterior_stats(ongoing_campaign, n_iterations, batch_size): """Posterior statistics have expected shape, index and columns.""" objective = ongoing_campaign.objective - tested_stats = ["mean", "std"] + tested_stats = ["mean", "std", "var"] test_quantiles = not ( isinstance(objective, ParetoObjective) or isinstance(objective.targets[0], BinaryTarget) From b7c9b139395a412b433312386b6b97e6c48d5871 Mon Sep 17 00:00:00 2001 From: Martin Fitzner Date: Wed, 12 Mar 2025 16:21:07 +0100 Subject: [PATCH 14/17] Move disabling of gradients --- baybe/campaign.py | 61 ++++++++++++++++++++++------------------------- 1 file changed, 29 insertions(+), 32 deletions(-) diff --git a/baybe/campaign.py b/baybe/campaign.py index 4df730f6a..099ebd79d 100644 --- a/baybe/campaign.py +++ b/baybe/campaign.py @@ -539,15 +539,11 @@ def posterior(self, candidates: pd.DataFrame | None = None) -> Posterior: f"provide a '{method_name}' method." ) - import torch - - with torch.no_grad(): - return surrogate.posterior(candidates) + return surrogate.posterior(candidates) def posterior_stats( self, candidates: pd.DataFrame | None = None, - /, stats: Sequence[Statistic] = ("mean", "std"), ) -> pd.DataFrame: """Return common posterior statistics for each target. @@ -593,33 +589,34 @@ def posterior_stats( import torch - result = pd.DataFrame(index=candidates.index) - for k, target_name in enumerate(targets): - for stat in stats: - try: - if isinstance(stat, float): - stat_name = f"Q_{stat}" - vals = posterior.quantile(torch.tensor(stat)) - else: - stat_name = stat - vals = getattr( - posterior, - stat if stat not in ["std", "var"] else "variance", - ) - except (AttributeError, NotImplementedError) as e: - # We could arrive here because an invalid statistics string has - # been requested or because a quantile point has been requested, - # but the posterior type does not implement quantiles. - raise TypeError( - f"The utilized posterior of type " - f"'{posterior.__class__.__name__}' does not support the " - f"statistic associated with the requested input '{stat}'." - ) from e - - if stat == "std": - vals = torch.sqrt(vals) - vals = vals.cpu().numpy().reshape((len(result), len(targets))) - result[f"{target_name}_{stat_name}"] = vals[:, k] + with torch.no_grad(): + result = pd.DataFrame(index=candidates.index) + for k, target_name in enumerate(targets): + for stat in stats: + try: + if isinstance(stat, float): + stat_name = f"Q_{stat}" + vals = posterior.quantile(torch.tensor(stat)) + else: + stat_name = stat + vals = getattr( + posterior, + stat if stat not in ["std", "var"] else "variance", + ) + except (AttributeError, NotImplementedError) as e: + # We could arrive here because an invalid statistics string has + # been requested or because a quantile point has been requested, + # but the posterior type does not implement quantiles. + raise TypeError( + f"The utilized posterior of type " + f"'{posterior.__class__.__name__}' does not support the " + f"statistic associated with the requested input '{stat}'." + ) from e + + if stat == "std": + vals = torch.sqrt(vals) + vals = vals.cpu().numpy().reshape((len(result), len(targets))) + result[f"{target_name}_{stat_name}"] = vals[:, k] return result From 09dbfb2ad237d38f6525cb5bd452f845403d797a Mon Sep 17 00:00:00 2001 From: Martin Fitzner Date: Wed, 12 Mar 2025 16:45:46 +0100 Subject: [PATCH 15/17] Enable quantiles for PosteriorList --- baybe/campaign.py | 26 ++++++++++++++++++++------ tests/test_campaign.py | 13 +++++-------- 2 files changed, 25 insertions(+), 14 deletions(-) diff --git a/baybe/campaign.py b/baybe/campaign.py index 099ebd79d..c970968a1 100644 --- a/baybe/campaign.py +++ b/baybe/campaign.py @@ -589,15 +589,28 @@ def posterior_stats( import torch + result = pd.DataFrame(index=candidates.index) with torch.no_grad(): - result = pd.DataFrame(index=candidates.index) for k, target_name in enumerate(targets): for stat in stats: try: - if isinstance(stat, float): + if isinstance(stat, float): # Calculate quantile statistic stat_name = f"Q_{stat}" - vals = posterior.quantile(torch.tensor(stat)) - else: + from botorch.posteriors import PosteriorList + + if isinstance(posterior, PosteriorList): + # Special treatment for PosteriorList because .quantile + # is not implemented + vals = torch.cat( + [ + p.quantile(torch.tensor(stat)) + for p in posterior.posteriors + ], + dim=-1, + ) + else: + vals = posterior.quantile(torch.tensor(stat)) + else: # Calculate non-quantile statistic stat_name = stat vals = getattr( posterior, @@ -615,8 +628,9 @@ def posterior_stats( if stat == "std": vals = torch.sqrt(vals) - vals = vals.cpu().numpy().reshape((len(result), len(targets))) - result[f"{target_name}_{stat_name}"] = vals[:, k] + + numpyvals = vals.cpu().numpy().reshape((len(result), len(targets))) + result[f"{target_name}_{stat_name}"] = numpyvals[:, k] return result diff --git a/tests/test_campaign.py b/tests/test_campaign.py index dbfc1c62f..8edd553c8 100644 --- a/tests/test_campaign.py +++ b/tests/test_campaign.py @@ -127,7 +127,7 @@ def test_setting_allow_flags(flag, space_type, value): ("parameter_names", "objective", "surrogate_model", "acqf", "batch_size"), [ param( - ["Categorical_1", "Num_Disc_1", "Conti_finite1"], + ["Categorical_1", "Num_disc_1", "Conti_finite1"], NumericalTarget("t1", "MAX").to_objective(), GaussianProcessSurrogate(), qLogEI(), @@ -135,7 +135,7 @@ def test_setting_allow_flags(flag, space_type, value): id="single_target", ), param( - ["Categorical_1", "Num_Disc_1", "Conti_finite1"], + ["Categorical_1", "Num_disc_1", "Conti_finite1"], DesirabilityObjective( ( NumericalTarget("t1", "MAX", bounds=(0, 1)), @@ -148,7 +148,7 @@ def test_setting_allow_flags(flag, space_type, value): id="desirability", ), param( - ["Categorical_1", "Num_Disc_1", "Conti_finite1"], + ["Categorical_1", "Num_disc_1", "Conti_finite1"], ParetoObjective( (NumericalTarget("t1", "MAX"), NumericalTarget("t2", "MIN")) ), @@ -168,15 +168,12 @@ def test_setting_allow_flags(flag, space_type, value): ], ) @pytest.mark.parametrize("n_grid_points", [5], ids=["g5"]) -@pytest.mark.parametrize("n_iterations", [2], ids=["i2"]) +@pytest.mark.parametrize("n_iterations", [1], ids=["i1"]) def test_posterior_stats(ongoing_campaign, n_iterations, batch_size): """Posterior statistics have expected shape, index and columns.""" objective = ongoing_campaign.objective tested_stats = ["mean", "std", "var"] - test_quantiles = not ( - isinstance(objective, ParetoObjective) - or isinstance(objective.targets[0], BinaryTarget) - ) + test_quantiles = not isinstance(objective.targets[0], BinaryTarget) if test_quantiles: tested_stats += [0.05, 0.95] From 61703ae3d17c7444685de88b5099121d54c35a55 Mon Sep 17 00:00:00 2001 From: Martin Fitzner Date: Wed, 12 Mar 2025 17:14:26 +0100 Subject: [PATCH 16/17] Fix docs --- docs/userguide/campaigns.md | 79 ++++++++++++++++++------------------- tests/docs/test_docs.py | 11 ------ 2 files changed, 39 insertions(+), 51 deletions(-) diff --git a/docs/userguide/campaigns.md b/docs/userguide/campaigns.md index 8885f3665..41cc3fe58 100644 --- a/docs/userguide/campaigns.md +++ b/docs/userguide/campaigns.md @@ -125,46 +125,6 @@ far. This is done by setting the following Boolean flags: `pending_experiments` can be recommended (see [asynchronous workflows](PENDING_EXPERIMENTS)). -### Prediction Statistics -You might be interested in statistics about the predicted target values for your -recommendations, or indeed for any set of possible candidate points. The -[`posterior`](baybe.campaign.Campaign.posterior) and -[`posterior_stats`](baybe.campaign.Campaign.posterior_stats) methods provide -a simple interface for this: -~~~python -stats = campaign.posterior_stats(rec) -~~~ - -This will return a table with mean and standard deviation (and possibly other -statistics) of the target predictions for the provided candidates: - -| | Yield_mean | Yield_std | Selectivity_mean | Selectivity_std | ... | -|---:|:-----------|:----------|:-----------------|:-----------------|-----| -| 15 | 83.54 | 5.23 | 91.22 | 7.42 |.....| -| 18 | 56.12 | 2.34 | 87.32 | 12.38 |.....| -| 9 | 59.10 | 5.34 | 83.72 | 9.62 |.....| - -You can also provide an optional sequence of statistic names to compute other -statistics. If a float is provided, the corresponding quantile points will be -calculated: -~~~python -stats = campaign.posterior_stats(rec, stats=["mode", 0.5]) -~~~ - -| | Yield_mode | Yield_Q_0.5 | Selectivity_mode | Selectivity_Q_0.5 | ... | -|---:|:-----------|:------------|:-----------------|:------------------|-----| -| 15 | 83.54 | 83.54 | 91.22 | 91.22 |.....| -| 18 | 56.12 | 56.12 | 87.32 | 87.32 |.....| -| 9 | 59.10 | 59.10 | 83.72 | 83.72 |.....| - -```{admonition} Posterior Statistics with Desirability Objectives -:class: note -A [`DesirabilityObjective`](baybe.objectives.desirability.DesirabilityObjective) -scalarizes all targets into one single quantity called "Desirability". As a result, -the posterior statistics are only shown for this quantity, and not for individual -targets. -``` - ### Caching of Recommendations The `Campaign` object caches the last batch of recommendations returned, in order to @@ -210,6 +170,45 @@ This requirement can be disabled using the method's `numerical_measurements_must_be_within_tolerance` flag. ``` +## Prediction Statistics +You might be interested in statistics about the predicted target values for your +recommendations, or indeed for any set of possible candidate points. The +[`posterior`](baybe.campaign.Campaign.posterior) and +[`posterior_stats`](baybe.campaign.Campaign.posterior_stats) methods provide +a simple interface for this: +~~~python +stats = campaign.posterior_stats(rec) +~~~ + +This will return a table with mean and standard deviation (and possibly other +statistics) of the target predictions for the provided candidates: + +| | Yield_mean | Yield_std | Selectivity_mean | Selectivity_std | ... | +|---:|:-----------|:----------|:-----------------|:-----------------|-----| +| 15 | 83.54 | 5.23 | 91.22 | 7.42 | ... | +| 18 | 56.12 | 2.34 | 87.32 | 12.38 | ... | +| 9 | 59.10 | 5.34 | 83.72 | 9.62 | ... | + +You can also provide an optional sequence of statistic names to compute other +statistics. If a float is provided, the corresponding quantile points will be +calculated: +~~~python +stats = campaign.posterior_stats(rec, stats=["mode", 0.5]) +~~~ + +| | Yield_mode | Yield_Q_0.5 | Selectivity_mode | Selectivity_Q_0.5 | ... | +|---:|:-----------|:------------|:-----------------|:------------------|-----| +| 15 | 83.54 | 83.54 | 91.22 | 91.22 | ... | +| 18 | 56.12 | 56.12 | 87.32 | 87.32 | ... | +| 9 | 59.10 | 59.10 | 83.72 | 83.72 | ... | + +```{admonition} Posterior Statistics with Desirability Objectives +:class: note +A [`DesirabilityObjective`](baybe.objectives.desirability.DesirabilityObjective) +scalarizes all targets into one single quantity called "Desirability". As a result, +the posterior statistics are only shown for this quantity, and not for individual +targets. +``` ## Serialization diff --git a/tests/docs/test_docs.py b/tests/docs/test_docs.py index 74d86796d..902f5db86 100644 --- a/tests/docs/test_docs.py +++ b/tests/docs/test_docs.py @@ -7,7 +7,6 @@ import pytest from baybe._optional.info import CHEM_INSTALLED, LINT_INSTALLED -from baybe.recommenders import RandomRecommender, TwoPhaseMetaRecommender from .utils import extract_code_blocks @@ -42,20 +41,10 @@ def test_code_executability(file: Path, campaign): # TODO: Needs a refactoring (files codeblocks should be auto-detected) @pytest.mark.parametrize("file", doc_files_pseudocode, ids=doc_files_pseudocode) -@pytest.mark.parametrize( - "recommender", - [ - TwoPhaseMetaRecommender( - initial_recommender=RandomRecommender(), recommender=RandomRecommender() - ) - ], -) def test_pseudocode_executability(file: Path, searchspace, objective, recommender): """The pseudocode blocks in the file are a valid python script when using fixtures. Blocks surrounded with "triple-backticks" are included. - Due to a bug related to the serialization of the default recommender, this currently - uses a non-default recommender. """ userguide_pseudocode = "\n".join(extract_code_blocks(file, include_tilde=True)) exec(userguide_pseudocode) From e78fb53feb17175ea4745394e41f091c34ee55d6 Mon Sep 17 00:00:00 2001 From: Martin Fitzner Date: Wed, 12 Mar 2025 18:28:32 +0100 Subject: [PATCH 17/17] Move core logic to surrogates --- CHANGELOG.md | 5 +- baybe/campaign.py | 90 +++++----------------------------- baybe/constraints/discrete.py | 4 +- baybe/surrogates/base.py | 91 ++++++++++++++++++++++++++++++++++- baybe/surrogates/composite.py | 20 +++++++- docs/userguide/campaigns.md | 8 +-- 6 files changed, 129 insertions(+), 89 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b66742528..4a317ec90 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,8 +36,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `SubstanceParameter`, `CustomDisreteParameter` and `CategoricalParameter` now also support restricting the search space via `active_values`, while `values` continue to identify allowed measurement inputs -- `Campaign.posterior_stats` as convenience for providing statistical measures - about the target predictions of a given set of candidates +- `Campaign.posterior_stats` and `Surrogate.posterior_stats` as convenience methods for + providing statistical measures about the target predictions of a given set of + candidates ### Changed - Acquisition function indicator `is_mc` has been removed in favor of new indicators diff --git a/baybe/campaign.py b/baybe/campaign.py index c970968a1..8d4508155 100644 --- a/baybe/campaign.py +++ b/baybe/campaign.py @@ -6,7 +6,7 @@ import json from collections.abc import Callable, Collection, Sequence from functools import reduce -from typing import TYPE_CHECKING, Any, Literal, TypeAlias +from typing import TYPE_CHECKING, Any import cattrs import numpy as np @@ -19,7 +19,6 @@ from baybe.constraints.base import DiscreteConstraint from baybe.exceptions import IncompatibilityError, NotEnoughPointsLeftError from baybe.objectives.base import Objective, to_objective -from baybe.objectives.desirability import DesirabilityObjective from baybe.parameters.base import Parameter from baybe.recommenders.base import RecommenderProtocol from baybe.recommenders.meta.base import MetaRecommender @@ -34,7 +33,7 @@ validate_searchspace_from_config, ) from baybe.serialization import SerialMixin, converter -from baybe.surrogates.base import SurrogateProtocol +from baybe.surrogates.base import PosteriorStatistic, SurrogateProtocol from baybe.targets.base import Target from baybe.telemetry import ( TELEM_LABELS, @@ -56,11 +55,6 @@ _EXCLUDED = "excluded" _METADATA_COLUMNS = [_RECOMMENDED, _MEASURED, _EXCLUDED] -Statistic: TypeAlias = float | Literal["mean", "std", "var", "mode"] -"""Type alias for requestable posterior statistics. - -A float will result in the corresponding quantile points.""" - def _make_allow_flag_default_factory( default: bool, @@ -544,15 +538,15 @@ def posterior(self, candidates: pd.DataFrame | None = None) -> Posterior: def posterior_stats( self, candidates: pd.DataFrame | None = None, - stats: Sequence[Statistic] = ("mean", "std"), + stats: Sequence[PosteriorStatistic] = ("mean", "std"), ) -> pd.DataFrame: - """Return common posterior statistics for each target. + """Return posterior statistics for each target. Args: candidates: The candidate points in experimental representation. If not provided, the statistics of the existing campaign measurements are calculated. For details, see - :meth:`baybe.surrogates.base.Surrogate.posterior`. + :meth:`baybe.surrogates.base.Surrogate.posterior_stats`. stats: Sequence indicating which statistics to compute. Also accepts floats, for which the corresponding quantile point will be computed. @@ -567,72 +561,14 @@ def posterior_stats( if candidates is None: candidates = self.measurements[[p.name for p in self.parameters]] - stat: Statistic - for stat in (x for x in stats if isinstance(x, float)): - if not 0.0 < stat < 1.0: - raise ValueError( - f"Posterior quantile statistics can only be computed for quantiles " - f"between 0 and 1 (non-inclusive). Provided value: '{stat}' as " - f"part of '{stats=}'." - ) - posterior = self.posterior(candidates) - - assert self.objective is not None - match self.objective: - case DesirabilityObjective(): - # TODO: Once desirability also supports posterior transforms this check - # here will have to depend on the configuration of the objective and - # whether it uses the transforms or not. - targets = ["Desirability"] - case _: - targets = [t.name for t in self.objective.targets] - - import torch - - result = pd.DataFrame(index=candidates.index) - with torch.no_grad(): - for k, target_name in enumerate(targets): - for stat in stats: - try: - if isinstance(stat, float): # Calculate quantile statistic - stat_name = f"Q_{stat}" - from botorch.posteriors import PosteriorList - - if isinstance(posterior, PosteriorList): - # Special treatment for PosteriorList because .quantile - # is not implemented - vals = torch.cat( - [ - p.quantile(torch.tensor(stat)) - for p in posterior.posteriors - ], - dim=-1, - ) - else: - vals = posterior.quantile(torch.tensor(stat)) - else: # Calculate non-quantile statistic - stat_name = stat - vals = getattr( - posterior, - stat if stat not in ["std", "var"] else "variance", - ) - except (AttributeError, NotImplementedError) as e: - # We could arrive here because an invalid statistics string has - # been requested or because a quantile point has been requested, - # but the posterior type does not implement quantiles. - raise TypeError( - f"The utilized posterior of type " - f"'{posterior.__class__.__name__}' does not support the " - f"statistic associated with the requested input '{stat}'." - ) from e - - if stat == "std": - vals = torch.sqrt(vals) - - numpyvals = vals.cpu().numpy().reshape((len(result), len(targets))) - result[f"{target_name}_{stat_name}"] = numpyvals[:, k] - - return result + surrogate = self.get_surrogate() + if not hasattr(surrogate, method_name := "posterior_stats"): + raise IncompatibilityError( + f"The used surrogate type '{surrogate.__class__.__name__}' does not " + f"provide a '{method_name}' method." + ) + + return surrogate.posterior_stats(candidates, stats) def get_surrogate( self, diff --git a/baybe/constraints/discrete.py b/baybe/constraints/discrete.py index 261624c54..e221ae9fd 100644 --- a/baybe/constraints/discrete.py +++ b/baybe/constraints/discrete.py @@ -71,7 +71,7 @@ class DiscreteSumConstraint(DiscreteConstraint): # class variables numerical_only: ClassVar[bool] = True - # see base class. + # See base class. # object variables condition: ThresholdCondition = field() @@ -99,7 +99,7 @@ class DiscreteProductConstraint(DiscreteConstraint): # class variables numerical_only: ClassVar[bool] = True - # see base class. + # See base class. # object variables condition: ThresholdCondition = field() diff --git a/baybe/surrogates/base.py b/baybe/surrogates/base.py index 340a015dd..43a86ed4e 100644 --- a/baybe/surrogates/base.py +++ b/baybe/surrogates/base.py @@ -4,8 +4,9 @@ import gc from abc import ABC, abstractmethod +from collections.abc import Sequence from enum import Enum, auto -from typing import TYPE_CHECKING, ClassVar, Protocol +from typing import TYPE_CHECKING, ClassVar, Literal, Protocol, TypeAlias import cattrs import pandas as pd @@ -21,6 +22,7 @@ from typing_extensions import override from baybe.exceptions import IncompatibleSurrogateError, ModelNotTrainedError +from baybe.objectives import DesirabilityObjective from baybe.objectives.base import Objective from baybe.parameters.base import Parameter from baybe.searchspace import SearchSpace @@ -43,6 +45,11 @@ from baybe.surrogates.composite import CompositeSurrogate +PosteriorStatistic: TypeAlias = float | Literal["mean", "std", "var", "mode"] +"""Type alias for requestable posterior statistics. + +A float will result in the corresponding quantile points.""" + _ONNX_ENCODING = "latin-1" """Constant signifying the encoding for onnx byte strings in pretrained models. @@ -218,7 +225,7 @@ def _make_output_scaler( return scaler - def posterior(self, candidates: pd.DataFrame, /) -> Posterior: + def posterior(self, candidates: pd.DataFrame) -> Posterior: """Compute the posterior for candidates in experimental representation. Takes a dataframe of parameter configurations in **experimental representation** @@ -306,6 +313,86 @@ def _posterior(self, candidates_comp_scaled: Tensor, /) -> Posterior: obtained via :meth:`baybe.surrogates.base.Surrogate._make_output_scaler`. """ + def posterior_stats( + self, + candidates: pd.DataFrame, + stats: Sequence[PosteriorStatistic] = ("mean", "std"), + ) -> pd.DataFrame: + """Return posterior statistics for each target. + + Args: + candidates: The candidate points in experimental representation. + For details, see :meth:`baybe.surrogates.base.Surrogate.posterior`. + stats: Sequence indicating which statistics to compute. Also accepts + floats, for which the corresponding quantile point will be computed. + + Raises: + ModelNotTrainedError: When called before the model has been trained. + ValueError: If a requested quantile is outside the open interval (0,1). + TypeError: If the posterior utilized by the surrogate does not support + a requested statistic. + + Returns: + A dataframe with posterior statistics for each target and candidate. + """ + if self._objective is None: + raise ModelNotTrainedError( + "The surrogate must be trained before a posterior can be computed." + ) + + stat: PosteriorStatistic + for stat in (x for x in stats if isinstance(x, float)): + if not 0.0 < stat < 1.0: + raise ValueError( + f"Posterior quantile statistics can only be computed for quantiles " + f"between 0 and 1 (non-inclusive). Provided value: '{stat}' as " + f"part of '{stats=}'." + ) + posterior = self.posterior(candidates) + + match self._objective: + case DesirabilityObjective(): + # TODO: Once desirability also supports posterior transforms this check + # here will have to depend on the configuration of the objective and + # whether it uses the transforms or not. + targets = ["Desirability"] + case _: + targets = [t.name for t in self._objective.targets] + + import torch + + result = pd.DataFrame(index=candidates.index) + with torch.no_grad(): + for k, target_name in enumerate(targets): + for stat in stats: + try: + if isinstance(stat, float): # Calculate quantile statistic + stat_name = f"Q_{stat}" + vals = posterior.quantile(torch.tensor(stat)) + else: # Calculate non-quantile statistic + stat_name = stat + vals = getattr( + posterior, + stat if stat not in ["std", "var"] else "variance", + ) + except (AttributeError, NotImplementedError) as e: + # We could arrive here because an invalid statistics string has + # been requested or because a quantile point has been requested, + # but the posterior type does not implement quantiles. + raise TypeError( + f"The utilized posterior of type " + f"'{posterior.__class__.__name__}' does not support the " + f"statistic associated with the requested input '{stat}'." + ) from e + + if stat == "std": + vals = torch.sqrt(vals) + + numpyvals = vals.cpu().numpy().reshape((len(result), len(targets))) + result[f"{target_name}_{stat_name}"] = numpyvals[:, k] + + return result + @override def fit( self, diff --git a/baybe/surrogates/composite.py b/baybe/surrogates/composite.py index 4e6caad69..417a65028 100644 --- a/baybe/surrogates/composite.py +++ b/baybe/surrogates/composite.py @@ -2,6 +2,7 @@ from __future__ import annotations +from collections.abc import Sequence from copy import deepcopy from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar @@ -14,7 +15,7 @@ from baybe.searchspace.core import SearchSpace from baybe.serialization import converter from baybe.serialization.mixin import SerialMixin -from baybe.surrogates.base import SurrogateProtocol +from baybe.surrogates.base import PosteriorStatistic, SurrogateProtocol from baybe.surrogates.gaussian_process.core import GaussianProcessSurrogate from baybe.utils.basic import is_all_instance @@ -113,7 +114,7 @@ def to_botorch(self) -> ModelList: ) return cls(*(s.to_botorch() for s in self._surrogates_flat)) - def posterior(self, candidates: pd.DataFrame, /) -> PosteriorList: + def posterior(self, candidates: pd.DataFrame) -> PosteriorList: """Compute the posterior for candidates in experimental representation. The (independent joint) posterior is represented as a collection of individual @@ -133,6 +134,21 @@ def posterior(self, candidates: pd.DataFrame, /) -> PosteriorList: posteriors = [s.posterior(candidates) for s in self._surrogates_flat] # type: ignore[attr-defined] return PosteriorList(*posteriors) + def posterior_stats( + self, + candidates: pd.DataFrame, + stats: Sequence[PosteriorStatistic] = ("mean", "std"), + ) -> pd.DataFrame: + """See :meth:`baybe.surrogates.base.Surrogate.posterior_stats`.""" + if not all(hasattr(s, "posterior_stats") for s in self._surrogates_flat): + raise IncompatibleSurrogateError( + "Posterior statistics can only be computed if all involved surrogates " + "offer this computation." + ) + + dfs = [s.posterior_stats(candidates, stats) for s in self._surrogates_flat] # type: ignore[attr-defined] + return pd.concat(dfs, axis=1) + def _structure_surrogate_getter(obj: dict, _) -> _SurrogateGetter: """Resolve the object type.""" diff --git a/docs/userguide/campaigns.md b/docs/userguide/campaigns.md index 41cc3fe58..683b550e6 100644 --- a/docs/userguide/campaigns.md +++ b/docs/userguide/campaigns.md @@ -170,12 +170,12 @@ This requirement can be disabled using the method's `numerical_measurements_must_be_within_tolerance` flag. ``` -## Prediction Statistics +## Predictive Statistics You might be interested in statistics about the predicted target values for your recommendations, or indeed for any set of possible candidate points. The -[`posterior`](baybe.campaign.Campaign.posterior) and -[`posterior_stats`](baybe.campaign.Campaign.posterior_stats) methods provide -a simple interface for this: +[`Campaign.posterior_stats`](baybe.campaign.Campaign.posterior_stats) and +[`Surrogate.posterior_stats`](baybe.surrogates.base.Surrogate.posterior_stats) methods +provide a simple interface for this: ~~~python stats = campaign.posterior_stats(rec) ~~~