9
9
10
10
from diffusion_for_multi_scale_molecular_dynamics .data .diffusion .data_module_parameters import \
11
11
DataModuleParameters
12
+ from diffusion_for_multi_scale_molecular_dynamics .data .diffusion .noising_transform import \
13
+ NoisingTransform
12
14
from diffusion_for_multi_scale_molecular_dynamics .data .element_types import \
13
15
ElementTypes
14
16
from diffusion_for_multi_scale_molecular_dynamics .namespace import (
15
17
ATOM_TYPES , CARTESIAN_FORCES , RELATIVE_COORDINATES )
18
+ from diffusion_for_multi_scale_molecular_dynamics .noise_schedulers .noise_parameters import \
19
+ NoiseParameters
16
20
from diffusion_for_multi_scale_molecular_dynamics .utils .basis_transformations import \
17
21
map_relative_coordinates_to_unit_cell
18
22
@@ -24,6 +28,9 @@ class GaussianDataModuleParameters(DataModuleParameters):
24
28
"""Hyper-parameters for a Gaussian, in memory data module."""
25
29
data_source = "gaussian"
26
30
31
+ noise_parameters : NoiseParameters
32
+ use_optimal_transport : bool
33
+
27
34
random_seed : int
28
35
# the number of atoms in a configuration.
29
36
number_of_atoms : int
@@ -42,14 +49,18 @@ def __post_init__(self):
42
49
"""Post init."""
43
50
assert self .sigma_d > 0.0 , "the sigma_d parameter should be positive."
44
51
45
- assert len (self .equilibrium_relative_coordinates ) == self .number_of_atoms , \
46
- "There should be exactly one list of equilibrium coordinates per atom."
52
+ assert (
53
+ len (self .equilibrium_relative_coordinates ) == self .number_of_atoms
54
+ ), "There should be exactly one list of equilibrium coordinates per atom."
47
55
48
56
for x in self .equilibrium_relative_coordinates :
49
- assert len (x ) == self .spatial_dimension , \
50
- "The equilibrium coordinates should be consistent with the spatial dimension."
57
+ assert (
58
+ len (x ) == self .spatial_dimension
59
+ ), "The equilibrium coordinates should be consistent with the spatial dimension."
51
60
52
- assert len (self .elements ) == 1 , "There can only be one element type for the gaussian data module."
61
+ assert (
62
+ len (self .elements ) == 1
63
+ ), "There can only be one element type for the gaussian data module."
53
64
54
65
55
66
class GaussianDataModule (pl .LightningDataModule ):
@@ -69,8 +80,9 @@ def __init__(
69
80
self .number_of_atoms = hyper_params .number_of_atoms
70
81
self .spatial_dimension = hyper_params .spatial_dimension
71
82
self .sigma_d = hyper_params .sigma_d
72
- self .equilibrium_coordinates = torch .tensor (hyper_params .equilibrium_relative_coordinates ,
73
- dtype = torch .float )
83
+ self .equilibrium_coordinates = torch .tensor (
84
+ hyper_params .equilibrium_relative_coordinates , dtype = torch .float
85
+ )
74
86
75
87
self .train_dataset_size = hyper_params .train_dataset_size
76
88
self .valid_dataset_size = hyper_params .valid_dataset_size
@@ -85,6 +97,41 @@ def __init__(
85
97
86
98
self .element_types = ElementTypes (hyper_params .elements )
87
99
100
+ self .noising_transform = NoisingTransform (
101
+ noise_parameters = hyper_params .noise_parameters ,
102
+ num_atom_types = len (hyper_params .elements ),
103
+ spatial_dimension = self .spatial_dimension ,
104
+ use_optimal_transport = hyper_params .use_optimal_transport ,
105
+ )
106
+
107
+ def get_raw_dataset (self , batch_size : int , rng : torch .Generator ):
108
+ """Get raw dataset."""
109
+ box = torch .ones (batch_size , self .spatial_dimension , dtype = torch .float )
110
+ atom_types = torch .zeros (batch_size , self .number_of_atoms , dtype = torch .long )
111
+
112
+ mean = einops .repeat (
113
+ self .equilibrium_coordinates ,
114
+ "natoms space -> batch natoms space" ,
115
+ batch = batch_size ,
116
+ )
117
+ std = self .sigma_d * torch .ones_like (mean )
118
+ relative_coordinates = map_relative_coordinates_to_unit_cell (
119
+ torch .normal (mean = mean , std = std , generator = rng ).to (torch .float )
120
+ )
121
+
122
+ natoms = self .number_of_atoms * torch .ones (batch_size )
123
+ potential_energy = torch .zeros (batch_size )
124
+
125
+ raw_dataset = {
126
+ "natom" : natoms ,
127
+ "box" : box ,
128
+ RELATIVE_COORDINATES : relative_coordinates ,
129
+ ATOM_TYPES : atom_types ,
130
+ CARTESIAN_FORCES : torch .zeros_like (relative_coordinates ),
131
+ "potential_energy" : potential_energy ,
132
+ }
133
+ return raw_dataset
134
+
88
135
def setup (self , stage : Optional [str ] = None ):
89
136
"""Setup method."""
90
137
self .train_dataset = []
@@ -93,29 +140,19 @@ def setup(self, stage: Optional[str] = None):
93
140
rng = torch .Generator ()
94
141
rng .manual_seed (self .random_seed )
95
142
96
- box = torch .ones (self .spatial_dimension , dtype = torch .float )
97
-
98
- atom_types = torch .zeros (self .number_of_atoms , dtype = torch .long )
99
-
100
- for dataset , batch_size in zip ([self .train_dataset , self .valid_dataset ],
101
- [self .train_dataset_size , self .valid_dataset_size ]):
102
-
103
- mean = einops .repeat (self .equilibrium_coordinates ,
104
- "natoms space -> batch natoms space" , batch = batch_size )
105
- std = self .sigma_d * torch .ones_like (mean )
106
- relative_coordinates = map_relative_coordinates_to_unit_cell (
107
- torch .normal (mean = mean , std = std , generator = rng ).to (torch .float ))
108
-
109
- for x in relative_coordinates :
110
- row = {
111
- "natom" : self .number_of_atoms ,
112
- "box" : box ,
113
- RELATIVE_COORDINATES : x ,
114
- ATOM_TYPES : atom_types ,
115
- CARTESIAN_FORCES : torch .zeros_like (x ),
116
- "potential_energy" : torch .tensor ([0.0 ], dtype = torch .float ),
117
- }
118
- dataset .append (row )
143
+ for dataset , batch_size in zip (
144
+ [self .train_dataset , self .valid_dataset ],
145
+ [self .train_dataset_size , self .valid_dataset_size ],
146
+ ):
147
+
148
+ raw_dataset_as_single_batch = self .get_raw_dataset (batch_size , rng )
149
+ dataset_as_single_batch = self .noising_transform .transform (
150
+ raw_dataset_as_single_batch
151
+ )
152
+
153
+ keys = dataset_as_single_batch .keys ()
154
+ for idx in range (batch_size ):
155
+ dataset .append ({key : dataset_as_single_batch [key ][idx ] for key in keys })
119
156
120
157
def train_dataloader (self ) -> DataLoader :
121
158
"""Create the training dataloader using the training data parser."""
0 commit comments