diff --git a/scripts/add_derived_haplotypes.py b/scripts/add_derived_haplotypes.py index e1df4afb..43219d2f 100644 --- a/scripts/add_derived_haplotypes.py +++ b/scripts/add_derived_haplotypes.py @@ -2,20 +2,35 @@ Annotate derived haplotypes per node from annotated clades and store as node data JSON. """ import argparse +import json import pandas as pd -def create_haplotype_for_record(record, clade_column, mutations_column, genes=None, strip_genes=False): +def create_haplotype_for_record(record, clade_column, mutations_column, genes=None, strip_genes=False, sites_by_gene=None): """Create a haplotype string for the given record based on the values in its clade and mutations column. If a list of genes is given, filter mutations to only those in the requested genes. """ clade = record[clade_column] + + if record[mutations_column] == "": + return clade + mutations = record[mutations_column].split(",") # Filter mutations to requested genes. - if genes is not None: + if sites_by_gene is not None: + filtered_mutations = [] + for mutation in mutations: + # mutation looks like "HA1:N145K" + gene, allele = mutation.split(":") + position = allele[1:-1] + if gene in sites_by_gene and position in sites_by_gene[gene]: + filtered_mutations.append(mutation) + + mutations = filtered_mutations + elif genes is not None: mutations = [ mutation for mutation in mutations @@ -44,6 +59,7 @@ def create_haplotype_for_record(record, clade_column, mutations_column, genes=No parser.add_argument("--clade-column", help="name of the branch attribute for clade labels in the given Nextclade annotations", default="subclade") parser.add_argument("--mutations-column", help="name of the attribute for mutations relative to clades in the given Nextclade annotations", default="founderMuts['subclade'].aaSubstitutions") parser.add_argument("--genes", nargs="+", help="list of genes to filter mutations to. If not provided, all mutations will be used.") + parser.add_argument("--distance-map", help="distance map JSON of genes and positions to include in haplotypes") parser.add_argument("--strip-genes", action="store_true", help="strip gene names from coordinates in output haplotypes") parser.add_argument("--attribute-name", default="haplotype", help="name of attribute to store the derived haplotype in the output file") parser.add_argument("--output", help="TSV file of Nextclade annotations with derived haplotype column added", required=True) @@ -60,6 +76,13 @@ def create_haplotype_for_record(record, clade_column, mutations_column, genes=No na_filter=False, ) + # Load distance map. + sites_by_gene = None + if args.distance_map: + with open(args.distance_map, "r", encoding="utf-8") as fh: + distance_map = json.load(fh) + sites_by_gene = distance_map["map"] + # Annotate derived haplotypes. df[args.attribute_name] = df.apply( lambda record: create_haplotype_for_record( @@ -68,6 +91,7 @@ def create_haplotype_for_record(record, clade_column, mutations_column, genes=No args.mutations_column, args.genes, args.strip_genes, + sites_by_gene, ), axis=1 )