4baf9e8d58909435aeff4102172d4b6c76e9ed0c
angie
  Sun Apr 19 11:43:07 2020 -0700
Better Nextstrain VCF (merge SNVs at same pos, add BACKMUTS); new output file .varPaths.  refs #25188
VCF improvements:
* Merge SNVs at the same position (multiple alt alleles like C>A and C>T, serial mutations
like T>G and G>A, and back-mutations like G>T and T>G) into a single VCF record with
accurate allele counts.
* Add BACKMUTS to INFO to highlight samples with back-mutations.
The new output file nextstrainSamples.varPaths shows the sequence of variants that occur at
nodes in the path from the root of the tree to the sample's leaf, for each sample.  This gives
some insight into the structure of the tree and how many mutations the various samples have;
by comparing .varPaths files from different NextStrain releases, we can also see that the
order of ancestors and descendants is sometimes shuffled from one release to the next.

diff --git src/hg/utils/otto/nextstrainNcov/nextstrain.py src/hg/utils/otto/nextstrainNcov/nextstrain.py
index 2bd0c8d..85347af 100755
--- src/hg/utils/otto/nextstrainNcov/nextstrain.py
+++ src/hg/utils/otto/nextstrainNcov/nextstrain.py
@@ -174,152 +174,234 @@
     # Add variants specific to this branch (if any)
     for varName in localVariants:
         branchVariants[varName] = 1
     # Make an ordered variant string as David requested: semicolons between nodes,
     # comma-separated within a node:
     branchVarStr = ''
     for varName in localVariants:
         if (len(branchVarStr)):
             branchVarStr += ', '
         branchVarStr += varName
         aaVar = variantAaChanges.get(varName)
         if (aaVar):
             branchVarStr += ' (' + aaVar + ')'
     if (len(parentVarStr) and len(branchVarStr)):
         branchVarStr = parentVarStr + '; ' + branchVarStr
+    elif (not len(branchVarStr)):
+        branchVarStr = parentVarStr
     nodeAttrs = branch['node_attrs']
     if (nodeAttrs.get('clade_membership')):
         cladeName = nodeAttrs['clade_membership']['value']
         if (cladeName != 'unassigned' and not cladeName in clades):
             clades[cladeName] = cladeFromVariants(cladeName, branchVariants, branchVarStr)
             addDatesToClade(clades[cladeName], nodeAttrs['num_date'])
             if (nodeAttrs.get('country')):
                 addCountryToClade(clades[cladeName], nodeAttrs['country'])
             elif (nodeAttrs.get('division')):
                 addCountryToClade(clades[cladeName], nodeAttrs['division'])
             else:
                 warn('No country or division for new clade ' + cladeName)
             cladeNodes[cladeName] = branch
     kids = branch.get('children')
     if (kids):
         for child in kids:
             rUnpackNextstrainTree(child, branchVariants, branchVarStr);
     else:
         for varName in branchVariants:
             variantCounts[varName] += 1
         samples.append({ 'id': nodeAttrs['gisaid_epi_isl']['value'],
                          'name': branch['name'],
                          'clade': nodeAttrs['clade_membership']['value'],
                          'date': numDateToMonthDay(nodeAttrs['num_date']['value']),
-                         'variants': branchVariants })
+                         'variants': branchVariants,
+                         'varStr': branchVarStr })
 
 rUnpackNextstrainTree(ncov['tree'], {}, '')
 
+def sampleName(sample):
+    return '|'.join([sample['id'], sample['name'], sample['date']])
+
 sampleCount = len(samples)
-sampleNames = [ '|'.join([sample['id'], sample['name'], sample['date']]) for sample in samples ]
+sampleNames = [ sampleName(sample)  for sample in samples ]
 
 # Parse variant names like 'G11083T' into pos and alleles; bundle in VCF column order
 parsedVars = []
 for varName in variantCounts.keys():
     m = snvRe.match(varName)
     if (m):
         ref, pos, alt = m.groups()
         parsedVars.append([int(pos), varName, ref, alt])
     else:
         warn("Can't match " + varName)
 # Sort by position
 def parsedVarPos(pv):
     return pv[0]
 parsedVars.sort(key=parsedVarPos)
 
-def boolToStr01(bool):
-    """Convert boolean to string 1 or 0."""
-    if (bool):
-        return '1'
+def parsedVarAlleleCount(pv):
+    return variantCounts[pv[1]]
+
+def mergeVariants(pvList):
+    """Given a list of parsedVars [pos, varName, ref, alt] at the same pos, resolve the actual allele at each sample, handling back-mutations and serial mutations."""
+    # Sort by descending allele count, assuming that the ref allele of the most frequently
+    # observed variant is the true ref allele.  For back-muts, ref and alt are swapped;
+    # for serial muts, alt of the first is ref of the second.
+    pvList.sort(key=parsedVarAlleleCount)
+    pvList.reverse()
+    pos, varName, trueRef, firstAlt = pvList[0]
+    alts = []
+    sampleAlleles = [ 0 for sample in samples ]
+    backMutSamples = []
+    for pv in pvList:
+        thisPos, thisName, thisRef, thisAlt = pv
+        if (thisPos != pos):
+            warn("mergeVariants: inconsistent pos " + pos + " and " + thisPos)
+        if (thisAlt == trueRef):
+            # Back-mutation, not a true alt
+            alIx = 0
+        else:
+            # Add to list of alts - unless it's an alt we've already seen, but from a different
+            # serial mutation.  For example, there might be T>A but also T>G+G>A; don't add A twice.
+            if (not thisAlt in alts):
+                alts.append(thisAlt)
+                if (thisName != varName):
+                    varName += "," + thisName
+            alIx = alts.index(thisAlt) + 1
+        for ix, sample in enumerate(samples):
+            if (sample['variants'].get(thisName)):
+                sampleAlleles[ix] = alIx
+                if (alIx == 0):
+                    backMutSamples.append(sampleName(sample))
+    # After handling back- and serial mutations, figure out true counts of each alt allele:
+    altCounts = [ 0 for alt in alts ]
+    for alIx in sampleAlleles:
+        if (alIx > 0):
+            altCounts[alIx - 1] += 1
+    return [ [pos, varName, trueRef, ','.join(alts)],
+             alts, altCounts, sampleAlleles, backMutSamples ]
+
+mergedVars = []
+
+variantsAtPos = []
+for pv in parsedVars:
+    pos = pv[0]
+    if (len(variantsAtPos) == 0 or pos == variantsAtPos[0][0]):
+        variantsAtPos.append(pv)
     else:
-        return '0'
+        mergedVars.append(mergeVariants(variantsAtPos))
+        variantsAtPos = [pv]
+mergedVars.append(mergeVariants(variantsAtPos))
 
 def writeVcfHeaderExceptSamples(outF):
     """Write VCF header lines -- except for sample names (this ends with a \t not a \n)."""
     outF.write('##fileformat=VCFv4.3\n')
     outF.write('##source=nextstrain.org\n')
     outF.write('\t'.join(['#CHROM', 'POS', 'ID', 'REF', 'ALT', 'QUAL', 'FILTER', 'INFO', 'FORMAT']) +
                '\t');
 
+def tallyAaChanges(varNameMerged):
+    """If any of the merged variants cause a coding change, then produce a comma-sep string in same order as variants with corresponding change(s) or '-' if a variant does not cause a coding change.  If none of the variants cause a coding change, return the empty string."""
+    varNames = varNameMerged.split(',')
+    gotAaChange = 0
+    aaChanges = []
+    for varName in varNames:
+        aaChange = variantAaChanges.get(varName)
+        if (aaChange):
+            gotAaChange = 1
+            aaChanges.append(aaChange)
+        else:
+            aaChanges.append('-');
+    if (gotAaChange):
+        aaChangeStr = ','.join(aaChanges)
+    else:
+        aaChangeStr = ''
+    return aaChangeStr
+
 # VCF
 with open('nextstrainSamples.vcf', 'w') as outC:
     writeVcfHeaderExceptSamples(outC)
     outC.write('\t'.join(sampleNames) + '\n')
-    for pv in parsedVars:
-        varName = pv[1]
-        info = 'AC=' + str(variantCounts[varName]) + ';AN=' + str(sampleCount)
-        aaChange = variantAaChanges.get(varName)
-        if (aaChange):
+    for mv in mergedVars:
+        pv, alts, altCounts, sampleAlleles, backMutSamples = mv
+        info = 'AC=' + ','.join(map(str, altCounts)) + ';AN=' + str(sampleCount)
+        varNameMerged = pv[1]
+        aaChange = tallyAaChanges(varNameMerged)
+        if (len(aaChange)):
             info += ';AACHANGE=' + aaChange
+        if (len(backMutSamples)):
+            info += ';BACKMUTS=' + ','.join(backMutSamples)
         genotypes = []
-        for sample in samples:
-            gt = boolToStr01(sample['variants'].get(varName))
+        for sample, alIx in zip(samples, sampleAlleles):
+            gt = str(alIx)
             genotypes.append(gt + ':' + sample['clade'])
         outC.write('\t'.join([ chrom,
                                '\t'.join(map(str, pv)),
                                '\t'.join(['.', 'PASS', info, 'GT:CLADE']),
                                '\t'.join(genotypes) ]) + '\n')
     outC.close()
 
 # Assign samples to clades; a sample can appear in multiple clades if they are nested.
 cladeSamples = {}
 cladeSampleCounts = {}
 cladeSampleNames = {}
 
 def sampleIdsFromNode(node, ids):
     """Fill in a dict of IDs of all samples found under node."""
     kids = node.get('children')
     if (kids):
         for kid in kids:
             sampleIdsFromNode(kid, ids)
     else:
         sampleId = node['node_attrs']['gisaid_epi_isl']['value']
         ids[sampleId] = 1
 
 for cladeName, node in cladeNodes.items():
     sampleIds = {}
     sampleIdsFromNode(node, sampleIds)
     cladeSampleList = [ sample for sample in samples if sample['id'] in sampleIds ]
-    cladeSamples[cladeName] = cladeSampleList
+    cladeSamples[cladeName] = sampleIds
     cladeSampleCounts[cladeName] = len(cladeSampleList)
-    cladeSampleNames[cladeName] = [ '|'.join([ sample['id'], sample['name'], sample['date'] ])
-                                    for sample in cladeSampleList ]
+    cladeSampleNames[cladeName] = [ sampleName(sample) for sample in cladeSampleList ]
 
 # Per-clade VCF subset
-for cladeName, cladeSampleList in cladeSamples.items():
+for cladeName, cladeSampleIds in cladeSamples.items():
     with open('nextstrainSamples' + cladeName + '.vcf', 'w') as outV:
         writeVcfHeaderExceptSamples(outV)
         outV.write('\t'.join(cladeSampleNames[cladeName]) + '\n')
-        for pv in parsedVars:
-            varName = pv[1]
+        for mv in mergedVars:
+            pv, alts, overallAltCounts, sampleAlleles, backMutSamples = mv
+            varNameMerged = pv[1]
             genotypes = []
-            ac=0
-            for sample in cladeSampleList:
-                gt = boolToStr01(sample['variants'].get(varName))
-                genotypes.append(gt)
-                if (sample['variants'].get(varName)):
-                    ac += 1
-            if (ac > 0):
-                info = 'AC=' + str(ac) + ';AN=' + str(cladeSampleCounts[cladeName])
-                aaChange = variantAaChanges.get(varName)
-                if (aaChange):
+            altCounts = [ 0 for alt in alts ]
+            acTotal=0
+            for sample, alIx in zip(samples, sampleAlleles):
+                if (sample['id'] in cladeSampleIds):
+                    gt = str(alIx)
+                    genotypes.append(gt + ':' + cladeName)
+                    if (alIx > 0):
+                        altCounts[alIx - 1] += 1
+                        acTotal += 1
+            if (acTotal > 0):
+                info = 'AC=' + ','.join(map(str, altCounts))
+                info += ';AN=' + str(cladeSampleCounts[cladeName])
+                aaChange = tallyAaChanges(varNameMerged)
+                if (len(aaChange)):
                     info += ';AACHANGE=' + aaChange
+                cladeBackMuts = [ sampleName for sampleName in backMutSamples
+                                  if sampleName in cladeSampleNames[cladeName] ]
+                if (len(cladeBackMuts)):
+                    info += ';BACKMUTS=' + ','.join(cladeBackMuts)
                 outV.write('\t'.join([ chrom,
                                        '\t'.join(map(str, pv)),
                                        '\t'.join(['.', 'PASS', info, 'GT']),
                                        '\t'.join(genotypes) ]) + '\n')
         outV.close()
 
 # BED+ file for clades
 with open('nextstrainClade.bed', 'w') as outC:
     for name, clade in clades.items():
         if (clade.get('thickStart')):
             outC.write('\t'.join(map(str,
                                      [ chrom, 0, 29903, name, 0, '.',
                                        clade['thickStart'], clade['thickEnd'], clade['color'],
                                        len(clade['varSizes']) + 2,
                                        '1,' + ','.join(map(str, clade['varSizes'])) + ',1,',
@@ -351,27 +433,33 @@
             cladeName = nodeAttrs['clade_membership']['value']
             color = str(cladeRgbFromName(cladeName))
         elif (parentColor):
             color = parentColor
         else:
             color = '0'
         descendants = ','.join([ rNextstrainToNewick(child, color) for child in kids ])
         treeString = '(' + descendants + ')' + ':' + color
     else:
         nodeAttrs = node['node_attrs']
         gId = nodeAttrs['gisaid_epi_isl']['value']
         name = node['name']
         date = numDateToMonthDay(nodeAttrs['num_date']['value'])
         cladeName = nodeAttrs['clade_membership']['value']
         color = str(cladeRgbFromName(cladeName))
-        treeString = '|'.join([ gId, name, date ]) + ':' + color
+        treeString = sampleName({ 'id': gId, 'name': name, 'date': date }) + ':' + color
     return treeString
 
 with open('nextstrain.nh', 'w') as outF:
     outF.write(rNextstrainToNewick(ncov['tree']) + ';\n')
-    outF.close
+    outF.close()
 
 for cladeName, node in cladeNodes.items():
     filename = 'nextstrain' + cladeName + '.nh'
     with open(filename, 'w') as outF:
         outF.write(rNextstrainToNewick(node) + ';\n')
-        outF.close
+        outF.close()
+
+# File with samples and their variant paths
+with open('nextstrainSamples.varPaths', 'w') as outF:
+    for sample in samples:
+        outF.write('\t'.join([sampleName(sample), sample['clade'], sample['varStr']]) + '\n');
+    outF.close()