forked from theislab/chemCPA
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmanual_run.yaml
136 lines (127 loc) · 4.92 KB
/
manual_run.yaml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
# Config for manual run
seml:
  # Launcher settings for this experiment collection. Key names follow the
  # SEML config layout -- confirm exact semantics against the SEML docs.
  executable: chemCPA/experiments_run.py
  name: manual_run_chemCPA
  output_dir: project_folder/logs
  conda_environment: chemical_CPA
  # Quoted so the lone dot is not overlooked; the value is still the string ".".
  project_root_dir: '.'
fixed:
  # Parameters that take a single, fixed value in this run.
  # Booleans are written in canonical lowercase throughout (the block
  # previously mixed `False`/`True` with `false`).

  # --- Profiling ---
  profiling.run_profiler: false
  profiling.outdir: './'

  # --- Training loop ---
  # Checkpoint frequency to run evaluate, and maybe save a checkpoint.
  training.checkpoint_freq: 25
  # Maximum epochs for training. One epoch updates either the autoencoder or
  # the adversary, depending on adversary_steps.
  training.num_epochs: 101
  # Maximum computation time.
  training.max_minutes: 1200
  training.full_eval_during_train: false
  # Whether to calculate the disentanglement loss when running the full eval.
  training.run_eval_disentangle: true
  training.run_eval_r2: true
  training.run_eval_r2_sc: false
  training.run_eval_logfold: false
  # Checkpoints tend to be ~250MB large for LINCS.
  training.save_checkpoints: true
  training.save_dir: project_folder/checkpoints

  # --- Dataset ---
  dataset.dataset_type: trapnell
  # Stores the name of the drug.
  dataset.data_params.perturbation_key: condition
  # Stores celltype_drugname_drugdose.
  dataset.data_params.pert_category: cov_drug_dose_name
  # Stores the drug dose as a float.
  dataset.data_params.dose_key: dose
  # Necessary field for cell types. Fill it with a dummy variable if no
  # cell types are present.
  dataset.data_params.covariate_keys: cell_type
  dataset.data_params.smiles_key: SMILES
  # If false, one-hot encoding is used instead.
  dataset.data_params.use_drugs_idx: true
  dataset.data_params.split_key: split_ood_multi_task

  # --- Model ---
  model.pretrained_model_path: project_folder/checkpoints
  model.pretrained_model_hashes: null
  # Patience for early stopping. Effective epochs: patience * checkpoint_freq.
  # NOTE(review): by that formula 50 * 25 = 1250 epochs, which exceeds
  # training.num_epochs (101), so early stopping presumably never triggers
  # in this run -- confirm this is intended.
  model.additional_params.patience: 50
  model.additional_params.decoder_activation: ReLU
  model.additional_params.doser_type: amortized
  model.additional_params.multi_task: false
  model.embedding.directory: embeddings
  model.additional_params.seed: 1337
  model.hparams.dim: 32
  model.hparams.dropout: 0.262378
  model.hparams.autoencoder_width: 256
  model.hparams.autoencoder_depth: 4
  model.hparams.reg_multi_task: 0
random:
  # Randomly sampled hyperparameters. Each entry below is a mapping with a
  # `type` plus either `options` (choice) or `min`/`max` (loguniform).
  #
  # Exponent literals carry an explicit mantissa dot (1.0e-4, not 1e-4):
  # YAML 1.1 resolvers such as PyYAML's only recognise a float when it
  # contains a dot, and would otherwise load the bound as a string. The
  # numeric values are unchanged.

  # Sampler settings (sample count and sampler seed, per the SEML
  # random-block layout -- confirm against the SEML docs).
  samples: 1
  seed: 42

  model.hparams.batch_size:
    type: choice
    options:
      - 128

  # --- Autoencoder optimizer ---
  model.hparams.autoencoder_lr:
    type: loguniform
    min: 1.0e-4
    max: 1.0e-2
  model.hparams.autoencoder_wd:
    type: loguniform
    min: 1.0e-8
    max: 1.0e-5

  # --- Adversary architecture and optimizer ---
  model.hparams.adversary_width:
    type: choice
    options:
      - 64
      - 128
      - 256
  model.hparams.adversary_depth:
    type: choice
    options:
      - 2
      - 3
      - 4
  model.hparams.adversary_lr:
    type: loguniform
    min: 5.0e-5
    max: 1.0e-2
  model.hparams.adversary_wd:
    type: loguniform
    min: 1.0e-8
    max: 1.0e-3
  # Every X steps, update the adversary INSTEAD OF the autoencoder.
  model.hparams.adversary_steps:
    type: choice
    options:
      - 2
      - 3

  # --- Adversarial regularisation weights ---
  model.hparams.reg_adversary:
    type: loguniform
    min: 1
    max: 40
  model.hparams.reg_adversary_cov:
    type: loguniform
    min: 5
    max: 50
  model.hparams.penalty_adversary:
    type: loguniform
    min: 0.05
    max: 2

  # --- Doser optimizer ---
  model.hparams.dosers_lr:
    type: loguniform
    min: 1.0e-4
    max: 1.0e-2
  model.hparams.dosers_wd:
    type: loguniform
    min: 1.0e-8
    max: 1.0e-5
grid:
  # Grid-searched parameters. Only `false` is active for
  # model.load_pretrained; uncomment `true` to add it as a second option.
  # Booleans use canonical lowercase spelling.
  model.load_pretrained:
    type: choice
    options:
      # - true
      - false
rdkit_model:
  # Named sub-configuration with its own `fixed` block, nested here so it does
  # not collide with the top-level `fixed` key.
  # NOTE(review): nesting reconstructed from key semantics -- confirm that all
  # keys below belong under rdkit_model.fixed.
  fixed:
    # Alternative: grover_base, jtvae
    model.embedding.model: rdkit
    # Should be larger for jtvae and grover.
    model.hparams.dosers_width: 64
    model.hparams.dosers_depth: 3
    # This applies to all optimizers (AE, ADV, DRUG).
    model.hparams.step_size_lr: 50
    # Should be larger for jtvae and grover.
    model.hparams.embedding_encoder_width: 128
    model.hparams.embedding_encoder_depth: 4

    # ------------------------------------------------------------------
    # Shared gene set (setting 1) -- disabled; uncomment to use it
    # instead of setting 2 below.
    # ------------------------------------------------------------------
    # model.append_ae_layer: false
    # dataset.data_params.dataset_path: project_folder/datasets/sciplex_complete_middle_subset_lincs_genes.h5ad
    # `uns` column name denoting the DEGs for each perturbation:
    # dataset.data_params.degs_key: lincs_DEGs

    # ------------------------------------------------------------------
    # Extended gene set (setting 2) -- active.
    # ------------------------------------------------------------------
    model.append_ae_layer: true
    model.enable_cpa_mode: false
    dataset.data_params.dataset_path: project_folder/datasets/sciplex_complete_middle_subset.h5ad
    dataset.data_params.degs_key: all_DEGs