Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Run rSPR when tree has duplicated nodes #191

Merged
merged 2 commits into from
May 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 40 additions & 13 deletions bin/rspr_approx.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,21 @@ def parse_args(args=None):
)
return parser.parse_args(args)

# A label is a maximal run of characters that are not Newick structure
# (parentheses, comma, colon, semicolon) or whitespace.
_NEWICK_TOKEN = re.compile(r"[^\s(),:;]+")
# Same shape the previous pattern required right before the colon: a letter
# followed by at least three word characters.  This keeps purely numeric
# tokens (e.g. support values such as "100") out of the comparison.
_NAME_TAIL = re.compile(r"[a-zA-Z]\w{3,}$")


def check_formatted_tree(tree_string):
    """Check if formatted tree has duplicate nodes.

    A label is remembered once it appears directly before a ``:`` and ends
    in a letter followed by three or more word characters.  The tree is
    reported as duplicated when a remembered label occurs again as a *whole*
    token.

    Compared with a single back-referencing regex this
    * does not flag labels that merely share a prefix or suffix
      (``gene1`` vs ``gene10``),
    * still finds duplicates when the tree text spans several lines, and
    * scans the text once instead of retrying from every candidate start.

    Args:
        tree_string: Tree text (Newick-like) to inspect.

    Returns:
        bool: True if some qualifying label occurs more than once.
    """
    seen = set()
    for token in _NEWICK_TOKEN.finditer(tree_string):
        label = token.group()
        if label in seen:
            return True
        # Only labels that carry a branch length are tracked, mirroring the
        # "<label>:" anchor of the earlier pattern.
        if tree_string.startswith(":", token.end()) and _NAME_TAIL.search(label):
            seen.add(label)
    return False

def read_tree(input_path):
    """Read a tree file, normalise its labels and flag duplicated nodes.

    Every ``;<text>:`` run in the file is collapsed to a plain ``:`` before
    the text is checked for duplicates and parsed.

    Args:
        input_path: Path of the tree file to read.

    Returns:
        tuple: ``(tree, is_duplicated)`` where ``tree`` is the ``Tree`` built
        from the normalised text and ``is_duplicated`` is the result of
        ``check_formatted_tree`` on that same text.
    """
    with open(input_path, "r") as f:
        tree_string = f.read()
    formatted = re.sub(r";[^:]+:", ":", tree_string)
    # The duplicate check must run before returning: callers unpack two
    # values, so an early ``return Tree(formatted)`` here would break them.
    is_duplicated = check_formatted_tree(formatted)

    return Tree(formatted), is_duplicated


#####################################################################
Expand All @@ -102,12 +111,27 @@ def read_tree(input_path):
#####################################################################


def root_tree(input_path, basename, output_path):
    """Root a gene tree at its midpoint outgroup and write it to disk.

    Trees with duplicated nodes are written to ``<output_path>/multiple``
    with ``.tre`` in the file name replaced by ``.tre.multiple``; all other
    trees are written to ``<output_path>/unique`` under ``basename``.

    Args:
        input_path: Path of the tree file to read.
        basename: File name to use for the rooted tree.
        output_path: Directory under which the ``multiple``/``unique``
            sub-directories are created.

    Returns:
        tuple: ``(tree_text, leaf_count, written_path, is_duplicated)``;
        ``written_path`` is always a ``str``.
    """
    tre, is_duplicated = read_tree(input_path)
    midpoint = tre.get_midpoint_outgroup()
    tre.set_outgroup(midpoint)

    if is_duplicated:
        outdir = Path(output_path) / "multiple"
        # Rename only the file name; applying the replacement to the whole
        # path would also rewrite any directory that contains ".tre".
        filename = str(basename).replace(".tre", ".tre.multiple")
    else:
        outdir = Path(output_path) / "unique"
        filename = basename
    outdir.mkdir(exist_ok=True, parents=True)
    rooted_path = str(outdir / filename)

    tre.write(outfile=rooted_path)
    return tre.write(), len(tre.get_leaves()), rooted_path, is_duplicated

def root_reference_tree(input_path, output_path):
    """Root the reference tree at its midpoint outgroup and write it out.

    The duplicate-node flag reported by ``read_tree`` is ignored here.

    Args:
        input_path: Path of the reference tree file to read.
        output_path: File path the rooted tree is written to; its parent
            directory is created when missing.

    Returns:
        tuple: ``(tree_text, leaf_count)`` for the rooted tree.
    """
    tre, _ = read_tree(input_path)
    midpoint = tre.get_midpoint_outgroup()
    tre.set_outgroup(midpoint)

    parent_dir = os.path.dirname(output_path)
    # A bare file name has an empty dirname, and os.makedirs("") raises.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    tre.write(outfile=output_path)
    return tre.write(), len(tre.get_leaves())

Expand Down Expand Up @@ -135,20 +159,23 @@ def root_trees(core_tree, gene_trees_path, output_dir, results, merge_pair=False
rooted_reference_tree = os.path.join(
output_dir, "rooted_reference_tree/core_gene_alignment.tre"
)
refer_content, refer_tree_size = root_tree(reference_tree, rooted_reference_tree)
refer_content, refer_tree_size = root_reference_tree(reference_tree, rooted_reference_tree)

df_gene_trees = pd.read_csv(gene_trees_path)
rooted_gene_trees_path = os.path.join(output_dir, "rooted_gene_trees")
for filename in df_gene_trees["path"]:
basename = Path(filename).name
rooted_gene_tree_path = os.path.join(rooted_gene_trees_path, basename)
gene_content, gene_tree_size = root_tree(filename, rooted_gene_tree_path)
results.loc[basename, "tree_size"] = gene_tree_size
gene_content, gene_tree_size, gene_tree_path, is_duplicated = root_tree(
filename,
basename,
rooted_gene_trees_path)
if not is_duplicated:
results.loc[basename, "tree_size"] = gene_tree_size
if merge_pair:
with open(rooted_gene_tree_path, "w") as f2:
with open(gene_tree_path, "w") as f2:
f2.write(refer_content + "\n" + gene_content)
#'''
return rooted_gene_trees_path
return os.path.join(rooted_gene_trees_path, "unique")


#####################################################################
Expand Down Expand Up @@ -212,7 +239,7 @@ def approx_rspr(
"-length " + str(min_branch_len),
"-support " + str(max_support_threshold),
]

group_size = 10000
cur_count = 0
lst_filename = []
Expand Down Expand Up @@ -498,7 +525,7 @@ def main(args=None):
# Generate group heatmap
group_fig_path = os.path.join(args.OUTPUT_DIR, "group_output.png")
make_group_heatmap(
results,
results,
group_fig_path,
args.MIN_HEATMAP_RSPR_DISTANCE,
args.MAX_HEATMAP_RSPR_DISTANCE
Expand Down
6 changes: 3 additions & 3 deletions bin/rspr_exact.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def fpt_rspr(results_df, min_branch_len=0, max_support_threshold=0.7, gather_clu
"-support " + str(max_support_threshold),
]

trees_path = os.path.join("rooted_gene_trees")
trees_path = os.path.join("rooted_gene_trees/unique")

cluster_file = None
if gather_cluster_info:
Expand Down Expand Up @@ -123,13 +123,13 @@ def fpt_rspr(results_df, min_branch_len=0, max_support_threshold=0.7, gather_clu
continue
elif "Clusters end" in line:
clustering_start = False

if clustering_start:
updated_line = line.replace('(', '').replace(')', '').replace('\n', '')
cluster_nodes = updated_line.split(',')
cluster_nodes = [int(node) for node in cluster_nodes if "X" not in node]
clusters.append(cluster_nodes)

output_lines.append(line)
cluster_file.write(json.dumps(clusters) + '\n')
process.wait()
Expand Down
Loading