diff --git a/Snakefile b/Snakefile index a45b1f33..b2f7a885 100644 --- a/Snakefile +++ b/Snakefile @@ -91,7 +91,7 @@ def _get_node_data_for_predictors(wildcards): rule convert_node_data_to_table: input: - tree = rules.refine.output.tree, + tree = rules.prune_reference.output.tree, metadata = rules.parse.output.metadata, node_data = _get_node_data_for_predictors output: @@ -114,7 +114,7 @@ rule convert_node_data_to_table: rule convert_frequencies_to_table: input: - tree = rules.refine.output.tree, + tree = rules.prune_reference.output.tree, frequencies = rules.tip_frequencies.output.tip_freq output: table = "results/frequencies_{center}_{lineage}_{segment}_{resolution}_{passage}_{assay}.tsv" @@ -230,7 +230,7 @@ rule merge_weighted_distances_to_future: rule export: input: - tree = rules.refine.output.tree, + tree = rules.prune_reference.output.tree, metadata = rules.parse.output.metadata, auspice_config = files.auspice_config, node_data = _get_node_data_for_export, diff --git a/Snakefile_WHO b/Snakefile_WHO index fbaf1c75..12a75c26 100644 --- a/Snakefile_WHO +++ b/Snakefile_WHO @@ -146,7 +146,7 @@ rule global_mutation_frequencies: rule scores: input: metadata = rules.parse.output.metadata, - tree = rules.refine.output.tree + tree = rules.prune_reference.output.tree output: scores = "results/scores_{center}_{lineage}_{segment}_{resolution}_{passage}_{assay}.json" conda: "environment.yaml" @@ -209,7 +209,7 @@ rule export_entropy: rule export_sequence_json: input: aln = rules.ancestral.output.node_data, - tree = rules.refine.output.tree, + tree = rules.prune_reference.output.tree, aa_seqs = translations params: genes = gene_names @@ -260,7 +260,7 @@ def _get_node_data_for_report_export(wildcards): rule export_who: input: - tree = rules.refine.output.tree, + tree = rules.prune_reference.output.tree, metadata = rules.parse.output.metadata, auspice_config = "config/auspice_config_who_{lineage}.json", node_data = _get_node_data_for_report_export, diff --git a/Snakefile_base 
b/Snakefile_base index 1c72686d..7819aade 100644 --- a/Snakefile_base +++ b/Snakefile_base @@ -369,10 +369,36 @@ rule select_strains: --output {output.strains} """ +def _get_lineage_reference(wildcards): + if wildcards.resolution in {'2y', '6m'}: + if wildcards.lineage == "h1n1pdm": + return "A/California/7/2009" + if wildcards.lineage == "h3n2": + return "A/Wisconsin/67/2005" + if wildcards.lineage == "vic": + return "B/Brisbane/60/2008" + + return "" + + +rule include_reference_strain: + message: "Adding reference strain to selected strains (only for 2y and 6m builds)" + input: + strains = "results/strains_{center}_{lineage}_{resolution}_{passage}_{assay}.txt" + output: + strains = "results/strains_with_reference_{center}_{lineage}_{resolution}_{passage}_{assay}.txt" + params: + reference = _get_lineage_reference + shell: + """ + cp {input.strains} {output.strains} + echo "\n{params.reference}" >> {output.strains} + """ + rule extract: input: sequences = rules.filter.output.sequences, - strains = rules.select_strains.output.strains + strains = rules.include_reference_strain.output.strains output: sequences = 'results/extracted_{center}_{lineage}_{segment}_{resolution}_{passage}_{assay}.fasta' conda: "environment.yaml" @@ -467,6 +493,9 @@ rule tree: --exclude-sites {input.exclude_sites} """ +def _get_refine_root(wildcards): + return _get_lineage_reference(wildcards) or "best" + rule refine: message: """ @@ -488,7 +517,8 @@ rule refine: date_inference = "marginal", clock_filter_iqd = 4, clock_rate = clock_rate, - clock_std_dev = clock_std_dev + clock_std_dev = clock_std_dev, + root = _get_refine_root conda: "environment.yaml" resources: mem_mb=16000 @@ -507,13 +537,31 @@ rule refine: --coalescent {params.coalescent} \ --date-confidence \ --date-inference {params.date_inference} \ - --clock-filter-iqd {params.clock_filter_iqd} + --clock-filter-iqd {params.clock_filter_iqd} \ + --root {params.root} """ +rule prune_reference: + message: "Pruning reference strain from 
the tree (only for 2y and 6m builds)" + input: + tree = rules.refine.output.tree + output: + tree = "results/tree_pruned_{center}_{lineage}_{segment}_{resolution}_{passage}_{assay}.nwk" + params: + reference = _get_lineage_reference + shell: + """ + python3 scripts/prune_reference.py \ + --tree {input.tree} \ + --reference {params.reference} \ + --output {output.tree} + """ + + rule ancestral: message: "Reconstructing ancestral sequences and mutations" input: - tree = rules.refine.output.tree, + tree = rules.prune_reference.output.tree, alignment = rules.align.output output: node_data = "results/nt-muts_{center}_{lineage}_{segment}_{resolution}_{passage}_{assay}.json" @@ -534,7 +582,7 @@ rule ancestral: rule translate: message: "Translating amino acid sequences" input: - tree = rules.refine.output.tree, + tree = rules.prune_reference.output.tree, node_data = rules.ancestral.output.node_data, reference = files.reference output: @@ -552,7 +600,7 @@ rule translate: rule reconstruct_translations: message: "Reconstructing translations required for titer models and frequencies" input: - tree = rules.refine.output.tree, + tree = rules.prune_reference.output.tree, node_data = "results/aa-muts_{center}_{lineage}_{segment}_{resolution}_{passage}_{assay}.json", output: aa_alignment = "results/aa-seq_{center}_{lineage}_{segment}_{resolution}_{passage}_{assay}_{gene}.fasta" @@ -572,7 +620,7 @@ rule reconstruct_translations: rule convert_translations_to_json: input: - tree = rules.refine.output.tree, + tree = rules.prune_reference.output.tree, translations = translations output: translations = "results/aa-seq_{center}_{lineage}_{segment}_{resolution}_{passage}_{assay}.json" @@ -595,7 +643,7 @@ rule traits: Inferring ancestral traits for {params.columns!s} """ input: - tree = rules.refine.output.tree, + tree = rules.prune_reference.output.tree, metadata = rules.parse.output.metadata output: node_data = 
"results/traits_{center}_{lineage}_{segment}_{resolution}_{passage}_{assay}.json", @@ -617,7 +665,7 @@ rule titers_sub: titers = rules.download_titers.output.titers, aa_muts = rules.translate.output, alignments = translations, - tree = rules.refine.output.tree + tree = rules.prune_reference.output.tree params: genes = gene_names output: @@ -637,7 +685,7 @@ rule titers_sub: rule titers_tree: input: titers = rules.download_titers.output.titers, - tree = rules.refine.output.tree + tree = rules.prune_reference.output.tree output: titers_model = "results/titers-tree-model_{center}_{lineage}_{segment}_{resolution}_{passage}_{assay}.json", conda: "environment.yaml" @@ -652,7 +700,7 @@ rule titers_tree: rule tip_frequencies: input: - tree = rules.refine.output.tree, + tree = rules.prune_reference.output.tree, metadata = rules.parse.output.metadata, params: narrow_bandwidth = 2 / 12.0, @@ -681,7 +729,7 @@ rule tip_frequencies: rule tree_frequencies: input: - tree = rules.refine.output.tree, + tree = rules.prune_reference.output.tree, metadata = rules.parse.output.metadata, params: min_date = min_date, @@ -709,7 +757,7 @@ rule tree_frequencies: rule diffusion_frequencies: input: - tree = rules.refine.output.tree, + tree = rules.prune_reference.output.tree, metadata = rules.parse.output.metadata, params: min_date = min_date, @@ -735,7 +783,7 @@ rule diffusion_frequencies: rule delta_frequency: input: - tree = rules.refine.output.tree, + tree = rules.prune_reference.output.tree, frequencies = rules.diffusion_frequencies.output.frequencies output: delta_frequency = "results/delta_frequency_{center}_{lineage}_{segment}_{resolution}_{passage}_{assay}.json" @@ -756,7 +804,7 @@ rule delta_frequency: rule clades: message: "Annotating clades" input: - tree = "results/tree_{center}_{lineage}_ha_{resolution}_{passage}_{assay}.nwk", + tree = "results/tree_pruned_{center}_{lineage}_ha_{resolution}_{passage}_{assay}.nwk", nt_muts = rules.ancestral.output, aa_muts = rules.translate.output, 
clades = _get_clades_file_for_wildcards @@ -781,7 +829,7 @@ rule clades: rule antigenic_distances_between_strains: input: - tree="results/tree_{center}_{lineage}_{segment}_{resolution}_{passage}_{assay}.nwk", + tree="results/tree_pruned_{center}_{lineage}_{segment}_{resolution}_{passage}_{assay}.nwk", clades="results/clades_{center}_{lineage}_{segment}_{resolution}_{passage}_{assay}.json", titer_model="results/titers-sub-model_{center}_{lineage}_{segment}_{resolution}_{passage}_{assay}.json", titers="data/{center}_{lineage}_{passage}_{assay}_titers.tsv", @@ -836,7 +884,7 @@ rule plot_antigenic_distances_between_strains: rule distances: input: - tree = rules.refine.output.tree, + tree = rules.prune_reference.output.tree, alignments = translations, distance_maps = _get_distance_maps_by_lineage_and_segment params: @@ -862,7 +910,7 @@ rule distances: rule pairwise_titer_tree_distances: input: - tree = rules.refine.output.tree, + tree = rules.prune_reference.output.tree, frequencies = rules.tip_frequencies.output.tip_freq, model = rules.titers_tree.output.titers_model, date_annotations = rules.refine.output.node_data @@ -917,7 +965,7 @@ rule titer_tree_cross_immunities: rule glyc: input: - tree = rules.refine.output.tree, + tree = rules.prune_reference.output.tree, alignment = _get_glyc_alignment output: glyc = "results/glyc_{center}_{lineage}_{segment}_{resolution}_{passage}_{assay}.json" @@ -933,7 +981,7 @@ rule glyc: rule lbi: message: "Calculating LBI" input: - tree = rules.refine.output.tree, + tree = rules.prune_reference.output.tree, branch_lengths = rules.refine.output.node_data params: tau = _get_lbi_tau_for_wildcards, diff --git a/Snakefile_countries b/Snakefile_countries index c3bc3a7e..06d1611f 100644 --- a/Snakefile_countries +++ b/Snakefile_countries @@ -104,7 +104,7 @@ def _get_node_data_for_export(wildcards): rule export: input: - tree = rules.refine.output.tree, + tree = rules.prune_reference.output.tree, metadata = rules.parse.output.metadata, 
#!/usr/bin/env python3
"""
Prunes a reference strain from the provided tree.
"""
import argparse
import shutil
import sys


def parse_args(argv=None):
    """Parse the command line.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``tree``, ``reference`` and ``output``.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    # --tree and --output are mandatory: without them there is nothing to read
    # or write, and failing here gives a usage error instead of a TypeError.
    parser.add_argument("--tree", required=True, help="Newick tree to prune")
    # nargs="?" lets a caller pass a bare "--reference" flag (for example when
    # an empty reference name is interpolated into the command line); the value
    # is then None and the tree is copied through unchanged.
    parser.add_argument("--reference", nargs="?", help="Name of the reference strain to prune")
    parser.add_argument("--output", required=True, help="Output Newick tree file")

    return parser.parse_args(argv)


def prune_reference(tree_path, reference, output_path):
    """Write the tree from ``tree_path`` to ``output_path`` without ``reference``.

    Args:
        tree_path: Path of the Newick tree to read.
        reference: Exact name of the tip to remove.
        output_path: Path of the Newick tree to write.

    Returns:
        True when a tip named ``reference`` was found and pruned, False when
        the tree was written without pruning.
    """
    # Imported here so that the copy-through path in main() does not require
    # Biopython to be importable.
    from Bio import Phylo

    tree = Phylo.read(tree_path, "newick")

    # Only terminal clades can be pruned, so internal nodes are not searched.
    # Names are compared exactly (string filters passed to find_clades itself
    # are matched as regular expressions, which strain names are not).
    target = next(
        (tip for tip in tree.find_clades(terminal=True) if tip.name == reference),
        None,
    )

    if target is None:
        print(
            "WARNING: Reference '%s' was not found in the tree, writing tree without pruning" % reference,
            file=sys.stderr
        )
    else:
        tree.prune(target)

    Phylo.write(tree, output_path, "newick")
    return target is not None


def main(argv=None):
    """Entry point: prune the reference strain, or copy the tree if none is given."""
    args = parse_args(argv)

    # If reference is not provided, then just copy the input to output without modifications
    if not args.reference:
        # Diagnostics go to stderr so they never mix with regular output.
        print("WARNING: No reference was provided, copying input tree to output tree", file=sys.stderr)
        shutil.copy(args.tree, args.output)
        return

    prune_reference(args.tree, args.reference, args.output)


if __name__ == '__main__':
    main()