Skip to content

Commit a26f78f

Browse files
committed
Adapt the Gaussian data module to the new noising paradigm.
1 parent 0a2ce8b commit a26f78f

File tree

3 files changed

+79
-32
lines changed

3 files changed

+79
-32
lines changed

src/diffusion_for_multi_scale_molecular_dynamics/data/diffusion/gaussian_data_module.py

+67-30
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,14 @@
99

1010
from diffusion_for_multi_scale_molecular_dynamics.data.diffusion.data_module_parameters import \
1111
DataModuleParameters
12+
from diffusion_for_multi_scale_molecular_dynamics.data.diffusion.noising_transform import \
13+
NoisingTransform
1214
from diffusion_for_multi_scale_molecular_dynamics.data.element_types import \
1315
ElementTypes
1416
from diffusion_for_multi_scale_molecular_dynamics.namespace import (
1517
ATOM_TYPES, CARTESIAN_FORCES, RELATIVE_COORDINATES)
18+
from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \
19+
NoiseParameters
1620
from diffusion_for_multi_scale_molecular_dynamics.utils.basis_transformations import \
1721
map_relative_coordinates_to_unit_cell
1822

@@ -24,6 +28,9 @@ class GaussianDataModuleParameters(DataModuleParameters):
2428
"""Hyper-parameters for a Gaussian, in memory data module."""
2529
data_source = "gaussian"
2630

31+
noise_parameters: NoiseParameters
32+
use_optimal_transport: bool
33+
2734
random_seed: int
2835
# the number of atoms in a configuration.
2936
number_of_atoms: int
@@ -42,14 +49,18 @@ def __post_init__(self):
4249
"""Post init."""
4350
assert self.sigma_d > 0.0, "the sigma_d parameter should be positive."
4451

45-
assert len(self.equilibrium_relative_coordinates) == self.number_of_atoms, \
46-
"There should be exactly one list of equilibrium coordinates per atom."
52+
assert (
53+
len(self.equilibrium_relative_coordinates) == self.number_of_atoms
54+
), "There should be exactly one list of equilibrium coordinates per atom."
4755

4856
for x in self.equilibrium_relative_coordinates:
49-
assert len(x) == self.spatial_dimension, \
50-
"The equilibrium coordinates should be consistent with the spatial dimension."
57+
assert (
58+
len(x) == self.spatial_dimension
59+
), "The equilibrium coordinates should be consistent with the spatial dimension."
5160

52-
assert len(self.elements) == 1, "There can only be one element type for the gaussian data module."
61+
assert (
62+
len(self.elements) == 1
63+
), "There can only be one element type for the gaussian data module."
5364

5465

5566
class GaussianDataModule(pl.LightningDataModule):
@@ -69,8 +80,9 @@ def __init__(
6980
self.number_of_atoms = hyper_params.number_of_atoms
7081
self.spatial_dimension = hyper_params.spatial_dimension
7182
self.sigma_d = hyper_params.sigma_d
72-
self.equilibrium_coordinates = torch.tensor(hyper_params.equilibrium_relative_coordinates,
73-
dtype=torch.float)
83+
self.equilibrium_coordinates = torch.tensor(
84+
hyper_params.equilibrium_relative_coordinates, dtype=torch.float
85+
)
7486

7587
self.train_dataset_size = hyper_params.train_dataset_size
7688
self.valid_dataset_size = hyper_params.valid_dataset_size
@@ -85,6 +97,41 @@ def __init__(
8597

8698
self.element_types = ElementTypes(hyper_params.elements)
8799

100+
self.noising_transform = NoisingTransform(
101+
noise_parameters=hyper_params.noise_parameters,
102+
num_atom_types=len(hyper_params.elements),
103+
spatial_dimension=self.spatial_dimension,
104+
use_optimal_transport=hyper_params.use_optimal_transport,
105+
)
106+
107+
def get_raw_dataset(self, batch_size: int, rng: torch.Generator):
108+
"""Get raw dataset."""
109+
box = torch.ones(batch_size, self.spatial_dimension, dtype=torch.float)
110+
atom_types = torch.zeros(batch_size, self.number_of_atoms, dtype=torch.long)
111+
112+
mean = einops.repeat(
113+
self.equilibrium_coordinates,
114+
"natoms space -> batch natoms space",
115+
batch=batch_size,
116+
)
117+
std = self.sigma_d * torch.ones_like(mean)
118+
relative_coordinates = map_relative_coordinates_to_unit_cell(
119+
torch.normal(mean=mean, std=std, generator=rng).to(torch.float)
120+
)
121+
122+
natoms = self.number_of_atoms * torch.ones(batch_size)
123+
potential_energy = torch.zeros(batch_size)
124+
125+
raw_dataset = {
126+
"natom": natoms,
127+
"box": box,
128+
RELATIVE_COORDINATES: relative_coordinates,
129+
ATOM_TYPES: atom_types,
130+
CARTESIAN_FORCES: torch.zeros_like(relative_coordinates),
131+
"potential_energy": potential_energy,
132+
}
133+
return raw_dataset
134+
88135
def setup(self, stage: Optional[str] = None):
89136
"""Setup method."""
90137
self.train_dataset = []
@@ -93,29 +140,19 @@ def setup(self, stage: Optional[str] = None):
93140
rng = torch.Generator()
94141
rng.manual_seed(self.random_seed)
95142

96-
box = torch.ones(self.spatial_dimension, dtype=torch.float)
97-
98-
atom_types = torch.zeros(self.number_of_atoms, dtype=torch.long)
99-
100-
for dataset, batch_size in zip([self.train_dataset, self.valid_dataset],
101-
[self.train_dataset_size, self.valid_dataset_size]):
102-
103-
mean = einops.repeat(self.equilibrium_coordinates,
104-
"natoms space -> batch natoms space", batch=batch_size)
105-
std = self.sigma_d * torch.ones_like(mean)
106-
relative_coordinates = map_relative_coordinates_to_unit_cell(
107-
torch.normal(mean=mean, std=std, generator=rng).to(torch.float))
108-
109-
for x in relative_coordinates:
110-
row = {
111-
"natom": self.number_of_atoms,
112-
"box": box,
113-
RELATIVE_COORDINATES: x,
114-
ATOM_TYPES: atom_types,
115-
CARTESIAN_FORCES: torch.zeros_like(x),
116-
"potential_energy": torch.tensor([0.0], dtype=torch.float),
117-
}
118-
dataset.append(row)
143+
for dataset, batch_size in zip(
144+
[self.train_dataset, self.valid_dataset],
145+
[self.train_dataset_size, self.valid_dataset_size],
146+
):
147+
148+
raw_dataset_as_single_batch = self.get_raw_dataset(batch_size, rng)
149+
dataset_as_single_batch = self.noising_transform.transform(
150+
raw_dataset_as_single_batch
151+
)
152+
153+
keys = dataset_as_single_batch.keys()
154+
for idx in range(batch_size):
155+
dataset.append({key: dataset_as_single_batch[key][idx] for key in keys})
119156

120157
def train_dataloader(self) -> DataLoader:
121158
"""Create the training dataloader using the training data parser."""

src/diffusion_for_multi_scale_molecular_dynamics/data/diffusion/instantiate_data_module.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,9 @@ def load_data_module(hyper_params: Dict[AnyStr, Any], args: argparse.Namespace)
4646
working_cache_dir=args.dataset_working_dir)
4747

4848
case "gaussian":
49-
data_params = GaussianDataModuleParameters(**data_config, elements=hyper_params["elements"])
49+
data_params = GaussianDataModuleParameters(**data_config,
50+
noise_parameters=noise_parameters,
51+
elements=hyper_params["elements"])
5052
data_module = GaussianDataModule(data_params)
5153
case _:
5254
raise NotImplementedError(

tests/data/diffusion/test_gaussian_data_module.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
GaussianDataModule, GaussianDataModuleParameters)
66
from diffusion_for_multi_scale_molecular_dynamics.namespace import \
77
RELATIVE_COORDINATES
8+
from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \
9+
NoiseParameters
810

911

1012
class TestGaussianDataModule:
@@ -33,17 +35,23 @@ def spatial_dimension(self):
3335
def sigma_d(self):
3436
return 0.01
3537

38+
@pytest.fixture(params=[True, False])
39+
def use_optimal_transport(self, request):
40+
return request.param
41+
3642
@pytest.fixture()
3743
def equilibrium_relative_coordinates(self, number_of_atoms, spatial_dimension):
3844
list_x = torch.rand(number_of_atoms, spatial_dimension)
3945
equilibrium_relative_coordinates = [list(x) for x in list_x.numpy()]
4046
return equilibrium_relative_coordinates
4147

4248
@pytest.fixture
43-
def data_module_hyperparameters(self, batch_size, train_dataset_size, valid_dataset_size,
49+
def data_module_hyperparameters(self, batch_size, train_dataset_size, valid_dataset_size, use_optimal_transport,
4450
number_of_atoms, spatial_dimension, sigma_d, equilibrium_relative_coordinates):
4551
return GaussianDataModuleParameters(
4652
batch_size=batch_size,
53+
noise_parameters=NoiseParameters(total_time_steps=10),
54+
use_optimal_transport=use_optimal_transport,
4755
random_seed=42,
4856
num_workers=0,
4957
sigma_d=sigma_d,

0 commit comments

Comments
 (0)