Skip to content

Commit

Permalink
shifted functions to package modules
Browse files Browse the repository at this point in the history
  • Loading branch information
Nicholas Youngblut committed Dec 5, 2021
1 parent 5ae8ef8 commit 83f8687
Show file tree
Hide file tree
Showing 10 changed files with 271 additions and 219 deletions.
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,14 @@ will create different taxIDs.

# Install

## Dependencies

* numpy
* networkx
* taxonkit

## Package

### From pypi

```
Expand Down
98 changes: 9 additions & 89 deletions bin/gtdb_to_diamond.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,93 +56,11 @@
help='Keep temporary output? (Default: %(default)s)')
parser.add_argument('--version', action='version', version='0.0.1')

# logging
logging.basicConfig(format='%(asctime)s - %(message)s', level=logging.DEBUG)


def copy_nodes(infile, outdir):
    """
    Copy a nodes.dmp file into the output directory as <outdir>/nodes.dmp.

    Arguments:
      infile : str, path to the source nodes.dmp file
      outdir : str, output directory (assumed to exist)
    Raises:
      IOError : if the source and destination resolve to the same file,
                which would truncate the input before it could be read
    """
    logging.info('Read nodes.dmp file: {}'.format(infile))
    outfile = os.path.join(outdir, 'nodes.dmp')
    # abspath comparison also catches equivalent paths spelled differently
    # (e.g. 'dir/nodes.dmp' vs './dir/nodes.dmp'), not just identical strings
    if os.path.abspath(infile) == os.path.abspath(outfile):
        raise IOError('Input == Output: {} <=> {}'.format(infile, outfile))
    # bulk buffered copy instead of a line-by-line Python loop
    with open(infile, 'rb') as inF, open(outfile, 'wb') as outF:
        shutil.copyfileobj(inF, outF)
    logging.info('File written: {}'.format(outfile))

def read_names_dmp(infile, outdir):
    """
    Read a names.dmp file, normalize each accession, and write the
    result to <outdir>/names.dmp.

    Arguments:
      infile : str, path to the source names.dmp file
      outdir : str, output directory (assumed to exist)
    Returns:
      dict : {accession : taxID} for every record with >= 2 fields
    """
    outfile = os.path.join(outdir, 'names.dmp')
    field_sep = re.compile(r'\t\|\t')
    acc_clean = re.compile(r'^[^_]+_|_')
    acc2taxid = {}
    logging.info('Reading dumpfile: {}'.format(infile))
    with open(infile) as inF, open(outfile, 'w') as outF:
        for raw in inF:
            fields = field_sep.split(raw.rstrip())
            if len(fields) >= 2:
                # drop the leading "<db>_" prefix and any remaining underscores
                fields[1] = acc_clean.sub('', fields[1])
                acc2taxid[fields[1]] = fields[0]
                outF.write('\t|\t'.join(fields) + '\n')
    logging.info(' File written: {}'.format(outfile))
    msg = ' No. of accession<=>taxID pairs: {}'
    logging.info(msg.format(len(acc2taxid.keys())))
    return acc2taxid

def faa_gz_files(members):
    """
    Yield tarball members whose file names end in .faa or .faa.gz.

    Arguments:
      members : iterable of tarfile.TarInfo-like objects (needs a .name attr)
    Yields:
      the members that are protein-fasta files
    """
    for member in members:
        # endswith accepts a tuple of suffixes; a name can match at most one
        if member.name.endswith(('.faa.gz', '.faa')):
            yield member

def faa_gz_index(directory='.', extensions=('.faa', '.faa.gz')):
    """
    Build an {accession : faa_file_path} index from files under `directory`.

    Arguments:
      directory : str, root directory to walk recursively
      extensions : iterable of str, file suffixes to index
                   (default changed from a list to a tuple: equivalent, but
                   avoids the mutable-default-argument pitfall)
    Returns:
      dict : {accession : full_path}; later matches overwrite earlier ones
    """
    # a tuple lets str.endswith test all suffixes in one call; this also
    # removes the original's no-op `continue` (a `break` was intended)
    suffixes = tuple(set(extensions))
    regex = re.compile(r'(_protein\.faa\.gz|_protein\.faa)$')
    regexp = re.compile(r'^[^_]+_|_')
    found = {}
    for dirpath, _dirnames, files in os.walk(directory):
        for name in files:
            # case-insensitive suffix match; accession derived from the
            # original-case name by stripping the _protein suffix, the
            # leading "<db>_" prefix, and any remaining underscores
            if name.lower().endswith(suffixes):
                accession = regexp.sub('', regex.sub('', name))
                found[accession] = os.path.join(dirpath, name)
    return found

def uncomp_tarball(tarball_file, tmp_dir):
    """
    Extract the .faa(.gz) members of a tarball into tmp_dir and index them.

    Arguments:
      tarball_file : str, path to the tarball of protein-fasta files
      tmp_dir : str, scratch directory (recreated from scratch each call)
    Returns:
      dict : {accession : faa_file_path} for the extracted files
    """
    # always start from a clean temp directory
    if os.path.isdir(tmp_dir):
        shutil.rmtree(tmp_dir)
    os.makedirs(tmp_dir)
    # extract only the protein-fasta members, not the whole archive
    logging.info('Extracting tarball: {}'.format(tarball_file))
    logging.info(' Extracting to: {}'.format(tmp_dir))
    with tarfile.open(tarball_file) as tar:
        tar.extractall(path=tmp_dir, members=faa_gz_files(tar))
    # index what was just extracted
    faa_files = faa_gz_index(tmp_dir, ['.faa', '.faa.gz'])
    n_files = len(faa_files)
    logging.info(' No. of .faa(.gz) files: {}'.format(n_files))
    if n_files == 0:
        logging.warning(' No .faa(.gz) files found!')
    return faa_files

# functions
def accession2taxid(names_dmp, faa_files, outdir):
"""
Creating accession2taxid table
Expand Down Expand Up @@ -189,19 +107,20 @@ def faa_merge(faa_files, outdir, gzip_out=False):
outF.write(line)
logging.info(' File written: {}'.format(outfile))
logging.info(' No. of seqs. written: {}'.format(seq_cnt))


## main interface
def main(args):
"""
Main interface
"""
if not os.path.isdir(args.outdir):
os.makedirs(args.outdir)
# copying nodes
copy_nodes(args.nodes_dmp, args.outdir)
gtdb2td.Dmp.copy_nodes(args.nodes_dmp, args.outdir)
# reading in names.dmp
names_dmp = read_names_dmp(args.names_dmp, args.outdir)
names_dmp = gtdb2td.Dmp.read_names_dmp(args.names_dmp, args.outdir)
# uncompressing tarball of faa fasta files
faa_files = uncomp_tarball(args.faa_tarball, args.tmpdir)
faa_files = gtdb2td.IO.uncomp_tarball(args.faa_tarball, args.tmpdir)
# create accession2taxid
accession2taxid(names_dmp, faa_files, args.outdir)
# creating combined faa fasta
Expand All @@ -210,7 +129,8 @@ def main(args):
if not args.keep_temp and os.path.isdir(args.tmpdir):
shutil.rmtree(args.tmpdir)
logging.info('Temp-dir removed: {}'.format(args.tmpdir))


# script main
if __name__ == '__main__':
args = parser.parse_args()
main(args)
Expand Down
32 changes: 1 addition & 31 deletions bin/gtdb_to_taxdump.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,44 +56,14 @@

logging.basicConfig(format='%(asctime)s - %(message)s', level=logging.DEBUG)



def find_all_paths(self, start_vertex, end_vertex, path=None):
    """
    Find all paths from start_vertex to end_vertex in the graph.

    Arguments:
      start_vertex : hashable, vertex to start from
      end_vertex : hashable, vertex to reach
      path : list or None, vertices already visited (recursion accumulator);
             callers normally omit it
    Returns:
      list of lists : every acyclic path from start to end (empty if none)
    """
    # None sentinel replaces the mutable default `path=[]`; the original
    # never mutated the shared default, so behavior is unchanged, but this
    # removes the classic Python pitfall. Backward compatible for callers.
    if path is None:
        path = []
    graph = self.__graph_dict
    # new list each call, so sibling recursion branches stay independent
    path = path + [start_vertex]
    if start_vertex == end_vertex:
        return [path]
    if start_vertex not in graph:
        return []
    paths = []
    for vertex in graph[start_vertex]:
        # skip vertices already on the path to avoid cycles
        if vertex not in path:
            for p in self.find_all_paths(vertex, end_vertex, path):
                paths.append(p)
    return paths

def get_url_data(url):
    """
    Download data from a URL, assuming gzip-compressed content.

    Arguments:
      url : str, URL to fetch (Accept-Encoding: gzip is requested)
    Returns:
      list of bytes : the decompressed payload split into lines
    """
    req = urllib.request.Request(url)
    req.add_header('Accept-Encoding', 'gzip')
    # `with` closes the HTTP response; the original leaked the connection
    with urllib.request.urlopen(req) as response:
        payload = response.read()
    return gzip.decompress(payload).splitlines()

def load_gtdb_tax(infile, graph):
"""
loading gtdb taxonomy & adding to DAG
"""
# url or file download/open
try:
inF = get_url_data(infile)
inF = gtdb2td.Utils.get_url_data(infile)
except (OSError,ValueError) as e:
try:
ftpstream = urllib.request.urlopen(infile)
Expand Down
55 changes: 5 additions & 50 deletions bin/lineage2taxid.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,53 +67,6 @@ class CustomFormatter(argparse.ArgumentDefaultsHelpFormatter,
parser.add_argument('--version', action='version', version='0.0.1')

# functions
def load_dmp(names_dmp_file, nodes_dmp_file):
    """
    Load NCBI names/nodes dmp files as a directed acyclic graph.

    Arguments:
      names_dmp_file : str, names.dmp file
      nodes_dmp_file : str, nodes.dmp file
    Returns:
      networkx.DiGraph rooted at a synthetic taxid-0 'root' node
    """
    field_sep = re.compile(r'\t\|\t')
    # pass 1: taxid => lower-cased name lookup from names.dmp
    logging.info('Loading file: {}'.format(names_dmp_file))
    taxid2name = {}
    with gtdb2td.Utils.Open(names_dmp_file) as inF:
        for raw in inF:
            stripped = raw.rstrip()
            if not stripped:
                continue
            fields = field_sep.split(stripped)
            taxid2name[int(fields[0])] = fields[1].lower()
    # pass 2: build the DAG from nodes.dmp
    logging.info('Loading file: {}'.format(nodes_dmp_file))
    G = nx.DiGraph()
    G.add_node(0, rank = 'root', name = 'root')
    with gtdb2td.Utils.Open(nodes_dmp_file) as inF:
        for raw in inF:
            stripped = raw.rstrip()
            if not stripped:
                continue
            fields = field_sep.split(stripped)
            child = int(fields[0])
            parent = int(fields[1])
            # unused, but preserves the original's KeyError on parents
            # missing from names.dmp
            parent_name = taxid2name[parent].lower()
            G.add_node(child, rank=fields[2], name=taxid2name[child].lower())
            # taxid 1 is NCBI's root; hang it off the synthetic 0 node
            G.add_edge(0 if parent == 1 else parent, child)
    taxid2name.clear()
    logging.info(' No. of nodes: {}'.format(G.number_of_nodes()))
    logging.info(' No. of edges: {}'.format(G.number_of_edges()))
    return G

def lineage2taxid(lineage, G):
lineage = lineage.split(';')
for cls in lineage[::-1]:
Expand Down Expand Up @@ -152,18 +105,20 @@ def parse_lineage_table(table_file, lineage_column, G,
# status
if i > 0 and (i+1) % 100 == 0:
logging.info(' Records processed: {}'.format(i+1))


## main interface
def main(args):
    """
    Main entry point: build the taxonomy DAG, then map lineages to taxIDs.

    Arguments:
      args : argparse.Namespace with names_dmp, nodes_dmp, table_file,
             lineage_column, taxid_column, taxid_rank_column attributes
    """
    # names/nodes dmp files => taxonomy DAG
    taxonomy_graph = gtdb2td.Dmp.load_dmp(args.names_dmp, args.nodes_dmp)
    # map each lineage in the table onto a taxID
    parse_lineage_table(args.table_file,
                        args.lineage_column,
                        G=taxonomy_graph,
                        taxid_column=args.taxid_column,
                        taxid_rank_column=args.taxid_rank_column)


# script main
if __name__ == '__main__':
args = parser.parse_args()
main(args)
100 changes: 51 additions & 49 deletions bin/ncbi-gtdb_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,54 +133,56 @@
logging.basicConfig(format='%(asctime)s - %(message)s', level=logging.DEBUG)


def load_dmp(names_dmp_file, nodes_dmp_file, no_prefix=False):
    """
    Load NCBI names/nodes dmp files as a directed acyclic graph.

    Arguments:
      names_dmp_file : str, names.dmp file
      nodes_dmp_file : str, nodes.dmp file
      no_prefix : bool, accepted for API compatibility (not used here)
    Returns:
      networkx.DiGraph rooted at a synthetic taxid-0 'root' node
    """
    field_sep = re.compile(r'\t\|\t')
    # pass 1: taxid => name lookup from names.dmp (case preserved)
    logging.info('Loading file: {}'.format(names_dmp_file))
    taxid2name = {}
    with open(names_dmp_file) as inF:
        for raw in inF:
            stripped = raw.rstrip()
            if not stripped:
                continue
            fields = field_sep.split(stripped)
            taxid2name[int(fields[0])] = fields[1]
    # pass 2: build the DAG from nodes.dmp
    logging.info('Loading file: {}'.format(nodes_dmp_file))
    G = nx.DiGraph()
    G.add_node(0, rank = 'root', name = 'root')
    with open(nodes_dmp_file) as inF:
        for raw in inF:
            stripped = raw.rstrip()
            if not stripped:
                continue
            fields = field_sep.split(stripped)
            child = int(fields[0])
            parent = int(fields[1])
            rank = fields[2]
            child_name = taxid2name[child]
            # unused, but preserves the original's KeyError on parents
            # missing from names.dmp
            parent_name = taxid2name[parent]
            # add the GTDB-style species prefix
            if rank == 'species':
                child_name = 's__' + child_name
            G.add_node(child, rank=rank, name=child_name)
            # taxid 1 is NCBI's root; hang it off the synthetic 0 node
            G.add_edge(0 if parent == 1 else parent, child)
    taxid2name.clear()
    logging.info(' No. of nodes: {}'.format(G.number_of_nodes()))
    logging.info(' No. of edges: {}'.format(G.number_of_edges()))
    return G
# functions

# def load_dmp(names_dmp_file, nodes_dmp_file, no_prefix=False):
# """
# Loading NCBI names/nodes dmp files as DAG
# Arguments:
# names_dmp_file : str, names.dmp file
# nodes_dmp_file : str, nodes.dmp file
# Return:
# network.DiGraph object
# """
# regex = re.compile(r'\t\|\t')
# # nodes
# logging.info('Loading file: {}'.format(names_dmp_file))
# idx = {} # {taxid : name}
# with open(names_dmp_file) as inF:
# for line in inF:
# line = line.rstrip()
# if line == '':
# continue
# line = regex.split(line)
# idx[int(line[0])] = line[1]
# # names
# logging.info('Loading file: {}'.format(nodes_dmp_file))
# G = nx.DiGraph()
# G.add_node(0, rank = 'root', name = 'root')
# with open(nodes_dmp_file) as inF:
# for line in inF:
# line = line.rstrip()
# if line == '':
# continue
# line = regex.split(line)
# taxid_child = int(line[0])
# taxid_parent = int(line[1])
# rank_child = line[2]
# name_child = idx[taxid_child]
# name_parent = idx[taxid_parent]
# if rank_child == 'species':
# name_child = 's__' + name_child
# # adding node
# G.add_node(taxid_child, rank=rank_child, name=name_child)
# # adding edge
# if taxid_parent == 1:
# G.add_edge(0, taxid_child)
# else:
# G.add_edge(taxid_parent, taxid_child)
# idx.clear()
# logging.info(' No. of nodes: {}'.format(G.number_of_nodes()))
# logging.info(' No. of edges: {}'.format(G.number_of_edges()))
# return G

def format_taxonomy(T, hierarchy, acc, no_prefix=False):
"""
Expand Down Expand Up @@ -594,7 +596,7 @@ def main(args):
"""
# loading ncbi dmp files if provided
if args.names_dmp is not None and args.nodes_dmp is not None:
ncbi_tax = load_dmp(args.names_dmp, args.nodes_dmp)
ncbi_tax = gtdb2td.Dmp.load_dmp(args.names_dmp, args.nodes_dmp)
else:
ncbi_tax = None
# loading the metadata as graphs
Expand Down
Loading

0 comments on commit 83f8687

Please sign in to comment.