Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring/search #15

Merged
merged 7 commits into from
May 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
__pycache__
*.egg*
.vscode
.vscode
.pytest_cache
tmp
85 changes: 39 additions & 46 deletions locidex/classes/blast.py
Original file line number Diff line number Diff line change
@@ -1,58 +1,51 @@
import pandas as pd

from locidex.classes import run_command
from locidex.utils import slots
from dataclasses import dataclass
from typing import Optional
import os


@dataclass
class FilterOptions:
    """
    Filter settings for a single blast result column.

    parse_blast reads the three fields by attribute (.min, .max, .include)
    and forwards them to its filter_df method.
    """
    # Presumably the lower bound for the column; None when unused -- confirm against filter_df.
    min: Optional[float]
    # Presumably the upper bound for the column; None when unused -- confirm against filter_df.
    max: Optional[float]
    # Forwarded to filter_df as its `include` argument.
    include: Optional[bool]
    # Slot names are produced by the project helper `slots` from the annotations above.
    __slots__ = slots(__annotations__)

class blast_search:
VALID_BLAST_METHODS = ['blastn','tblastn','blastp']
BLAST_TABLE_COLS = []
input_db_path = None
input_query_path = None
output_db_path = None
output_results = None
blast_params = {}
blast_method = None
create_db = False
parse_seqids = False
status = True
messages = []

def __init__(self,input_db_path,input_query_path,output_results,blast_params,blast_method,blast_columns,output_db_path=None,create_db=False,parse_seqids=False):

def __init__(self,input_db_path,input_query_path,output_results,blast_params,blast_method,blast_columns,output_db_path=None,parse_seqids=False):
self.input_query_path = input_query_path
self.input_db_path = input_db_path
self.output_db_path = output_db_path
self.output_results = output_results
self.blast_method = blast_method
self.blast_params = blast_params
self.create_db = create_db
#self.create_db = create_db
self.BLAST_TABLE_COLS = blast_columns
self.parse_seqids = parse_seqids

if self.output_db_path is None:
self.output_db_path = input_db_path

if not os.path.isfile(self.input_query_path):
self.messages.append(f'Error {self.input_query_path} query fasta does not exist')
self.status = False

elif not create_db and not self.is_blast_db_valid():
self.messages.append(f'Error {self.input_db_path} is not a valid blast db and creation of db is disabled')
self.status = False
return

raise ValueError(f'Error {self.input_query_path} query fasta does not exist')
elif not self.is_blast_db_valid():
raise ValueError(f'Error {self.input_db_path} is not a valid blast db')
elif not blast_method in self.VALID_BLAST_METHODS:
self.messages.append(f'Error {blast_method} is not a supported blast method: {self.blast_method}')
self.status = False

elif create_db and not self.is_blast_db_valid():
(stdout,stderr) = self.makeblastdb()
if not self.is_blast_db_valid():
self.messages.append(f'Error {self.output_db_path} is not a valid blast db and creation of db failed')
self.status = False


self.messages.append(self.run_blast())
raise ValueError(f'Error {blast_method} is not a supported blast method: {self.blast_method}')

#elif create_db and not self.is_blast_db_valid():
# (stdout,stderr) = self.makeblastdb()
# print(stdout, stderr)
# if not self.is_blast_db_valid():
# raise ValueError(f'Error {self.output_db_path} is not a valid blast db and creation of db failed')
# TODO add logging for blast messages
(stdout, stderr) = self.run_blast()
print(stdout, stderr)


def makeblastdb(self):
Expand All @@ -71,10 +64,10 @@ def makeblastdb(self):
def is_blast_db_valid(self):
extensions = ['nsq', 'nin', 'nhr']
for e in extensions:
if not os.path.isfile(f'{self.input_db_path}.{e}'):
if not os.path.isfile(f'{self.output_db_path}.{e}'):
extensions2 = ['pto', 'ptf', 'phr']
for e2 in extensions2:
if not os.path.isfile(f'{self.input_db_path}.{e2}'):
if not os.path.isfile(f'{self.output_db_path}.{e2}'):
return False

return True
Expand All @@ -96,23 +89,20 @@ def run_blast(self):


class parse_blast:
BLAST_TABLE_COLS = []
input_file = None
df = None
columns = []
filter_options = {}
status = True
messages = []

def __init__(self, input_file,blast_columns,filter_options):
self.input_file = input_file
self.filter_options = filter_options
self.BLAST_TABLE_COLS = blast_columns

if not os.path.isfile(self.input_file):
self.messages.append(f'Error {self.input_file} does not exist')
self.status = False
self.read_hit_table()
raise FileNotFoundError(f'Error {self.input_file} does not exist')

self.df = self.read_hit_table()
print(self.df)
self.columns = self.df.columns.tolist()

for id_col in ['qseqid','sseqid']:
Expand All @@ -122,13 +112,16 @@ def __init__(self, input_file,blast_columns,filter_options):
self.df = self.df.astype(tp)
for col_name in self.filter_options:
if col_name in self.columns:
min_value = self.filter_options[col_name]['min']
max_value = self.filter_options[col_name]['max']
include = self.filter_options[col_name]['include']
#min_value = self.filter_options[col_name]['min']
#max_value = self.filter_options[col_name]['max']
#include = self.filter_options[col_name]['include']
min_value = self.filter_options[col_name].min
max_value = self.filter_options[col_name].max
include = self.filter_options[col_name].include
self.filter_df(col_name, min_value, max_value,include)

def read_hit_table(self):
self.df = pd.read_csv(self.input_file,header=None,names=self.BLAST_TABLE_COLS,sep="\t",low_memory=False)
return pd.read_csv(self.input_file,header=None,names=self.BLAST_TABLE_COLS,sep="\t",low_memory=False)


def filter_df(self,col_name,min_value,max_value,include):
Expand Down
167 changes: 167 additions & 0 deletions locidex/classes/blast2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
"""
Blast module refactored
"""

import pandas as pd

from locidex.classes import run_command
from locidex.utils import slots
from locidex.manifest import DBData
from locidex.constants import BlastCommands
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, List, Dict
import logging
import sys

# Module-level logger used by the classes below.
logger = logging.getLogger(__name__)
# Send log records to stderr at INFO level.  `stream` is the basicConfig
# keyword that selects the output stream; `filemode` is only consulted when a
# `filename` is given and expects a mode string such as "a" or "w".
logging.basicConfig(stream=sys.stderr, level=logging.INFO)

@dataclass
class FilterOptions:
    """
    Filter settings for a single blast result column.

    BlastSearch.parse_blast looks these up per column name and hands the
    three fields to BlastSearch.filter_df.
    """
    # Rows with a column value below this are dropped; None disables the check.
    min: Optional[float]
    # Rows with a column value above this are dropped; None disables the check.
    max: Optional[float]
    # NOTE(review): filter_df passes this to Series.isin(), which needs a
    # list-like argument -- the bool annotation looks wrong; confirm intent.
    include: Optional[bool]
    # Slot names are produced by the project helper `slots` from the annotations above.
    __slots__ = slots(__annotations__)


class BlastMakeDB:
    """
    Create a blast database by running the `makeblastdb` command.
    """

    def __init__(self, input_file: Path, db_type: str, parse_seqids: bool, output_db_path: Optional[Path]):
        """
        input_file Path: sequence file passed to `-in`
        db_type str: value passed to `-dbtype`
        parse_seqids bool: when true, `-parse_seqids` is added to the command
        output_db_path Optional[Path]: value passed to `-out`; None leaves the option off
        """
        self.input_file = input_file
        self.db_type = db_type
        self.parse_seqids = parse_seqids
        self.output_db_path = output_db_path

    def _build_command(self) -> List[str]:
        """Assemble the makeblastdb argument list from the instance settings."""
        command = ['makeblastdb', '-in', str(self.input_file), '-dbtype', str(self.db_type)]
        if self.parse_seqids:
            command.append('-parse_seqids')
        if self.output_db_path is not None:
            command += ['-out', str(self.output_db_path)]
        return command

    def makeblastdb(self):
        """
        Run makeblastdb and log whatever it wrote to stdout and stderr.

        Returns the output_db_path given to the constructor (None when it was not set).
        """
        # run_command receives one space-joined string; arguments are not quoted.
        mk_db_cmd = " ".join(self._build_command())
        logger.info("Blast database command: %s", mk_db_cmd)
        # NOTE(review): only stdout/stderr come back here; no exit status is
        # inspected -- confirm whether run_command raises on failure.
        stdout, stderr = run_command(mk_db_cmd)
        if stdout:
            logger.info("Blast makedb stdout: %s", stdout)
        if stderr:
            logger.info("Blast makedb stderr: %s", stderr)
        return self.output_db_path

class BlastSearch:
    """
    Run a blast search and load the tabular hits into a pandas DataFrame,
    optionally filtering rows per column.
    """
    # Accepted values for blast_method.
    __blast_commands = set(BlastCommands._keys())
    # File suffixes that must be present in a nucleotide / protein database directory.
    __blast_extensions_nt = frozenset(['.nsq', '.nin', '.nhr'])
    __blast_extensions_pt = frozenset(['.pto', '.ptf', '.phr'])
    # Identifier columns that are cast to object dtype after parsing.
    __filter_columns = ["qseqid", "sseqid"]

    def __init__(self, db_data: DBData, query_path: Path, blast_params: dict, blast_method: str, blast_columns: List[str], filter_options: Optional[Dict[str, FilterOptions]]=None):
        """
        db_data DBData: database description used by validate_blast_db
        query_path Path: file passed to blast as `-query`
        blast_params dict: extra command line options, keyed by option name without the dash
        blast_method str: blast program to run; must be one of the BlastCommands keys
        blast_columns List[str]: column names for `-outfmt 6` and for the parsed table
        filter_options Optional[Dict[str, FilterOptions]]: per-column filters; None disables filtering

        Raises ValueError when blast_method is not a known blast command.
        """
        self.db_data = db_data
        if blast_method not in self.__blast_commands:
            raise ValueError("{} is not a valid blast command please pick from {}".format(blast_method, self.__blast_commands))
        self.query_path = query_path
        self.blast_params = blast_params
        self.blast_method = blast_method
        self.blast_columns = blast_columns
        self.filter_options = filter_options

    def get_blast_data(self, db_path: Path, output: Path) -> pd.DataFrame:
        """
        Run blast and parse results
        db_path Path: blast database to search
        output Path: file the blast results are written to and then parsed from

        TODO clean up the db_path hand-off from the DBData object
        """
        stdout, stderr = self._run_blast(db=db_path, output=output)
        if stdout:
            logger.info("Blast search stdout: %s", stdout)
        if stderr:
            logger.info("Blast search stderr: %s", stderr)
        return self.parse_blast(output_file=output)

    def parse_blast(self, output_file: Path):
        """
        Parse a blast output file
        output_file Path: tab separated blast results to load

        Returns the hit table with the id columns cast to object dtype and,
        when filter_options is set, the per-column filters applied.
        """
        df = self.read_hit_table(output_file)
        columns = df.columns.tolist()
        # Cast every id column that is present in one pass.
        id_types = {id_col: 'object' for id_col in self.__filter_columns if id_col in columns}
        if id_types:
            df = df.astype(id_types)
        if self.filter_options is None:
            return df

        for col_name, options in self.filter_options.items():
            if col_name in columns:
                df = self.filter_df(df, col_name, options.min, options.max, options.include, columns)
        return df

    def filter_df(self, df, col_name, min_value, max_value, include, columns):
        """
        Keep the rows of df whose col_name value passes the given limits.
        min_value: rows below it are dropped; None skips the check
        max_value: rows above it are dropped; None skips the check
        include: values passed to Series.isin(); None skips the check
        columns: column names of df

        Returns the filtered DataFrame; df is returned unchanged when
        col_name is not one of columns.
        """
        if col_name not in columns:
            # Hand the frame back rather than False so callers always get a DataFrame.
            return df
        if min_value is not None:
            df = df[df[col_name] >= min_value]
        if max_value is not None:
            df = df[df[col_name] <= max_value]
        if include is not None:
            df = df[df[col_name].isin(include)]
        return df

    def read_hit_table(self, blast_data):
        """
        Load a headerless, tab separated blast table, naming the columns from blast_columns.
        """
        return pd.read_csv(blast_data, header=None, names=self.blast_columns, sep="\t", low_memory=False)

    def _check_blast_files(self, db_dir: Path, extensions: frozenset):
        """
        Check that db_dir contains at least one file for each required suffix.
        db_dir Path: directory holding the blast database files
        extensions frozenset: required file suffixes, including the leading dot

        Raises ValueError listing the suffixes that were not found.
        """
        found = {entry.suffix for entry in db_dir.iterdir()}
        missing = set(extensions).difference(found)
        if missing:
            raise ValueError("Missing required blast files. {}".format(sorted(missing)))

    def validate_blast_db(self, db_data=None):
        """
        Check the nucleotide and protein database directories of db_data for
        the expected blast files; a directory that is unset is skipped.
        db_data: defaults to the db_data given to the constructor

        Raises ValueError when a required file suffix is missing.
        """
        if db_data is None:
            db_data = self.db_data
        if db_data.nucleotide:
            self._check_blast_files(db_data.nucleotide, self.__blast_extensions_nt)

        if db_data.protein:
            self._check_blast_files(db_data.protein, self.__blast_extensions_pt)

    def _run_blast(self, db: Path, output: Path):
        """
        db Path: Path to the blast database to use
        output Path: Path to file for blast output

        Returns whatever run_command returns for the assembled command.

        TODO use the class's db version or do not pass it to the initializer
        """
        command = [
            self.blast_method,
            '-query', self.query_path,
            '-db', str(db),
            '-out', str(output),
            '-outfmt', "'6 {}'".format(' '.join(self.blast_columns)),
        ]
        for param, value in self.blast_params.items():
            if param == "parse_seqids":
                # Bare flag: takes no value on the command line.
                command.append(f"-{param}")
            else:
                command += [f'-{param}', f'{value}']
        # NOTE(review): the command is handed over as one space-joined string
        # and the paths are not quoted -- confirm paths with spaces are not expected.
        blast_search_cmd = " ".join([str(x) for x in command])
        logger.info("Blast command: %s", blast_search_cmd)
        return run_command(blast_search_cmd)




Loading
Loading