1af2fddea7e6ede49cc2b334a6c8e78df998c9fa
angie
  Mon Jun 8 16:15:28 2020 -0700
Nextstrain switched to a new set of clade definitions but still keeps the old clade assignments in a separate tag (legacy_clade_membership).  Make subtracks for both the old and the new sets of clades.  refs #25188

diff --git src/hg/utils/otto/nextstrainNcov/nextstrain.py src/hg/utils/otto/nextstrainNcov/nextstrain.py
index 37b0be1..710773c 100755
--- src/hg/utils/otto/nextstrainNcov/nextstrain.py
+++ src/hg/utils/otto/nextstrainNcov/nextstrain.py
@@ -17,42 +17,50 @@
     if (gene != 'nuc'):
         genePos[gene] = anno['start'] - 1
         geneBed.append([chrom, anno['start']-1, anno['end'], gene])
 def bedStart(bed):
     return bed[1]
 geneBed.sort(key=bedStart)
 with open('nextstrainGene.bed', 'w') as outG:
     for bed in geneBed:
         outG.write('\t'.join(map(str, bed)) + '\n')
 
 # Variants and "clades"
 
 snvRe = re.compile('^([ACGT])([0-9]+)([ACGT])$')
 snvAaRe = re.compile('^([A-Z*])([0-9]+)([A-Z*])$')
 
-clades = {}
-cladeNodes = {}
+newClades = {}
+oldClades = {}
 variantCounts = {}
 variantAaChanges = {}
 samples = []
 
-cladeColors = { 'A1a': '73,75,225', 'A2': '75,131,233', 'A2a': '92,173,207',
+# Clades from March 15th, 2020 to early morning of June 2nd, 2020:
+oldCladeColors = { 'A1a': '73,75,225', 'A2': '75,131,233', 'A2a': '92,173,207',
                    'A3': '119,199,164', 'A6': '154,212,122', 'A7': '173,189,81',
                    'B': '233,205,74', 'B1': '255,176,65', 'B2': '255,122,53',
                    'B4': '249,53,41' }
 
-def cladeColorFromName(cladeName):
+# Clades from late morning of June 2nd, 2020:
+newCladeColors = { '19A': '76,135,232',
+                   '19B': '110,194,178',
+                   '20A': '168,214,110',
+                   '20B': '232,206,75',
+                   '20C': '255,146,58' }
+
+def cladeColorFromName(cladeName, cladeColors):
     color = cladeColors.get(cladeName);
     if (not color):
         color = '0,0,0'
     return color
 
 def cladeFromVariants(name, variants, varStr):
     """Extract bed12 info from an object whose keys are SNV variant names"""
     clade = {}
     snvEnds = []
     # Watch out for back-mutations which make invalid BED because then we have multiple "blocks"
     # at the same position.  Instead, make a back-mutation cancel out the mutation because the
     # mutation is not found at this node.
     changesByPos = defaultdict(list)
     ixsToRemove = []
     for varName in variants:
@@ -70,51 +78,61 @@
             else:
                 ix = len(snvEnds)
                 changesByPos[pos] = (ix, ref, alt)
                 snvEnds.append(int(pos))
     if ixsToRemove:
         ixsToRemove.sort(reverse=True)
         for ix in ixsToRemove:
             del snvEnds[ix]
     if snvEnds:
         snvEnds.sort()
         snvStarts = [ e-1 for e in snvEnds ]
         snvSizes = [ 1 for e in snvEnds ]
         clade['thickStart'] = min(snvStarts)
         clade['thickEnd'] = max(snvEnds)
         clade['name'] = name
-        clade['color'] = cladeColorFromName(name)
         clade['varSizes'] = snvSizes
         clade['varStarts'] = snvStarts
         clade['varNames'] = varStr
     return clade
 
 def addDatesToClade(clade, numDateAttrs):
     """Add the numeric dates from ncov.json node_attrs.num_date to clade record"""
     clade['dateInferred'] = numDateAttrs['value']
     clade['dateConfMin'] = numDateAttrs['confidence'][0]
     clade['dateConfMax'] = numDateAttrs['confidence'][1]
 
 def addCountryToClade(clade, countryAttrs):
     """Add country data from ncov.json node_attrs.country to clade"""
     clade['countryInferred'] = countryAttrs['value']
-    confString = ''
-    for country, conf in countryAttrs['confidence'].items():
-        if (len(confString)):
-            confString += ', '
-        confString += "%s: %0.5f" % (country, conf)
-    clade['countryConf'] = confString
+    conf = countryAttrs.get('confidence')
+    clade['countryConf'] = ', '.join([ "%s: %0.5f" % (country, conf)
+                                       for country, conf in conf.items()]) if conf else ''
+
+def processClade(branch, tag, branchVariants, branchVarStr, clades):
+    """If this is the first time we've seen (old or new) clade, add it to clades"""
+    nodeAttrs = branch['node_attrs']
+    if (nodeAttrs.get(tag)):
+        cladeName = nodeAttrs[tag]['value']
+        if (cladeName != 'unassigned' and not cladeName in clades):
+            clades[cladeName] = cladeFromVariants(cladeName, branchVariants, branchVarStr)
+            addDatesToClade(clades[cladeName], nodeAttrs['num_date'])
+            if (nodeAttrs.get('country')):
+                addCountryToClade(clades[cladeName], nodeAttrs['country'])
+            elif (nodeAttrs.get('division')):
+                addCountryToClade(clades[cladeName], nodeAttrs['division'])
+            clades[cladeName]['topNode'] = branch
 
 def numDateToYmd(numDate):
     """Convert numeric date (decimal year) to integer year, month, day"""
     year = int(numDate)
     isLeapYear = 1 if (year % 4 == 0) else 0
     # Get rid of the year
     numDate -= year
     # Convert to Julian day
     daysInYear = 366 if isLeapYear else 365
     jDay = int(numDate * daysInYear) + 1
     if (jDay > 334 + isLeapYear):
         month, day = 11, (jDay - 334 - isLeapYear)
     elif (jDay > 304 + isLeapYear):
         month, day = 10, (jDay - 304 - isLeapYear)
     elif (jDay > 273 + isLeapYear):
@@ -136,32 +154,35 @@
     elif (jDay > 31):
         month, day = 1, (jDay - 31)
     else:
         month, day = 0, jDay
     return year, month, day
 
 months = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec'];
 
 def numDateToMonthDay(numDate):
     """Transform decimal year timestamp to string with only month and day"""
     year, month, day = numDateToYmd(numDate)
     return months[month] + str(day)
 
 def numDateToYmdStr(numDate):
     """Convert decimal year to YY-MM-DD string"""
+    if numDate:
         year, month, day = numDateToYmd(numDate)
         return "%02d-%02d-%02d" % (year, month+1, day)
+    else:
+        return ''
 
 def rUnpackNextstrainTree(branch, parentVariants, parentVarStr):
     """Recursively descend ncov.tree and build data structures for genome browser tracks"""
     # Gather variants specific to this node/branch (if any)
     localVariants = []
     if (branch.get('branch_attrs') and branch['branch_attrs'].get('mutations') and
         branch['branch_attrs']['mutations'].get('nuc')):
         # Nucleotide variants specific to this branch
         for varName in branch['branch_attrs']['mutations']['nuc']:
             if (snvRe.match(varName)):
                 localVariants.append(varName)
                 if (not variantCounts.get(varName)):
                     variantCounts[varName] = 0;
         # Amino acid variants: figure out which nucleotide variant goes with each
         for geneName in branch['branch_attrs']['mutations'].keys():
@@ -190,57 +211,52 @@
         branchVariants[varName] = 1
     # Make an ordered variant string as David requested: semicolons between nodes,
     # comma-separated within a node:
     branchVarStr = ''
     for varName in localVariants:
         if (len(branchVarStr)):
             branchVarStr += ', '
         branchVarStr += varName
         aaVar = variantAaChanges.get(varName)
         if (aaVar):
             branchVarStr += ' (' + aaVar + ')'
     if (len(parentVarStr) and len(branchVarStr)):
         branchVarStr = parentVarStr + '; ' + branchVarStr
     elif (not len(branchVarStr)):
         branchVarStr = parentVarStr
-    nodeAttrs = branch['node_attrs']
-    if (nodeAttrs.get('clade_membership')):
-        cladeName = nodeAttrs['clade_membership']['value']
-        if (cladeName != 'unassigned' and not cladeName in clades):
-            clades[cladeName] = cladeFromVariants(cladeName, branchVariants, branchVarStr)
-            addDatesToClade(clades[cladeName], nodeAttrs['num_date'])
-            if (nodeAttrs.get('country')):
-                addCountryToClade(clades[cladeName], nodeAttrs['country'])
-            elif (nodeAttrs.get('division')):
-                addCountryToClade(clades[cladeName], nodeAttrs['division'])
-            else:
-                warn('No country or division for new clade ' + cladeName)
-            cladeNodes[cladeName] = branch
+    processClade(branch, 'clade_membership', branchVariants, branchVarStr, newClades)
+    processClade(branch, 'legacy_clade_membership', branchVariants, branchVarStr, oldClades)
     kids = branch.get('children')
     if (kids):
         for child in kids:
             rUnpackNextstrainTree(child, branchVariants, branchVarStr);
     else:
         for varName in branchVariants:
             variantCounts[varName] += 1
+        nodeAttrs = branch['node_attrs']
         if (nodeAttrs.get('submitting_lab')):
             lab = nodeAttrs['submitting_lab']['value']
         else:
             lab = ''
+        if (nodeAttrs.get('legacy_clade_membership')):
+            oldClade = nodeAttrs['legacy_clade_membership']['value']
+        else:
+            oldClade = ''
         samples.append({ 'id': nodeAttrs['gisaid_epi_isl']['value'],
                          'name': branch['name'],
                          'clade': nodeAttrs['clade_membership']['value'],
+                         'oldClade': oldClade,
                          'date': numDateToMonthDay(nodeAttrs['num_date']['value']),
                          'lab': lab,
                          'variants': branchVariants,
                          'varStr': branchVarStr })
 
 rUnpackNextstrainTree(ncov['tree'], {}, '')
 
 def sampleName(sample):
     return '|'.join([sample['id'], sample['name'], sample['date']])
 
 sampleCount = len(samples)
 sampleNames = [ sampleName(sample)  for sample in samples ]
 
 # Parse variant names like 'G11083T' into pos and alleles; bundle in VCF column order
 parsedVars = []
@@ -324,76 +340,84 @@
     aaChanges = []
     for varName in varNames:
         aaChange = variantAaChanges.get(varName)
         if (aaChange):
             gotAaChange = 1
             aaChanges.append(aaChange)
         else:
             aaChanges.append('-');
     if (gotAaChange):
         aaChangeStr = ','.join(aaChanges)
     else:
         aaChangeStr = ''
     return aaChangeStr
 
 # VCF
-with open('nextstrainSamples.vcf', 'w') as outC:
+def vcfForAll(fileName, cladeTag):
+    with open(fileName, 'w') as outC:
         writeVcfHeaderExceptSamples(outC)
         outC.write('\t'.join(sampleNames) + '\n')
         for mv in mergedVars:
             pv, alts, altCounts, sampleAlleles, backMutSamples = mv
             info = 'AC=' + ','.join(map(str, altCounts)) + ';AN=' + str(sampleCount)
             varNameMerged = pv[1]
             aaChange = tallyAaChanges(varNameMerged)
             if (len(aaChange)):
                 info += ';AACHANGE=' + aaChange
             if (len(backMutSamples)):
                 info += ';BACKMUTS=' + ','.join(backMutSamples)
             genotypes = []
             for sample, alIx in zip(samples, sampleAlleles):
                 gt = str(alIx)
-            genotypes.append(gt + ':' + sample['clade'])
+                genotypes.append(gt + ':' + sample[cladeTag])
             outC.write('\t'.join([ chrom,
                                    '\t'.join(map(str, pv)),
                                    '\t'.join(['.', 'PASS', info, 'GT:CLADE']),
                                    '\t'.join(genotypes) ]) + '\n')
 
+vcfForAll('nextstrainSamples.vcf', 'clade')
+# I'll skip writing an enormous file with wasteful clades-in-genotypes... really need
+# a sample metadata file!
+
 # Assign samples to clades; a sample can appear in multiple clades if they are nested.
-cladeSamples = {}
-cladeSampleCounts = {}
-cladeSampleNames = {}
 
-def sampleIdsFromNode(node, ids):
-    """Fill in a dict of IDs of all samples found under node."""
+def sampleIdsFromNode(node, cladeTops=()):
+    """Return a list of IDs of all samples found under node."""
     kids = node.get('children')
     if (kids):
+        sampleIds = []
         for kid in kids:
-            sampleIdsFromNode(kid, ids)
+            if (kid not in cladeTops):
+                sampleIds += sampleIdsFromNode(kid, cladeTops)
     else:
         sampleId = node['node_attrs']['gisaid_epi_isl']['value']
-        ids[sampleId] = 1
+        sampleIds = [sampleId]
+    return sampleIds
 
-for cladeName, node in cladeNodes.items():
-    sampleIds = {}
-    sampleIdsFromNode(node, sampleIds)
-    cladeSampleList = [ sample for sample in samples if sample['id'] in sampleIds ]
-    cladeSamples[cladeName] = sampleIds
+cladeSampleCounts = {}
+cladeSampleNames = {}
+
+def vcfForClades(clades, cladeTops=()):
+    """Given a set of clades (old or new), dump out VCF for each clade.
+    Stop at nodes in cladeTops (for Nextstrain's new clade scheme where 19A is root, 20A fully
+    contains 20B and 20C, etc."""
+    for cladeName in clades:
+        node = clades[cladeName]['topNode']
+        cladeSampleIds = set(sampleIdsFromNode(node, cladeTops))
+        cladeSampleList = [ sample for sample in samples if sample['id'] in cladeSampleIds ]
         cladeSampleCounts[cladeName] = len(cladeSampleList)
         cladeSampleNames[cladeName] = [ sampleName(sample) for sample in cladeSampleList ]
-
-# Per-clade VCF subset
-for cladeName, cladeSampleIds in cladeSamples.items():
         with open('nextstrainSamples' + cladeName + '.vcf', 'w') as outV:
             writeVcfHeaderExceptSamples(outV)
             outV.write('\t'.join(cladeSampleNames[cladeName]) + '\n')
             for mv in mergedVars:
                 pv, alts, overallAltCounts, sampleAlleles, backMutSamples = mv
                 varNameMerged = pv[1]
                 genotypes = []
                 altCounts = [ 0 for alt in alts ]
                 acTotal=0
                 for sample, alIx in zip(samples, sampleAlleles):
                     if (sample['id'] in cladeSampleIds):
                         gt = str(alIx)
                         genotypes.append(gt + ':' + cladeName)
                         if (alIx > 0):
                             altCounts[alIx - 1] += 1
@@ -401,103 +425,144 @@
                 if (acTotal > 0):
                     info = 'AC=' + ','.join(map(str, altCounts))
                     info += ';AN=' + str(cladeSampleCounts[cladeName])
                     aaChange = tallyAaChanges(varNameMerged)
                     if (len(aaChange)):
                         info += ';AACHANGE=' + aaChange
                     cladeBackMuts = [ sampleName for sampleName in backMutSamples
                                       if sampleName in cladeSampleNames[cladeName] ]
                     if (len(cladeBackMuts)):
                         info += ';BACKMUTS=' + ','.join(cladeBackMuts)
                     outV.write('\t'.join([ chrom,
                                            '\t'.join(map(str, pv)),
                                            '\t'.join(['.', 'PASS', info, 'GT:CLADE']),
                                            '\t'.join(genotypes) ]) + '\n')
 
-# BED+ file for clades
-with open('nextstrainClade.bed', 'w') as outC:
+def bedForClades(fileName, clades, cladeColors):
+    """Make a BED file summarizing each clade"""
+    with open(fileName, 'w') as outC:
         for name, clade in clades.items():
-        if (clade.get('thickStart')):
+            if (not clade.get('thickStart')):
+                # "Clade" 19A encompasses the entire tree (minus the parts assigned to
+                # other "clades").  It has no identifying variants, and (as of June 7)
+                # no dates assigned.
+                clade['thickStart'] = clade['thickEnd'] = 0
+                clade['varStarts'] = clade['varSizes'] = []
+                clade['varNames'] = ''
+                clade['dateInferred'] = clade['dateConfMin'] = clade['dateConfMax'] = 0
+            countryConf = clade.get('countryConf')
+            if (not countryConf):
+                countryConf = ''
+            countryInferred = clade.get('countryInferred')
+            if (not countryInferred):
+                countryInferred = ''
             outC.write('\t'.join(map(str,
                                      [ chrom, 0, 29903, name, 0, '.',
-                                       clade['thickStart'], clade['thickEnd'], clade['color'],
+                                       clade['thickStart'], clade['thickEnd'],
+                                       cladeColorFromName(name, cladeColors),
                                        len(clade['varSizes']) + 2,
-                                       '1,' + ','.join(map(str, clade['varSizes'])) + ',1,',
-                                       '0,' + ','.join(map(str, clade['varStarts'])) + ',29902,',
+                                       ','.join(map(str, ([1] + clade['varSizes']) + [1])),
+                                       ','.join(map(str, ([0] + clade['varStarts']) + [29902])),
                                        clade['varNames'],
                                        numDateToYmdStr(clade['dateInferred']),
                                        numDateToYmdStr(clade['dateConfMin']),
                                        numDateToYmdStr(clade['dateConfMax']),
-                                       clade['countryInferred'],
-                                       clade['countryConf'],
+                                       countryInferred,
+                                       countryConf,
                                        cladeSampleCounts[name],
                                        ', '.join(cladeSampleNames[name]) ])) + '\n')
 
+if (len(oldClades) == 0):
+    # This ncov.json must be from when old clades were the new clades, and the new clades
+    # had not yet arrived.  Revert to old colors and don't exclude subclades.
+    newCladeColors = oldCladeColors
+    newCladeTops = ()
+else:
+    newCladeTops = [ newClades[cladeName]['topNode'] for cladeName in newClades ]
+vcfForClades(newClades, newCladeTops)
+bedForClades('nextstrainClade.bed', newClades, newCladeColors)
+if (len(oldClades)):
+    vcfForClades(oldClades)
+    bedForClades('nextstrainOldClade.bed', oldClades, oldCladeColors)
+
 # Newick-formatted tree of samples for VCF display
-def cladeRgbFromName(cladeName):
+def cladeRgbFromName(cladeName, cladeColors):
     """Look up the r,g,b string color for clade; convert to int RGB."""
-    rgbCommaStr = cladeColorFromName(cladeName)
+    rgbCommaStr = cladeColorFromName(cladeName, cladeColors)
     r, g, b = [ int(x) for x in rgbCommaStr.split(',') ]
     rgb = (r << 16) | (g << 8) | b
     return rgb
 
-def rNextstrainToNewick(node, parentClade=None, parentVarStr=''):
-    """Recursively descend ncov.tree and build Newick tree string of samples to file"""
+def rNextstrainToNewick(node, cladeColors, cladeTops=(), parentClade=None, parentVarStr=''):
+    """Recursively descend ncov.tree and build Newick tree string of samples to file.
+    Exclude nodes in cladeTops."""
     kids = node.get('children')
     if (kids):
         # Make a more concise variant path string than the one we make for the clade track,
         # to embed in internal node labels for Yatish's tree explorations.
         localVariants = []
         if (node.get('branch_attrs') and node['branch_attrs'].get('mutations') and
             node['branch_attrs']['mutations'].get('nuc')):
             # Nucleotide variants specific to this branch
             for varName in node['branch_attrs']['mutations']['nuc']:
                 if (snvRe.match(varName)):
                     localVariants.append(varName)
         varStr = '+'.join(localVariants)
         if (len(parentVarStr) and len(varStr)):
             varStr = '$'.join([parentVarStr, varStr])
         elif (not len(varStr)):
             varStr = parentVarStr
         nodeAttrs = node['node_attrs']
         if (nodeAttrs.get('clade_membership')):
             cladeName = nodeAttrs['clade_membership']['value']
         elif (parentClade):
             cladeName = parentClade
         else:
             cladeName = 'unassigned'
-        color = str(cladeRgbFromName(cladeName))
-        descendants = ','.join([ rNextstrainToNewick(child, cladeName, varStr) for child in kids ])
+        color = str(cladeRgbFromName(cladeName, cladeColors))
+        descendants = ','.join([ rNextstrainToNewick(child, cladeColors, cladeTops, cladeName,
+                                                     varStr)
+                                 for child in kids if child not in cladeTops ])
         label = '#'.join([cladeName, varStr])
         treeString = '(' + descendants + ')' + label + ':' + color
     else:
         nodeAttrs = node['node_attrs']
         gId = nodeAttrs['gisaid_epi_isl']['value']
         name = node['name']
         date = numDateToMonthDay(nodeAttrs['num_date']['value'])
         cladeName = nodeAttrs['clade_membership']['value']
-        color = str(cladeRgbFromName(cladeName))
+        color = str(cladeRgbFromName(cladeName, cladeColors))
         treeString = sampleName({ 'id': gId, 'name': name, 'date': date }) + ':' + color
     return treeString
 
 with open('nextstrain.nh', 'w') as outF:
-    outF.write(rNextstrainToNewick(ncov['tree']) + ';\n')
+    outF.write(rNextstrainToNewick(ncov['tree'], newCladeColors) + ';\n')
 
-for cladeName, node in cladeNodes.items():
+if (len(oldClades)):
+    with open('nextstrainOldCladeColors.nh', 'w') as outF:
+        outF.write(rNextstrainToNewick(ncov['tree'], oldCladeColors) + ';\n')
+
+def newickForClades(clades, cladeColors, cladeTops=()):
+    for cladeName in clades:
         filename = 'nextstrain' + cladeName + '.nh'
+        node = clades[cladeName]['topNode']
         with open(filename, 'w') as outF:
-        outF.write(rNextstrainToNewick(node) + ';\n')
+            outF.write(rNextstrainToNewick(node, cladeColors, cladeTops) + ';\n')
+
+newickForClades(newClades, newCladeColors, newCladeTops)
+if (len(oldClades)):
+    newickForClades(oldClades, oldCladeColors)
 
 # File with samples and their clades, labs and variant paths
 
 apostropheSRe = re.compile("'s");
 firstLetterRe = re.compile('(\w)\w+');
 spacePunctRe = re.compile('\W');
 
 def abbreviateLab(lab):
     """Lab names are very long and sometimes differ by punctuation or typos.  Abbreviate for easier comparison."""
     labAbbrev = apostropheSRe.sub('', lab)
     labAbbrev = firstLetterRe.sub(r'\1', labAbbrev, count=0)
     labAbbrev = spacePunctRe.sub('', labAbbrev, count=0)
     return labAbbrev
 
 with open('nextstrainSamples.varPaths', 'w') as outF: