#!/usr/bin/env python3
# src/hg/utils/otto/sarscov2phylo/branchSpecificMaskStats.py
# Added in commit 74115dfde3f8c981f8f9119875d8bb014cce7b50 (angie, Thu Aug 17 20:23:19 2023 -0700):
#   Script to identify most specific branch and total number of sites masked for each sample,
#   for https://github.com/iqbal-lab/improve-sc2-phylo-paper/issues/1

import yaml
import subprocess, tempfile
import logging, argparse, os, sys
from collections import defaultdict

import branchSpecificMask

def getArgs():
    """Parse and return the command-line arguments (samplePaths, maskFile, yamlIn)."""
    parser = argparse.ArgumentParser(description="""
Given a full sample-paths file and a maskfile from the same protobuf, print
(to stdout) tab-separated output with three columns:
* sample name
* most specific Pango lineage (e.g. if BA.2, BA.5 and BA.5.3 all apply, BA.5.3)
* total number of sites (and/or reversions) masked across all lineages
"""
)
    parser.add_argument('samplePaths', metavar='merged.pb.sample-paths',
                        help='sample-paths from same protobuf file as was branch-specific masked')
    parser.add_argument('maskFile', metavar='merged.pb.branchSpecificMask.tsv',
                        help='branch-specific masking file generated by branchSpecificMask.py')
    parser.add_argument('yamlIn', metavar='branchSpecificMask.yml',
                        help='YAML spec that identifies representative node and sites to mask')
    return parser.parse_args()

def die(message):
    """Log an error message and exit with nonzero status"""
    logging.error(message)
    # Use sys.exit, not the exit() builtin: exit() is added by the site module for
    # interactive use and does not exist when Python is run with -S.
    sys.exit(1)

def getRepresentativeNodeIds(maskFile):
    """Read representative node IDs from maskFile, return dict of node IDs -> #occurrences.

    Each non-blank line of maskFile must have exactly two tab-separated columns: a mutation
    and the ID of the node whose branch it is masked on.  The count for a node is therefore
    the number of mask lines (sites and/or reversions) attributed to it.  Blank lines are
    ignored; any other malformed line is fatal (see die).
    """
    repNodes = defaultdict(int)
    with open(maskFile, 'r') as m:
        for line in m:
            line = line.rstrip()
            if not line:
                # Tolerate blank/trailing empty lines instead of treating them as malformed.
                continue
            try:
                _mut, nodeId = line.split('\t')
            except ValueError:
                die(f"maskFile {maskFile} has unexpected format (expect two tab-sep columns):\n"
                    + line)
            repNodes[nodeId] += 1
    return repNodes

def getLineageNodeId(name, nodeMuts, branchSpec):
    """Given a name that is a representative for branchSpec, and its node list, find the node that
    starts the branch.

    name: sample name (first |-separated field of the full name).
    nodeMuts: the sample's path split on spaces; each element is "nodeId:mutations".
    branchSpec: this lineage's entry from the YAML spec; an optional 'representativeBacktrack'
        says how many additional nodes to step back toward the start of the path.
    Returns the node ID (mutations stripped).  Calls die() if the requested position is
    outside the path.
    """
    # In most cases we want the last node in the path [-1]
    nodeIx = -1
    # ... but if the last word starts with the sample name (with private mutations)
    # then we do not want to mask just that sample, so backtrack to [-2]
    if nodeMuts[-1].startswith(name):
        nodeIx -= 1
    # ... and the spec might say to backtrack even more (e.g. parent or grandparent):
    backtrack = branchSpec.get('representativeBacktrack')
    if backtrack is not None:
        nodeIx -= backtrack
    # Guard against running off the front of the path (IndexError) or, with a negative
    # backtrack, wrapping to a non-negative index that silently picks the wrong node.
    if nodeIx >= 0 or nodeIx < -len(nodeMuts):
        die(f"Invalid node index {nodeIx} for representative {name} (path has "
            f"{len(nodeMuts)} element(s); check representativeBacktrack)")
    # Strip to just the node ID, discard mutations
    return nodeMuts[nodeIx].split(':')[0]

def _collectSampleStats(spec, repNodes, samplePaths):
    """Read samplePaths and return (sampleStats, nodeLineages).

    sampleStats: list of (fullName, mostSpecificBranchNodeId or '', siteCount) in file order.
    nodeLineages: branch-starting node ID -> lineage name, derived from representative samples.
    Lines that are not "name<TAB>path" after rstrip() are skipped; a single warning reports
    how many were skipped.
    """
    # Representative sample name -> lineage (branch) name from the spec.
    repLineages = {spec[branch]['representative']: branch for branch in spec}
    nodeLineages = {}
    sampleStats = []
    skipped = 0
    with open(samplePaths, 'r') as s:
        for line in s:
            try:
                fullName, path = line.rstrip().split('\t')
            except ValueError:
                skipped += 1
                continue
            nodeMuts = path.split(' ')
            # Collect all branch-starting nodes found in nodeMuts, keeping path order.
            nodeIds = [nodeMut.split(':')[0] for nodeMut in nodeMuts]
            branchNodes = [nodeId for nodeId in nodeIds if nodeId in repNodes]
            # If this sample is a representative, find which branch-starting node it has
            # and map that node to the lineage that the sample represents.
            name = fullName.split('|')[0]
            lineage = repLineages.get(name)
            if lineage is not None:
                nodeId = getLineageNodeId(name, nodeMuts, spec[lineage])
                nodeLineages[nodeId] = lineage
            # When the path contains multiple branch-starting nodes, the last one is the most
            # specific.
            mostSpecificBranch = branchNodes[-1] if branchNodes else ''
            siteCount = sum(repNodes[node] for node in branchNodes)
            sampleStats.append((fullName, mostSpecificBranch, siteCount))
    if skipped:
        logging.warning(f"Skipped {skipped} line(s) of {samplePaths} without exactly two "
                        "tab-separated columns")
    return sampleStats, nodeLineages

def printSampleStats(spec, repNodes, samplePaths):
    """For each sample in samplePaths, print out sample name, most specific pango lineage that is
    used for branch-specific masking, and total number of sites/reversions masked in sample.

    Output is tab-separated on stdout; the lineage column is empty for samples on no masked
    branch.  Calls die() if a sample's most specific branch node has no known lineage.
    """
    # The node -> lineage map is complete only after the whole file is read (a representative
    # may come after samples on its branch), so collect everything first.
    sampleStats, nodeLineages = _collectSampleStats(spec, repNodes, samplePaths)
    # Resolve every lineage before printing so a failure does not leave truncated output.
    rows = []
    for fullName, mostSpecificBranch, siteCount in sampleStats:
        lineage = ''
        if mostSpecificBranch != '':
            lineage = nodeLineages.get(mostSpecificBranch)
            if lineage is None:
                die(f"No lineage for node {mostSpecificBranch} (sample {fullName})")
        rows.append('\t'.join([fullName, lineage, str(siteCount)]))
    for row in rows:
        print(row)


def main():
    """Command-line entry point: load the YAML spec and mask file, then print per-sample stats."""
    args = getArgs()
    spec = branchSpecificMask.getSpec(args.yamlIn)
    repNodes = getRepresentativeNodeIds(args.maskFile)
    printSampleStats(spec, repNodes, args.samplePaths)


if __name__ == '__main__':
    main()