644aca2e8d7af3d037a6fc29a400c5d79484756e
angie
  Wed Feb 8 20:49:24 2023 -0800
Data-driven branch-specific masking: replaced the contents of maskDelta.sh
with a YAML specification of branches and sites/ranges/reversions to mask
and a python script that reads the spec and then makes a matUtils mask file
and runs matUtils mask as before.  Did this so I could point people to the
YAML spec instead of a non-portable bash script when they asked what sites
are masked in what branches.  See https://github.com/yatisht/usher/issues/324

diff --git src/hg/utils/otto/sarscov2phylo/branchSpecificMask.py src/hg/utils/otto/sarscov2phylo/branchSpecificMask.py
new file mode 100755
index 0000000..52b2e59
--- /dev/null
+++ src/hg/utils/otto/sarscov2phylo/branchSpecificMask.py
@@ -0,0 +1,141 @@
+#!/usr/bin/env python3
+
+import yaml
+import subprocess, tempfile
+import logging, argparse, os, sys
+
+def getArgs():
+    """Parse the command line and return the argparse namespace (pbIn, yamlIn, pbOut)."""
+    description = """
+Apply branch-specific masking to selected nodes as specified in spec.yaml
+to in.pb and write out.pb.  Requires matUtils from usher package.
+"""
+    # (dest, metavar, help) for each required positional argument, in command-line order.
+    positionals = (
+        ('pbIn', 'in.pb',
+         'MAT protobuf file for input tree'),
+        ('yamlIn', 'spec.yaml',
+         'YAML spec that identifies representative node and sites to mask'),
+        ('pbOut', 'out.pb',
+         'MAT protobuf file to which output tree will be written'),
+    )
+    argParser = argparse.ArgumentParser(description=description)
+    for dest, metavar, helpText in positionals:
+        argParser.add_argument(dest, metavar=metavar, help=helpText)
+    return argParser.parse_args()
+
+def die(message):
+    """Log an error message and exit with nonzero status.
+
+    message is passed straight to logging.error, so it may be a string or an exception.
+    Exits via sys.exit(1), i.e. by raising SystemExit with code 1."""
+    logging.error(message)
+    # sys.exit, not the bare exit(): exit() is a helper added by the site module for
+    # interactive use and is not available when Python runs without site (python -S).
+    sys.exit(1)
+
+def getSpec(yamlIn):
+    """Read yamlIn and make sure it looks like a spec should.
+
+    Returns the parsed spec as a dict mapping each branch name to that branch's settings,
+    with the top-level 'version' key removed.  Calls die() if the YAML cannot be parsed,
+    if the top level or a branch is not a mapping, if version is missing, non-numeric or
+    newer than this script supports, if a branch has no representative, if a branch has
+    none of ranges/sites/reversions, or if a range is not a [start, end] pair of integers."""
+    with open(yamlIn) as f:
+        try:
+            spec = yaml.safe_load(f)
+        except yaml.YAMLError as e:
+            die(e)
+    # An empty file loads as None and a bare scalar or list has no .get(); report those
+    # as a spec problem instead of crashing with AttributeError.
+    if not isinstance(spec, dict):
+        die(f'Expecting a mapping of branch names to settings in {yamlIn}')
+    myVersion = 0
+    specVersion = spec.get('version')
+    if specVersion is None:
+        die(f'Expecting to find version {myVersion} in {yamlIn} but no version found.')
+    elif not isinstance(specVersion, (int, float)):
+        # A quoted version such as "0" would otherwise raise TypeError in the comparison.
+        die(f'Version "{specVersion}" in {yamlIn} is not a number.')
+    elif specVersion > myVersion:
+        die(f'Version {specVersion} in {yamlIn} is too new, this script is at version {myVersion}')
+    # Everything left after removing version is treated as a branch.
+    del spec['version']
+    for branch, branchSpec in spec.items():
+        if not isinstance(branchSpec, dict):
+            die(f'Expecting a mapping of settings for {branch} in {yamlIn}')
+        repSeq = branchSpec.get('representative')
+        if repSeq is None:
+            die(f'No representative sequence was given for {branch} in {yamlIn}')
+        ranges = branchSpec.get('ranges')
+        sites = branchSpec.get('sites')
+        reversions = branchSpec.get('reversions')
+        if not ranges and not sites and not reversions:
+            die('Found none of {ranges, sites, reversions} ' + f'for {branch} in {yamlIn}')
+        if ranges:
+            if not isinstance(ranges, list):
+                die(f'Expecting a list of [start, end] pairs in ranges for {branch} in {yamlIn}')
+            for r in ranges:
+                try:
+                    # Same expression that makeMaskFile evaluates; building the range
+                    # verifies that both ends exist and are integers.
+                    range(r[0], r[1]+1)
+                except (TypeError, IndexError, KeyError):
+                    die(f'Unexpected non-list value "{r}" in ranges for {branch} in {yamlIn}')
+    return spec
+
+def run(cmd):
+    """Run a command and exit with error output if it fails.
+
+    cmd is handed to subprocess.run as-is (the callers in this file pass an argument list).
+    Failure means either a nonzero exit status or the program not being launchable at all,
+    e.g. matUtils missing from PATH; both are reported through die()."""
+    try:
+        subprocess.run(cmd, check=True)
+    except (subprocess.CalledProcessError, OSError) as e:
+        die(e)
+
+def getBacktrack(spec, rep):
+    """Look up representativeBacktrack for the branch whose representative is rep.
+    Returns the first such value found in spec order that is not None; returns 0 if no
+    branch with that representative sets one."""
+    for settings in spec.values():
+        if settings['representative'] != rep:
+            continue
+        steps = settings.get('representativeBacktrack')
+        if steps is not None:
+            return steps
+    return 0
+
+def getRepresentativeNodes(pbIn, spec):
+    """Run matUtils extract --sample-paths on pbIn and find path to each branch's representative.
+    Return dict mapping representative name to final node in path.
+
+    The node picked for a representative is the last element of its path, or the one before
+    it when the last element starts with the representative's own name, moved back a further
+    representativeBacktrack elements if the spec sets that.  Calls die() if matUtils fails,
+    if a path is too short for the requested position, or if any representative does not
+    appear in the sample-paths output."""
+    # Empty string marks a representative that has not been found yet.
+    repNodes = {spec[branch]['representative']: '' for branch in spec}
+    samplePaths = tempfile.NamedTemporaryFile(delete=False)
+    samplePaths.close()
+    # matUtils SEGVs if given a full path output file name unless output dir is '/', need to fix that
+    run(['matUtils', 'extract', '-i', pbIn, '-d', '/', '--sample-paths', samplePaths.name])
+    with open(samplePaths.name) as f:
+        for line in f:
+            try:
+                [fullName, path] = line.rstrip().split('\t')
+            except ValueError:
+                # Not exactly two tab-separated columns; skip the line.
+                continue
+            # Spec names are matched against the part of the sample name before the first '|'.
+            name = fullName.split('|')[0]
+            if name not in repNodes:
+                continue
+            nodes = path.split(' ')
+            nodeIx = -1
+            if nodes[-1].startswith(name):
+                # Path ends with an element named after the sample itself; step past it.
+                nodeIx = nodeIx - 1
+            nodeIx = nodeIx - getBacktrack(spec, name)
+            if -nodeIx > len(nodes):
+                die(f"path for {name} in sample-paths file {samplePaths.name} has only "
+                    f"{len(nodes)} nodes, too few for position {nodeIx}")
+            # Each path element is nodeId:mutations; keep only the node ID.
+            repNodes[name] = nodes[nodeIx].split(':')[0]
+    # Make sure we found all of them
+    for rep in repNodes:
+        if repNodes[rep] == '':
+            # die() exits before the unlink below, so the file named here still exists.
+            die(f"sample-paths file {samplePaths.name} does not have name {rep}")
+    os.unlink(samplePaths.name)
+    return repNodes
+
+def makeMaskFile(spec, repNodes, maskFileName):
+    """Write maskFileName, the input for matUtils mask --mask-mutations.
+    Each line is a mutation and a node ID separated by a tab: N<pos>N for every position
+    in each inclusive range and for every site, and the reversion string as given for each
+    reversion, paired with the node that repNodes has for the branch's representative."""
+    with open(maskFileName, 'w') as out:
+        for settings in spec.values():
+            node = repNodes[settings['representative']]
+            # Ranges come first, then single sites, then reversions, per branch.
+            for bounds in settings.get('ranges') or ():
+                out.writelines(f'N{pos}N\t{node}\n'
+                               for pos in range(bounds[0], bounds[1]+1))
+            for pos in settings.get('sites') or ():
+                out.write(f'N{pos}N\t{node}\n')
+            for reversion in settings.get('reversions') or ():
+                out.write(f'{reversion}\t{node}\n')
+
+def main():
+    """Apply the branch-specific masking described on the command line.
+
+    Reads the YAML spec, finds each branch's representative node in in.pb, writes the mask
+    file to in.pb + '.branchSpecificMask.tsv' and runs matUtils mask to produce out.pb.
+    The mask file is not removed afterwards."""
+    args = getArgs()
+    spec = getSpec(args.yamlIn)
+    repNodes = getRepresentativeNodes(args.pbIn, spec)
+    maskFileName = args.pbIn + '.branchSpecificMask.tsv'
+    makeMaskFile(spec, repNodes, maskFileName)
+    run(['matUtils', 'mask', '-i', args.pbIn, '--mask-mutations', maskFileName, '-o', args.pbOut])
+
+# Run only when executed as a script, so importing this module has no side effects.
+if __name__ == '__main__':
+    main()