Skip to content

Commit 277ccee

Browse files
committed
Rework of aryl_halides benchmark
1 parent f3c1302 commit 277ccee

File tree

4 files changed

+225
-8
lines changed

4 files changed

+225
-8
lines changed

benchmarks/domains/__init__.py

+7-8
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,22 @@
11
"""Benchmark domains."""
22

33
from benchmarks.definition.base import Benchmark
4-
from benchmarks.domains.arylhalides_tl_substance import (
5-
arylhalides_tl_substance_benchmark,
6-
)
7-
from benchmarks.domains.direct_arylation_tl_temp import (
4+
from benchmarks.domains.direct_arylation_tl_temperature import (
85
direct_arylation_tl_temp_benchmark,
96
)
107
from benchmarks.domains.easom_tl_noise import easom_tl_noise_benchmark
118
from benchmarks.domains.hartmann_tl_inverted_noise import (
129
hartmann_tl_inverted_noise_benchmark,
1310
)
1411
from benchmarks.domains.michalewicz_tl_noise import michalewicz_tl_noise_benchmark
15-
16-
# from benchmarks.domains.synthetic_2C1D_1C import synthetic_2C1D_1C_benchmark
12+
from benchmarks.domains.synthetic_2C1D_1C import synthetic_2C1D_1C_benchmark
13+
from benchmarks.domains.transfer_learning.aryl_halides.ChlorTrifluour_IodMeth import (
14+
arylhalides_1Iodo4Metho_1Chloro4Trifluour_benchmark,
15+
)
1716

1817
BENCHMARKS: list[Benchmark] = [
19-
# synthetic_2C1D_1C_benchmark,
20-
arylhalides_tl_substance_benchmark,
18+
synthetic_2C1D_1C_benchmark,
19+
arylhalides_1Iodo4Metho_1Chloro4Trifluour_benchmark,
2120
direct_arylation_tl_temp_benchmark,
2221
hartmann_tl_inverted_noise_benchmark,
2322
easom_tl_noise_benchmark,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
"""Aryl-Halide benchmark for transfer learning.
2+
3+
As source parameter, this benchmark uses 1-chloro-4-(trifluoromethyl)benzene.
4+
As target parameter, this benchmark uses 1-iodo-4-methoxybenzene.
5+
"""
6+
7+
from __future__ import annotations
8+
9+
import pandas as pd
10+
11+
from benchmarks.definition import (
12+
ConvergenceBenchmark,
13+
ConvergenceBenchmarkSettings,
14+
)
15+
from benchmarks.domains.transfer_learning.aryl_halides.base import (
16+
abstract_arylhalides_tl_substance_benchmark,
17+
)
18+
19+
20+
def arylhalides_ChloroTrifluour_IodoMetho(
21+
settings: ConvergenceBenchmarkSettings,
22+
) -> pd.DataFrame:
23+
"""Actual benchmark function."""
24+
return abstract_arylhalides_tl_substance_benchmark(
25+
settings=settings,
26+
source_tasks=["1-chloro-4-(trifluoromethyl)benzene"],
27+
target_tasks=["1-iodo-4-methoxybenzene"],
28+
percentages=[0.01, 0.02],
29+
)
30+
31+
32+
benchmark_config = ConvergenceBenchmarkSettings(
33+
batch_size=2,
34+
n_doe_iterations=10,
35+
n_mc_iterations=30,
36+
)
37+
38+
arylhalides_1Iodo4Metho_1Chloro4Trifluour_benchmark = ConvergenceBenchmark(
39+
function=arylhalides_ChloroTrifluour_IodoMetho,
40+
optimal_target_values={"yield": 68.24812709999999},
41+
settings=benchmark_config,
42+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
"""Aryl halides transfer learning benchmarks."""
2+
3+
from benchmarks.domains.transfer_learning.aryl_halides.ChlorTrifluour_IodMeth import (
4+
arylhalides_ChloroTrifluour_IodoMetho,
5+
)
6+
7+
__all__ = ["arylhalides_ChloroTrifluour_IodoMetho"]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
"""Benchmark on ArylHalides data with two distinct arylhalides as TL tasks.
2+
3+
This file provides the basic structure such that one can easily create different
4+
benchmarks by changing the source and target tasks. The benchmark compares TL and
5+
non-TL campaigns.
6+
7+
By convention, the benchmarks are named in the format "SourceHalides-TargetHalides.py"
8+
where SourceHalides and TargetHalides are abbreviations of the used source and target
9+
tasks respectively.
10+
"""
11+
12+
from __future__ import annotations
13+
14+
import pandas as pd
15+
16+
from baybe.campaign import Campaign
17+
from baybe.objectives import SingleTargetObjective
18+
from baybe.parameters import SubstanceParameter, TaskParameter
19+
from baybe.searchspace import SearchSpace
20+
from baybe.simulation import simulate_scenarios
21+
from baybe.targets import NumericalTarget
22+
from benchmarks.data.utils import DATA_PATH
23+
from benchmarks.definition import (
24+
ConvergenceBenchmarkSettings,
25+
)
26+
27+
28+
def get_data() -> pd.DataFrame:
29+
"""Load the data for the benchmark."""
30+
data_path = DATA_PATH / "ArylHalides"
31+
data = pd.read_table(data_path / "data.csv", sep=",").dropna(
32+
subset=["base", "ligand", "additive", "aryl_halide"]
33+
)
34+
# Only keep relevant columns
35+
data = data[
36+
[
37+
"base",
38+
"ligand",
39+
"additive",
40+
"ligand_smiles",
41+
"base_smiles",
42+
"additive_smiles",
43+
"aryl_halide",
44+
"yield",
45+
]
46+
]
47+
return data
48+
49+
50+
def create_searchspace(
51+
data: pd.DataFrame,
52+
use_task_parameter: bool,
53+
target_tasks: list[str],
54+
source_tasks: list[str],
55+
) -> SearchSpace:
56+
"""Create the search space for the benchmark."""
57+
params = [
58+
SubstanceParameter(
59+
name=substance,
60+
data=dict(zip(data[substance], data[f"{substance}_smiles"])),
61+
encoding="MORDRED",
62+
)
63+
for substance in ["base", "ligand", "additive"]
64+
]
65+
if use_task_parameter:
66+
params.append(
67+
TaskParameter(
68+
name="aryl_halide",
69+
values=target_tasks + source_tasks,
70+
active_values=target_tasks,
71+
)
72+
)
73+
return SearchSpace.from_product(parameters=params)
74+
75+
76+
def create_objective() -> SingleTargetObjective:
77+
"""Create the objective for the benchmark."""
78+
return SingleTargetObjective(NumericalTarget(name="yield", mode="MAX"))
79+
80+
81+
def create_lookup(data: pd.DataFrame, target_tasks: list[str]) -> pd.DataFrame:
82+
"""Create the lookup for the benchmark."""
83+
return data[data["aryl_halide"].isin(target_tasks)]
84+
85+
86+
def create_initial_data(data: pd.DataFrame, source_tasks: list[str]) -> pd.DataFrame:
87+
"""Create the initial data for the benchmark."""
88+
return data[data["aryl_halide"].isin(source_tasks)]
89+
90+
91+
def abstract_arylhalides_tl_substance_benchmark(
92+
settings: ConvergenceBenchmarkSettings,
93+
source_tasks: list[str],
94+
target_tasks: list[str],
95+
percentages: list[float],
96+
) -> pd.DataFrame:
97+
"""Benchmark function comparing TL and non-TL campaigns.
98+
99+
Inputs:
100+
base: Substance with MORDRED encoding
101+
ligand: Substance with MORDRED encoding
102+
additive: Substance with MORDRED encoding
103+
aryl_halide: Task parameter
104+
Output: Continuous (yield)
105+
Objective: Maximization
106+
Optimal Inputs:
107+
base: "MTBD",
108+
ligand: "AdBrettPhos",
109+
additive: "N,N-dibenzylisoxazol-3-amine"
110+
Optimal Output: 68.24812709999999
111+
"""
112+
data = get_data()
113+
114+
# target_tasks = ["1-iodo-4-methoxybenzene"]
115+
# source_tasks = [
116+
# # Dissimilar source task
117+
# "1-chloro-4-(trifluoromethyl)benzene"
118+
# ]
119+
searchspace = create_searchspace(
120+
data=data,
121+
use_task_parameter=True,
122+
source_tasks=source_tasks,
123+
target_tasks=target_tasks,
124+
)
125+
searchspace_nontl = create_searchspace(
126+
data=data,
127+
use_task_parameter=False,
128+
source_tasks=source_tasks,
129+
target_tasks=target_tasks,
130+
)
131+
132+
lookup = create_lookup(data, target_tasks)
133+
initial_data = create_initial_data(data, source_tasks)
134+
135+
tl_campaign = Campaign(
136+
searchspace=searchspace,
137+
objective=create_objective(),
138+
)
139+
non_tl_campaign = Campaign(
140+
searchspace=searchspace_nontl, objective=create_objective()
141+
)
142+
143+
results = []
144+
for p in percentages:
145+
results.append(
146+
simulate_scenarios(
147+
{f"{int(100 * p)}": tl_campaign},
148+
lookup,
149+
initial_data=[
150+
initial_data.sample(frac=p) for _ in range(settings.n_mc_iterations)
151+
],
152+
batch_size=settings.batch_size,
153+
n_doe_iterations=settings.n_doe_iterations,
154+
impute_mode="error",
155+
)
156+
)
157+
# No training data and non-TL campaign
158+
results.append(
159+
simulate_scenarios(
160+
{"0": tl_campaign, "non_TL": non_tl_campaign},
161+
lookup,
162+
batch_size=settings.batch_size,
163+
n_doe_iterations=settings.n_doe_iterations,
164+
n_mc_iterations=settings.n_mc_iterations,
165+
impute_mode="error",
166+
)
167+
)
168+
results = pd.concat(results)
169+
return results

0 commit comments

Comments
 (0)