74115dfde3f8c981f8f9119875d8bb014cce7b50
angie
  Thu Aug 17 20:23:19 2023 -0700
Script to identify most specific branch and total number of sites masked for each sample, for https://github.com/iqbal-lab/improve-sc2-phylo-paper/issues/1

diff --git src/hg/utils/otto/sarscov2phylo/branchSpecificMaskStats.py src/hg/utils/otto/sarscov2phylo/branchSpecificMaskStats.py
new file mode 100755
index 0000000..5f90f18
--- /dev/null
+++ src/hg/utils/otto/sarscov2phylo/branchSpecificMaskStats.py
@@ -0,0 +1,118 @@
+#!/usr/bin/env python3
+
+import yaml
+import subprocess, tempfile
+import logging, argparse, os, sys
+from collections import defaultdict
+
+import branchSpecificMask
+
+def getArgs():
+    """Parse the command line; return the argparse namespace (samplePaths, maskFile, yamlIn)."""
+    # The raw formatter keeps the bullet list below as written instead of re-wrapping it.
+    parser = argparse.ArgumentParser(description="""
+Given a full sample-paths file and a maskfile from the same protobuf, print
+(to stdout) tab-separated output with three columns:
+* sample name
+* most specific Pango lineage (e.g. if BA.2, BA.5 and BA.5.3 all apply, BA.5.3)
+* total number of sites (and/or reversions) masked across all lineages
+""", formatter_class=argparse.RawDescriptionHelpFormatter)
+    parser.add_argument('samplePaths', metavar='merged.pb.sample-paths',
+                        help='sample-paths from same protobuf file as was branch-specific masked')
+    parser.add_argument('maskFile', metavar='merged.pb.branchSpecificMask.tsv',
+                        help='branch-specific masking file generated by branchSpecificMask.py')
+    parser.add_argument('yamlIn', metavar='branchSpecificMask.yml',
+                        help='YAML spec that identifies representative node and sites to mask')
+    return parser.parse_args()
+
+def die(message):
+    """Log message at ERROR level, then raise SystemExit(1); never returns to the caller."""
+    logging.error(message)
+    sys.exit(1)  # sys.exit, not the site-provided exit(), which is absent under python -S
+
+def getRepresentativeNodeIds(maskFile):
+    """Read the two-column (mutation, node ID) tab-separated maskFile and return a dict mapping
+    each node ID to the number of lines that name it; die if a line is not two columns."""
+    repNodes = defaultdict(int)
+    with open(maskFile, 'r') as m:
+        for line in m:
+            fields = line.rstrip().split('\t')
+            if len(fields) != 2:
+                die(f"maskFile {maskFile} has unexpected format (expect two tab-sep columns):\n" + line)
+            repNodes[fields[1]] += 1
+    return repNodes
+
+def getLineageNodeId(name, nodeMuts, branchSpec):
+    """Return the ID of the node that starts the branch represented by sample name, given
+    nodeMuts (the sample's path as 'nodeId:mutations' words) and that branch's spec"""
+    # Count steps back from the end of the path; one step selects the last node, the usual case.
+    stepsBack = 1
+    # A last word that starts with the sample name is the sample itself (with private
+    # mutations); masking only that sample is not wanted, so go one node further back.
+    if nodeMuts[-1].startswith(name):
+        stepsBack += 1
+    # The spec may ask for still more backtracking (e.g. parent or grandparent).
+    extraSteps = branchSpec.get('representativeBacktrack')
+    if extraSteps is not None:
+        stepsBack += extraSteps
+    # Keep only the node ID that precedes the first ':'; the mutations are discarded.
+    return nodeMuts[-stepsBack].partition(':')[0]
+
+def printSampleStats(spec, repNodes, samplePaths):
+    """Print one tab-separated line per sample in samplePaths: the full sample name, the most
+    specific lineage from spec that has a branch-starting node on the sample's path (empty if
+    there is none), and the sum of the repNodes counts of all branch-starting nodes on the path.
+    Lines of samplePaths without exactly two tab-separated columns are skipped; dies if a
+    branch-starting node was not matched to a lineage by any representative sample."""
+    # Representative sample name -> lineage (branch) name, taken from the spec
+    repLineages = {spec[branch]['representative']: branch for branch in spec}
+    # Branch-starting node ID -> lineage, filled in whenever a representative sample is read
+    nodeLineages = {}
+    # One (fullName, most specific node ID or '', masked count) tuple per sample, in file order
+    sampleStats = []
+    with open(samplePaths, 'r') as s:
+        for line in s:
+            columns = line.rstrip().split('\t')
+            # Anything other than name<TAB>path is silently ignored.
+            if len(columns) != 2:
+                continue
+            fullName, path = columns
+            nodeMuts = path.split(' ')
+            # Branch-starting nodes on this path, in path order (node ID = text before ':')
+            pathNodeIds = [word.split(':')[0] for word in nodeMuts]
+            branchNodes = [nodeId for nodeId in pathNodeIds if nodeId in repNodes]
+            # A representative sample identifies the node that starts the lineage it represents;
+            # only the part of fullName before the first '|' is looked up in the spec.
+            name = fullName.split('|')[0]
+            lineage = repLineages.get(name)
+            if lineage is not None:
+                nodeLineages[getLineageNodeId(name, nodeMuts, spec[lineage])] = lineage
+            if branchNodes:
+                # The last branch-starting node on the path is the most specific one.
+                maskedCount = sum(repNodes[nodeId] for nodeId in branchNodes)
+                sampleStats.append((fullName, branchNodes[-1], maskedCount))
+            else:
+                sampleStats.append((fullName, '', 0))
+
+
+    # Printing waits until every line has been read, so nodeLineages is complete no matter
+    # where in the file the representative samples appear.
+    for fullName, mostSpecificBranch, siteCount in sampleStats:
+        if mostSpecificBranch == '':
+            lineage = ''
+        else:
+            lineage = nodeLineages.get(mostSpecificBranch)
+            if lineage is None:
+                die(f"No lineage for node {mostSpecificBranch} (sample {fullName})")
+        print('\t'.join([fullName, lineage, str(siteCount)]))
+
+
+def main():
+    """Parse arguments, load the YAML spec and the mask-file node counts, print sample stats."""
+    args = getArgs()
+    spec = branchSpecificMask.getSpec(args.yamlIn)
+    printSampleStats(spec, getRepresentativeNodeIds(args.maskFile), args.samplePaths)
+
+
+if __name__ == '__main__':  # true only when run as a script, so importing does not call main()
+    main()