From b6bb3de53ab4e65f033b05154a807e06bea7430b Mon Sep 17 00:00:00 2001 From: Matthew Wells Date: Tue, 7 May 2024 17:13:53 -0500 Subject: [PATCH 1/7] began refactoring search --- locidex/classes/seq_intake.py | 192 +++++++++++++++++++++------------- locidex/manifest.py | 9 +- locidex/search.py | 65 +----------- locidex/utils.py | 5 +- tests/test_seq_intake.py | 21 ++-- 5 files changed, 144 insertions(+), 148 deletions(-) diff --git a/locidex/classes/seq_intake.py b/locidex/classes/seq_intake.py index 1cfc0fa..cf350a4 100644 --- a/locidex/classes/seq_intake.py +++ b/locidex/classes/seq_intake.py @@ -6,6 +6,27 @@ from locidex.utils import guess_alphabet, calc_md5, six_frame_translation from locidex.classes.prodigal import gene_prediction from locidex.constants import DNA_AMBIG_CHARS, DNA_IUPAC_CHARS +from typing import NamedTuple, Optional +from dataclasses import dataclass + +@dataclass +class SeqObject: + + parent_id: str + locus_name: str + seq_id: str + dna_seq: str + dna_ambig_count: int + dna_hash: str + dna_len: int + aa_seq: str + aa_hash: str + aa_len: int + start_codon: Optional[str] + end_codon: Optional[str] + count_internal_stop: Optional[int] + __slots__ = ("parent_id", "locus_name", "seq_id", "dna_seq", "dna_ambig_count", + "dna_hash", "dna_len", "aa_seq", "aa_hash", "aa_len", "start_codon", "end_codon", "count_internal_stop") class seq_intake: input_file = '' @@ -14,9 +35,8 @@ class seq_intake: feat_key = 'CDS' translation_table = 11 is_file_valid = '' - status = True + messages = [] - seq_data = [] prodigal_genes = [] skip_trans = False @@ -26,6 +46,9 @@ def __init__(self,input_file,file_type,feat_key='CDS',translation_table=11,perfo self.translation_table = translation_table self.feat_key = feat_key self.skip_trans = skip_trans + self.status = True + self.num_threads = num_threads + #self.seq_data = self.process_fasta() if not os.path.isfile(self.input_file): self.messages.append(f'Error {self.input_file} does not exist') @@ -35,15 +58,17 @@ def 
__init__(self,input_file,file_type,feat_key='CDS',translation_table=11,perfo return if file_type == 'genbank': - self.status = self.process_gbk() + self.seq_data = self.process_gbk() + self.status = True elif file_type == 'fasta' and perform_annotation==True: - sobj = gene_prediction(self.input_file) - sobj.predict(num_threads) - self.prodigal_genes = sobj.genes - self.process_seq_hash(sobj.sequences) - self.process_fasta() + #sobj = gene_prediction(self.input_file) + #sobj.predict(num_threads) + #self.prodigal_genes = sobj.genes + #self.process_seq_hash(sobj.sequences) + #self.seq_data = self.process_fasta() + self.seq_data = self.annotate_fasta(self.input_file, num_threads=self.num_threads) elif file_type == 'fasta' and perform_annotation==False: - self.process_fasta() + self.seq_data = self.process_fasta() elif file_type == 'gff': self.status = False elif file_type == 'gtf': @@ -51,12 +76,19 @@ def __init__(self,input_file,file_type,feat_key='CDS',translation_table=11,perfo if self.status: self.add_codon_data() + + def annotate_fasta(self, input_file, num_threads): + sobj = gene_prediction(input_file) + sobj.predict(num_threads) + self.prodigal_genes = sobj.genes + seq_data = self.process_seq_hash(sobj.sequences) + return self.process_fasta(seq_data) def add_codon_data(self): for record in self.seq_data: - if record['aa_len'] == 0: + if record.aa_len == 0: continue - dna_seq = record['dna_seq'].lower().replace('-','') + dna_seq = record.dna_seq.lower().replace('-','') dna_len = len(dna_seq) start_codon = '' stop_codon = '' @@ -64,15 +96,15 @@ def add_codon_data(self): if dna_len >= 6: start_codon = dna_seq[0:3] stop_codon = dna_seq[-3:] - count_internal_stop = record['aa_seq'][:-1].count('*') - record['start_codon'] = start_codon - record['stop_codon'] = stop_codon - record['count_internal_stop'] = count_internal_stop + count_internal_stop = record.aa_seq[:-1].count('*') + record.start_codon = start_codon + record.end_codon = stop_codon + 
record.count_internal_stop = count_internal_stop - def process_gbk(self): + def process_gbk(self) -> list[SeqObject]: obj = parse_gbk(self.input_file) - + seq_data = [] if obj.status == False: return False acs = obj.get_acs() @@ -84,22 +116,25 @@ def process_gbk(self): for char in DNA_IUPAC_CHARS: s = s.replace(char,"n") - self.seq_data.append( { - 'parent_id': a, - 'locus_name':seq['gene_name'], - 'seq_id': seq['gene_name'], - 'dna_seq': s, - 'dna_ambig_count': self.count_ambig_chars(seq['dna_seq'], DNA_AMBIG_CHARS), - 'dna_hash': calc_md5([s])[0], - 'dna_len': len(seq['dna_seq']), - 'aa_seq': seq['aa_seq'], - 'aa_hash': seq['aa_hash'], - 'aa_len': len(seq['aa_seq']), - - } ) - return True - - def process_fasta(self): + seq_data.append(SeqObject( + parent_id = a, + locus_name = seq['gene_name'], + seq_id = seq['gene_name'], + dna_seq = s, + dna_ambig_count = self.count_ambig_chars(seq['dna_seq'], DNA_AMBIG_CHARS), + dna_hash = calc_md5([s])[0], + dna_len = len(seq['dna_seq']), + aa_seq = seq['aa_seq'], + aa_hash = seq['aa_hash'], + aa_len = len(seq['aa_seq']), + start_codon=None, + end_codon=None, + count_internal_stop=None, + )) + return seq_data + + def process_fasta(self, seq_data = []) -> list[SeqObject]: + obj = parse_fasta(self.input_file) if obj.status == False: return @@ -132,23 +167,25 @@ def process_fasta(self): aa_hash = calc_md5([aa_seq])[0] aa_len = len(aa_seq) - self.seq_data.append({ - 'parent_id': features['gene_name'], - 'locus_name': features['gene_name'], - 'seq_id': features['seq_id'], - 'dna_seq': dna_seq, - 'dna_ambig_count': self.count_ambig_chars(dna_seq, DNA_AMBIG_CHARS), - 'dna_hash': dna_hash, - 'dna_len': dna_len, - 'aa_seq': aa_seq, - 'aa_hash': aa_hash, - 'aa_len': aa_len, - - }) - - return - - def process_seq_hash(self,sequences): + seq_data.append(SeqObject( + parent_id=features['gene_name'], + locus_name = features['gene_name'], + seq_id = features['seq_id'], + dna_seq = dna_seq, + dna_ambig_count = self.count_ambig_chars(dna_seq, 
DNA_AMBIG_CHARS), + dna_hash = dna_hash, + dna_len = dna_len, + aa_seq = aa_seq, + aa_hash = aa_hash, + aa_len = aa_len, + start_codon=None, + end_codon=None, + count_internal_stop=None,)) + return seq_data + + + def process_seq_hash(self,sequences) -> list[SeqObject]: + seq_data = [] for id in sequences: seq = sequences[id] dtype = guess_alphabet(seq) @@ -174,21 +211,23 @@ def process_seq_hash(self,sequences): aa_seq = seq.lower().replace('-','') aa_hash = calc_md5([aa_seq])[0] aa_len = len(aa_seq) - self.seq_data.append({ - 'parent_id': id, - 'locus_name': id, - 'seq_id': id, - 'dna_seq': dna_seq, - 'dna_hash': dna_hash, - 'dna_ambig_count':self.count_ambig_chars(dna_seq, DNA_AMBIG_CHARS), - 'dna_len': dna_len, - 'aa_seq': aa_seq, - 'aa_hash': aa_hash, - 'aa_len': aa_len, - - }) - - return + seq_data.append(SeqObject( + parent_id=id, + locus_name=id, + seq_id=id, + dna_seq= dna_seq, + dna_hash= dna_hash, + dna_ambig_count=self.count_ambig_chars(dna_seq, DNA_AMBIG_CHARS), + dna_len=dna_len, + aa_seq=aa_seq, + aa_hash=aa_hash, + aa_len=aa_len, + start_codon=None, + end_codon=None, + count_internal_stop=None, + )) + + return seq_data def count_ambig_chars(self,seq,chars): count = 0 @@ -200,7 +239,7 @@ def count_ambig_chars(self,seq,chars): class seq_store: - stored_fields = ['parent_id','locus_name','seq_id','dna_hash','dna_len','aa_hash','aa_len','start_codon','stop_codon','count_internal_stop','dna_ambig_count'] + #stored_fields = ['parent_id','locus_name','seq_id','dna_hash','dna_len','aa_hash','aa_len','start_codon','stop_codon','count_internal_stop','dna_ambig_count'] record = { 'db_info': {}, 'db_seq_info': {}, @@ -213,13 +252,14 @@ class seq_store: } } - def __init__(self,sample_name,db_config_dict,metadata_dict,query_seq_records,blast_columns,filters={},stored_fields=[]): + #def __init__(self,sample_name,db_config_dict,metadata_dict,query_seq_records,blast_columns,filters={},stored_fields=[]): + def 
__init__(self,sample_name,db_config_dict,metadata_dict,query_seq_records,blast_columns,filters={}): self.sample_name = sample_name self.record['query_data']['sample_name'] = sample_name self.add_db_config(db_config_dict) self.add_seq_data(query_seq_records) - if len(stored_fields) > 0 : - self.stored_fields = stored_fields + #if len(stored_fields) > 0 : + # self.stored_fields = stored_fields self.add_db_metadata(metadata_dict) self.add_hit_cols(blast_columns) self.filters = filters @@ -233,12 +273,14 @@ def add_hit_cols(self,columns): self.record['query_hit_columns'] = columns def add_seq_data(self,query_seq_records): - for idx in range(0, len(query_seq_records)): - self.record['query_data']['query_seq_data'][idx] = {} - for f in self.stored_fields: - self.record['query_data']['query_seq_data'][idx][f] = '' - if f in query_seq_records[idx]: - self.record['query_data']['query_seq_data'][idx][f] = query_seq_records[idx][f] + print(type(query_seq_records[0])) + for idx, v in enumerate(query_seq_records): + self.record['query_data']['query_seq_data'][idx] = v + #print(self.record['query_data']['query_seq_data'][idx]) + #for f in self.stored_fields: + # self.record['query_data']['query_seq_data'][idx][f] = '' + # if f in query_seq_records[idx]: + # self.record['query_data']['query_seq_data'][idx][f] = query_seq_records[idx][f] def add_db_metadata(self,metadata_dict): locus_profile = {} diff --git a/locidex/manifest.py b/locidex/manifest.py index 7fe010c..13c48c7 100644 --- a/locidex/manifest.py +++ b/locidex/manifest.py @@ -157,7 +157,7 @@ def run(cmd_args=None): manifest = create_manifest(directory_in) return write_manifest(directory_in, manifest) -def select_db(manifest_data: Dict[str, List[ManifestItem]], name: str, version: str): +def select_db(manifest_data: Dict[str, List[ManifestItem]], name: str, version: str) -> ManifestItem: """ Select a locidex database from the manifest file provided. 
@@ -195,10 +195,13 @@ def read_manifest(input_file: pathlib.Path) -> dict: manifest_data[k].append(manifest_item) return manifest_data -def get_manifest_db(input_file: pathlib.Path, name: str, version: str): +def get_manifest_db(input_file: pathlib.Path, name: str, version: str) -> ManifestItem: + """ + Retruns path to the database file selected + """ output = read_manifest(input_file) db_out = select_db(output, name, version) - return db_out.db_path + return db_out # call main function if __name__ == '__main__': diff --git a/locidex/search.py b/locidex/search.py index 3fa045c..ede6fef 100644 --- a/locidex/search.py +++ b/locidex/search.py @@ -95,8 +95,7 @@ def run_search(config): sample_name = config['name'] perform_annotation = config['annotate'] max_target_seqs = config['max_target_seqs'] - db_name = config['db_name'] - db_version = config['db_version'] + if 'max_ambig_count' in config: max_ambig_count = config['max_ambig_count'] else: @@ -113,65 +112,11 @@ def run_search(config): run_data['analysis_start_time'] = datetime.now().strftime("%d/%m/%Y %H:%M:%S") run_data['parameters'] = config - #check if user supplied a manifest of different databases - if os.path.isfile(db_dir): - if db_name is None or db_name == '': - print(f'You specified a file as the locidex db but no db_name to run, please specify a valid --db_name: {db_dir}') - sys.exit() - - with open(db_dir ,'r') as fh: - manifest = json.load(fh) - - if db_name not in manifest: - print(f'You specified a db name "{db_name}" which does not exist in the manifest file: {db_dir}') - print(f'list of keys in manifest: {list(manifest.keys())}') - sys.exit() - - if db_version is not None and db_version != '': - if db_version not in manifest[db_name]: - print(f'You specified a db name "{db_name}" and db version but the db version "{db_version}"was not found in the manifest {list(manifest[db_name].keys())} ') - sys.exit() - else: - version_codes = list(manifest[db_name].keys()) - if len(version_codes) == 1: - 
version_code = version_codes[0] - else: - latest_date = None - version_code = None - - for code in version_codes: - if not 'db_date' in manifest[db_name][code]: - print(f'Error db_date field missing from manifest for {db_name}, this field is required if more than 1 db version exists') - sys.exit() - db_date = manifest[db_name][code]['db_date'] - db_date = datetime.strptime(db_date, '%Y/%d/%m') - if version_code is None: - version_code = code - latest_date = db_date - continue - if db_date > latest_date: - latest_date = db_date - version_code = code - db_version = version_code - db_dir_prefix = str(os.path.dirname(db_dir)).split('/') - db_dir_rel_path = manifest[db_name][db_version]['db_relative_path_dir'].split('/') - if db_dir_prefix[-1] == db_dir_rel_path[0]: - db_dir_prefix = db_dir_prefix[0:-1] - db_dir_prefix = os.path.join(db_dir_prefix) - db_dir_rel_path = os.path.join(db_dir_rel_path) - db_dir = os.path.join(db_dir_prefix,db_dir_rel_path) - - if not os.path.isdir(db_dir): - print(f'Error DB does not exist: {db_dir}') - sys.exit() - - - # Validate database is valid - db_database_config = search_db_conf(db_dir, DB_EXPECTED_FILES, DBConfig._keys()) - if db_database_config.status == False: - print(f'There is an issue with provided db directory: {db_dir}\n {db_database_config.messages}') - sys.exit() + #db_database_config = search_db_conf(db_dir, DB_EXPECTED_FILES, DBConfig._keys()) + #if db_database_config.status == False: + # print(f'There is an issue with provided db directory: {db_dir}\n {db_database_config.messages}') + # sys.exit() metadata_path = db_database_config.meta_file_path metadata_obj = db_config(metadata_path, ['meta', 'info']) diff --git a/locidex/utils.py b/locidex/utils.py index a93af5a..0546626 100644 --- a/locidex/utils.py +++ b/locidex/utils.py @@ -4,7 +4,7 @@ import argparse from collections import Counter from pathlib import Path - +from locidex.manifest import ManifestItem from Bio.Seq import Seq from locidex.constants import NT_SUB, 
PROTEIN_ALPHA, DNA_ALPHA, OPTION_GROUPS @@ -21,7 +21,8 @@ def check_db_groups(analysis_params: dict, cmd_args: argparse.Namespace, param_d raise AttributeError("Missing required parameter: {}".format(option)) if cmd_args.db_group is not None: - analysis_params[param_db] = str(manifest.get_manifest_db(input_file=Path(cmd_args.db_group), name=cmd_args.db_name, version=cmd_args.db_version)) + manifest_data = manifest.get_manifest_db(input_file=Path(cmd_args.db_group), name=cmd_args.db_name, version=cmd_args.db_version) + analysis_params[param_db] = str(manifest_data.db_path) return analysis_params diff --git a/tests/test_seq_intake.py b/tests/test_seq_intake.py index 37865cf..75df91e 100644 --- a/tests/test_seq_intake.py +++ b/tests/test_seq_intake.py @@ -1,8 +1,10 @@ import os, warnings import locidex.classes.seq_intake +from locidex.classes.seq_intake import SeqObject from locidex.constants import BLAST_TABLE_COLS, DB_EXPECTED_FILES, DBConfig from locidex.classes.db import search_db_conf, db_config from collections import Counter +from dataclasses import asdict PACKAGE_ROOT = os.path.dirname(locidex.__file__) @@ -36,8 +38,10 @@ def test_seq_store_class(): assert len(seq_store_obj.record['query_data']['query_seq_data']) == 1 else: warnings.warn(f"expected len(seq_store_obj.record['query_data']['query_seq_data']) == 1 but got {len(seq_store_obj.record['query_data']['query_seq_data'])}") - assert list(seq_store_obj.record['query_data']['query_seq_data'][0].keys()) == ['parent_id', 'locus_name', 'seq_id', 'dna_hash', 'dna_len', 'aa_hash', - 'aa_len', 'start_codon', 'stop_codon', 'count_internal_stop', 'dna_ambig_count'] + + compare_dict = asdict(seq_store_obj.record['query_data']['query_seq_data'][0]) + assert set(compare_dict.keys()) == set(['parent_id', 'locus_name', 'seq_id', 'dna_hash', 'dna_len', 'aa_hash', + 'aa_len', 'start_codon', 'end_codon', 'count_internal_stop', 'dna_ambig_count', 'dna_seq', 'aa_seq']) assert 
list(seq_store_obj.record['query_data']['locus_profile'].keys()) == ['aroC', 'dnaN', 'hemD', 'hisD', 'purE', 'sucA', 'thrA'] assert seq_store_obj.record['query_data']['query_hit_columns'] == [] assert seq_store_obj.record['query_data']['query_hits'] == {} @@ -45,7 +49,7 @@ def test_seq_store_class(): def test_read_gbk_file(): seq_intake_object = seq_intake_class_init(input_file=os.path.join(PACKAGE_ROOT, 'example/search/NC_003198.1.gbk'), - file_type='genbank', perform_annotation=False) + file_type='genbank', perform_annotation=False) expected_orfs = 4644 assert seq_intake_object.file_type == 'genbank' @@ -58,12 +62,13 @@ def test_read_gbk_file(): else: assert len(seq_intake_object.seq_data) == expected_orfs - assert seq_intake_object.seq_data[0] == {'parent_id': 'NC_003198', 'locus_name': 'STY_RS00005', 'seq_id': 'STY_RS00005', + assert seq_intake_object.seq_data[0] == SeqObject(**{'parent_id': 'NC_003198', 'locus_name': 'STY_RS00005', 'seq_id': 'STY_RS00005', 'dna_seq': 'atgaaccgcatcagcaccaccaccattaccaccatcaccattaccacaggtaacggtgcgggctga', - 'dna_hash': 'f46b7aac05dba47f42391aaa5ac25edf', 'count_internal_stop': 0, 'dna_ambig_count': 0, 'start_codon': 'atg', 'stop_codon': 'tga', - 'dna_len': 66, 'aa_seq': 'mnristttittitittgngag', 'aa_hash': '8b370db9e32fd0a8362c35f3535303d8', 'aa_len': 21} - assert all(['start_codon' in record for record in seq_intake_object.seq_data if record['aa_len'] > 0]) == True - assert dict(Counter(['start_codon' in record if record['aa_len'] > 0 else False in record for record in seq_intake_object.seq_data ])) == {True: 4325, False: 319} + 'dna_hash': 'f46b7aac05dba47f42391aaa5ac25edf', 'count_internal_stop': 0, 'dna_ambig_count': 0, 'start_codon': 'atg', 'end_codon': 'tga', + 'dna_len': 66, 'aa_seq': 'mnristttittitittgngag', 'aa_hash': '8b370db9e32fd0a8362c35f3535303d8', 'aa_len': 21}) + assert all([record.start_codon for record in seq_intake_object.seq_data if record.aa_len > 0]) == True + #assert dict(Counter(['start_codon' in record if 
record.aa_len > 0 else False in record for record in seq_intake_object.seq_data ])) == {True: 4325, False: 319} + assert dict(Counter([record.start_codon is not None if record.aa_len > 0 else False for record in seq_intake_object.seq_data ])) == {True: 4325, False: 319} def test_read_fasta_file(): From b81bab848bc139690c738ae58297d8aaa54a2e47 Mon Sep 17 00:00:00 2001 From: Matthew Wells Date: Wed, 8 May 2024 16:42:17 -0500 Subject: [PATCH 2/7] updated blast module --- locidex/classes/blast.py | 86 ++++++++--------- locidex/classes/fasta.py | 46 +++++---- locidex/classes/seq_intake.py | 46 ++++----- locidex/constants.py | 94 ++++++++++--------- locidex/example/build_db_mlst_out/config.json | 4 +- locidex/extract.py | 6 +- locidex/format.py | 4 +- locidex/manifest.py | 67 ++++++++++++- locidex/search.py | 83 ++++++++-------- locidex/utils.py | 8 +- tests/test_blast.py | 73 +++++++------- tests/test_extractor.py | 33 ++++--- tests/test_fasta.py | 8 +- tests/test_seq_intake.py | 36 ++++--- 14 files changed, 337 insertions(+), 257 deletions(-) diff --git a/locidex/classes/blast.py b/locidex/classes/blast.py index f9aa4bd..148a9c7 100644 --- a/locidex/classes/blast.py +++ b/locidex/classes/blast.py @@ -1,30 +1,30 @@ import pandas as pd from locidex.classes import run_command +from locidex.utils import slots +from dataclasses import dataclass +from typing import Optional import os + +@dataclass +class FilterOptions: + min: Optional[float] + max: Optional[float] + include: Optional[bool] + __slots__ = slots(__annotations__) + class blast_search: VALID_BLAST_METHODS = ['blastn','tblastn','blastp'] - BLAST_TABLE_COLS = [] - input_db_path = None - input_query_path = None - output_db_path = None - output_results = None - blast_params = {} - blast_method = None - create_db = False - parse_seqids = False - status = True - messages = [] - - def 
__init__(self,input_db_path,input_query_path,output_results,blast_params,blast_method,blast_columns,output_db_path=None,create_db=False,parse_seqids=False): + + def __init__(self,input_db_path,input_query_path,output_results,blast_params,blast_method,blast_columns,output_db_path=None,parse_seqids=False): self.input_query_path = input_query_path self.input_db_path = input_db_path self.output_db_path = output_db_path self.output_results = output_results self.blast_method = blast_method self.blast_params = blast_params - self.create_db = create_db + #self.create_db = create_db self.BLAST_TABLE_COLS = blast_columns self.parse_seqids = parse_seqids @@ -32,27 +32,20 @@ def __init__(self,input_db_path,input_query_path,output_results,blast_params,bla self.output_db_path = input_db_path if not os.path.isfile(self.input_query_path): - self.messages.append(f'Error {self.input_query_path} query fasta does not exist') - self.status = False - - elif not create_db and not self.is_blast_db_valid(): - self.messages.append(f'Error {self.input_db_path} is not a valid blast db and creation of db is disabled') - self.status = False - return - + raise ValueError(f'Error {self.input_query_path} query fasta does not exist') + elif not self.is_blast_db_valid(): + raise ValueError(f'Error {self.input_db_path} is not a valid blast db') elif not blast_method in self.VALID_BLAST_METHODS: - self.messages.append(f'Error {blast_method} is not a supported blast method: {self.blast_method}') - self.status = False - - elif create_db and not self.is_blast_db_valid(): - (stdout,stderr) = self.makeblastdb() - if not self.is_blast_db_valid(): - self.messages.append(f'Error {self.output_db_path} is not a valid blast db and creation of db failed') - self.status = False - - - self.messages.append(self.run_blast()) + raise ValueError(f'Error {blast_method} is not a supported blast method: {self.blast_method}') + #elif create_db and not self.is_blast_db_valid(): + # (stdout,stderr) = self.makeblastdb() + # 
print(stdout, stderr) + # if not self.is_blast_db_valid(): + # raise ValueError(f'Error {self.output_db_path} is not a valid blast db and creation of db failed') + # TODO add logging for blast messages + (stdout, stderr) = self.run_blast() + print(stdout, stderr) def makeblastdb(self): @@ -71,10 +64,11 @@ def makeblastdb(self): def is_blast_db_valid(self): extensions = ['nsq', 'nin', 'nhr'] for e in extensions: - if not os.path.isfile(f'{self.input_db_path}.{e}'): + print(self.output_db_path, e) + if not os.path.isfile(f'{self.output_db_path}.{e}'): extensions2 = ['pto', 'ptf', 'phr'] for e2 in extensions2: - if not os.path.isfile(f'{self.input_db_path}.{e2}'): + if not os.path.isfile(f'{self.output_db_path}.{e2}'): return False return True @@ -96,13 +90,9 @@ def run_blast(self): class parse_blast: - BLAST_TABLE_COLS = [] input_file = None - df = None columns = [] filter_options = {} - status = True - messages = [] def __init__(self, input_file,blast_columns,filter_options): self.input_file = input_file @@ -110,9 +100,10 @@ def __init__(self, input_file,blast_columns,filter_options): self.BLAST_TABLE_COLS = blast_columns if not os.path.isfile(self.input_file): - self.messages.append(f'Error {self.input_file} does not exist') - self.status = False - self.read_hit_table() + raise FileNotFoundError(f'Error {self.input_file} does not exist') + + self.df = self.read_hit_table() + print(self.df) self.columns = self.df.columns.tolist() for id_col in ['qseqid','sseqid']: @@ -122,13 +113,16 @@ def __init__(self, input_file,blast_columns,filter_options): self.df = self.df.astype(tp) for col_name in self.filter_options: if col_name in self.columns: - min_value = self.filter_options[col_name]['min'] - max_value = self.filter_options[col_name]['max'] - include = self.filter_options[col_name]['include'] + #min_value = self.filter_options[col_name]['min'] + #max_value = self.filter_options[col_name]['max'] + #include = self.filter_options[col_name]['include'] + min_value = 
self.filter_options[col_name].min + max_value = self.filter_options[col_name].max + include = self.filter_options[col_name].include self.filter_df(col_name, min_value, max_value,include) def read_hit_table(self): - self.df = pd.read_csv(self.input_file,header=None,names=self.BLAST_TABLE_COLS,sep="\t",low_memory=False) + return pd.read_csv(self.input_file,header=None,names=self.BLAST_TABLE_COLS,sep="\t",low_memory=False) def filter_df(self,col_name,min_value,max_value,include): diff --git a/locidex/classes/fasta.py b/locidex/classes/fasta.py index 83b123b..b6cf7a3 100644 --- a/locidex/classes/fasta.py +++ b/locidex/classes/fasta.py @@ -3,41 +3,49 @@ from mimetypes import guess_type from functools import partial import os -from locidex.utils import calc_md5 +from locidex.utils import calc_md5, slots +from dataclasses import dataclass + +@dataclass +class Fasta: + """ + """ + gene_name: str + seq_id: str + seq: str + hash: str + length: int + __slots__ = slots(__annotations__) class parse_fasta: - input_file = None - seq_obj = None - status = True - messages = [] def __init__(self, input_file,parse_def=False,seq_type=None,delim="|"): self.input_file = input_file if not os.path.isfile(self.input_file): - self.messages.append(f'Error {self.input_file} does not exist') - self.status = False + raise FileNotFoundError("Input file: {} not found.".format(self.input_file)) self.delim = delim self.seq_type = seq_type self.parse_def = parse_def self.seq_obj = self.parse_fasta() - - return + @staticmethod + def normalize_sequence(fasta:str) -> str: + """ + Remove INDELS and lower all characters in sequence + """ + return fasta.lower().replace("-", "") def get_seqids(self): - if self.seq_obj is not None: + if self.seq_obj: return list(self.seq_obj.keys()) - else: - return [] - - def get_seq_by_id(self, id): - if self.seq_obj is not None and id in self.seq_obj: - return self.seq_obj[id] - else: - return {} + raise AssertionError("No fasta file loaded.") + def get_seq_by_id(self, 
fasta_id: str): + if seq_data := self.seq_obj.get(fasta_id): #is not None and id in self.seq_obj: + return seq_data + raise KeyError("Missing sequence id: {}".format(fasta_id)) def parse_fasta(self): encoding = guess_type(self.input_file)[1] @@ -54,5 +62,5 @@ def parse_fasta(self): if len(h) > 1: gene_name = h[0] seq_id = h[1] - data[id] = {'gene_name': gene_name, 'seq_id':seq_id,'seq': seq, 'hash':calc_md5([seq])[0],'length':len(seq)} + data[id] = Fasta(gene_name=gene_name, seq_id=seq_id, seq=self.normalize_sequence(seq), hash=calc_md5([seq])[0], length=len(seq)) return data diff --git a/locidex/classes/seq_intake.py b/locidex/classes/seq_intake.py index cf350a4..5db080f 100644 --- a/locidex/classes/seq_intake.py +++ b/locidex/classes/seq_intake.py @@ -3,12 +3,25 @@ from locidex.classes.gbk import parse_gbk from locidex.classes.fasta import parse_fasta -from locidex.utils import guess_alphabet, calc_md5, six_frame_translation +from locidex.utils import guess_alphabet, calc_md5, six_frame_translation, slots from locidex.classes.prodigal import gene_prediction -from locidex.constants import DNA_AMBIG_CHARS, DNA_IUPAC_CHARS +from locidex.constants import DNA_AMBIG_CHARS, DNA_IUPAC_CHARS, CharacterConstants from typing import NamedTuple, Optional from dataclasses import dataclass +@dataclass +class HitFilters: + min_dna_len: int + max_dna_len: int + min_dna_ident: float + min_dna_match_cov: float + min_aa_len: int + max_aa_len: int + min_aa_ident: float + min_aa_match_cov: float + dna_ambig_count: int + __slots__ = slots(__annotations__) + @dataclass class SeqObject: @@ -25,8 +38,8 @@ class SeqObject: start_codon: Optional[str] end_codon: Optional[str] count_internal_stop: Optional[int] - __slots__ = ("parent_id", "locus_name", "seq_id", "dna_seq", "dna_ambig_count", - "dna_hash", "dna_len", "aa_seq", "aa_hash", "aa_len", "start_codon", "end_codon", "count_internal_stop") + # Manually adding slots for compatibility + __slots__ = slots(__annotations__) class 
seq_intake: input_file = '' @@ -96,7 +109,7 @@ def add_codon_data(self): if dna_len >= 6: start_codon = dna_seq[0:3] stop_codon = dna_seq[-3:] - count_internal_stop = record.aa_seq[:-1].count('*') + count_internal_stop = record.aa_seq[:-1].count(CharacterConstants.stop_codon) record.start_codon = start_codon record.end_codon = stop_codon record.count_internal_stop = count_internal_stop @@ -134,14 +147,11 @@ def process_gbk(self) -> list[SeqObject]: return seq_data def process_fasta(self, seq_data = []) -> list[SeqObject]: - obj = parse_fasta(self.input_file) - if obj.status == False: - return ids = obj.get_seqids() for id in ids: features = obj.get_seq_by_id(id) - seq = features['seq'].lower().replace('-','') + seq = features.seq dtype = guess_alphabet(seq) dna_seq = '' dna_hash = '' @@ -168,9 +178,9 @@ def process_fasta(self, seq_data = []) -> list[SeqObject]: aa_len = len(aa_seq) seq_data.append(SeqObject( - parent_id=features['gene_name'], - locus_name = features['gene_name'], - seq_id = features['seq_id'], + parent_id=features.gene_name, + locus_name = features.gene_name, + seq_id = features.seq_id, dna_seq = dna_seq, dna_ambig_count = self.count_ambig_chars(dna_seq, DNA_AMBIG_CHARS), dna_hash = dna_hash, @@ -236,8 +246,6 @@ def count_ambig_chars(self,seq,chars): return count - - class seq_store: #stored_fields = ['parent_id','locus_name','seq_id','dna_hash','dna_len','aa_hash','aa_len','start_codon','stop_codon','count_internal_stop','dna_ambig_count'] record = { @@ -253,13 +261,11 @@ class seq_store: } #def __init__(self,sample_name,db_config_dict,metadata_dict,query_seq_records,blast_columns,filters={},stored_fields=[]): - def __init__(self,sample_name,db_config_dict,metadata_dict,query_seq_records,blast_columns,filters={}): + def __init__(self,sample_name,db_config_dict,metadata_dict,query_seq_records,blast_columns,filters: HitFilters): self.sample_name = sample_name self.record['query_data']['sample_name'] = sample_name self.add_db_config(db_config_dict) 
self.add_seq_data(query_seq_records) - #if len(stored_fields) > 0 : - # self.stored_fields = stored_fields self.add_db_metadata(metadata_dict) self.add_hit_cols(blast_columns) self.filters = filters @@ -273,14 +279,8 @@ def add_hit_cols(self,columns): self.record['query_hit_columns'] = columns def add_seq_data(self,query_seq_records): - print(type(query_seq_records[0])) for idx, v in enumerate(query_seq_records): self.record['query_data']['query_seq_data'][idx] = v - #print(self.record['query_data']['query_seq_data'][idx]) - #for f in self.stored_fields: - # self.record['query_data']['query_seq_data'][idx][f] = '' - # if f in query_seq_records[idx]: - # self.record['query_data']['query_seq_data'][idx][f] = query_seq_records[idx][f] def add_db_metadata(self,metadata_dict): locus_profile = {} diff --git a/locidex/constants.py b/locidex/constants.py index 10707fa..0ba58c0 100644 --- a/locidex/constants.py +++ b/locidex/constants.py @@ -18,25 +18,57 @@ STOP_CODONS = ['taa','tag','tta','tca','tga','aga','agg'] -BLAST_TABLE_COLS = ''' -qseqid -sseqid -qlen -slen -qstart -qend -sstart -send -length -mismatch -pident -qcovhsp -qcovs -sstrand -evalue -bitscore -'''.strip().split('\n') +@dataclass(frozen=True) +class CharacterConstants: + stop_codon: str = "*" + +#BLAST_TABLE_COLS = ''' +#qseqid +#sseqid +#qlen +#slen +#qstart +#qend +#sstart +#send +#length +#mismatch +#pident +#qcovhsp +#qcovs +#sstrand +#evalue +#bitscore +#'''.strip().split('\n') + +class BlastColumns(NamedTuple): + qseqid: str + sseqid: str + qlen: int + slen: int + qstart: int + qend: int + sstart: int + send: int + length: int + mismatch: str + pident: float + qcovhsp: float + qcovs: float + sstrand: str + evalue: float + bitscore: float + +@dataclass(frozen=True) +class BlastCommands: + # upgrading this to a string enum would be nice + tblastn: str = "tblastn" + blastn: str = "blastn" + blastp: str = "blastp" + @classmethod + def _keys(cls) -> list: + return [i.name for i in fields(cls)] FILE_TYPES = 
{ 'genbank': ["gbk","genbank","gbf","gbk.gz","genbank.gz","gbf.gz","gbff","gbff.gz"], @@ -153,29 +185,3 @@ class LocidexDBHeader(NamedTuple): min_aa_match_cov: Optional[int] count_int_stops: int dna_ambig_count: int - - -#LOCIDEX_DB_HEADER = [ -# 'seq_id', -# 'locus_name', -# 'locus_name_alt', -# 'locus_product', -# 'locus_description', -# 'locus_uid', -# 'dna_seq', -# 'dna_seq_len', -# 'dna_seq_hash', -# 'aa_seq', -# 'aa_seq_len', -# 'aa_seq_hash', -# 'dna_min_len', -# 'dna_max_len', -# 'aa_min_len', -# 'aa_max_len', -# 'dna_min_ident', -# 'aa_min_ident', -# 'min_dna_match_cov', -# 'min_aa_match_cov', -# 'count_int_stops', -# 'dna_ambig_count' -#] diff --git a/locidex/example/build_db_mlst_out/config.json b/locidex/example/build_db_mlst_out/config.json index b896622..e367f41 100644 --- a/locidex/example/build_db_mlst_out/config.json +++ b/locidex/example/build_db_mlst_out/config.json @@ -2,8 +2,8 @@ "db_name": "Locidex Database", "db_version": "1.0.0", "db_date": "2024/30/04", - "db_author": "", - "db_desc": "", + "db_author": "mw", + "db_desc": "test", "db_num_seqs": 53, "is_nucl": true, "is_prot": true, diff --git a/locidex/extract.py b/locidex/extract.py index c0637cd..637dda8 100644 --- a/locidex/extract.py +++ b/locidex/extract.py @@ -12,7 +12,7 @@ from locidex.classes.blast import blast_search, parse_blast from locidex.classes.db import search_db_conf, db_config from locidex.classes.seq_intake import seq_intake, seq_store -from locidex.constants import SEARCH_RUN_DATA, FILE_TYPES, BLAST_TABLE_COLS, DBConfig, DB_EXPECTED_FILES, NT_SUB, EXTRACT_MODES, OPTION_GROUPS +from locidex.constants import SEARCH_RUN_DATA, FILE_TYPES, BlastColumns, DBConfig, DB_EXPECTED_FILES, NT_SUB, EXTRACT_MODES, OPTION_GROUPS from locidex.version import __version__ from locidex.classes.aligner import perform_alignment, aligner from locidex.utils import check_db_groups @@ -170,7 +170,7 @@ def run_extract(config): hit_file = os.path.join(blast_dir_base, "hsps.txt") obj = 
blast_search(input_db_path=db_path, input_query_path=nt_db, output_results=hit_file, blast_params=blast_params, blast_method='blastn', - blast_columns=BLAST_TABLE_COLS,create_db=True) + blast_columns=BlastColumns._fields,create_db=True) if obj.status == False: print("Error something went wrong, please check error messages above") @@ -183,7 +183,7 @@ def run_extract(config): 'qcovhsp': {'min': min_dna_match_cov, 'max': None, 'include': None}, } - hit_df = parse_blast(hit_file, BLAST_TABLE_COLS, filter_options).df + hit_df = parse_blast(hit_file, BlastColumns._fields, filter_options).df hit_df['sseqid'] = hit_df['sseqid'].astype(str) hit_df['qseqid'] = hit_df['qseqid'].astype(str) diff --git a/locidex/format.py b/locidex/format.py index d77bb6e..6e2d919 100644 --- a/locidex/format.py +++ b/locidex/format.py @@ -15,7 +15,7 @@ from Bio.Seq import Seq from pyrodigal import GeneFinder -from locidex.constants import FILE_TYPES, LocidexDBHeader +from locidex.constants import FILE_TYPES, LocidexDBHeader, CharacterConstants from locidex.utils import six_frame_translation, revcomp, calc_md5 from locidex.version import __version__ @@ -24,7 +24,7 @@ class locidex_format: input_type = None delim = '_' status = True - __stop_codon = "*" + __stop_codon = CharacterConstants.stop_codon # ? These two parameters below can probably be cleaned up __file_input = "file" diff --git a/locidex/manifest.py b/locidex/manifest.py index 13c48c7..5fe1e77 100644 --- a/locidex/manifest.py +++ b/locidex/manifest.py @@ -1,6 +1,6 @@ import pathlib import json -from typing import List, Union, Tuple, Dict +from typing import List, Union, Tuple, Dict, Optional from dataclasses import dataclass import os import re @@ -12,6 +12,65 @@ +class DBData: + """ + Validate and get all database data for other modules. + + * This class will create some redundancy and will need to be refactored to reflect + * the overall use of this module in the future. 
But once refactoring is complete + * it should be much easier to refactor this and the other modules. Additionally, at that + * point we should have a better understanding of how all the modules fit together. + """ + + __db_names = ["nucleotide", "protein"] + __nucleotide_path = pathlib.Path(__db_names[0]) + __protein_path = pathlib.Path(__db_names[1]) + + def __init__(self, db_dir: pathlib.Path): + self.db_dir = db_dir + self.config_data: DBConfig = self._get_config(self.db_dir) + self.metadata: dict = self._get_metadata(self.db_dir) + self.nucleotide, self.protein = self._get_blast_dbs(db_dir, self.config_data) + + @property + def nucleotide_blast_db(self): + if self.nucleotide is None: + raise ValueError("Nucleotide blast database does not exist") + return self.nucleotide / self.__nucleotide_path + + @property + def protein_blast_db(self): + if self.protein is None: + raise ValueError("Protein blast database does not exist") + return self.protein / self.__protein_path + + def _get_config(self, db_dir: pathlib.Path) -> DBConfig: + """ + Validates the config file and serializes the data into a DBConfig object + """ + return check_config(db_dir) + + def _get_metadata(self, db_dir: pathlib.Path) -> dict: + metadata_file = db_dir.joinpath(DBFiles.meta_file) + if not metadata_file.exists(): + raise FileNotFoundError("Metadata file does not exist. Database path maybe incorrect: {}".format(db_dir)) + + def _get_blast_dbs(self, db_dir: pathlib.Path, config_data: DBConfig) -> Tuple[Optional[pathlib.Path], Optional[pathlib.Path]]: + blast_db = db_dir.joinpath(DBFiles.blast_dir) + nucleotide: Optional[pathlib.Path] = None + protein: Optional[pathlib.Path] = None + if not blast_db.exists(): + raise OSError("blast directory not found. Database path maybe incorrect: {}".format(db_dir)) + if config_data.is_nucl: + nucleotide = blast_db.joinpath(self.__nucleotide_path) + if not nucleotide.exists(): + raise FileNotFoundError("Cannot find nucleotide database, but it should exist. 
{}".format(nucleotide)) + if config_data.is_prot: + protein = blast_db.joinpath(self.__protein_path) + if not protein.exists(): + raise FileNotFoundError("Cannot find protein database, but it should exist. {}".format(protein)) + return nucleotide, protein + class ManifestItem: """ @@ -68,7 +127,7 @@ def add_args(parser=None): return parser -def check_config(directory: pathlib.Path) -> None: +def check_config(directory: pathlib.Path) -> DBConfig: """ Validate config file in a directory. Throws an error if any required parameters are missing. @@ -77,7 +136,7 @@ def check_config(directory: pathlib.Path) -> None: """ config_dir = pathlib.Path(directory / DBFiles.config_file) - config_data: Union[DBConfig, None] = None + config_data: Optional[DBConfig] = None with open(config_dir, 'r') as conf: config_data = DBConfig(**json.load(conf)) for k, v in config_data.to_dict().items(): @@ -197,7 +256,7 @@ def read_manifest(input_file: pathlib.Path) -> dict: def get_manifest_db(input_file: pathlib.Path, name: str, version: str) -> ManifestItem: """ - Retruns path to the database file selected + Returns path to the database file selected """ output = read_manifest(input_file) db_out = select_db(output, name, version) diff --git a/locidex/search.py b/locidex/search.py index ede6fef..9f41794 100644 --- a/locidex/search.py +++ b/locidex/search.py @@ -5,14 +5,17 @@ from pathlib import Path from argparse import (ArgumentParser, ArgumentDefaultsHelpFormatter, RawDescriptionHelpFormatter) from datetime import datetime +from typing import Optional +from dataclasses import dataclass import pandas as pd -from locidex.classes.blast import blast_search, parse_blast +from locidex.classes.blast import blast_search, parse_blast, FilterOptions from locidex.classes.db import search_db_conf, db_config -from locidex.classes.seq_intake import seq_intake, seq_store -from locidex.constants import SEARCH_RUN_DATA, FILE_TYPES, BLAST_TABLE_COLS, DB_EXPECTED_FILES, OPTION_GROUPS, DBConfig -from 
locidex.utils import write_seq_dict, check_db_groups +from locidex.manifest import DBData +from locidex.classes.seq_intake import seq_intake, seq_store, HitFilters +from locidex.constants import SEARCH_RUN_DATA, FILE_TYPES, BlastColumns, DB_EXPECTED_FILES, OPTION_GROUPS, DBConfig +from locidex.utils import write_seq_dict, check_db_groups, slots from locidex.version import __version__ def add_args(parser=None): @@ -62,8 +65,6 @@ def add_args(parser=None): return parser - - def perform_search(query_file,results_file,db_path,blast_prog,blast_params,columns): return blast_search(db_path,query_file,results_file,blast_params,blast_prog,columns) @@ -72,7 +73,6 @@ def create_fasta_from_df(df,label_col,seq_col,out_file): write_seq_dict(dict(zip(df[label_col].tolist(), df[seq_col])), out_file) - def run_search(config): # Input Parameters @@ -96,10 +96,10 @@ def run_search(config): perform_annotation = config['annotate'] max_target_seqs = config['max_target_seqs'] - if 'max_ambig_count' in config: - max_ambig_count = config['max_ambig_count'] + if max_count := config.get('max_ambig_count'): + max_ambig_count = max_count else: - max_ambig_count = 99999999999999 + max_ambig_count = float('inf') if not perform_annotation: perform_annotation = False @@ -118,8 +118,12 @@ def run_search(config): # print(f'There is an issue with provided db directory: {db_dir}\n {db_database_config.messages}') # sys.exit() - metadata_path = db_database_config.meta_file_path - metadata_obj = db_config(metadata_path, ['meta', 'info']) + db_data = DBData(db_dir=db_dir) + + #metadata_path = db_database_config.meta_file_path + #metadata_obj = db_config(metadata_path, ['meta', 'info']) + metadata_obj = db_data.metadata + #blast_database_paths = db_database_config.blast_paths blast_database_paths = db_database_config.blast_paths if os.path.isdir(outdir) and not force: print(f'Error {outdir} exists, if you would like to overwrite, then specify --force') @@ -145,11 +149,6 @@ def run_search(config): seq_obj = 
seq_intake(query_file, format, 'CDS', translation_table, perform_annotation) - if seq_obj.status == False: - print( - f'Something went wrong parsing query file: {query_file}, please check logs and messages:\n{seq_obj.messages}') - sys.exit() - if perform_annotation: gbk_data = [] for idx,genes in enumerate(seq_obj.prodigal_genes): @@ -174,42 +173,42 @@ def run_search(config): } filter_options = { - 'evalue': {'min': None, 'max': min_evalue, 'include': None}, + 'evalue': FilterOptions(min=None, max=min_evalue, include=None) } df = pd.DataFrame.from_dict(seq_obj.seq_data) filtered_df = df filtered_df['index'] = filtered_df.index.to_list() - hit_filters = { - 'min_dna_len': min_dna_len, - 'max_dna_len': max_dna_len, - 'min_dna_ident': min_dna_ident, - 'min_dna_match_cov': min_dna_match_cov, - 'min_aa_len': min_aa_len, - 'max_aa_len': max_aa_len, - 'min_aa_ident': min_aa_ident, - 'min_aa_match_cov': min_aa_match_cov, - 'dna_ambig_count':max_ambig_count + hit_filters = HitFilters( + min_dna_len = min_dna_len, + max_dna_len=max_dna_len, + min_dna_ident=min_dna_ident, + min_dna_match_cov=min_dna_match_cov, + min_aa_len=min_aa_len, + max_aa_len=max_aa_len, + min_aa_ident=min_aa_ident, + min_aa_match_cov=min_aa_match_cov, + dna_ambig_count=max_ambig_count) + - } - store_obj = seq_store(sample_name, db_database_config.config_obj.config, metadata_obj.config['meta'], - seq_obj.seq_data, BLAST_TABLE_COLS, hit_filters) + store_obj = seq_store(sample_name, db_data.config, metadata_obj.config['meta'], + seq_obj.seq_data, BlastColumns._fields, hit_filters) - for db_label in blast_database_paths: + for db_label in (db_data.nucleotide,): label_col = 'index' - if db_label == 'nucleotide': + if db_data.nucleotide: blast_prog = 'blastn' seq_col = 'dna_seq' - d = os.path.join(blast_dir_base, 'nucleotide') - filter_options['pident'] = {'min': min_dna_ident, 'max': None, 'include': None} - filter_options['qcovs'] = {'min': min_dna_match_cov, 'max': None, 'include': None} + d = 
db_data.nucleotide + filter_options['pident'] = FilterOptions(min=min_dna_ident, max=None, include=None) + filter_options['qcovs'] = FilterOptions(min=min_dna_match_cov, max=None, include=None) - elif db_label == 'protein': + elif db_data.protein: blast_prog = 'blastp' seq_col = 'aa_seq' - d = os.path.join(blast_dir_base, 'protein') - filter_options['pident'] = {'min': min_aa_ident, 'max': None, 'include': None} - filter_options['qcovs'] = {'min': min_aa_match_cov, 'max': None, 'include': None} + d = db_data.protein + filter_options['pident'] = FilterOptions(min=min_aa_ident, max=None, include=None) + filter_options['qcovs'] = FilterOptions(min=min_aa_match_cov, max=None, include=None) if not os.path.isdir(d): os.makedirs(d, 0o755) @@ -222,8 +221,8 @@ def run_search(config): db_path = blast_database_paths[db_label] create_fasta_from_df(filtered_df, label_col, seq_col, os.path.join(d, "queries.fasta")) perform_search(os.path.join(d, "queries.fasta"), os.path.join(d, "hsps.txt"), db_path, blast_prog, blast_params, - BLAST_TABLE_COLS) - hit_obj = parse_blast(os.path.join(d, "hsps.txt"), BLAST_TABLE_COLS, filter_options) + BlastColumns._fields) + hit_obj = parse_blast(os.path.join(d, "hsps.txt"), BlastColumns._fields, filter_options) hit_df = hit_obj.df store_obj.add_hit_data(hit_df, db_label, 'qseqid') diff --git a/locidex/utils.py b/locidex/utils.py index 0546626..3d9b79c 100644 --- a/locidex/utils.py +++ b/locidex/utils.py @@ -6,10 +6,16 @@ from pathlib import Path from locidex.manifest import ManifestItem from Bio.Seq import Seq - +from typing import Dict, FrozenSet from locidex.constants import NT_SUB, PROTEIN_ALPHA, DNA_ALPHA, OPTION_GROUPS import locidex.manifest as manifest +def slots(annotations: Dict[str, object]) -> FrozenSet[str]: + """ + Thank you for this: https://stackoverflow.com/a/63658478 + """ + return frozenset(annotations.keys()) + def check_db_groups(analysis_params: dict, cmd_args: argparse.Namespace, param_db: str = "db") -> dict: """ Verify 
that a locidex database, or database group passed has all of the require parameters diff --git a/tests/test_blast.py b/tests/test_blast.py index 364c313..4ae8998 100644 --- a/tests/test_blast.py +++ b/tests/test_blast.py @@ -1,51 +1,60 @@ import pytest, os import pandas as pd import locidex.classes.blast -from locidex.constants import BLAST_TABLE_COLS +import shutil +from locidex.classes.blast import FilterOptions +from locidex.constants import BlastColumns PACKAGE_ROOT = os.path.dirname(locidex.__file__) -@pytest.fixture +@pytest.fixture() def blast_search_class_init(tmpdir): - blast_search_obj = locidex.classes.blast.blast_search(input_db_path=None, + test_dir = tmpdir + blast_search_obj = locidex.classes.blast.blast_search(input_db_path=os.path.join(PACKAGE_ROOT, "example/build_db_mlst_out/blast/nucleotide/nucleotide"), input_query_path=os.path.join(PACKAGE_ROOT, 'example/search/NC_003198.1.fasta'), - output_results=os.path.join(tmpdir,"hsps.txt"), blast_params={'evalue': 0.0001,'max_target_seqs': 10,'num_threads': 1}, + output_results=os.path.join(test_dir,"hsps.txt"), + blast_params={'evalue': 0.0001,'max_target_seqs': 10,'num_threads': 1}, blast_method='blastn', - blast_columns=BLAST_TABLE_COLS,create_db=False) - return blast_search_obj - - -def test_make_mlst_database(blast_search_class_init, tmpdir): - blast_search_obj = blast_search_class_init - blast_search_obj.input_db_path = os.path.join(PACKAGE_ROOT, "example/build_db_mlst_out/blast/nucleotide/nucleotide.fasta") - blast_search_obj.input_query_path = os.path.join(PACKAGE_ROOT, 'example/search/NC_003198.1.fasta') - blast_search_obj.output_db_path = os.path.join(tmpdir,"nucleotide_mlst_database") - blast_search_obj.makeblastdb(); - - assert len([file for file in os.listdir(tmpdir) if "nucleotide_mlst_database" in file]) > 0 - - blast_search_obj.input_db_path = os.path.join(tmpdir, "nucleotide_mlst_database") # assign new path to freshly created database to check its validity - assert 
blast_search_obj.is_blast_db_valid() == True - - - -def test_run_blast_on_genome_and_check_output(blast_search_class_init, tmpdir): - test_make_mlst_database(blast_search_class_init,tmpdir) - blast_search_obj = blast_search_class_init - blast_search_obj.input_db_path = os.path.join(tmpdir, "nucleotide_mlst_database") - blast_search_obj.run_blast() - output_blast_results_path = os.path.join(tmpdir,"hsps.txt") + blast_columns=BlastColumns._fields) + #blast_search_obj.run_blast() + return test_dir, blast_search_obj + + +#def test_make_mlst_database(blast_search_class_init): +# test_dir, obj = blast_search_class_init +# #print(test_dir) +# #blast_search_obj.input_db_path = os.path.join(PACKAGE_ROOT, "example/build_db_mlst_out/blast/nucleotide/nucleotide.fasta") +# #blast_search_obj.input_query_path = os.path.join(PACKAGE_ROOT, 'example/search/NC_003198.1.fasta') +# #blast_search_obj.output_db_path = os.path.join(tmpdir,"nucleotide_mlst_database") +# #blast_search_obj.makeblastdb(); +# +# assert len([file for file in os.listdir(test_dir) if "nucleotide_mlst_database" in file]) > 0 +# +# #blast_search_obj.input_db_path = os.path.join(test_dir, "nucleotide_mlst_database") # assign new path to freshly created database to check its validity +# #assert blast_search_obj.is_blast_db_valid() == True + + + +def test_run_blast_on_genome_and_check_output(blast_search_class_init): + #test_make_mlst_database(blast_search_class_init,tmpdir) + blast_search_obj, obj = blast_search_class_init + #blast_search_obj.input_db_path = os.path.join(tmpdir, "nucleotide_mlst_database") + #obj.run_blast() + #output_blast_results_path = os.path.join(tmpdir,"hsps.txt") + output_blast_results_path = os.path.join(blast_search_obj, "hsps.txt") assert os.path.exists(output_blast_results_path) == True + print(output_blast_results_path) + with open(output_blast_results_path, "r") as fp: output_blast_results_file = fp.readlines() assert len(output_blast_results_file) == 10 - + 
print(output_blast_results_path) parse_blast_obj = locidex.classes.blast.parse_blast(input_file = output_blast_results_path, - blast_columns = BLAST_TABLE_COLS, - filter_options={'bitscore':{'min':600, 'max':None, 'include':None}}) + blast_columns = BlastColumns._fields, + filter_options={'bitscore': FilterOptions(min=600, max=None, include=None)}) assert parse_blast_obj.df.shape[0] == 7 assert parse_blast_obj.df['bitscore'].max() == 926 - assert len([item for item in parse_blast_obj.df.columns.to_list() if item not in BLAST_TABLE_COLS]) == 0 #check if columns in df and constant are identical + assert len([item for item in parse_blast_obj.df.columns.to_list() if item not in BlastColumns._fields]) == 0 #check if columns in df and constant are identical diff --git a/tests/test_extractor.py b/tests/test_extractor.py index 41953a8..1336ca2 100644 --- a/tests/test_extractor.py +++ b/tests/test_extractor.py @@ -1,7 +1,8 @@ import pytest import os -import locidex.classes.blast -from locidex.constants import BLAST_TABLE_COLS +import locidex +from locidex.classes.blast import FilterOptions, blast_search, parse_blast +from locidex.constants import BlastColumns from locidex.classes.extractor import extractor from locidex.classes.seq_intake import seq_intake from locidex.classes.db import db_config @@ -14,15 +15,15 @@ def blast_db_and_search(tmpdir,input_db_path): - blast_search_obj = locidex.classes.blast.blast_search(input_db_path=input_db_path, + blast_search_obj = blast_search(input_db_path=input_db_path, input_query_path=os.path.join(PACKAGE_ROOT, 'example/build_db_mlst_out/blast/nucleotide/nucleotide.fasta'), output_results=os.path.join(tmpdir,"hsps.txt"), blast_params={'evalue': 0.0001,'max_target_seqs': 10,'num_threads': 1}, blast_method='blastn', - blast_columns=BLAST_TABLE_COLS,create_db=True) + blast_columns=BlastColumns._fields,create_db=True) blast_search_obj.run_blast() output_blast_results_path = os.path.join(tmpdir,"hsps.txt") - parse_blast_obj = 
locidex.classes.blast.parse_blast(input_file = output_blast_results_path, - blast_columns = BLAST_TABLE_COLS, + parse_blast_obj = parse_blast(input_file = output_blast_results_path, + blast_columns = BlastColumns._fields, filter_options={'bitscore':{'min':600, 'max':None, 'include':None}}) return parse_blast_obj @@ -38,20 +39,22 @@ def seq_intake_fixture(): def test_extractor_initialization(seq_intake_fixture, tmpdir): db_path=os.path.join(tmpdir,"contigs.fasta") - nt_db = os.path.join(PACKAGE_ROOT,'example/build_db_mlst_out/blast/nucleotide/nucleotide.fasta') + nt_db_test = os.path.join(PACKAGE_ROOT,'example/build_db_mlst_out/blast/nucleotide/nucleotide.fasta') + nt_db = os.path.join(PACKAGE_ROOT,'example/build_db_mlst_out/blast/nucleotide/') hit_file = os.path.join(tmpdir,"hsps.txt") blast_params={'evalue': 0.0001, 'max_target_seqs': 10, 'num_threads': 1} metadata_path = os.path.join(PACKAGE_ROOT,'example/build_db_mlst_out/meta.json') seq_obj = seq_intake_fixture seq_data={} - with open(db_path,'w') as oh: - for idx,seq in enumerate(seq_obj.seq_data): - seq_data[str(idx)] = {'id':str(seq['seq_id']),'seq':seq['dna_seq']} - oh.write(">{}\n{}\n".format(idx,seq['dna_seq'])) - locidex.classes.blast.blast_search(input_db_path=db_path, input_query_path=nt_db, - output_results=hit_file, blast_params=blast_params, blast_method='blastn', - blast_columns=BLAST_TABLE_COLS,create_db=True) - hit_df = locidex.classes.blast.parse_blast(hit_file, BLAST_TABLE_COLS, {}).df + #with open(db_path,'w') as oh: + # for idx,seq in enumerate(seq_obj.seq_data): + # seq_data[str(idx)] = {'id':str(seq.seq_id),'seq':seq.dna_seq} + # oh.write(">{}\n{}\n".format(idx,seq.dna_seq)) + blast_search(input_db_path=nt_db, + input_query_path=nt_db_test, + output_results=hit_file, blast_params=blast_params, blast_method='blastn', + blast_columns=BlastColumns._fields) + hit_df = parse_blast(hit_file, BlastColumns._fields, {}).df loci = []; metadata_obj = db_config(metadata_path, ['meta', 'info']) for 
idx,row in hit_df.iterrows(): qid = str(row['qseqid']) diff --git a/tests/test_fasta.py b/tests/test_fasta.py index db64e02..c4cfc18 100644 --- a/tests/test_fasta.py +++ b/tests/test_fasta.py @@ -24,10 +24,9 @@ def fasta_file(fasta_content): def test_parse_fasta_normal(fasta_file): parser = parse_fasta(fasta_file) - assert parser.status, "Parser status should be True" assert len(parser.get_seqids()) == 2, "There should be two sequences" seq_data = parser.get_seq_by_id("gene1|123") - assert seq_data['seq'] == "ATGCGTACGTAGCTAGC", "Sequence data should match the input" + assert seq_data.seq == "ATGCGTACGTAGCTAGC", "Sequence data should match the input" def test_parse_fasta_with_nonexistent_file(): # Assuming that the locidex errors out with File not found error if FASTA is non existant: @@ -39,7 +38,6 @@ def test_parse_fasta_with_nonexistent_file(): def test_parse_fasta_with_definitions(fasta_file): parser = parse_fasta(fasta_file, parse_def=True, delim="|") - assert parser.status, "Parser should correctly parse with definitions" seq_data = parser.get_seq_by_id("gene1|123") - assert seq_data['gene_name'] == "gene1", "Gene name should be correctly parsed from definition" - assert seq_data['seq_id'] == "123", "Seq ID should be correctly parsed from definition" + assert seq_data.gene_name == "gene1", "Gene name should be correctly parsed from definition" + assert seq_data.seq_id == "123", "Seq ID should be correctly parsed from definition" diff --git a/tests/test_seq_intake.py b/tests/test_seq_intake.py index 75df91e..318985b 100644 --- a/tests/test_seq_intake.py +++ b/tests/test_seq_intake.py @@ -1,7 +1,7 @@ import os, warnings -import locidex.classes.seq_intake -from locidex.classes.seq_intake import SeqObject -from locidex.constants import BLAST_TABLE_COLS, DB_EXPECTED_FILES, DBConfig +import locidex +from locidex.classes.seq_intake import SeqObject, seq_intake, seq_store +from locidex.constants import BlastColumns, DB_EXPECTED_FILES, DBConfig from locidex.classes.db 
import search_db_conf, db_config from collections import Counter from dataclasses import asdict @@ -11,10 +11,8 @@ def seq_intake_class_init(input_file, file_type, perform_annotation): #reset global class variables to avoid ambiguous results - locidex.classes.seq_intake.seq_intake.seq_data = [] - locidex.classes.seq_intake.seq_intake.messages = [] - locidex.classes.seq_intake.seq_intake.prodigal_genes = [] - obj = locidex.classes.seq_intake.seq_intake(input_file=input_file, + + obj = seq_intake(input_file=input_file, file_type=file_type,feat_key='CDS',translation_table=11, perform_annotation=perform_annotation,num_threads=1,skip_trans=False) return obj @@ -24,12 +22,12 @@ def test_seq_store_class(): db_database_config = search_db_conf(db_dir, DB_EXPECTED_FILES, DBConfig._keys()) metadata_obj = db_config(db_database_config.meta_file_path, ['meta', 'info']) sample_name = 'NC_003198.1.fasta' - seq_obj = locidex.classes.seq_intake.seq_intake(input_file=os.path.join(PACKAGE_ROOT, 'example/search/NC_003198.1.fasta'), + seq_obj = seq_intake(input_file=os.path.join(PACKAGE_ROOT, 'example/search/NC_003198.1.fasta'), file_type='fasta', perform_annotation=False) hit_filters = {'min_dna_len': 1, 'max_dna_len': 10000000, 'min_dna_ident': 80.0, 'min_dna_match_cov': 80.0, 'min_aa_len': 1, 'max_aa_len': 10000000, 'min_aa_ident': 80.0, 'min_aa_match_cov': 80.0, 'dna_ambig_count': 99999999999999} - seq_store_obj = locidex.classes.seq_intake.seq_store(sample_name, db_database_config.config_obj.config, metadata_obj.config['meta'], - seq_obj.seq_data, BLAST_TABLE_COLS, hit_filters) + seq_store_obj = seq_store(sample_name, db_database_config.config_obj.config, metadata_obj.config['meta'], + seq_obj.seq_data, BlastColumns._fields, hit_filters) assert list(seq_store_obj.record.keys()) == ['db_info', 'db_seq_info', 'query_data', 'query_hit_columns'] assert list(seq_store_obj.record['db_info'].keys()) == ['db_name', 'db_version', 'db_date', 'db_author', 'db_desc', 'db_num_seqs', 'is_nucl', 
'is_prot', 'nucleotide_db_name', 'protein_db_name'] @@ -84,14 +82,14 @@ def test_read_fasta_file(): msg = f"Expected ORFs number is {expected_orfs} but found {len(seq_intake_object.seq_data)}! Check pyrodigal and python versions." warnings.warn(msg) assert len(seq_intake_object.seq_data) > 0 - assert sum([contig['aa_len'] for contig in seq_intake_object.seq_data]) > 0 - assert any([True if 'NC_003198.1' in contig['parent_id'] else False for contig in seq_intake_object.seq_data]) == True - assert any([True if 'NC_003198.1' in contig['locus_name'] else False for contig in seq_intake_object.seq_data]) == True - assert any([True if 'NC_003198.1' in contig['seq_id'] else False for contig in seq_intake_object.seq_data]) == True - assert len(seq_intake_object.seq_data[0]['dna_seq']) > 0 - assert seq_intake_object.seq_data[0]['dna_len'] > 0 - assert len(seq_intake_object.seq_data[0]['dna_hash']) > 0 - assert any([ True if contig['dna_ambig_count'] == 0 else False for contig in seq_intake_object.seq_data]) == True - + assert sum([contig.aa_len for contig in seq_intake_object.seq_data]) > 0 + assert any([True if 'NC_003198.1' in contig.parent_id else False for contig in seq_intake_object.seq_data]) == True + assert any([True if 'NC_003198.1' in contig.locus_name else False for contig in seq_intake_object.seq_data]) == True + assert any([True if 'NC_003198.1' in contig.seq_id else False for contig in seq_intake_object.seq_data]) == True + assert len(seq_intake_object.seq_data[0].dna_seq) > 0 + assert seq_intake_object.seq_data[0].dna_len > 0 + assert len(seq_intake_object.seq_data[0].dna_hash) > 0 + assert any([ True if contig.dna_ambig_count == 0 else False for contig in seq_intake_object.seq_data]) == True + From 66de6d6053536aef7191bb7c8dd1af10f2dbe564 Mon Sep 17 00:00:00 2001 From: Matthew Wells Date: Wed, 8 May 2024 16:53:18 -0500 Subject: [PATCH 3/7] updated blast module to other class --- locidex/classes/blast2.py | 124 ++++++++++++++++++++++++++++++++++++++ 
locidex/search.py | 32 ++++++---- tests/test_blast2.py | 40 ++++++++++++ 3 files changed, 183 insertions(+), 13 deletions(-) create mode 100644 locidex/classes/blast2.py create mode 100644 tests/test_blast2.py diff --git a/locidex/classes/blast2.py b/locidex/classes/blast2.py new file mode 100644 index 0000000..f69ba1f --- /dev/null +++ b/locidex/classes/blast2.py @@ -0,0 +1,124 @@ +""" +Blast module refactored +""" + +import pandas as pd + +from locidex.classes import run_command +from locidex.utils import slots +from locidex.manifest import DBData +from locidex.constants import BlastCommands +from dataclasses import dataclass +from pathlib import Path +from typing import Optional, List, Dict +import os + + +@dataclass +class FilterOptions: + min: Optional[float] + max: Optional[float] + include: Optional[bool] + __slots__ = slots(__annotations__) + + +class BlastSearch: + __blast_commands = set(BlastCommands._keys()) + __blast_extensions_nt = frozenset(['.nsq', '.nin', '.nhr']) + __blast_extensions_pt = frozenset(['.pto', '.ptf', '.phr']) + __filter_columns = ["qseqid", "sseqid"] + + def __init__(self, db_data: DBData, query_path: Path, blast_params: dict, blast_method: str, blast_columns: List[str], filter_options: Dict[str, FilterOptions]): + self.db_data = db_data + if blast_method not in self.__blast_commands: + raise ValueError("{} is not a valid blast command please pick from {}".format(blast_method, self.__blast_commands)) + self.query_path = query_path + self.blast_params = blast_params + self.blast_method = blast_method + self.blast_columns = blast_columns + self.filter_options = filter_options + + def get_blast_data(self, db_path, output) -> pd.DataFrame: + """ + Run blast and parse results + """ + stdout, stderr = self._run_blast(db=db_path, output=output) + blast_data = self.parse_blast(output_file=output) + return blast_data + + def parse_blast(self, output_file: Path): + """ + Parse a blast output file + output_file Path: Generate blast data + """ + 
df = self.read_hit_table(output_file) + columns = df.columns.tolist() + for id_col in self.__filter_columns: + tp = {} + if id_col in columns: + tp[id_col] = 'object' + df = df.astype(tp) + for col_name in self.filter_options: + if col_name in self.columns: + min_value = self.filter_options[col_name].min + max_value = self.filter_options[col_name].max + include = self.filter_options[col_name].include + self.filter_df(col_name, min_value, max_value, include) + return df + + + def filter_df(self,col_name,min_value,max_value,include): + if col_name not in self.columns: + return False + if min_value is not None: + self.df = self.df[self.df[col_name] >= min_value] + if max_value is not None: + self.df = self.df[self.df[col_name] <= max_value] + if include is not None: + self.df = self.df[self.df[col_name].isin(include)] + return True + + def read_hit_table(self, blast_data): + return pd.read_csv(blast_data,header=None,names=self.blast_columns,sep="\t",low_memory=False) + + + def _check_blast_files(self, db_dir: Path, extensions: frozenset): + """ + """ + extensions_ = set([i.suffix for i in db_dir.iterdir()]) + if not extensions_.issuperset(extensions): + raise ValueError("Missing required blast files. 
{}".format([i for i in extensions_ if i not in extensions])) + + def validate_blast_db(self, db_data=None): + """ + """ + if db_data is None: + db_data = self.db_data + if db_data.nucleotide: + self._check_blast_files(db_data.nucleotide, self.__blast_extensions_nt) + + if db_data.protein: + self._check_blast_files(db_data.protein, self.__blast_extensions_pt) + + def _run_blast(self, db: Path, output: Path): + """ + db Path: Path to the blast database to use, + output Path: Path to file for blast output + """ + command = [ + self.blast_method, + '-query', self.query_path, + '-db', str(db), + '-out', str(output), + '-outfmt', "'6 {}'".format(' '.join(self.blast_columns)), + ] + for param in self.blast_params: + if param == "parse_seqids": + command.append(f"-{param}") + else: + command += [f'-{param}', f'{self.blast_params[param]}'] + return run_command(" ".join([str(x) for x in command])) + + + + diff --git a/locidex/search.py b/locidex/search.py index 9f41794..b62bc35 100644 --- a/locidex/search.py +++ b/locidex/search.py @@ -10,7 +10,8 @@ import pandas as pd -from locidex.classes.blast import blast_search, parse_blast, FilterOptions +#from locidex.classes.blast import blast_search, parse_blast, FilterOptions +from locidex.classes.blast2 import BlastSearch, FilterOptions from locidex.classes.db import search_db_conf, db_config from locidex.manifest import DBData from locidex.classes.seq_intake import seq_intake, seq_store, HitFilters @@ -124,7 +125,7 @@ def run_search(config): #metadata_obj = db_config(metadata_path, ['meta', 'info']) metadata_obj = db_data.metadata #blast_database_paths = db_database_config.blast_paths - blast_database_paths = db_database_config.blast_paths + #blast_database_paths = db_database_config.blast_paths if os.path.isdir(outdir) and not force: print(f'Error {outdir} exists, if you would like to overwrite, then specify --force') sys.exit() @@ -162,9 +163,9 @@ oh.write("\n".join([str(x) for x in gbk_data])) - 
blast_dir_base = os.path.join(outdir, 'blast') - if not os.path.isdir(blast_dir_base): - os.makedirs(blast_dir_base, 0o755) + #blast_dir_base = os.path.join(outdir, 'blast') + #if not os.path.isdir(blast_dir_base): + # os.makedirs(blast_dir_base, 0o755) blast_params = { 'evalue': min_evalue, @@ -194,7 +195,8 @@ store_obj = seq_store(sample_name, db_data.config, metadata_obj.config['meta'], seq_obj.seq_data, BlastColumns._fields, hit_filters) - for db_label in (db_data.nucleotide,): + ############# Tomorrow wrap this in two individual functions + for db_label in dbs: label_col = 'index' if db_data.nucleotide: blast_prog = 'blastn' @@ -218,13 +220,17 @@ if os.path.isfile(os.path.join(d, "hsps.txt")): os.remove(os.path.join(d, "hsps.txt")) - db_path = blast_database_paths[db_label] + #db_path = blast_database_paths[db_label] create_fasta_from_df(filtered_df, label_col, seq_col, os.path.join(d, "queries.fasta")) - perform_search(os.path.join(d, "queries.fasta"), os.path.join(d, "hsps.txt"), db_path, blast_prog, blast_params, - BlastColumns._fields) - hit_obj = parse_blast(os.path.join(d, "hsps.txt"), BlastColumns._fields, filter_options) - hit_df = hit_obj.df + #perform_search(os.path.join(d, "queries.fasta"), os.path.join(d, "hsps.txt"), db_path, blast_prog, blast_params, + # BlastColumns._fields) + #hit_obj = parse_blast(os.path.join(d, "hsps.txt"), BlastColumns._fields, filter_options) + #hit_df = hit_obj.df + search_data = BlastSearch(db_data, os.path.join(d, "queries.fasta"), blast_params, blast_prog, BlastColumns._fields, filter_options) + searched_df = search_data.get_blast_data(db_data.nucleotide_blast_db, os.path.join(d, "hsps.txt")) + + #store_obj.add_hit_data(hit_df, db_label, 'qseqid') + store_obj.add_hit_data(searched_df, db_label, 'qseqid') store_obj.filter_hits() store_obj.convert_profile_to_list() @@ -234,7 +240,7 @@
fh.write(json.dumps(store_obj.record, indent=4)) run_data['analysis_end_time'] = datetime.now().strftime("%d/%m/%Y %H:%M:%S") - print(run_data) + with open(os.path.join(outdir,"run.json"),'w' ) as fh: fh.write(json.dumps(run_data, indent=4)) diff --git a/tests/test_blast2.py b/tests/test_blast2.py new file mode 100644 index 0000000..4663c82 --- /dev/null +++ b/tests/test_blast2.py @@ -0,0 +1,40 @@ +import pytest +import os +import locidex +from pathlib import Path +from locidex.manifest import DBData +from locidex.classes import blast2 +from locidex.constants import BlastCommands, BlastColumns + + +PACKAGE_ROOT = os.path.dirname(locidex.__file__) + + +@pytest.fixture() +def db_data(): + db_dir = DBData(Path(PACKAGE_ROOT).joinpath("example", "build_db_mlst_out")) + return db_dir + +@pytest.fixture() +def fasta(): + return Path(PACKAGE_ROOT).joinpath("example", "search", "NC_003198.1.fasta") + +def test_validate_blast_db(db_data): + + test_class = blast2.BlastSearch(db_data, Path("home"), dict(), BlastCommands.blastn, BlastColumns._fields, dict()) + test_class.validate_blast_db() + + +def test_blast_runs(db_data, fasta, tmpdir): + test_class = blast2.BlastSearch(db_data, fasta, dict(), BlastCommands.blastn, BlastColumns._fields, dict()) + out_file = "out.txt" + stdout, stderr = test_class._run_blast(db_data.nucleotide_blast_db, tmpdir / out_file) + with open(tmpdir / out_file, "r") as fp: + assert len(fp.readlines()) == 30 + + +def test_blast_runs(db_data, fasta, tmpdir): + test_class = blast2.BlastSearch(db_data, fasta, dict(), BlastCommands.blastn, BlastColumns._fields, dict()) + out_file = tmpdir / "out.txt" + bd = test_class.get_blast_data(db_data.nucleotide_blast_db, out_file) + assert len(bd) == 30 \ No newline at end of file From 3e240b233b37755f2a7e2beedae2a80bcaf6d440 Mon Sep 17 00:00:00 2001 From: Matthew Wells Date: Thu, 9 May 2024 17:15:54 -0500 Subject: [PATCH 4/7] in way deeper than I want to be --- .gitignore | 4 +- locidex/classes/blast2.py | 48 
+++++++++--- locidex/classes/fasta.py | 8 +- locidex/classes/gbk.py | 2 - locidex/classes/seq_intake.py | 54 ++++++------- locidex/constants.py | 4 +- locidex/extract.py | 59 +++++++++------ locidex/manifest.py | 36 +++++++-- locidex/search.py | 139 +++++++++++++++++----------------- locidex/utils.py | 1 + tests/test_fasta.py | 8 +- tests/test_workflows.yml | 4 +- 12 files changed, 213 insertions(+), 154 deletions(-) diff --git a/.gitignore b/.gitignore index 9816aea..ea3ec62 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ __pycache__ *.egg* -.vscode \ No newline at end of file +.vscode +.pytest_cache +tmp \ No newline at end of file diff --git a/locidex/classes/blast2.py b/locidex/classes/blast2.py index f69ba1f..1cc8244 100644 --- a/locidex/classes/blast2.py +++ b/locidex/classes/blast2.py @@ -11,8 +11,11 @@ from dataclasses import dataclass from pathlib import Path from typing import Optional, List, Dict -import os +import logging +import sys +logger = logging.getLogger(__name__) +logging.basicConfig(filemode=sys.stderr, level=logging.INFO) @dataclass class FilterOptions: @@ -22,6 +25,27 @@ class FilterOptions: __slots__ = slots(__annotations__) +class BlastMakeDB: + """ + Create a blast database + """ + + def __init__(self, input_file: Path, db_type: str, parse_seqids: bool, output_db_path: Optional[Path]): + self.input_file = input_file + self.db_type = db_type + self.parse_seqids = parse_seqids + self.output_db_path = output_db_path + + def makeblastdb(self): + command = ['makeblastdb', '-in', str(self.input_file), '-dbtype', self.db_type] + if self.parse_seqids: + command.append('-parse_seqids') + if self.output_db_path != None: + command +=['-out', str(self.output_db_path)] + stdout, stderr = run_command(" ".join([str(x) for x in command])) + print(stdout, stderr) + return self.output_db_path + class BlastSearch: __blast_commands = set(BlastCommands._keys()) __blast_extensions_nt = frozenset(['.nsq', '.nin', '.nhr']) @@ -38,11 +62,15 @@ def 
__init__(self, db_data: DBData, query_path: Path, blast_params: dict, blast_ self.blast_columns = blast_columns self.filter_options = filter_options - def get_blast_data(self, db_path, output) -> pd.DataFrame: + def get_blast_data(self, db_path: Path, output: Path) -> pd.DataFrame: """ Run blast and parse results + TODO need to clean up the db_path hand off from the DBData obj its dirty """ + stdout, stderr = self._run_blast(db=db_path, output=output) + #logger.info("Blast stdout: {}".format(stdout)) + #logger.info("Blast stderr: {}".format(stderr)) blast_data = self.parse_blast(output_file=output) return blast_data @@ -59,24 +87,24 @@ def parse_blast(self, output_file: Path): tp[id_col] = 'object' df = df.astype(tp) for col_name in self.filter_options: - if col_name in self.columns: + if col_name in columns: min_value = self.filter_options[col_name].min max_value = self.filter_options[col_name].max include = self.filter_options[col_name].include - self.filter_df(col_name, min_value, max_value, include) + df = self.filter_df(df,col_name, min_value, max_value, include, columns) return df - def filter_df(self,col_name,min_value,max_value,include): - if col_name not in self.columns: + def filter_df(self,df, col_name,min_value,max_value,include, columns): + if col_name not in columns: return False if min_value is not None: - self.df = self.df[self.df[col_name] >= min_value] + df = df[df[col_name] >= min_value] if max_value is not None: - self.df = self.df[self.df[col_name] <= max_value] + df = df[df[col_name] <= max_value] if include is not None: - self.df = self.df[self.df[col_name].isin(include)] - return True + df = df[df[col_name].isin(include)] + return df def read_hit_table(self, blast_data): return pd.read_csv(blast_data,header=None,names=self.blast_columns,sep="\t",low_memory=False) diff --git a/locidex/classes/fasta.py b/locidex/classes/fasta.py index b6cf7a3..a0884e8 100644 --- a/locidex/classes/fasta.py +++ b/locidex/classes/fasta.py @@ -5,6 +5,7 @@ import 
os from locidex.utils import calc_md5, slots from dataclasses import dataclass +from pathlib import Path @dataclass class Fasta: @@ -18,11 +19,11 @@ class Fasta: __slots__ = slots(__annotations__) -class parse_fasta: +class ParseFasta: - def __init__(self, input_file,parse_def=False,seq_type=None,delim="|"): + def __init__(self, input_file: Path,parse_def=False,seq_type=None,delim="|"): self.input_file = input_file - if not os.path.isfile(self.input_file): + if not self.input_file.exists(): raise FileNotFoundError("Input file: {} not found.".format(self.input_file)) self.delim = delim @@ -38,6 +39,7 @@ def normalize_sequence(fasta:str) -> str: return fasta.lower().replace("-", "") def get_seqids(self): + print(bool(self.seq_obj), self.seq_obj.keys()) if self.seq_obj: return list(self.seq_obj.keys()) raise AssertionError("No fasta file loaded.") diff --git a/locidex/classes/gbk.py b/locidex/classes/gbk.py index fa2c705..a1a4443 100644 --- a/locidex/classes/gbk.py +++ b/locidex/classes/gbk.py @@ -6,8 +6,6 @@ from locidex.utils import revcomp,calc_md5 class parse_gbk: - input_file = None - seq_obj = None status = True messages = [] diff --git a/locidex/classes/seq_intake.py b/locidex/classes/seq_intake.py index 5db080f..7cb05f5 100644 --- a/locidex/classes/seq_intake.py +++ b/locidex/classes/seq_intake.py @@ -2,12 +2,12 @@ import sys from locidex.classes.gbk import parse_gbk -from locidex.classes.fasta import parse_fasta +from locidex.classes.fasta import ParseFasta from locidex.utils import guess_alphabet, calc_md5, six_frame_translation, slots from locidex.classes.prodigal import gene_prediction -from locidex.constants import DNA_AMBIG_CHARS, DNA_IUPAC_CHARS, CharacterConstants -from typing import NamedTuple, Optional -from dataclasses import dataclass +from locidex.constants import DNA_AMBIG_CHARS, DNA_IUPAC_CHARS, CharacterConstants, DBConfig +from typing import NamedTuple, Optional, List +from dataclasses import dataclass, asdict @dataclass class HitFilters: @@ 
-24,7 +24,6 @@ class HitFilters: @dataclass class SeqObject: - parent_id: str locus_name: str seq_id: str @@ -41,51 +40,44 @@ class SeqObject: # Manually adding slots for compatibility __slots__ = slots(__annotations__) + def to_dict(self) -> dict: + return asdict(self) + class seq_intake: - input_file = '' valid_types = ['genbank','gff','gtf','fasta'] - file_type = None - feat_key = 'CDS' - translation_table = 11 is_file_valid = '' - - messages = [] - prodigal_genes = [] - skip_trans = False + def __init__(self,input_file,file_type,feat_key='CDS',translation_table=11,perform_annotation=False,num_threads=1,skip_trans=False): + if not input_file.exists(): + raise FileNotFoundError("File {} does not exist.".format(input_file)) + self.input_file = input_file self.file_type = file_type self.translation_table = translation_table self.feat_key = feat_key self.skip_trans = skip_trans - self.status = True self.num_threads = num_threads + self.prodigal_genes = [] + self.skip_trans = False #self.seq_data = self.process_fasta() + self.status = True - if not os.path.isfile(self.input_file): - self.messages.append(f'Error {self.input_file} does not exist') - self.status = False - - if not self.status: - return if file_type == 'genbank': self.seq_data = self.process_gbk() self.status = True elif file_type == 'fasta' and perform_annotation==True: - #sobj = gene_prediction(self.input_file) - #sobj.predict(num_threads) - #self.prodigal_genes = sobj.genes - #self.process_seq_hash(sobj.sequences) - #self.seq_data = self.process_fasta() self.seq_data = self.annotate_fasta(self.input_file, num_threads=self.num_threads) elif file_type == 'fasta' and perform_annotation==False: self.seq_data = self.process_fasta() - elif file_type == 'gff': + elif file_type == 'gff': # TODO these lists do not contain all allowed file types self.status = False elif file_type == 'gtf': self.status = False + else: + raise AttributeError + if self.status: self.add_codon_data() @@ -147,7 +139,7 @@ def 
process_gbk(self) -> list[SeqObject]: return seq_data def process_fasta(self, seq_data = []) -> list[SeqObject]: - obj = parse_fasta(self.input_file) + obj = ParseFasta(self.input_file) ids = obj.get_seqids() for id in ids: features = obj.get_seq_by_id(id) @@ -272,15 +264,15 @@ def __init__(self,sample_name,db_config_dict,metadata_dict,query_seq_records,bla self.record['query_data']['sample_name'] = self.sample_name - def add_db_config(self,conf): - self.record['db_info'] = conf + def add_db_config(self,conf: DBConfig): + self.record['db_info'] = conf.to_dict() def add_hit_cols(self,columns): self.record['query_hit_columns'] = columns - def add_seq_data(self,query_seq_records): + def add_seq_data(self,query_seq_records: List[SeqObject]): for idx, v in enumerate(query_seq_records): - self.record['query_data']['query_seq_data'][idx] = v + self.record['query_data']['query_seq_data'][idx] = v.to_dict() def add_db_metadata(self,metadata_dict): locus_profile = {} diff --git a/locidex/constants.py b/locidex/constants.py index 0ba58c0..137ec28 100644 --- a/locidex/constants.py +++ b/locidex/constants.py @@ -71,8 +71,8 @@ def _keys(cls) -> list: return [i.name for i in fields(cls)] FILE_TYPES = { - 'genbank': ["gbk","genbank","gbf","gbk.gz","genbank.gz","gbf.gz","gbff","gbff.gz"], - 'fasta': ["fasta","fas","fa","ffn","fna","fasta.gz","fas.gz","fa.gz","ffn.gz","fna.gz"], + 'genbank': [".gbk",".genbank",".gbf",".gbk.gz",".genbank.gz",".gbf.gz",".gbff",".gbff.gz"], + 'fasta': [".fasta",".fas",".fa",".ffn",".fna",".fasta.gz",".fas.gz",".fa.gz",".ffn.gz",".fna.gz"], } diff --git a/locidex/extract.py b/locidex/extract.py index 637dda8..56e574f 100644 --- a/locidex/extract.py +++ b/locidex/extract.py @@ -9,10 +9,12 @@ import numpy as np import pandas as pd from locidex.classes.extractor import extractor -from locidex.classes.blast import blast_search, parse_blast +#from locidex.classes.blast import blast_search, parse_blast +from locidex.classes.blast2 import BlastSearch, 
FilterOptions, BlastMakeDB +from locidex.manifest import DBData from locidex.classes.db import search_db_conf, db_config from locidex.classes.seq_intake import seq_intake, seq_store -from locidex.constants import SEARCH_RUN_DATA, FILE_TYPES, BlastColumns, DBConfig, DB_EXPECTED_FILES, NT_SUB, EXTRACT_MODES, OPTION_GROUPS +from locidex.constants import SEARCH_RUN_DATA, FILE_TYPES, BlastColumns, BlastCommands, DBConfig, DB_EXPECTED_FILES, NT_SUB, EXTRACT_MODES, OPTION_GROUPS from locidex.version import __version__ from locidex.classes.aligner import perform_alignment, aligner from locidex.utils import check_db_groups @@ -76,8 +78,8 @@ def write_seq_info(seq_data,out_file): def run_extract(config): # Input Parameters - input_fasta = config['in_fasta'] - outdir = config['outdir'] + input_fasta = Path(config['in_fasta']) + outdir = Path(config['outdir']) db_dir = config['db'] min_dna_ident = config['min_dna_ident'] min_evalue = config['min_evalue'] @@ -90,6 +92,7 @@ def run_extract(config): sample_name = config['name'] max_target_seqs = config['max_target_seqs'] mode = config['mode'].lower() + db_data = DBData(db_dir=db_dir) if not mode in EXTRACT_MODES: @@ -107,7 +110,7 @@ def run_extract(config): if format is None: for t in FILE_TYPES: for ext in FILE_TYPES[t]: - if re.search(f"{ext}$", input_fasta): + if ext == input_fasta.suffix: format = t else: format = format.lower() @@ -146,13 +149,16 @@ def run_extract(config): if not os.path.isdir(db_path): os.makedirs(db_path, 0o755) - db_path = os.path.join(db_path,'contigs.fasta') + contigs_path = os.path.join(db_path,'contigs.fasta') seq_data = {} - with open(db_path,'w') as oh: + with open(contigs_path,'w') as oh: for idx,seq in enumerate(seq_obj.seq_data): - seq_data[str(idx)] = {'id':str(seq['seq_id']),'seq':seq['dna_seq']} - oh.write(">{}\n{}\n".format(idx,seq['dna_seq'])) + seq_data[str(idx)] = {'id':str(seq.seq_id),'seq':seq.dna_seq} + oh.write(">{}\n{}\n".format(idx,seq.dna_seq)) del(seq_obj) + contigs_db = 
BlastMakeDB(contigs_path, DBData.nucleotide_db_type(), True, contigs_path) + contigs_db.makeblastdb() + blast_dir_base = os.path.join(outdir, 'blast') if not os.path.isdir(blast_dir_base): @@ -166,24 +172,27 @@ def run_extract(config): 'num_threads': n_threads, 'word_size':11 } - nt_db = "{}.fasta".format(blast_database_paths['nucleotide']) - hit_file = os.path.join(blast_dir_base, "hsps.txt") - obj = blast_search(input_db_path=db_path, input_query_path=nt_db, - output_results=hit_file, blast_params=blast_params, blast_method='blastn', - blast_columns=BlastColumns._fields,create_db=True) - - if obj.status == False: - print("Error something went wrong, please check error messages above") - sys.exit() + nt_db = Path("{}.fasta".format(blast_database_paths['nucleotide'])) filter_options = { - 'evalue': {'min': None, 'max': min_evalue, 'include': None}, - 'pident': {'min': min_dna_ident, 'max': None, 'include': None}, - 'qcovs': {'min': min_dna_match_cov, 'max': None, 'include': None}, - 'qcovhsp': {'min': min_dna_match_cov, 'max': None, 'include': None}, + 'evalue': FilterOptions(min=None, max=min_evalue, include=None), + 'pident': FilterOptions(min=min_dna_ident, max=None, include=None), + 'qcovs': FilterOptions(min=min_dna_match_cov, max=None, include=None), + 'qcovhsp': FilterOptions(min=min_dna_match_cov, max=None, include=None), } - hit_df = parse_blast(hit_file, BlastColumns._fields, filter_options).df + hit_file = os.path.join(blast_dir_base, "hsps.txt") + # TODO is this supposed to support nucleotide and amino acid? 
+ #obj = BlastSearch(db_data=db_data.nucleotide_blast_db, + print("NT DB", nt_db) + obj = BlastSearch(db_data=contigs_db.output_db_path, + query_path=nt_db, + blast_params=blast_params, + blast_method=BlastCommands.blastn, + blast_columns=BlastColumns._fields, + filter_options=filter_options) + hit_df = obj.get_blast_data(db_data.nucleotide_blast_db, Path(hit_file)) + hit_df['sseqid'] = hit_df['sseqid'].astype(str) hit_df['qseqid'] = hit_df['qseqid'].astype(str) @@ -198,6 +207,7 @@ def run_extract(config): filt_trunc = True if keep_truncated: filt_trunc = False + #print(seq_data) exobj = extractor(hit_df,seq_data,sseqid_col='sseqid',queryid_col='qseqid',qstart_col='qstart',qend_col='qend', qlen_col='qlen',sstart_col='sstart',send_col='send',slen_col='slen',sstrand_col='sstrand', @@ -209,13 +219,14 @@ def run_extract(config): nt_db_seq_obj = seq_intake(nt_db, 'fasta', 'source', translation_table, perform_annotation=False, skip_trans=True) nt_db_seq_data = {} for idx, seq in enumerate(nt_db_seq_obj.seq_data): - nt_db_seq_data[str(seq['seq_id'])] = seq['dna_seq'] + nt_db_seq_data[str(seq.seq_id)] = seq.dna_seq del(nt_db_seq_obj) ext_seq_data = {} with open(os.path.join(outdir,'raw.extracted.seqs.fasta'), 'w') as oh: for idx,record in enumerate(exobj.seqs): + print(record) if min_dna_len > len(record['seq']): continue seq_id = "{}:{}:{}:{}".format(record['locus_name'],record['query_id'],record['seqid'],record['id']) diff --git a/locidex/manifest.py b/locidex/manifest.py index 5fe1e77..69440fa 100644 --- a/locidex/manifest.py +++ b/locidex/manifest.py @@ -22,15 +22,34 @@ class DBData: * point we should have a better understanding of how all the modules fit together. 
""" - __db_names = ["nucleotide", "protein"] - __nucleotide_path = pathlib.Path(__db_names[0]) - __protein_path = pathlib.Path(__db_names[1]) + __nucleotide_name = "nucleotide" + __nucleotide_db_type = "nucl" + __protein_name = "protein" + __protein_db_type = "prot" + __nucleotide_path = pathlib.Path(__nucleotide_name) + __protein_path = pathlib.Path(__protein_name) def __init__(self, db_dir: pathlib.Path): - self.db_dir = db_dir + self.db_dir = pathlib.Path(db_dir) self.config_data: DBConfig = self._get_config(self.db_dir) self.metadata: dict = self._get_metadata(self.db_dir) - self.nucleotide, self.protein = self._get_blast_dbs(db_dir, self.config_data) + self.nucleotide, self.protein = self._get_blast_dbs(self.db_dir, self.config_data) + + @classmethod + def nucleotide_db_type(cls): + return cls.__nucleotide_db_type + + @classmethod + def protein_db_type(cls): + return cls.__protein_db_type + + @classmethod + def protein_name(cls): + return cls.__protein_name + + @classmethod + def nucleotide_name(cls): + return cls.__nucleotide_name @property def nucleotide_blast_db(self): @@ -54,6 +73,11 @@ def _get_metadata(self, db_dir: pathlib.Path) -> dict: metadata_file = db_dir.joinpath(DBFiles.meta_file) if not metadata_file.exists(): raise FileNotFoundError("Metadata file does not exist. Database path maybe incorrect: {}".format(db_dir)) + md_data = None + with open(metadata_file, 'r') as md: + md_data = json.load(md) + return md_data + def _get_blast_dbs(self, db_dir: pathlib.Path, config_data: DBConfig) -> Tuple[Optional[pathlib.Path], Optional[pathlib.Path]]: blast_db = db_dir.joinpath(DBFiles.blast_dir) @@ -135,7 +159,7 @@ def check_config(directory: pathlib.Path) -> DBConfig: directory: Path of the directory containing the parent. 
""" - config_dir = pathlib.Path(directory / DBFiles.config_file) + config_dir = pathlib.Path(directory).joinpath(DBFiles.config_file) config_data: Optional[DBConfig] = None with open(config_dir, 'r') as conf: config_data = DBConfig(**json.load(conf)) diff --git a/locidex/search.py b/locidex/search.py index b62bc35..933df56 100644 --- a/locidex/search.py +++ b/locidex/search.py @@ -7,15 +7,15 @@ from datetime import datetime from typing import Optional from dataclasses import dataclass - import pandas as pd +from functools import partial #from locidex.classes.blast import blast_search, parse_blast, FilterOptions from locidex.classes.blast2 import BlastSearch, FilterOptions from locidex.classes.db import search_db_conf, db_config from locidex.manifest import DBData from locidex.classes.seq_intake import seq_intake, seq_store, HitFilters -from locidex.constants import SEARCH_RUN_DATA, FILE_TYPES, BlastColumns, DB_EXPECTED_FILES, OPTION_GROUPS, DBConfig +from locidex.constants import BlastCommands, SEARCH_RUN_DATA, FILE_TYPES, BlastColumns, DB_EXPECTED_FILES, OPTION_GROUPS, DBConfig from locidex.utils import write_seq_dict, check_db_groups, slots from locidex.version import __version__ @@ -57,7 +57,7 @@ def add_args(parser=None): parser.add_argument('--format', type=str, required=False, help='Format of query file [genbank,fasta]') parser.add_argument('--translation_table', type=int, required=False, - help='output directory', default=11) + help='Table to use for translation', default=11) parser.add_argument('-a', '--annotate', required=False, help='Perform annotation on unannotated input fasta', action='store_true') parser.add_argument('-V', '--version', action='version', version="%(prog)s " + __version__) @@ -66,19 +66,48 @@ def add_args(parser=None): return parser -def perform_search(query_file,results_file,db_path,blast_prog,blast_params,columns): - return blast_search(db_path,query_file,results_file,blast_params,blast_prog,columns) +#def 
perform_search(query_file,results_file,db_path,blast_prog,blast_params,columns): +# return blast_search(db_path,query_file,results_file,blast_params,blast_prog,columns) def create_fasta_from_df(df,label_col,seq_col,out_file): - write_seq_dict(dict(zip(df[label_col].tolist(), df[seq_col])), out_file) - + return write_seq_dict(dict(zip(df[label_col].tolist(), df[seq_col])), out_file) + +@dataclass +class DefaultSearchOpts: + program: str + seq_col: str + pident_filter: FilterOptions + qcovs_filter: FilterOptions + db_dir: Path + output_dir: str + + +def create_outputs(output_dir: Path, db_data: DBData, blast_params: dict, configuration: DefaultSearchOpts, filtered_df: pd.DataFrame, filter_options: dict) -> pd.DataFrame: + """ + Create outputs of blast hits + output_dir Path: output location of search data + configuration DefaultSearchOpts: Pararmeters passed to 'run_search' from the cli + + This function will have some needed clean up once the cli is tidied + """ + hsps_out = "hsps.txt" #? 
Need to follow up on what hsps stands for + label_col = 'index' + query_fasta = output_dir.joinpath("queries.fasta") + output_hsps = output_dir.joinpath(hsps_out) + if not output_dir.exists() or not output_dir.is_dir(): + os.makedirs(output_dir, 0o755) + + output_file = create_fasta_from_df(filtered_df, label_col=label_col, seq_col=configuration.seq_col, out_file=query_fasta) + search_data = BlastSearch(db_data, output_file, blast_params, configuration.program, BlastColumns._fields, filter_options) + searched_df = search_data.get_blast_data(configuration.db_dir, output_hsps) + return searched_df def run_search(config): # Input Parameters - query_file = config['query'] - outdir = config['outdir'] + query_file = Path(config['query']) + outdir = Path(config['outdir']) db_dir = config['db'] min_dna_ident = config['min_dna_ident'] min_aa_ident =config['min_aa_ident'] @@ -106,26 +135,15 @@ def run_search(config): perform_annotation = False if sample_name == None: - sample_name = os.path.basename(query_file) + sample_name = query_file.stem run_data = SEARCH_RUN_DATA run_data['analysis_start_time'] = datetime.now().strftime("%d/%m/%Y %H:%M:%S") run_data['parameters'] = config - # Validate database is valid - #db_database_config = search_db_conf(db_dir, DB_EXPECTED_FILES, DBConfig._keys()) - #if db_database_config.status == False: - # print(f'There is an issue with provided db directory: {db_dir}\n {db_database_config.messages}') - # sys.exit() - db_data = DBData(db_dir=db_dir) - #metadata_path = db_database_config.meta_file_path - #metadata_obj = db_config(metadata_path, ['meta', 'info']) - metadata_obj = db_data.metadata - #blast_database_paths = db_database_config.blast_paths - #blast_database_paths = db_database_config.blast_paths if os.path.isdir(outdir) and not force: print(f'Error {outdir} exists, if you would like to overwrite, then specify --force') sys.exit() @@ -136,7 +154,7 @@ def run_search(config): if format is None: for t in FILE_TYPES: for ext in 
FILE_TYPES[t]: - if re.search(f"{ext}$", query_file): + if query_file.suffix == ext: format = t else: format = format.lower() @@ -148,7 +166,8 @@ def run_search(config): print(f'Format for query file must be one of {list(FILE_TYPES.keys())}, you supplied {format}') sys.exit() - seq_obj = seq_intake(query_file, format, 'CDS', translation_table, perform_annotation) + seq_obj = seq_intake(input_file=query_file, file_type=format, feat_key='CDS', + translation_table=translation_table, perform_annotation=perform_annotation) if perform_annotation: gbk_data = [] @@ -162,11 +181,6 @@ def run_search(config): with open(f, 'w') as oh: oh.write("\n".join([str(x) for x in gbk_data])) - - #blast_dir_base = os.path.join(outdir, 'blast') - #if not os.path.isdir(blast_dir_base): - # os.makedirs(blast_dir_base, 0o755) - blast_params = { 'evalue': min_evalue, 'max_target_seqs': max_target_seqs, @@ -192,61 +206,48 @@ def run_search(config): dna_ambig_count=max_ambig_count) - store_obj = seq_store(sample_name, db_data.config, metadata_obj.config['meta'], + store_obj = seq_store(sample_name, db_data.config_data, db_data.metadata['meta'], seq_obj.seq_data, BlastColumns._fields, hit_filters) - ############# Tommorow wrap this in two individual functions - for db_label in dbs: - label_col = 'index' - if db_data.nucleotide: - blast_prog = 'blastn' - seq_col = 'dna_seq' - d = db_data.nucleotide - filter_options['pident'] = FilterOptions(min=min_dna_ident, max=None, include=None) - filter_options['qcovs'] = FilterOptions(min=min_dna_match_cov, max=None, include=None) - - elif db_data.protein: - blast_prog = 'blastp' - seq_col = 'aa_seq' - d = db_data.protein - filter_options['pident'] = FilterOptions(min=min_aa_ident, max=None, include=None) - filter_options['qcovs'] = FilterOptions(min=min_aa_match_cov, max=None, include=None) - - if not os.path.isdir(d): - os.makedirs(d, 0o755) - else: - if os.path.isfile(os.path.join(d, "queries.fasta")): - os.remove(os.path.join(d, "queries.fasta")) - if 
os.path.isfile(os.path.join(d, "hsps.txt")): - os.remove(os.path.join(d, "hsps.txt")) - - #db_path = blast_database_paths[db_label] - create_fasta_from_df(filtered_df, label_col, seq_col, os.path.join(d, "queries.fasta")) - #perform_search(os.path.join(d, "queries.fasta"), os.path.join(d, "hsps.txt"), db_path, blast_prog, blast_params, - # BlastColumns._fields) - #hit_obj = parse_blast(os.path.join(d, "hsps.txt"), BlastColumns._fields, filter_options) - #hit_df = hit_obj.df - search_data = BlastSearch(db_data, os.path.join(d, "queries.fasta"), blast_params, blast_prog, BlastColumns._fields, filter_options) - searched_df = search_data.get_blast_data(db_data.nucleotide_blast_db, os.path.join(d, "hsps.txt")) - - #store_obj.add_hit_data(hit_df, db_label, 'qseqid') - store_obj.add_hit_data(searched_df, db_label, 'qseqid') + protein_filter = DefaultSearchOpts( + program=BlastCommands.blastp, + seq_col="aa_seq", + pident_filter=FilterOptions(min=min_aa_ident, max=None, include=None), + qcovs_filter=FilterOptions(min=min_aa_match_cov, max=None, include=None), + db_dir=db_data.protein_blast_db, + output_dir=DBData.protein_name()) + + nucleotide_filter = DefaultSearchOpts( + program=BlastCommands.blastn, + seq_col="dna_seq", + pident_filter=FilterOptions(min=min_dna_ident, max=None, include=None), + qcovs_filter=FilterOptions(min=min_dna_match_cov, max=None, include=None), + db_dir=db_data.nucleotide_blast_db, + output_dir=DBData.nucleotide_name()) + + searched_hits_col = 'qseqid' + output_creation = partial(create_outputs, output_dir=outdir, db_data=db_data, filtered_df=filtered_df, blast_params=blast_params, filter_options=filter_options) + if db_data.nucleotide: + searched_data = output_creation(configuration=nucleotide_filter) + store_obj.add_hit_data(searched_data, DBData.nucleotide_name(), searched_hits_col) + if db_data.protein: + searched_data = output_creation(configuration=protein_filter) + store_obj.add_hit_data(searched_data, DBData.protein_name(), 
searched_hits_col) store_obj.filter_hits() store_obj.convert_profile_to_list() - run_data['result_file'] = os.path.join(outdir,"seq_store.json") - del (filtered_df) + run_data['result_file'] = str(outdir.joinpath("seq_store.json")) + with open(run_data['result_file'], "w") as fh: fh.write(json.dumps(store_obj.record, indent=4)) run_data['analysis_end_time'] = datetime.now().strftime("%d/%m/%Y %H:%M:%S") - with open(os.path.join(outdir,"run.json"),'w' ) as fh: + with open(outdir.joinpath("run.json"),'w' ) as fh: fh.write(json.dumps(run_data, indent=4)) def run(cmd_args=None): - #cmd_args = parse_args() if cmd_args is None: parser = add_args() cmd_args = parser.parse_args() diff --git a/locidex/utils.py b/locidex/utils.py index 3d9b79c..bb50b60 100644 --- a/locidex/utils.py +++ b/locidex/utils.py @@ -134,6 +134,7 @@ def write_seq_dict(data,output_file): with open(output_file, 'w') as oh: for id in data: oh.write(f">{id}\n{data[id]}\n") + return output_file def validate_dict_keys(data_dict,required_keys): diff --git a/tests/test_fasta.py b/tests/test_fasta.py index c4cfc18..a61e71d 100644 --- a/tests/test_fasta.py +++ b/tests/test_fasta.py @@ -4,7 +4,7 @@ from Bio.SeqRecord import SeqRecord import tempfile import os -from locidex.classes.fasta import parse_fasta +from locidex.classes.fasta import ParseFasta @pytest.fixture def fasta_content(): @@ -23,7 +23,7 @@ def fasta_file(fasta_content): os.unlink(tmp_path) def test_parse_fasta_normal(fasta_file): - parser = parse_fasta(fasta_file) + parser = ParseFasta(fasta_file) assert len(parser.get_seqids()) == 2, "There should be two sequences" seq_data = parser.get_seq_by_id("gene1|123") assert seq_data.seq == "ATGCGTACGTAGCTAGC", "Sequence data should match the input" @@ -31,13 +31,13 @@ def test_parse_fasta_normal(fasta_file): def test_parse_fasta_with_nonexistent_file(): # Assuming that the locidex errors out with File not found error if FASTA is non existant: with pytest.raises(FileNotFoundError): - 
parse_fasta("nonexistent.fasta") + ParseFasta("nonexistent.fasta") # parser = parse_fasta("nonexistent.fasta") # assert not parser.status, "Parser status should be False when file does not exist" def test_parse_fasta_with_definitions(fasta_file): - parser = parse_fasta(fasta_file, parse_def=True, delim="|") + parser = ParseFasta(fasta_file, parse_def=True, delim="|") seq_data = parser.get_seq_by_id("gene1|123") assert seq_data.gene_name == "gene1", "Gene name should be correctly parsed from definition" assert seq_data.seq_id == "123", "Seq ID should be correctly parsed from definition" diff --git a/tests/test_workflows.yml b/tests/test_workflows.yml index a85800b..281e420 100644 --- a/tests/test_workflows.yml +++ b/tests/test_workflows.yml @@ -25,8 +25,8 @@ - name: Run all command: > bash -c " - locidex extract -i locidex/example/search/NC_003198.1.fasta -d locidex/example/build_db_mlst_out -o here; - locidex search --query here/raw.extracted.seqs.fasta -d locidex/example/build_db_mlst_out -o searched; + locidex extract -i locidex/example/search/NC_003198.1.fasta -d locidex/example/build_db_mlst_out -o here && + locidex search --query here/raw.extracted.seqs.fasta -d locidex/example/build_db_mlst_out -o searched && locidex report -i searched/seq_store.json -o reported " files: From 8a53a90782763971a27368e0ac28e1d056be4ff3 Mon Sep 17 00:00:00 2001 From: Matthew Wells Date: Fri, 10 May 2024 10:45:12 -0500 Subject: [PATCH 5/7] all tests passing now --- locidex/classes/blast2.py | 17 +++++++++++++---- locidex/extract.py | 7 +++---- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/locidex/classes/blast2.py b/locidex/classes/blast2.py index 1cc8244..7af6e3b 100644 --- a/locidex/classes/blast2.py +++ b/locidex/classes/blast2.py @@ -43,7 +43,10 @@ def makeblastdb(self): if self.output_db_path != None: command +=['-out', str(self.output_db_path)] stdout, stderr = run_command(" ".join([str(x) for x in command])) - print(stdout, stderr) + if stdout: + 
logger.info("Blast makedb stdout: {}".format(stdout)) + if stderr: + logger.info("Blast makedb stderr: {}".format(stderr)) return self.output_db_path class BlastSearch: @@ -69,8 +72,11 @@ def get_blast_data(self, db_path: Path, output: Path) -> pd.DataFrame: """ stdout, stderr = self._run_blast(db=db_path, output=output) - #logger.info("Blast stdout: {}".format(stdout)) - #logger.info("Blast stderr: {}".format(stderr)) + if stdout: + logger.info("Blast search stdout: {}".format(stdout)) + if stderr: + logger.info("Blast search stderr: {}".format(stderr)) + blast_data = self.parse_blast(output_file=output) return blast_data @@ -132,6 +138,8 @@ def _run_blast(self, db: Path, output: Path): """ db PAth: Path to the blast database to use, output Path: Path to file for blast output + + TODO need to use classes db versoin or not pass it too the initializer """ command = [ self.blast_method, @@ -144,7 +152,8 @@ def _run_blast(self, db: Path, output: Path): if param == "parse_seqids": command.append(f"-{param}") else: - command += [f'-{param}', f'{self.blast_params[param]}'] + command += [f'-{param}', f'{self.blast_params[param]}'] + logger.info("Blast command: {}".format(" ".join([str(x) for x in command]))) return run_command(" ".join([str(x) for x in command])) diff --git a/locidex/extract.py b/locidex/extract.py index 56e574f..2f5c352 100644 --- a/locidex/extract.py +++ b/locidex/extract.py @@ -183,15 +183,14 @@ def run_extract(config): hit_file = os.path.join(blast_dir_base, "hsps.txt") # TODO is this supposed to support nucleotide and amino acid? 
- #obj = BlastSearch(db_data=db_data.nucleotide_blast_db, - print("NT DB", nt_db) + obj = BlastSearch(db_data=contigs_db.output_db_path, query_path=nt_db, blast_params=blast_params, blast_method=BlastCommands.blastn, blast_columns=BlastColumns._fields, filter_options=filter_options) - hit_df = obj.get_blast_data(db_data.nucleotide_blast_db, Path(hit_file)) + hit_df = obj.get_blast_data(contigs_db.output_db_path, Path(hit_file)) hit_df['sseqid'] = hit_df['sseqid'].astype(str) hit_df['qseqid'] = hit_df['qseqid'].astype(str) @@ -207,7 +206,7 @@ def run_extract(config): filt_trunc = True if keep_truncated: filt_trunc = False - #print(seq_data) + exobj = extractor(hit_df,seq_data,sseqid_col='sseqid',queryid_col='qseqid',qstart_col='qstart',qend_col='qend', qlen_col='qlen',sstart_col='sstart',send_col='send',slen_col='slen',sstrand_col='sstrand', From 206d2bcc3bed7981d840c1581a14e2325bed0e78 Mon Sep 17 00:00:00 2001 From: Matthew Wells Date: Fri, 10 May 2024 12:51:20 -0500 Subject: [PATCH 6/7] removed redundant print statements --- locidex/classes/blast2.py | 9 ++++++--- locidex/extract.py | 1 - tests/test_aligner.py | 2 +- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/locidex/classes/blast2.py b/locidex/classes/blast2.py index 7af6e3b..50436b0 100644 --- a/locidex/classes/blast2.py +++ b/locidex/classes/blast2.py @@ -42,7 +42,9 @@ def makeblastdb(self): command.append('-parse_seqids') if self.output_db_path != None: command +=['-out', str(self.output_db_path)] - stdout, stderr = run_command(" ".join([str(x) for x in command])) + mk_db_cmd = " ".join([str(x) for x in command]) + logger.info("Blast database command: {}".format(mk_db_cmd)) + stdout, stderr = run_command(mk_db_cmd) if stdout: logger.info("Blast makedb stdout: {}".format(stdout)) if stderr: @@ -153,8 +155,9 @@ def _run_blast(self, db: Path, output: Path): command.append(f"-{param}") else: command += [f'-{param}', f'{self.blast_params[param]}'] - logger.info("Blast command: {}".format(" 
".join([str(x) for x in command]))) - return run_command(" ".join([str(x) for x in command])) + blast_search_cmd = " ".join([str(x) for x in command]) + logger.info("Blast command: {}".format(blast_search_cmd)) + return run_command(blast_search_cmd) diff --git a/locidex/extract.py b/locidex/extract.py index 2f5c352..bc80fbc 100644 --- a/locidex/extract.py +++ b/locidex/extract.py @@ -225,7 +225,6 @@ def run_extract(config): ext_seq_data = {} with open(os.path.join(outdir,'raw.extracted.seqs.fasta'), 'w') as oh: for idx,record in enumerate(exobj.seqs): - print(record) if min_dna_len > len(record['seq']): continue seq_id = "{}:{}:{}:{}".format(record['locus_name'],record['query_id'],record['seqid'],record['id']) diff --git a/tests/test_aligner.py b/tests/test_aligner.py index 9799db9..f820934 100644 --- a/tests/test_aligner.py +++ b/tests/test_aligner.py @@ -1,7 +1,7 @@ import pytest import os import locidex.classes.aligner -from locidex.constants import BLAST_TABLE_COLS +from locidex.constants import BlastColumns from dataclasses import dataclass PACKAGE_ROOT = os.path.dirname(locidex.__file__) From 84073401e754592eed00e1e0bb3b4bc3acb33d34 Mon Sep 17 00:00:00 2001 From: Matthew Wells Date: Fri, 10 May 2024 14:10:57 -0500 Subject: [PATCH 7/7] fixed type hints, and utils tests --- locidex/classes/blast.py | 1 - locidex/classes/blast2.py | 5 ++- locidex/classes/fasta.py | 1 - locidex/constants.py | 20 +++++------ .../example/build_db_mlst_out/results.json | 4 +-- locidex/extract.py | 5 ++- locidex/format.py | 2 +- tests/test_build.py | 4 +-- tests/test_extractor.py | 34 +++++++++++-------- tests/test_fasta.py | 7 ++-- tests/test_seq_intake.py | 22 +++++++----- tests/test_utils.py | 5 ++- 12 files changed, 65 insertions(+), 45 deletions(-) diff --git a/locidex/classes/blast.py b/locidex/classes/blast.py index 148a9c7..3b01f3d 100644 --- a/locidex/classes/blast.py +++ b/locidex/classes/blast.py @@ -64,7 +64,6 @@ def makeblastdb(self): def is_blast_db_valid(self): 
extensions = ['nsq', 'nin', 'nhr'] for e in extensions: - print(self.output_db_path, e) if not os.path.isfile(f'{self.output_db_path}.{e}'): extensions2 = ['pto', 'ptf', 'phr'] for e2 in extensions2: diff --git a/locidex/classes/blast2.py b/locidex/classes/blast2.py index 50436b0..1344b31 100644 --- a/locidex/classes/blast2.py +++ b/locidex/classes/blast2.py @@ -57,7 +57,7 @@ class BlastSearch: __blast_extensions_pt = frozenset(['.pto', '.ptf', '.phr']) __filter_columns = ["qseqid", "sseqid"] - def __init__(self, db_data: DBData, query_path: Path, blast_params: dict, blast_method: str, blast_columns: List[str], filter_options: Dict[str, FilterOptions]): + def __init__(self, db_data: DBData, query_path: Path, blast_params: dict, blast_method: str, blast_columns: List[str], filter_options: Optional[Dict[str, FilterOptions]]=None): self.db_data = db_data if blast_method not in self.__blast_commands: raise ValueError("{} is not a valid blast command please pick from {}".format(blast_method, self.__blast_commands)) @@ -94,6 +94,9 @@ def parse_blast(self, output_file: Path): if id_col in columns: tp[id_col] = 'object' df = df.astype(tp) + if self.filter_options is None: + return df + for col_name in self.filter_options: if col_name in columns: min_value = self.filter_options[col_name].min diff --git a/locidex/classes/fasta.py b/locidex/classes/fasta.py index a0884e8..ada7536 100644 --- a/locidex/classes/fasta.py +++ b/locidex/classes/fasta.py @@ -39,7 +39,6 @@ def normalize_sequence(fasta:str) -> str: return fasta.lower().replace("-", "") def get_seqids(self): - print(bool(self.seq_obj), self.seq_obj.keys()) if self.seq_obj: return list(self.seq_obj.keys()) raise AssertionError("No fasta file loaded.") diff --git a/locidex/constants.py b/locidex/constants.py index 137ec28..6044e92 100644 --- a/locidex/constants.py +++ b/locidex/constants.py @@ -105,16 +105,16 @@ def _keys(cls) -> list: @dataclass class DBConfig: - db_name: Union[str, None] = None - db_version: Union[str, 
None] = None - db_date: Union[str, None] = None - db_author: Union[str, None] = None - db_desc: Union[str, None] = None - db_num_seqs: Union[str, int] = None - is_nucl: Union[bool, None] = None - is_prot: Union[bool, None] = None - nucleotide_db_name: Union[str, None] = None - protein_db_name: Union[str, None] = None + db_name: Optional[str] = None + db_version: Optional[str] = None + db_date: Optional[str] = None + db_author: Optional[str] = None + db_desc: Optional[str] = None + db_num_seqs: Optional[Union[str, int]] = None + is_nucl: Optional[bool] = None + is_prot: Optional[bool] = None + nucleotide_db_name: Optional[str] = None + protein_db_name: Optional[str] = None def __getitem__(self, name: str) -> Any: return getattr(self, str(name)) diff --git a/locidex/example/build_db_mlst_out/results.json b/locidex/example/build_db_mlst_out/results.json index e8c0c25..bca6cb6 100644 --- a/locidex/example/build_db_mlst_out/results.json +++ b/locidex/example/build_db_mlst_out/results.json @@ -4,9 +4,9 @@ "input_file": "locidex/example/build_db_mlst_in/senterica.mlst.txt", "outdir": "locidex/example/build_db_mlst_out/", "name": "Locidex Database", - "author": "", + "author": "mw", "db_ver": "1.0.0", - "db_desc": "", + "db_desc": "test", "force": true }, "analysis_end_time": "2024-04-30 16:29:13" diff --git a/locidex/extract.py b/locidex/extract.py index bc80fbc..ef3e059 100644 --- a/locidex/extract.py +++ b/locidex/extract.py @@ -172,7 +172,10 @@ def run_extract(config): 'num_threads': n_threads, 'word_size':11 } - nt_db = Path("{}.fasta".format(blast_database_paths['nucleotide'])) + #nt_db = Path("{}.fasta".format(blast_database_paths['nucleotide'])) + nt_db = Path("{}.fasta".format(db_data.nucleotide_blast_db)) + if not nt_db.exists(): + raise FileNotFoundError("Could not find nucleotide database: {}".format(nt_db)) filter_options = { 'evalue': FilterOptions(min=None, max=min_evalue, include=None), diff --git a/locidex/format.py b/locidex/format.py index 
6e2d919..6523fa7 100644 --- a/locidex/format.py +++ b/locidex/format.py @@ -70,7 +70,7 @@ def process_dir(self): for f in files[self.__file_input]: for e in self.valid_ext: if e in f[1]: - self.gene_name = f[1].replace(f'.{e}','') + self.gene_name = f[1].replace(f'{e}','') self.parse_fasta(f[0]) break diff --git a/tests/test_build.py b/tests/test_build.py index 680e89a..5a8e8ce 100644 --- a/tests/test_build.py +++ b/tests/test_build.py @@ -37,9 +37,9 @@ def cmd_args(output_directory): outdir=output_directory, name='Locidex Database', db_ver='1.0.0', - db_desc='', + db_desc='test', force=True, - author='', + author='mw', date='' ) return command diff --git a/tests/test_extractor.py b/tests/test_extractor.py index 1336ca2..9b5afbd 100644 --- a/tests/test_extractor.py +++ b/tests/test_extractor.py @@ -1,12 +1,14 @@ import pytest import os import locidex -from locidex.classes.blast import FilterOptions, blast_search, parse_blast -from locidex.constants import BlastColumns +from locidex.classes.blast import blast_search, parse_blast +from locidex.classes.blast2 import FilterOptions, BlastSearch, BlastMakeDB +from locidex.constants import BlastColumns, BlastCommands from locidex.classes.extractor import extractor from locidex.classes.seq_intake import seq_intake from locidex.classes.db import db_config - +from locidex.manifest import DBData +from pathlib import Path #could be tested via locidex extract -i ./locidex/example/search/NC_003198.1.fasta -d ./locidex/example/build_db_mlst_out/ -o tmp --force @@ -30,7 +32,7 @@ def blast_db_and_search(tmpdir,input_db_path): @pytest.fixture def seq_intake_fixture(): # Mimicking the creation of seq_data from a given input fasta file - input_fasta = os.path.join(PACKAGE_ROOT, 'example/search/NC_003198.1.fasta') + input_fasta = Path(PACKAGE_ROOT).joinpath('example/search/NC_003198.1.fasta') format = "fasta" # Adjust this based on your file type translation_table = 11 seq_obj = seq_intake(input_fasta, format, 'source', 
translation_table, perform_annotation=False,skip_trans=True) @@ -40,21 +42,25 @@ def seq_intake_fixture(): def test_extractor_initialization(seq_intake_fixture, tmpdir): db_path=os.path.join(tmpdir,"contigs.fasta") nt_db_test = os.path.join(PACKAGE_ROOT,'example/build_db_mlst_out/blast/nucleotide/nucleotide.fasta') - nt_db = os.path.join(PACKAGE_ROOT,'example/build_db_mlst_out/blast/nucleotide/') hit_file = os.path.join(tmpdir,"hsps.txt") blast_params={'evalue': 0.0001, 'max_target_seqs': 10, 'num_threads': 1} metadata_path = os.path.join(PACKAGE_ROOT,'example/build_db_mlst_out/meta.json') seq_obj = seq_intake_fixture seq_data={} - #with open(db_path,'w') as oh: - # for idx,seq in enumerate(seq_obj.seq_data): - # seq_data[str(idx)] = {'id':str(seq.seq_id),'seq':seq.dna_seq} - # oh.write(">{}\n{}\n".format(idx,seq.dna_seq)) - blast_search(input_db_path=nt_db, - input_query_path=nt_db_test, - output_results=hit_file, blast_params=blast_params, blast_method='blastn', - blast_columns=BlastColumns._fields) - hit_df = parse_blast(hit_file, BlastColumns._fields, {}).df + with open(db_path,'w') as oh: + for idx,seq in enumerate(seq_obj.seq_data): + seq_data[str(idx)] = {'id':str(seq.seq_id),'seq':seq.dna_seq} + oh.write(">{}\n{}\n".format(idx,seq.dna_seq)) + + blast_db = BlastMakeDB(db_path, DBData.nucleotide_db_type(), True, db_path) + blast_db.makeblastdb() + blast_search_obj = BlastSearch(db_data=blast_db.output_db_path, + query_path=nt_db_test, + blast_params=blast_params, + blast_method=BlastCommands.blastn, + blast_columns=BlastColumns._fields) + + hit_df = blast_search_obj.get_blast_data(blast_db.output_db_path, Path(hit_file)) loci = []; metadata_obj = db_config(metadata_path, ['meta', 'info']) for idx,row in hit_df.iterrows(): qid = str(row['qseqid']) diff --git a/tests/test_fasta.py b/tests/test_fasta.py index a61e71d..2ab217b 100644 --- a/tests/test_fasta.py +++ b/tests/test_fasta.py @@ -3,6 +3,7 @@ from Bio.Seq import Seq from Bio.SeqRecord import SeqRecord 
import tempfile +from pathlib import Path import os from locidex.classes.fasta import ParseFasta @@ -19,19 +20,19 @@ def fasta_file(fasta_content): with tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix=".fasta") as tmp: SeqIO.write(fasta_content, tmp, "fasta") tmp_path = tmp.name - yield tmp_path + yield Path(tmp_path) os.unlink(tmp_path) def test_parse_fasta_normal(fasta_file): parser = ParseFasta(fasta_file) assert len(parser.get_seqids()) == 2, "There should be two sequences" seq_data = parser.get_seq_by_id("gene1|123") - assert seq_data.seq == "ATGCGTACGTAGCTAGC", "Sequence data should match the input" + assert seq_data.seq == "ATGCGTACGTAGCTAGC".lower(), "Sequence data should match the input" def test_parse_fasta_with_nonexistent_file(): # Assuming that the locidex errors out with File not found error if FASTA is non existant: with pytest.raises(FileNotFoundError): - ParseFasta("nonexistent.fasta") + ParseFasta(Path("nonexistent.fasta")) # parser = parse_fasta("nonexistent.fasta") # assert not parser.status, "Parser status should be False when file does not exist" diff --git a/tests/test_seq_intake.py b/tests/test_seq_intake.py index 318985b..9828ea5 100644 --- a/tests/test_seq_intake.py +++ b/tests/test_seq_intake.py @@ -1,10 +1,11 @@ import os, warnings import locidex -from locidex.classes.seq_intake import SeqObject, seq_intake, seq_store +from locidex.classes.seq_intake import SeqObject, seq_intake, seq_store, HitFilters from locidex.constants import BlastColumns, DB_EXPECTED_FILES, DBConfig from locidex.classes.db import search_db_conf, db_config from collections import Counter from dataclasses import asdict +from pathlib import Path PACKAGE_ROOT = os.path.dirname(locidex.__file__) @@ -12,22 +13,27 @@ def seq_intake_class_init(input_file, file_type, perform_annotation): #reset global class variables to avoid ambiguous results - obj = seq_intake(input_file=input_file, + obj = seq_intake(input_file=Path(input_file), 
file_type=file_type,feat_key='CDS',translation_table=11, perform_annotation=perform_annotation,num_threads=1,skip_trans=False) return obj + #@pytest.mark.skip(reason="no way of currently testing this") def test_seq_store_class(): db_dir = os.path.join(PACKAGE_ROOT, 'example/build_db_mlst_out') db_database_config = search_db_conf(db_dir, DB_EXPECTED_FILES, DBConfig._keys()) metadata_obj = db_config(db_database_config.meta_file_path, ['meta', 'info']) sample_name = 'NC_003198.1.fasta' - seq_obj = seq_intake(input_file=os.path.join(PACKAGE_ROOT, 'example/search/NC_003198.1.fasta'), + + seq_obj = seq_intake(input_file=Path(os.path.join(PACKAGE_ROOT, 'example/search/NC_003198.1.fasta')), file_type='fasta', perform_annotation=False) - hit_filters = {'min_dna_len': 1, 'max_dna_len': 10000000, 'min_dna_ident': 80.0, 'min_dna_match_cov': 80.0, 'min_aa_len': 1, - 'max_aa_len': 10000000, 'min_aa_ident': 80.0, 'min_aa_match_cov': 80.0, 'dna_ambig_count': 99999999999999} - seq_store_obj = seq_store(sample_name, db_database_config.config_obj.config, metadata_obj.config['meta'], + + hit_filters = HitFilters(**{'min_dna_len': 1, 'max_dna_len': 10000000, 'min_dna_ident': 80.0, 'min_dna_match_cov': 80.0, 'min_aa_len': 1, + 'max_aa_len': 10000000, 'min_aa_ident': 80.0, 'min_aa_match_cov': 80.0, 'dna_ambig_count': 99999999999999}) + + seq_store_obj = seq_store(sample_name, DBConfig(**db_database_config.config_obj.config), metadata_obj.config['meta'], seq_obj.seq_data, BlastColumns._fields, hit_filters) + assert list(seq_store_obj.record.keys()) == ['db_info', 'db_seq_info', 'query_data', 'query_hit_columns'] assert list(seq_store_obj.record['db_info'].keys()) == ['db_name', 'db_version', 'db_date', 'db_author', 'db_desc', 'db_num_seqs', 'is_nucl', 'is_prot', 'nucleotide_db_name', 'protein_db_name'] @@ -37,7 +43,7 @@ def test_seq_store_class(): else: warnings.warn(f"expected len(seq_store_obj.record['query_data']['query_seq_data']) == 1 but got 
{len(seq_store_obj.record['query_data']['query_seq_data'])}") - compare_dict = asdict(seq_store_obj.record['query_data']['query_seq_data'][0]) + compare_dict = seq_store_obj.record['query_data']['query_seq_data'][0] assert set(compare_dict.keys()) == set(['parent_id', 'locus_name', 'seq_id', 'dna_hash', 'dna_len', 'aa_hash', 'aa_len', 'start_codon', 'end_codon', 'count_internal_stop', 'dna_ambig_count', 'dna_seq', 'aa_seq']) assert list(seq_store_obj.record['query_data']['locus_profile'].keys()) == ['aroC', 'dnaN', 'hemD', 'hisD', 'purE', 'sucA', 'thrA'] @@ -71,7 +77,7 @@ def test_read_gbk_file(): def test_read_fasta_file(): expected_orfs=4653 - seq_intake_object = seq_intake_class_init(input_file=os.path.join(PACKAGE_ROOT, 'example/search/NC_003198.1.fasta'), + seq_intake_object = seq_intake_class_init(input_file=Path(os.path.join(PACKAGE_ROOT, 'example/search/NC_003198.1.fasta')), file_type='fasta', perform_annotation=True) assert seq_intake_object.file_type == 'fasta' assert seq_intake_object.feat_key == 'CDS' diff --git a/tests/test_utils.py b/tests/test_utils.py index d32d3f9..81d97f7 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -6,6 +6,7 @@ from locidex import utils from locidex import manifest from argparse import Namespace +from collections import namedtuple def test_check_db_groups_pass(monkeypatch): @@ -13,7 +14,9 @@ def test_check_db_groups_pass(monkeypatch): analysis_params = {"db_group": "Db1", "db_name": "test_name", "db_version": "1.0.0"} def mockreturn(*args, **kwargs): - return True + ret_tup = namedtuple('stuff', ["db_path"]) + ret_val = ret_tup(True) + return ret_val monkeypatch.setattr(manifest, "get_manifest_db", mockreturn) analysis_params = utils.check_db_groups(analysis_params, nm_group) assert analysis_params["db"]