3aec6ebceadc5dee647d4c1b294b3a3d81a391e6
angie
  Wed Mar 25 17:38:38 2020 -0700
Clade variant list format change requested by David: variants from different nodes are separated by semicolons, variants within a node are separated by commas; amino acid changes are also added when applicable.  refs #25188

diff --git src/hg/utils/otto/nextstrainNcov/nextstrain.py src/hg/utils/otto/nextstrainNcov/nextstrain.py
index 8a239b5..e3cd141 100755
--- src/hg/utils/otto/nextstrainNcov/nextstrain.py
+++ src/hg/utils/otto/nextstrainNcov/nextstrain.py
@@ -36,51 +36,51 @@
 
 cladeColors = { 'A1a': '73,75,225', 'A2': '75,131,233', 'A2a': '92,173,207',
                 'A3': '119,199,164', 'A6': '154,212,122', 'A7': '173,189,81',
                 'B': '233,205,74', 'B1': '255,176,65', 'B2': '255,122,53',
                 'B4': '249,53,41' }
 
 def cladeColorFromName(cladeName):
     color = cladeColors.get(cladeName);
     if (not color):
         color = 'purple'
     return color
 
 def subtractStart(coord, start):
     return coord - start
 
-def cladeFromVariants(name, variants):
+def cladeFromVariants(name, variants, varStr):
     """Extract bed12 info from an object whose keys are SNV variant names"""
     clade = {}
     snvEnds = []
     varNames = []
     for varName in variants:
         m = snvRe.match(varName)
         if (m):
             snvEnds.append(int(m.group(2)))
             varNames.append(varName)
     if snvEnds:
         snvEnds.sort()
         snvStarts = list(map(lambda x: x-1, snvEnds))
         snvSizes = list(map(lambda x: 1, snvEnds))
         clade['thickStart'] = min(snvStarts)
         clade['thickEnd'] = max(snvEnds)
         clade['name'] = name
         clade['color'] = cladeColorFromName(name)
         clade['varSizes'] = snvSizes
         clade['varStarts'] = snvStarts
-        clade['varNames'] = varNames
+        clade['varNames'] = varStr
     return clade
 
 def numDateToMonthDay(numDate):
     """Transform decimal year timestamp to string with only month and day"""
     year = int(numDate)
     isLeapYear = 1 if (year % 4 == 0) else 0
     # Get rid of the year
     numDate -= year
     # Convert to Julian day
     jDay = int(numDate * 365) + 1
     if (jDay > 334 + isLeapYear):
         monthDay ="Dec" + str(jDay - 334 - isLeapYear)
     elif (jDay > 304 + isLeapYear):
         monthDay ="Nov" + str(jDay - 304 - isLeapYear)
     elif (jDay > 273 + isLeapYear):
@@ -93,94 +93,109 @@
         monthDay ="Jul" + str(jDay - 181 - isLeapYear)
     elif (jDay > 151 + isLeapYear):
         monthDay ="Jun" + str(jDay - 151 - isLeapYear)
     elif (jDay > 120 + isLeapYear):
         monthDay ="May" + str(jDay - 120 - isLeapYear)
     elif (jDay > 90 + isLeapYear):
         monthDay ="Apr" + str(jDay - 90 - isLeapYear)
     elif (jDay > 59 + isLeapYear):
         monthDay ="Mar" + str(jDay - 59 - isLeapYear)
     elif (jDay > 31):
         monthDay ="Feb" + str(jDay - 31)
     else:
         monthDay ="Jan" + str(jDay)
     return monthDay
 
-def rUnpackNextstrainTree(branch, parentVariants):
+def rUnpackNextstrainTree(branch, parentVariants, parentVarStr):
     """Recursively descend ncov.tree and build data structures for genome browser tracks"""
-    # Inherit parent variants
-    branchVariants = parentVariants.copy()
-    # Add variants specific to this branch (if any)
-    try:
+    # Gather variants specific to this node/branch (if any)
+    localVariants = []
+    if (branch.get('branch_attrs') and branch['branch_attrs'].get('mutations') and
+        branch['branch_attrs']['mutations'].get('nuc')):
         # Nucleotide variants specific to this branch
         for varName in branch['branch_attrs']['mutations']['nuc']:
             if (snvRe.match(varName)):
-                branchVariants[varName] = 1
+                localVariants.append(varName)
                 if (not variantCounts.get(varName)):
                     variantCounts[varName] = 0;
-        # Amino acid variants specific to this branch
+        # Amino acid variants: figure out which nucleotide variant goes with each
         for geneName in branch['branch_attrs']['mutations'].keys():
             if (geneName != 'nuc'):
                 for change in branch['branch_attrs']['mutations'][geneName]:
                     # Get nucleotide coords & figure out which nuc var this aa change corresponds to
                     aaM = snvAaRe.match(change)
                     if (aaM):
                         aaRef, aaPos, aaAlt = aaM.groups()
                         varStartMin = (int(aaPos) - 1) * 3
                         if (genePos.get(geneName)):
                             cdsStart = genePos.get(geneName)
                             varStartMin += cdsStart
                             varStartMax = varStartMin + 2
-                            for varName in branchVariants.keys():
+                            for varName in localVariants:
                                 ref, pos, alt = snvRe.match(varName).groups()
-                                pos = int(pos)
+                                pos = int(pos) - 1
                                 if (pos >= varStartMin and pos <= varStartMax):
                                     variantAaChanges[varName] = geneName + ':' + change
                         else:
                             warn("Can't find start for gene " + geneName)
                     else:
                         warn("Can't match amino acid change" + change)
-    except KeyError:
-        pass
+    # Inherit parent variants
+    branchVariants = parentVariants.copy()
+    # Add variants specific to this branch (if any)
+    for varName in localVariants:
+        branchVariants[varName] = 1
+    # Make an ordered variant string as David requested: semicolons between nodes,
+    # comma-separated within a node:
+    branchVarStr = ''
+    for varName in localVariants:
+        if (len(branchVarStr)):
+            branchVarStr += ', '
+        branchVarStr += varName
+        aaVar = variantAaChanges.get(varName)
+        if (aaVar):
+            branchVarStr += ' (' + aaVar + ')'
+    if (len(parentVarStr) and len(branchVarStr)):
+        branchVarStr = parentVarStr + '; ' + branchVarStr
     if (branch['node_attrs'].get('clade_membership')):
         cladeName = branch['node_attrs']['clade_membership']['value']
         if (not cladeName in clades):
-            clades[cladeName] = cladeFromVariants(cladeName, branchVariants)
+            clades[cladeName] = cladeFromVariants(cladeName, branchVariants, branchVarStr)
     kids = branch.get('children')
     if (kids):
         for child in kids:
-            rUnpackNextstrainTree(child, branchVariants);
+            rUnpackNextstrainTree(child, branchVariants, branchVarStr);
     else:
         for varName in branchVariants:
             variantCounts[varName] += 1
         date = numDateToMonthDay(branch['node_attrs']['num_date']['value'])
         samples.append({ 'id': branch['node_attrs']['gisaid_epi_isl']['value'],
                          'name': branch['name'],
                          'clade': branch['node_attrs']['clade_membership']['value'],
                          'date': date,
                          'variants': branchVariants })
         if (cladeName):
             if (clades[cladeName].get('sampleCount')):
                 clades[cladeName]['sampleCount'] += 1
             else:
                 clades[cladeName]['sampleCount'] = 1
             if (clades[cladeName].get('samples')):
                 clades[cladeName]['samples'].append(branch['name'])
             else:
                 clades[cladeName]['samples'] = [ branch['name'] ]
 
-rUnpackNextstrainTree(ncov['tree'], {})
+rUnpackNextstrainTree(ncov['tree'], {}, '')
 
 sampleCount = len(samples)
 sampleNames = [ ':'.join([sample['id'], sample['name'], sample['date']]) for sample in samples ]
 
 # Parse variant names like 'G11083T' into pos and alleles; bundle in VCF column order
 parsedVars = []
 for varName in variantCounts.keys():
     m = snvRe.match(varName)
     if (m):
         ref, pos, alt = m.groups()
         parsedVars.append([int(pos), varName, ref, alt])
     else:
         warn("Can't match " + varName)
 # Sort by position
 def parsedVarPos(pv):
@@ -211,19 +226,19 @@
         outC.write('\t'.join([ chrom,
                                '\t'.join(map(str, pv)),
                                '\t'.join(['.', 'PASS', info, 'GT:CLADE']),
                                '\t'.join(genotypes) ]) + '\n')
     outC.close()
 
 with open('nextstrainClade.bed', 'w') as outC:
     for name, clade in clades.items():
         if (clade.get('thickStart')):
             outC.write('\t'.join(map(str,
                                      [ chrom, 0, 29903, name, 0, '.',
                                        clade['thickStart'], clade['thickEnd'], clade['color'],
                                        len(clade['varSizes']) + 2,
                                        '1,' + ','.join(map(str, clade['varSizes'])) + ',1,',
                                        '0,' + ','.join(map(str, clade['varStarts'])) + ',29902,',
-                                       ', '.join(clade['varNames']),
+                                       clade['varNames'],
                                        clade['sampleCount'],
                                        ', '.join(clade['samples']) ])) + '\n')
     outC.close()