Skip to content

Commit bab73f8

Browse files
authored
Botorch with cardinality constraint via sampling (#301)
(edited by @AdrianSosic) This PR adds support for cardinality constraints to `BotorchRecommender`. The core idea is to tackle the problem in an exhaustive search like manner, i.e. by * enumerating the possible combinations of in-/active parameters dictated by the cardinality constraints * optimizing the corresponding restricted subspaces, where the cardinality constraint can then be simply removed since the in-/active sets are fixed within these subspaces * aggregating the optimization results of the individual subspaces into a single recommendation batch. The PR implements two mechanisms for determining the configuration of inactive parameters: - When the combinatorial list of possible inactive parameter configurations is not too large, we iterate the full list - otherwise, a fixed amount of inactive parameter configurations is randomly selected The current aggregation step is to simply optimize all subspaces independently of each other and then return the batch from the subspace where the highest acquisition value is achieved. This has the side-effect that the set of inactive parameters is the same across the entire recommendation batch. This can be a desirable property in many use cases but potentially higher acquisition values can be obtained by altering the in-/activity sets across the batch. A simple way to achieve this (though out of scope for this PR) is by generalizing the sequential greedy principle to multiple subspaces. ### Out of scope * Fulfilling cardinality constraints by passing them in suitable form as nonlinear constraints to the optimizer * Sequential greedy optimization to achieve varying in-/activity sets (see explanation above)
2 parents d32e05a + c582314 commit bab73f8

File tree

13 files changed

+1024
-93
lines changed

13 files changed

+1024
-93
lines changed

CHANGELOG.md

+13
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,19 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1919
- `is_multi_output` attribute to `Objective`
2020
- `supports_multi_output` attribute/property to `Surrogate`/`AcquisitionFunction`
2121
- `n_outputs` property to `Objective`
22+
- `ContinuousCardinalityConstraint` is now compatible with `BotorchRecommender`
23+
- A `MinimumCardinalityViolatedWarning` is triggered when minimum cardinality
24+
constraints are violated
25+
- Attribute `max_n_subspaces` to `BotorchRecommender`, allowing to control
26+
optimization behavior in the presence of cardinality constraints
27+
- Utilities `inactive_parameter_combinations` and`n_inactive_parameter_combinations`
28+
to both `ContinuousCardinalityConstraint`and `SubspaceContinuous` for iterating
29+
over cardinality-constrained parameter sets
30+
- Attribute `relative_threshold` and method `get_absolute_thresholds` to
31+
`ContinuousCardinalityConstraint` for handling inactivity ranges
32+
- Utilities `activate_parameter` and `is_cardinality_fulfilled` for enforcing and
33+
validating cardinality constraints
34+
- Utility `is_inactive` for determining if parameters are inactive
2235

2336
### Changed
2437
- Acquisition function indicator `is_mc` has been removed in favor of new indicators

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ The following provides a non-comprehensive overview:
4343
- 🎭 Hybrid (mixed continuous and discrete) spaces
4444
- 🚀 Transfer learning: Mix data from multiple campaigns and accelerate optimization
4545
- 🎰 Bandit models: Efficiently find the best among many options in noisy environments (e.g. A/B Testing)
46+
- 🔢 Cardinality constraints: Control the number of active factors in your design
4647
- 🌎 Distributed workflows: Run campaigns asynchronously with pending experiments
4748
- 🎓 Active learning: Perform smart data acquisition campaigns
4849
- ⚙️ Custom surrogate models: Enhance your predictions through mechanistic understanding

baybe/constraints/continuous.py

+67-2
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@
44

55
import gc
66
import math
7-
from collections.abc import Collection, Sequence
7+
from collections.abc import Collection, Iterator, Sequence
8+
from itertools import combinations
9+
from math import comb
810
from typing import TYPE_CHECKING, Any
911

1012
import numpy as np
11-
from attr.validators import in_
13+
from attr.validators import gt, in_, lt
1214
from attrs import define, field
1315

1416
from baybe.constraints.base import (
@@ -17,6 +19,7 @@
1719
ContinuousNonlinearConstraint,
1820
)
1921
from baybe.parameters import NumericalContinuousParameter
22+
from baybe.utils.interval import Interval
2023
from baybe.utils.numerical import DTypeFloatNumpy
2124
from baybe.utils.validation import finite_float
2225

@@ -138,6 +141,40 @@ class ContinuousCardinalityConstraint(
138141
):
139142
"""Class for continuous cardinality constraints."""
140143

144+
relative_threshold: float = field(
145+
default=1e-3, converter=float, validator=[gt(0.0), lt(1.0)]
146+
)
147+
"""A relative threshold for determining if a value is considered zero.
148+
149+
The threshold is translated into an asymmetric open interval around zero via
150+
:meth:`get_absolute_thresholds`.
151+
152+
**Note:** The interval induced by the threshold is considered **open** because
153+
numerical routines that optimize parameter values on the complementary set (i.e. the
154+
value range considered "nonzero") may push the numerical value exactly to the
155+
interval boundary, which should therefore also be considered "nonzero".
156+
"""
157+
158+
@property
159+
def n_inactive_parameter_combinations(self) -> int:
160+
"""The number of possible inactive parameter combinations."""
161+
return sum(
162+
comb(len(self.parameters), n_inactive_parameters)
163+
for n_inactive_parameters in self._inactive_set_sizes()
164+
)
165+
166+
def _inactive_set_sizes(self) -> range:
167+
"""Get all possible sizes of inactive parameter sets."""
168+
return range(
169+
len(self.parameters) - self.max_cardinality,
170+
len(self.parameters) - self.min_cardinality + 1,
171+
)
172+
173+
def inactive_parameter_combinations(self) -> Iterator[frozenset[str]]:
174+
"""Get an iterator over all possible combinations of inactive parameters."""
175+
for n_inactive_parameters in self._inactive_set_sizes():
176+
yield from combinations(self.parameters, n_inactive_parameters)
177+
141178
def sample_inactive_parameters(self, batch_size: int = 1) -> list[set[str]]:
142179
"""Sample sets of inactive parameters according to the cardinality constraints.
143180
@@ -176,6 +213,34 @@ def sample_inactive_parameters(self, batch_size: int = 1) -> list[set[str]]:
176213

177214
return inactive_params
178215

216+
def get_absolute_thresholds(self, bounds: Interval, /) -> Interval:
217+
"""Get the absolute thresholds for a given interval.
218+
219+
Turns the relative threshold of the constraint into absolute thresholds
220+
for the considered interval. That is, for a given interval ``(a, b)`` with
221+
``a <= 0`` and ``b >= 0``, the method returns the interval ``(r*a, r*b)``,
222+
where ``r`` is the relative threshold defined by the constraint.
223+
224+
Args:
225+
bounds: The specified interval.
226+
227+
Returns:
228+
The absolute thresholds represented as an interval.
229+
230+
Raises:
231+
ValueError: When the specified interval does not contain zero.
232+
"""
233+
if not bounds.contains(0.0):
234+
raise ValueError(
235+
f"The specified interval must contain zero. "
236+
f"Given: {bounds.to_tuple()}."
237+
)
238+
239+
return Interval(
240+
lower=self.relative_threshold * bounds.lower,
241+
upper=self.relative_threshold * bounds.upper,
242+
)
243+
179244

180245
# Collect leftover original slotted classes processed by `attrs.define`
181246
gc.collect()

baybe/constraints/utils.py

+49
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
"""Constraint utilities."""
2+
3+
import numpy as np
4+
import pandas as pd
5+
6+
from baybe.parameters.utils import is_inactive
7+
from baybe.searchspace import SubspaceContinuous
8+
9+
10+
def is_cardinality_fulfilled(
11+
df: pd.DataFrame,
12+
subspace_continuous: SubspaceContinuous,
13+
*,
14+
check_minimum: bool = True,
15+
check_maximum: bool = True,
16+
) -> bool:
17+
"""Validate cardinality constraints in a dataframe of parameter configurations.
18+
19+
Args:
20+
df: The dataframe to be checked.
21+
subspace_continuous: The subspace spanned by the considered parameters.
22+
check_minimum: If ``True``, minimum cardinality constraints are validated.
23+
check_maximum: If ``True``, maximum cardinality constraints are validated.
24+
25+
Returns:
26+
``True`` if all cardinality constraints are fulfilled, ``False`` otherwise.
27+
"""
28+
for c in subspace_continuous.constraints_cardinality:
29+
# Get the activity thresholds for all parameters
30+
cols = df[c.parameters]
31+
thresholds = {
32+
p.name: c.get_absolute_thresholds(p.bounds)
33+
for p in subspace_continuous.get_parameters_by_name(c.parameters)
34+
}
35+
lower_thresholds = [thresholds[p].lower for p in cols.columns]
36+
upper_thresholds = [thresholds[p].upper for p in cols.columns]
37+
38+
# Count the number of active values per dataframe row
39+
inactives = is_inactive(cols, lower_thresholds, upper_thresholds)
40+
n_zeros = inactives.sum(axis=1)
41+
n_active = len(c.parameters) - n_zeros
42+
43+
# Check if cardinality is violated
44+
if check_minimum and np.any(n_active < c.min_cardinality):
45+
return False
46+
if check_maximum and np.any(n_active > c.max_cardinality):
47+
return False
48+
49+
return True

baybe/constraints/validation.py

+63
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,14 @@
88
from baybe.constraints.discrete import (
99
DiscreteDependenciesConstraint,
1010
)
11+
from baybe.parameters import NumericalContinuousParameter
1112
from baybe.parameters.base import Parameter
1213

14+
try: # For python < 3.11, use the exceptiongroup backport
15+
ExceptionGroup
16+
except NameError:
17+
from exceptiongroup import ExceptionGroup
18+
1319

1420
def validate_constraints( # noqa: DOC101, DOC103
1521
constraints: Collection[Constraint], parameters: Collection[Parameter]
@@ -26,6 +32,8 @@ def validate_constraints( # noqa: DOC101, DOC103
2632
ValueError: If any discrete constraint includes a continuous parameter.
2733
ValueError: If any discrete constraint that is valid only for numerical
2834
discrete parameters includes non-numerical discrete parameters.
35+
ValueError: If any parameter affected by a cardinality constraint does
36+
not include zero.
2937
"""
3038
if sum(isinstance(itm, DiscreteDependenciesConstraint) for itm in constraints) > 1:
3139
raise ValueError(
@@ -41,6 +49,9 @@ def validate_constraints( # noqa: DOC101, DOC103
4149
param_names_discrete = [p.name for p in parameters if p.is_discrete]
4250
param_names_continuous = [p.name for p in parameters if p.is_continuous]
4351
param_names_non_numerical = [p.name for p in parameters if not p.is_numerical]
52+
params_continuous: list[NumericalContinuousParameter] = [
53+
p for p in parameters if isinstance(p, NumericalContinuousParameter)
54+
]
4455

4556
for constraint in constraints:
4657
if not all(p in param_names_all for p in constraint.parameters):
@@ -78,6 +89,11 @@ def validate_constraints( # noqa: DOC101, DOC103
7889
f"Parameter list of the affected constraint: {constraint.parameters}."
7990
)
8091

92+
if isinstance(constraint, ContinuousCardinalityConstraint):
93+
validate_cardinality_constraint_parameter_bounds(
94+
constraint, params_continuous
95+
)
96+
8197

8298
def validate_cardinality_constraints_are_nonoverlapping(
8399
constraints: Collection[ContinuousCardinalityConstraint],
@@ -98,3 +114,50 @@ def validate_cardinality_constraints_are_nonoverlapping(
98114
f"cannot share the same parameters. Found the following overlapping "
99115
f"parameter sets: {s1}, {s2}."
100116
)
117+
118+
119+
def validate_cardinality_constraint_parameter_bounds(
120+
constraint: ContinuousCardinalityConstraint,
121+
parameters: Collection[NumericalContinuousParameter],
122+
) -> None:
123+
"""Validate that all parameters of a continuous cardinality constraint include zero.
124+
125+
Args:
126+
constraint: A continuous cardinality constraint.
127+
parameters: A collection of parameters, including those affected by the
128+
constraint.
129+
130+
Raises:
131+
ValueError: If one of the affected parameters does not include zero.
132+
ExceptionGroup: If several of the affected parameters do not include zero.
133+
"""
134+
exceptions = []
135+
for name in constraint.parameters:
136+
try:
137+
parameter = next(p for p in parameters if p.name == name)
138+
except StopIteration as ex:
139+
raise ValueError(
140+
f"The parameter '{name}' referenced by the constraint is not contained "
141+
f"in the given collection of parameters."
142+
) from ex
143+
144+
if not parameter.is_in_range(0.0):
145+
exceptions.append(
146+
ValueError(
147+
f"The bounds of all parameters affected by a constraint of type "
148+
f"'{ContinuousCardinalityConstraint.__name__}' must include zero, "
149+
f"but the bounds of parameter '{name}' are "
150+
f"{parameter.bounds.to_tuple()}, which may indicate unintended "
151+
f"settings in your parameter definition. "
152+
f"A parameter whose value range excludes zero trivially "
153+
f"increases the cardinality of the resulting configuration by one. "
154+
f"Therefore, if your parameter definitions are all correct, "
155+
f"consider excluding the parameter from the constraint and "
156+
f"reducing the cardinality limits by one accordingly."
157+
)
158+
)
159+
160+
if exceptions:
161+
if len(exceptions) == 1:
162+
raise exceptions[0]
163+
raise ExceptionGroup("Invalid parameter bounds", exceptions)

baybe/exceptions.py

+8
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@ def __str__(self):
3333
return self.message
3434

3535

36+
class MinimumCardinalityViolatedWarning(UserWarning):
37+
"""Minimum cardinality constraints are violated."""
38+
39+
3640
##### Exceptions #####
3741

3842

@@ -63,6 +67,10 @@ class IncompatibleArgumentError(IncompatibilityError):
6367
"""An incompatible argument was passed to a callable."""
6468

6569

70+
class InfeasibilityError(Exception):
71+
"""An optimization problem has no feasible solution."""
72+
73+
6674
class NotEnoughPointsLeftError(Exception):
6775
"""
6876
More recommendations are requested than there are viable parameter configurations

baybe/parameters/numerical.py

+34-1
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def is_in_range(self, item: float) -> bool:
136136

137137
@override
138138
@property
139-
def comp_rep_columns(self) -> tuple[str, ...]:
139+
def comp_rep_columns(self) -> tuple[str]:
140140
return (self.name,)
141141

142142
@override
@@ -150,5 +150,38 @@ def summary(self) -> dict:
150150
return param_dict
151151

152152

153+
@define(frozen=True, slots=False)
154+
class _FixedNumericalContinuousParameter(ContinuousParameter):
155+
"""Parameter class for fixed numerical parameters."""
156+
157+
is_numeric: ClassVar[bool] = True
158+
# See base class.
159+
160+
value: float = field(converter=float)
161+
"""The fixed value of the parameter."""
162+
163+
@property
164+
def bounds(self) -> Interval:
165+
"""The value of the parameter as a degenerate interval."""
166+
return Interval(self.value, self.value)
167+
168+
@override
169+
def is_in_range(self, item: float) -> bool:
170+
return item == self.value
171+
172+
@override
173+
@property
174+
def comp_rep_columns(self) -> tuple[str]:
175+
return (self.name,)
176+
177+
@override
178+
def summary(self) -> dict:
179+
return dict(
180+
Name=self.name,
181+
Type=self.__class__.__name__,
182+
Value=self.value,
183+
)
184+
185+
153186
# Collect leftover original slotted classes processed by `attrs.define`
154187
gc.collect()

0 commit comments

Comments
 (0)