b88380414ddecaab81f5c19b9f0cf06dbbedc787
max
  Fri May 26 06:36:28 2017 -0700
changing python hash bang again for hgGeneGraph, tiny wrangling fix, refs #13634

diff --git src/utils/ggTables src/utils/ggTables
index bb07669..7dbe660 100755
--- src/utils/ggTables
+++ src/utils/ggTables
@@ -1,1255 +1,1267 @@
 #!/usr/bin/env python2.7
 
 import logging, sys, optparse, glob, itertools, os, tempfile, gzip, operator, re, cPickle, gc
 import marshal
 from collections import defaultdict, namedtuple, Counter
 from itertools import chain, combinations
 from os.path import join, basename, dirname, isfile, expanduser
 import ujson
 
 # saves 20% of time when loading graph marshal
 import gc
 gc.disable()
 
 #LTPMAXMEMBERS=5 # maximum number of proteins in a complex for an interaction to qualify for low throughput
 LTPMAX=5 # maximum number of interactions a PMID can have to be declared low throughput
 
 # don't even write out links with less than this number of documents. 2 = weeds out many false positives.
 # not using right now, because a few pathway databases do not annotate ANY PMID. Increasing this filter would remove
 # all interactions from these pathways databases.
 MINSUPP=0
 
 outFields = ["gene1", "gene2", "flags", "refs", "fwWeight", "revWeight", "snip"]
 
 # directory with autoSql descriptions of output tables
 autoSqlDir = expanduser("~/kent/src/hg/lib/")
 
 # file with all of medline in short form
 allArtFname = "textInfo.tab"
 # file with just pmids and events
 #pmidEventFname = "temp.pmidToEvent.tab"
 
 # RE to split sentences into alphanumeric words
 wordRe = re.compile("[a-zA-Z0-9]+")
 
 # === COMMAND LINE INTERFACE, OPTIONS AND HELP ===
 parser = optparse.OptionParser("""usage: %prog [options] build|load pathwayDir ppiDir textDir outDir - given various tab sep files with text-mining, gene interaction or pathway information, build the  table ggLink, ggDoc, ggDb and ggText
 
 run it like this:
 
-%prog medline - to reduce the big medline table to something smaller, only needed once
+Reduce the big medline table to something smaller, only needed once:
+    %prog medline
 
+Slowest part: build the big table of interactions mysql/ggLink.tmp.tab
     %prog build pathways ppi text mysql
-%prog docs mysql  # creates the ggDocs.tab file, slow
+
+Create mysql/ggDocs.tab, very slow
+    %prog docs mysql
+
+Add the "context" (aka mesh terms) to mysql/ggLink.tmp.tab 
+and write to mysql/ggLink.tab file
     %prog context mysql  
-%prog load mysql publications
+format is:
+gene1, gene2, flags, forwDocCount, revDocCount, allDocCount, databaseList, minimalPairCountPaper, snippet
+
+Load all tables in mysql/ into MySql:
+    %prog load mysql hgFixed
+
+Create the bigBed File
     %prog bigBed outDir bigBedFile db
 """)
 
 parser.add_option("-d", "--debug", dest="debug", action="store_true", help="show debug messages") 
 #parser.add_option("-t", "--test", dest="test", action="store_true", help="run tests") 
 parser.add_option("-t", "--textDir", dest="textDir", action="store", help="directory with the parsed copy of medline, default %default", default="/hive/data/inside/pubs/text/medline")
 parser.add_option("-m", "--meshFile", dest="meshFname", action="store", help="An mtrees<year>.bin file, default %default", default="/hive/data/outside/ncbi/mesh/mtrees2015.bin")
 parser.add_option("-j", "--journalInfo", dest="journalInfo", action="store", help="tab-sep file with journal info from the NLM Catalog converted by 'pubPrepCrawl publishers'. Used to shorten the journal names. Optional and not used if file is not found. Default %default", default="/cluster/home/max/projects/pubs/tools/data/journals/journals.tab")
 parser.add_option("-b", "--wordFname", dest="wordFname", action="store", help="a file with common English words", default="/hive/data/outside/pubs/wordFrequency/bnc/bnc.txt")
 #parser.add_option("-f", "--file", dest="file", action="store", help="run on file") 
 #parser.add_option("", "--test", dest="test", action="store_true", help="do something") 
 (options, args) = parser.parse_args()
 
 if options.debug:
     logging.basicConfig(level=logging.DEBUG)
 else:
     logging.basicConfig(level=logging.INFO)
 # ==== FUNCTIONs =====
     
 def parseAutoSql(asFname):
     """ parse an autoSql (.as) file and return the list of its field names.
 
     Parentheses are removed from every line first. Lines that start with a
     double quote or with "table" are skipped, as are lines without a ";".
     For the remaining lines, the second whitespace-separated word before the
     first ";" is taken as the field name.
     """
     headers = []
     with open(asFname) as asFh:
         for line in asFh:
             line = line.replace("(","").replace(")","").strip()
             # skip the quoted table description and the "table <name>" line
             if line.startswith('"') or line.startswith("table"):
                 continue
             # only field definitions are terminated by a semicolon
             if ";" not in line:
                 continue
             parts = line.split(";")[0].split()
             assert(len(parts)>=2)
             headers.append(parts[1])
     return headers
 
 def lineFileNext(fh, headers=None, asFname=None):
     """ parse a tab-sep file and yield one collections.namedtuple per data line.
 
     fh is an open file-like object. The first line is ALWAYS read and consumed.
     It is used as the header line (newline and "#" stripped, spaces replaced
     with "_", parentheses removed) unless the field names are supplied through
     headers (a list of names) or asFname (an autoSql file, parsed with
     parseAutoSql).
     Lines that cannot be converted to the namedtuple (e.g. wrong number of
     fields) are logged as errors and skipped.
     """
     line1 = fh.readline()
     line1 = line1.strip("\n").strip("#")
     if headers is None and asFname is not None:
         headers = parseAutoSql(asFname)
     elif headers is None:
         headers = [h.replace(" ", "_").replace("(", "").replace(")", "") \
             for h in line1.split("\t")]
     Record = namedtuple('tsvRec', headers)
 
     for line in fh:
         line = line.rstrip("\n")
         fields = line.split("\t")
         try:
             rec = Record(*fields)
         except Exception as msg:
             logging.error("Exception occured while parsing line, %s" % msg)
             # not every file-like object has a name (e.g. in-memory buffers)
             logging.error("Filename %s" % getattr(fh, "name", "<unknown>"))
             logging.error("Line was: %s" % repr(line))
             logging.error("Does number of fields match headers?")
             logging.error("Headers are: %s" % headers)
             continue
         yield rec
 
 def loadFiles(inDir, prefix=None):
     """ parse all .tab files in inDir with lineFileNext and return their rows
     as a single list of namedtuples.
     If prefix is given, it is prepended to the eventId field of every row.
     """
     rows = list()
     for inFname in glob.glob(join(inDir, "*.tab")):
         logging.info("Loading %s" % inFname)
         with open(inFname) as inFh:
             for row in lineFileNext(inFh):
                 if prefix is not None:
                     row = row._replace(eventId=prefix+row.eventId)
                 rows.append(row)
     return rows
 
 def getResultCounts(pairs):
     """
     Input is a dict pair -> rows, where every row has a "|"-separated pmids field.
     For each document ID, collect the set of pairs that reference it; the size
     of this set is something like the "resultCount" of a paper, the lower, the
     better.
     Returns a tuple (pairMinResultCounts, docToPairs):
     - pairMinResultCounts: dict pair -> smallest resultCount over all documents
       of the pair, or 0 if the pair has no document at all
     - docToPairs: dict document ID -> set of pairs
     """
     logging.info("Getting low-throughput studies in PPI and pathway data")
 
     # pass 1: document -> set of gene pairs
     docToPairs = defaultdict(set)
     for pair, rows in pairs.iteritems():
         for row in rows:
             for docId in row.pmids.split("|"):
                 if docId!="":
                     docToPairs[docId].add(pair)
 
     # pass 2: pair -> minimum over the result counts of its documents
     pairMinResultCounts = {}
     for pair, rows in pairs.iteritems():
         pairDocIds = set()
         for row in rows:
             pairDocIds.update(row.pmids.split("|"))
         pairDocIds.discard("")
         resultCounts = [len(docToPairs[docId]) for docId in pairDocIds]
         if resultCounts:
             pairMinResultCounts[pair] = min(resultCounts)
         else:
             pairMinResultCounts[pair] = 0
 
     return pairMinResultCounts, docToPairs
             
 def allSubPairs(pair):
     """ given a pair of two strings, where each can be a "_"-separated list of
     genes (a family), yield all (member1, member2) combinations.
     Members that are empty, "-" or start with "compound" are skipped.
     """
     family1, family2 = pair
     members2 = family2.split("_")
     for a in family1.split("_"):
         if a in ("", "-") or a.startswith("compound"):
             continue
         for b in members2:
             if b in ("", "-") or b.startswith("compound"):
                 continue
             yield (a, b)
     
 def iterAllPairs(row):
     """ yield all pairs of interacting genes for a given row. Handles families.
 
     row needs the fields causeType, themeType, causeGenes and themeGenes; the
     gene fields are "|"-separated lists. Every gene can itself be a
     "_"-separated family, which is expanded with allSubPairs (this also drops
     empty, "-" and "compound..." members).
     - if a side has the type "complex", all 2-combinations of its genes are yielded
     - if themeType is not empty, all (cause, theme) combinations are yielded
     """
     causeGenes = set(row.causeGenes.split("|"))
     themeGenes = set(row.themeGenes.split("|"))
 
     # all genes of complexes interact in some way; a complex can contain families
     for sideType, sideGenes in ((row.causeType, causeGenes), (row.themeType, themeGenes)):
         if sideType!="complex":
             continue
         for memberPair in itertools.combinations(sideGenes, 2):
             for subPair in allSubPairs(memberPair):
                 yield tuple(subPair)
 
     if row.themeType=="":
         return
 
     # all genes from the left and the right side interact
     for cause in causeGenes:
         for theme in themeGenes:
             for subPair in allSubPairs((cause, theme)):
                 yield subPair
     
 def indexPairs(ppiRows, desc):
     """ given rows with theme and cause genes, return a dict with
     sorted (gene1, gene2) -> list of the rows that contain this pair.
     desc is only used for the log message. """
     logging.info("enumerating all interacting pairs: %s" % desc)
     pairs = defaultdict(list)
     for row in ppiRows:
         for gene1, gene2 in iterAllPairs(row):
             if gene1.startswith("compound") or gene2.startswith("compound"):
                 continue
             sortedPair = tuple(sorted((gene1, gene2)))
             pairs[sortedPair].append(row)
     logging.info("got %d pairs" % len(pairs))
     return pairs
 
 def mergePairs(dicts):
     " merge a list of dicts key -> list of values into one defaultdict key -> set of values "
     logging.info("Merging pairs")
     merged = defaultdict(set)
     for pairDict in dicts:
         for pair, values in pairDict.iteritems():
             merged[pair].update(values)
     return merged
 
 def directedPairToDocs(rows):
     """ given text mining rows (fields causeGenes, themeGenes, pmid), return a
     dict (cause, theme) -> set of pmids. The gene fields are "|"-separated and
     every cause is combined with every theme. Note that these pairs are
     DIRECTED - so they can be used to infer the direction of the interaction
     """
     pairPmids = defaultdict(set)
     for row in rows:
         causes = set(row.causeGenes.split("|"))
         themes = set(row.themeGenes.split("|"))
         for cause in causes:
             for theme in themes:
                 pairPmids[(cause, theme)].add(row.pmid)
     return pairPmids
 
 def writeGraphTable(allPairs, pairDocs, pairToDbs, pairMinResCounts, pwDirPairs, bestSentences, outFname, outFname2):
     """ write the ggLink table to outFname and the pair-to-event table to outFname2.
 
     allPairs:         dict (gene1, gene2) -> rows, every row has an eventId
     pairDocs:         dict directed pair -> documents, looked up in both directions
     pairToDbs:        dict pair -> database names, joined with "|"
     pairMinResCounts: dict pair -> minimal result count (0 if missing)
     pwDirPairs:       collection of directed pairs, used for the fwd/rev flags
     bestSentences:    dict pair -> snippet ("" if missing)
 
     Columns of outFname: gene1, gene2, flags, forwDocCount, revDocCount,
     allDocCount, databaseList, minResCount, snippet.
     Columns of outFname2: gene1, gene2, eventId. Both files are sorted.
 
     NOTE(review): dbPairs, pwPairs and textPairs are not parameters. They are
     read as module-level names and have to exist when this function is called.
     """
     logging.info("writing merged graph to %s" % outFname)
     rows = []
     rows2 = []
     for pair,pairRows in allPairs.iteritems():
         gene1, gene2 = pair
         revPair = tuple(reversed(pair))
 
         flags = []
         if pair in dbPairs:
             flags.append("ppi")
         if pair in pwPairs:
             flags.append("pwy")
         if pair in textPairs:
             flags.append("text")
         # direction of interaction - only based on pathways
         if pair in pwDirPairs:
             flags.append("fwd")
         if revPair in pwDirPairs:
             flags.append("rev")
 
         forwDocs = pairDocs.get(pair, [])
         revDocs = pairDocs.get(revPair, [])
         allDocs = set(forwDocs).union(set(revDocs))
 
         if len(allDocs)<MINSUPP and "pwy" not in flags and "ppi" not in flags:
             # if it's text-mining only and less than X documents, just skip it
             continue
         pairMinResCount = pairMinResCounts.get(pair, 0)
 
         pairDbs = "|".join(pairToDbs.get(pair, []))
         snippet = bestSentences.get(pair, "")
         rows.append([gene1, gene2, ",".join(flags), str(len(forwDocs)), str(len(revDocs)), \
             str(len(allDocs)), pairDbs, str(pairMinResCount), snippet])
 
         # one line per (pair, event) for the second table
         for ref in sorted([row.eventId for row in pairRows]):
             rows2.append([gene1, gene2, ref])
 
     rows.sort()
     with open(outFname, "w") as ofh:
         for row in rows:
             ofh.write("\t".join(row))
             ofh.write("\n")
 
     rows2.sort()
     with open(outFname2, "w") as ofh2:
         for row in rows2:
             ofh2.write("\t".join(row))
             ofh2.write("\n")
 
 def runCmd(cmd):
     """ run command in shell with os.system. Returns the exit status (always 0),
     raises an Exception if the exit status is not 0 """
     logging.debug("Running shell command: %s" % cmd)
     exitStatus = os.system(cmd)
     if exitStatus==0:
         return exitStatus
     raise Exception("Could not run command (Exitcode %d): %s" % (exitStatus, cmd))
 
 def asToSql(table, sqlDir):
     """ given a table name, run autoSql on <autoSqlDir>/<table>.as with the output
     base name <sqlDir>/<table> and return the name of the resulting .sql file
     (the one with the CREATE TABLE statement) """
     outBase = join(sqlDir, table)
     runCmd("autoSql %s %s" % (join(autoSqlDir, table+".as"), outBase))
     return outBase+".sql"
 
 def loadTable(db, tableDir, table):
     """ load <tableDir>/<table>.tab into the mysql table <table> of db with
     hgLoadSqlTab. The CREATE TABLE statement is generated with asToSql. """
     tmpSqlFname = asToSql(table, tableDir)
     tabFname = join(tableDir, table+".tab")
 
     cmd = "hgLoadSqlTab %s %s %s %s" % (db, table, tmpSqlFname, tabFname)
     try:
         runCmd(cmd)
     finally:
         # make sure that the generated .sql file gets deleted, even on errors
         os.remove(tmpSqlFname)
 
 def hgsql(db, query):
     """ run query on database db with the hgsql command line tool.
     Raises an Exception (see runCmd) if hgsql fails. """
     # the query is put between single quotes on the shell command line
     assert('"' not in query)
     cmd = "hgsql %s -NBe '%s'" % (db, query)
     runCmd(cmd)
 
 def addIndexes(db):
     " add the mysql indexes to the tables ggLinkEvent and ggEventText of db "
     queries = (
         "ALTER TABLE ggLinkEvent ADD INDEX gene12Idx (gene1, gene2);",
         "ALTER TABLE ggEventText ADD INDEX docIdIdx (docId);",
     )
     for query in queries:
         hgsql(db, query)
 
 def loadTables(tableDir, db):
     " load all graph tables from tableDir into the mysql database db, then add the indexes "
     tables = ("ggDoc", "ggDocEvent", "ggEventDb", "ggEventText", "ggLink", "ggLinkEvent")
     for table in tables:
         loadTable(db, tableDir, table)
 
     addIndexes(db)
 
 def indexPmids(rowList, textRows):
     """ return dict pmid -> set of event Ids.
     rowList is a list of row lists; these rows have a "|"-separated pmids
     field. textRows have a single pmid field. Empty pmids are skipped. """
     pmidToIds = defaultdict(set)
     for rows in rowList:
         for row in rows:
             for pmid in row.pmids.split("|"):
                 if pmid!="":
                     pmidToIds[pmid].add(row.eventId)
 
     for row in textRows:
         if row.pmid!="":
             pmidToIds[row.pmid].add(row.eventId)
 
     return pmidToIds
 
 def writeDocEvents(pmidToId, outFname):
     """ write a tab-sep table docId, eventId to outFname, one line per event.
     pmidToId is a dict docId -> event Ids; the event Ids of a document are
     written in sorted order. """
     logging.info("Writing docId-eventId %s" % outFname)
     with open(outFname, "w") as ofh:
         for docId, eventIds in pmidToId.iteritems():
             for eventId in sorted(eventIds):
                 ofh.write("%s\t%s\n" % (docId, eventId))
 
 def writeEventTable(rowList, outFname, colCount=None):
     """ write the event table with event details to outFname.
     rowList is a list of row lists, all rows are sequences of strings.
     If colCount is set, only the first colCount fields of every row are written. """
     logging.info("Writing events to %s" % outFname)
     with open(outFname, "w") as ofh:
         for rows in rowList:
             for row in rows:
                 if colCount:
                     row = row[:colCount]
                 ofh.write("%s\n" % ("\t".join(row)))
 
 def pairToDbs(pairs):
     """ given a dict pair -> rows, return a dict pair -> set of the sourceDb
     values of these rows. Pairs without any rows do not get an entry.
     """
     pairDbs = defaultdict(set)
     for pair, rows in pairs.iteritems():
         dbNames = [row.sourceDb for row in rows]
         if dbNames:
             pairDbs[pair].update(dbNames)
     return pairDbs
 
 def parseMeshContext(fname):
     """ given a medline trees file, return the set of disease and pathway names in it.
     Every line has the format term;code. Empty lines are skipped. """
     # ex. filename is mtrees2013.bin (it's ascii)
     # WAGR Syndrome;C10.597.606.643.969
     terms = set()
     with open(fname) as lines:
         for line in lines:
             line = line.strip()
             if line=="":
                 continue
             term, code = line.split(";")
             term = term.strip()
             # all disease terms start with C
             if code.startswith("C"):
                 terms.add(term)
             # all signalling pathways are below a very specific branch
             elif code.startswith("G02.149.115.800") and not code=="G02.149.115.800":
                 terms.add(term)
     logging.info("Read %d disease/context MESH terms from %s" % (len(terms), fname))
     return terms
 
 
 def getDirectedPairs(pwRows):
     " return the set of all gene pairs that iterAllPairs yields for the rows, the direction is kept "
     directedPairs = set()
     for row in pwRows:
         directedPairs.update(iterAllPairs(row))
     return directedPairs
             
 def writeAllDocInfo(textDir, outFname):
     """ get all author/year/journal/title as tab-sep from a pubtools-style input directory, ~5GB big
 
     Reads all *.articles.gz files in textDir and writes the fields pmid,
     authors, year, journal, printIssn, title, abstract, keywords to outFname.
     Only the first row of every pmid is used and only articles with a numeric
     year after 1975 are written.
     """
     mask = join(textDir, "*.articles.gz")
     fnames = glob.glob(mask)
     doneDocs = set()
     with open(outFname, "w") as ofh:
         for i, fname in enumerate(fnames):
             if i % 10 == 0:
                 logging.info("%d out of %d files" % (i, len(fnames)))
 
             with gzip.open(fname) as inFh:
                 for row in lineFileNext(inFh):
                     # skip duplicates
                     if row.pmid in doneDocs:
                         continue
                     doneDocs.add(row.pmid)
 
                     if row.year.isdigit() and int(row.year)>1975:
                         newRow = (row.pmid, row.authors, row.year, row.journal, row.printIssn, \
                                 row.title, row.abstract, row.keywords)
                         ofh.write("\t".join(newRow))
                         ofh.write("\n")
     logging.info("Article info written to %s" % outFname)
 
 def parseShortNames(journalFname):
     """ return a dict print ISSN (pIssn) -> short journal name (medlineTA), parsed
     from the tab-sep file journalFname. Rows where one of the two fields is
     empty are skipped. Returns an empty dict if the file does not exist. """
     shortNames = {}
     if not isfile(journalFname):
         logging.info("%s not found, not shortening journal names" % journalFname)
         return shortNames
 
     with open(journalFname) as journalFh:
         for row in lineFileNext(journalFh):
             if row.medlineTA!="" and row.pIssn!="":
                 shortNames[row.pIssn] = row.medlineTA
     logging.info("Read a short journal name for %d ISSNs from  %s" % (len(shortNames), journalFname))
     return shortNames
 
 def writeDocsTable(pmidEventPath, medlinePath, shortNames, contextFilter, resCounts, outFname):
     """ join pmid-Event info and our shortened medline version and write the
     result to outFname.
 
     pmidEventPath: tab-sep file docId, eventId. NOTE(review): lineFileNext always
                    consumes the first line as a header line, so the first
                    line of this file is not used - confirm that this is intended.
     medlinePath:   tab-sep file with the fields docId, authors, year, journal,
                    printIssn, title, abstract, keywords
     shortNames:    dict ISSN -> short journal name, replaces the journal field
     contextFilter: collection of keywords to keep. Keywords are "/"-separated
                    in the input and written "|"-separated
     resCounts:     looked up with .get(docId, "0"), the value is appended as a
                    last field and therefore has to be a string
 
     Returns a dict docId -> "|"-separated context keywords, only for
     documents with at least one keyword left after filtering.
     """
     # parse the PMIDs to export
     docIds = set()
     with open(pmidEventPath) as pmidFh:
         for row in lineFileNext(pmidFh, headers=["docId", "eventId"]):
             docIds.add(row.docId)
     logging.info("read %d document IDs from %s" %  (len(docIds), pmidEventPath))
 
     docContexts = {}
     foundIds = set()
 
     logging.info("Writing to %s" % outFname)
     with open(outFname, "w") as ofh:
         with open(medlinePath) as medlineFh:
             for line in medlineFh:
                 fields = line.rstrip("\n").split("\t")
                 docId = fields[0]
                 if docId not in docIds:
                     continue
 
                 # shorten the journal name, if possible
                 shortName = shortNames.get(fields[4])
                 if shortName is not None:
                     fields[3] = shortName
 
                 # keep only the keywords that are contexts
                 docContext = "|".join([kw for kw in fields[7].split("/") if kw in contextFilter])
                 fields[7] = docContext
                 if docContext!="":
                     docContexts[docId] = docContext
 
                 # add a field: how many gene-pairs are associated to this paper
                 fields.append(resCounts.get(docId, "0"))
 
                 ofh.write("\t".join(fields)+"\n")
                 foundIds.add(docId)
 
     notFoundIds = docIds - foundIds
     logging.info("No info for %d documents" % len(notFoundIds))
     logging.debug("No info for these documents: %s" % notFoundIds)
 
     return docContexts
 
 def sumBasic(sentences, commonWords):
     """ pick the best sentence from a list of sentences (strings).
 
     The probability of a word is its count divided by the total number of
     words, where every word is counted only once per sentence. The score of a
     sentence is the average probability of its words that are not in the set
     commonWords. Algorithm is described in
     http://ijcai.org/papers07/Papers/IJCAI07-287.pdf, but only a single
     sentence is selected here.
 
     Returns the sentence with the highest score. If several sentences have the
     highest score, the shortest one is returned, and the first one of these in
     input order. Returns "" if there are no sentences or if all sentences
     consist of common words only.
     """
     if len(sentences)==0:
         return ""
 
     sentWordsList = [set(wordRe.findall(sentence)) for sentence in sentences]
     allWords = list(chain.from_iterable(sentWordsList))
     wordTotal = len(allWords)
     wordProbs = {}
     for word, count in Counter(allWords).items():
         wordProbs[word] = float(count)/wordTotal
 
     bestSent = None
     bestScore = None
     for sentWords, sentence in zip(sentWordsList, sentences):
         mainWords = sentWords - commonWords
         if len(mainWords)==0:
             continue
         score = sum([wordProbs[word] for word in mainWords]) / len(mainWords)
         # strict comparisons: on complete ties, the earlier sentence is kept
         if bestSent is None or score > bestScore or \
                 (score == bestScore and len(sentence) < len(bestSent)):
             bestSent = sentence
             bestScore = score
 
     # happens rarely: all words are common English words
     if bestSent is None:
         return ""
     return bestSent
 
 def runSumBasic(textPairs, wordFname):
     """ Get all sentences for an interaction and use sumBasic to pick the best one.
     textPairs is a dict pair -> rows with a sentence field, wordFname is a file
     with common English words: the first word of every non-empty line is used.
     Returns a dict pair -> best sentence.
     """
     # get list of very common English words
     with open(wordFname) as wordFh:
         wordLines = wordFh.read().splitlines()
     # empty lines have no first word, skip them
     bncWords = set([line.split()[0] for line in wordLines if line.strip()!=""])
     logging.info("Loaded %d common English words from %s" % (len(bncWords), wordFname))
 
     logging.info("Running SumBasic on sentences")
     bestSentences = {}
     for pair, rows in textPairs.iteritems():
         sentSet = set()
         for row in rows:
             if row.sentence!="":
                 sentSet.add(row.sentence)
 
         bestSentences[pair] = sumBasic(list(sentSet), bncWords)
     return bestSentences
 
 def readDictList(fname, reverse=False):
     """ read a key-value tab-sep table and return as a defaultdict key -> list of values.
     Every line must have exactly two fields. If reverse is set, the second
     field is the key and the first field is the value. """
     logging.info("reading %s" % fname)
     data = defaultdict(list)
     with open(fname) as inFh:
         for line in inFh:
             key, val = line.rstrip("\n").split("\t")
             if reverse:
                 key, val = val, key
             data[key].append(val)
     return data
 
 def readDict(fname, reverse=False):
     """ read a key-value tab-sep table and return as a dict of key -> value.
     Every line must have exactly two fields; for duplicate keys, the last
     value wins. If reverse is set, the second field is the key.
     The result is a defaultdict(list): a direct lookup of a missing key
     returns an empty list, it does not raise a KeyError. """
     logging.info("reading %s" % fname)
     data = defaultdict(list)
     with open(fname) as inFh:
         for line in inFh:
             key, val = line.rstrip("\n").split("\t")
             if reverse:
                 key, val = val, key
             data[key] = val
     return data
 
 def readPairEvent(fname):
     """ read a tab-sep table in format (gene1, gene2, eventId) and return as
     defaultdict (gene1,gene2) -> list of eventId. Every line must have exactly
     three fields. """
     logging.info("reading %s" % fname)
     data = defaultdict(list)
     with open(fname) as inFh:
         for line in inFh:
             gene1, gene2, val = line.rstrip("\n").split("\t")
             data[(gene1, gene2)].append(val)
     return data
 
 def addContext(ctFname, docEventFname, linkEventFname, linkFname):
     """ return the lines of linkFname (without newlines), with an additional
     last field that contains the contexts of the gene pair.
 
     ctFname:        tab-sep file docId, "|"-separated contexts
     docEventFname:  tab-sep file docId, eventId
     linkEventFname: tab-sep file gene1, gene2, eventId
     linkFname:      tab-sep file, the first two fields are gene1 and gene2
 
     The contexts of a pair are the contexts of all documents of all its
     events. The new field has the format "context (count), ...", most
     frequent context first.
 
     NOTE(review): the original comment said "take best three contexts" and
     an unused "..." suffix was prepared for pairs with more than three, but
     ALL contexts are written. This is unchanged here - confirm what is wanted.
     """
     docContext = readDict(ctFname)
     eventDocs = readDictList(docEventFname, reverse=True)
     pairEvents = readPairEvent(linkEventFname)
 
     logging.info("Reading %s" % linkFname)
     newLines = []
     with open(linkFname) as linkFh:
         for line in linkFh:
             fields = line.rstrip("\n").split("\t")
             pair = (fields[0], fields[1])
 
             contextCounts = Counter()
             # .get() does not add unknown pairs to the defaultdict
             for eventId in pairEvents.get(pair, []):
                 for docId in eventDocs.get(eventId, []):
                     contexts = docContext.get(docId, "")
                     for context in contexts.split("|"):
                         if context=="" or context==" ":
                             continue
                         contextCounts[context]+=1
 
             # reformat the contexts as a string, most frequent first
             contextStrings = ["%s (%d)" % (ct, count) for ct, count in contextCounts.most_common()]
             fields.append(", ".join(contextStrings))
             newLines.append("\t".join(fields))
     return newLines
 
 def convGraph(outDir):
     """ read the gene pairs (first two whitespace-separated fields of every line)
     from <outDir>/ggLink.tab into a networkx graph and pickle it to
     <outDir>/graph.bin """
     import networkx as nx
     fname = join(outDir, "ggLink.tab")
     G=nx.Graph()
     with open(fname) as inFh:
         for line in inFh:
             g1, g2 = line.split()[:2]
             G.add_edge(g1, g2)
     outFname = join(outDir, "graph.bin")
     # HIGHEST_PROTOCOL is a binary format, so the file is opened in binary mode
     with open(outFname, "wb") as outFh:
         cPickle.dump(G, outFh, cPickle.HIGHEST_PROTOCOL)
     logging.info("Wrote graph to %s" % outFname)
 
 def convGraph2(outDir):
     """ convert <outDir>/ggLink.tab to an igraph graph with integer node IDs.
     Writes the graph in "lgl" format to <outDir>/graph.lgl and the tab-sep
     mapping gene symbol, node ID to <outDir>/graph.genes.txt
     """
     import igraph as ig
     fname = join(outDir, "ggLink.tab")
     G=ig.Graph()
     geneToId = {}
     nextId = 0
     edges = []
     for line in open(fname):
         # only the first two whitespace-separated fields (the genes) are used
         g1, g2 = line.split()[:2]
 
         # genes get consecutive integer IDs, in order of first appearance
         if g1 not in geneToId:
             id1 = geneToId[g1] = nextId
             nextId += 1
         else:
             id1 = geneToId[g1]
 
         if g2 not in geneToId:
             id2 = geneToId[g2] = nextId
             nextId += 1
         else:
             id2 = geneToId[g2]
 
         edges.append( (id1, id2) )
     # NOTE(review): the IDs in use are 0..nextId-1, so nextId+1 looks like one
     # vertex more than needed - confirm before changing
     G.add_vertices(nextId+1)
     G.add_edges(edges)
     outFname = join(outDir, "graph.lgl")
     G.write(outFname, "lgl")
     logging.info("Wrote graph to %s" % outFname)
 
     # the graph file only has integer IDs, so write the symbols separately
     outFname = join(outDir, "graph.genes.txt")
     ofh = open(outFname, "w")
     for gene, geneId in geneToId.iteritems():
         ofh.write("%s\t%s\n" % (gene, geneId))
     ofh.close()
     logging.info("Wrote nodeId-symbol mapping to %s" % outFname)
 
 def convGraph3(outDir):
     """ convert graph to a compact format gene -> set of (connected gene, docCount)
     and write it with marshal to <outDir>/graph.marshal
 
     Input is <outDir>/ggLink.tab. A link is only kept if it has at least three
     documents or one of the flags ppi/pwy, and if its minimal result count is
     not higher than 100.
     """
     fname = join(outDir, "ggLink.tab")
     geneLinks = defaultdict(set)
     # format: gene1, gene2, flags, forwDocs, revDocs, allDocs, pairDbs, pairMinResCount, snippet
     with open(fname) as inFh:
         for line in inFh:
             fields = line.split("\t")
 
             g1, g2, flags = fields[:3]
             pairMinResCount = int(fields[7])
             docCount = int(fields[5])
 
             # require many documents or manually curated database
             # flags is not split, these are substring searches
             if docCount < 3 and not "ppi" in flags and not "pwy" in flags:
                 continue
 
             # ignore all interactions that are derived from papers
             # that reported more than 100 interactions, e.g. big complexes
             if pairMinResCount > 100:
                 continue
 
             geneLinks[g1].add((g2, docCount))
             geneLinks[g2].add((g1, docCount))
 
     geneLinks = dict(geneLinks)
 
     # marshal was the fastest format when taking into account set-building time
     # it is a binary format, so the file is opened in binary mode
     outFname = join(outDir, "graph.marshal")
     with open(outFname, "wb") as outFh:
         marshal.dump(geneLinks, outFh, 2) # prot 0, 1 not faster
     logging.info("Wrote links to %s" % outFname)
 
 def loadGraph(outDir):
     import networkx as nx
     inFname = join(outDir, "graph.bin")
     gc.disable() # no need for GC here, saves 2 seconds
     G = cPickle.load(open(inFname))
     gc.enable()
     geneList = ["OTX2", "PITX2", "APOE", "TP53", "TNF", "SP1"]
     foundPairs = set()
     G2 = G.subgraph(geneList)
     print nx.nodes(G)
     #for g1, g2 in combinations(geneList, 2):
         #path = nx.shortest_path(G, g1, g2)
         #print path
         #for i in range(0, len(path)-1):
             #pair = tuple(sorted(path[i:i+2]))
             #foundPairs.add(pair)
     #print foundPairs
         
 def parseGeneToId(inFname):
     """ parse a tab-sep file with two columns (gene, integer ID) and return a tuple
     (geneToId, idToGene) with the two dicts gene -> ID and ID -> gene.
 
     If a gene or an ID appears more than once, the last line wins.
     Raises ValueError if a line does not have exactly two fields or if the
     ID is not an integer.
     """
     geneToId = {}
     idToGene = {}
     # "with" makes sure that the file handle is closed, also if a line is malformed
     with open(inFname) as ifh:
         for line in ifh:
             gene, geneId = line.rstrip("\n").split("\t")
             geneId = int(geneId)
             geneToId[gene] = geneId
             idToGene[geneId] = gene
     return geneToId, idToGene
 
 def loadGraph2(outDir):
     import igraph as ig
     inFname = join(outDir, "graph.lgl")
     #gc.disable() # no need for GC here, saves 2 seconds
     #G = cPickle.load(open(inFname))
     #gc.enable()
     #G=ig.Graph()
     G = ig.load(inFname)
     #print G
 
     inFname = join(outDir, "graph.genes.txt")
     geneToId, idToGene = parseGeneToId(inFname)
 
     geneList = ["OTX2", "PITX2", "APOE", "TP53", "TNF", "SP1","ABCA1", "CD4", "BRCA2", "APP", "SRY", "GAST", "MYOD1"]
     idList = [geneToId[g] for g in geneList]
     allSyms = set()
     for i in range(0, len(idList)-1):
         fromId = idList[i]
         fromGene = geneList[i]
         toIds = idList[i+1:]
         toGenes = geneList[i+1:]
         print fromId, fromGene, toIds, toGenes
         paths = G.get_shortest_paths(fromId, toIds, mode=ig.ALL) 
         genePaths = []
         for idPath in paths:
             genePath = [idToGene[i] for i in idPath]
             genePaths.append(genePath)
             allSyms.update(genePath)
         print "paths", genePaths
         print ",".join(allSyms)
         #for i in range(0, len(path)-1):
             #pair = tuple(sorted(path[i:i+2]))
             #foundPairs.add(pair)
     #print foundPairs
     #geneList = ["OTX2", "PITX2", "APOE", "TP53", "TNF", "SP1"]
     #G2 = G.subgraph(geneList)
 
 def loadGraph3(outDir):
     """ load the gene graph from the file graph.marshal in outDir and return it.
 
     The object is returned exactly as marshal stored it. The caller in this file
     treats it as a dict gene -> collection of (neighborGene, docCount).
     Raises IOError if the file cannot be opened.
     """
     inFname = join(outDir, "graph.marshal")
     # marshal is a binary format, so read in binary mode; "with" closes the handle
     with open(inFname, "rb") as ifh:
         geneLinks = marshal.load(ifh)
     return geneLinks
 
     # Notes on the alternatives that were tried (code removed):
     # - ujson instead of marshal, and integer IDs with an idToGene table
     #   instead of symbols
     # - reversing the list was 25% slower than reading it all from disk
     # - a text file with gene<tab>comma-sep neighbors (graph.links.txt):
     #   1.1 seconds, so 20% slower than the marshal version
 
 
 def parseLinkTargets(outDir, validSyms):
     """ parse the ggLink table in outDir and return a dict gene -> Counter() of targetGenes -> count.
     Count is either the article count or, if there is no text mining hit, the count of databases.
 
     validSyms is a collection of accepted gene symbols. Rows with a symbol that is
     not part of it are still added to the result, but they are also written to
     ggLink.errors.tab in the current directory, prefixed with exactly one of
     BothSymsInvalid, sym1Invalid or sym2Invalid.
     """
     inFname = join(outDir, "ggLink.tab")
     asPath = join(autoSqlDir, "ggLink.as")
     targets = defaultdict(Counter)
 
     # both handles are closed by "with", also if a row cannot be parsed
     with open("ggLink.errors.tab", "w") as errFh:
         logging.info("Parsing %s" % inFname)
         with open(inFname) as inFh:
             for row in lineFileNext(inFh, asFname=asPath):
                 gene1, gene2 = row.gene1, row.gene2
                 count = int(row.docCount)
                 if count==0:
                     # no document for this pair: use the number of databases instead
                     count = len(row.dbList.split("|"))
 
                 sym1Valid = gene1 in validSyms
                 sym2Valid = gene2 in validSyms
                 # the three cases exclude each other, so a row is reported only once
                 # (rows with two invalid symbols used to be written twice)
                 if not sym1Valid and not sym2Valid:
                     errFh.write("BothSymsInvalid\t"+"\t".join(row)+"\n")
                 elif not sym1Valid:
                     errFh.write("sym1Invalid\t"+"\t".join(row)+"\n")
                 elif not sym2Valid:
                     errFh.write("sym2Invalid\t"+"\t".join(row)+"\n")
 
                 # the link is stored for both directions
                 targets[gene1][gene2]=count
                 targets[gene2][gene1]=count
 
     logging.info("Wrote rows from ggLink.tab with invalid symbols to ggLink.errors.tab")
 
     return targets
 
 
 def makeBigBed(inDir, outDir, bedFname, db):
     """ create a file geneInteractions.<db>.bb in outDir from bedFname.
 
     bedFname is a tab-sep file with the gene symbol in the fourth column.
     Genes without any interaction in the ggLink table of inDir are skipped.
     For all other genes, the line is rewritten to geneInteractions.<db>.bed:
     the name becomes "<gene>: <the 10 interactors with the highest counts>"
     and five fields are appended: the sum of all interactor counts (at most
     1000), ".", copies of the second and third field, and a color that gets
     darker with the sum of counts. (Presumably these are the BED fields
     score, strand, thickStart, thickEnd and itemRgb, which requires an
     input file with exactly four columns.)
     The result is sorted with bedSort and converted with bedToBigBed, using
     /hive/data/genomes/<db>/chrom.sizes. Symbols that are in ggLink but were
     not written to the BED file go to missSym.txt in the current directory.
     """
     validSyms = set()
     with open(bedFname) as ifh:
         for line in ifh:
             sym = line.split("\t")[3].rstrip("\n")
             validSyms.add(sym)
 
     # get interactors from our ggLink table
     geneCounts = parseLinkTargets(inDir, validSyms)
 
     # get genes from knownGenes table and write to bed
     #bedFname = join(outDir, "genes.%s.bed" % db)
     #logging.info("Writing genes to %s" % bedFname)
     #cmd = "hgsql %s -NBe 'select chrom, chromStart, chromEnd, geneSymbol from knownCanonical JOIN kgXref ON kgId=transcript' > %s" % (db, bedFname)
     #runCmd(cmd)
     #bedFname = "geneModels/gencode19.median.bed"
 
     # rewriting bed file and fill with counts
     bedOutFname = join(outDir, "geneInteractions.%s.bed" % db)
     doneSymbols = set()
     # the output has to be closed before bedSort runs, "with" does this for us
     with open(bedOutFname, "w") as ofh:
         logging.info("Rewriting %s to %s" % (bedFname, bedOutFname))
         with open(bedFname) as ifh:
             for line in ifh:
                 row = line.rstrip("\n").split("\t")
                 gene = row[3]
                 counts = geneCounts.get(gene)
                 if counts is None:
                     # skip gene if not found
                     continue
 
                 # create the new name field: only the 10 best interactors are
                 # shown, but all of them contribute to the score
                 ranked = counts.most_common()
                 docCount = sum(count for _, count in ranked)
                 targetGenes = ",".join(targetGene for targetGene, _ in ranked[:10])
 
                 score = min(docCount, 1000)
 
                 if docCount > 100:
                     color = "0,0,0"
                 elif docCount > 10:
                     color = "0,0,128"
                 else:
                     color = "173,216,230"
 
                 row[3] = gene+": "+targetGenes # why a space? see linkIdInName trackDb statement
                 row.extend([str(score), ".", row[1], row[2], color])
 
                 ofh.write("\t".join(row)+"\n")
                 doneSymbols.add(gene)
 
     missingSyms = set(geneCounts) - doneSymbols
     logging.info("%d symbols in ggLink not found in BED file" % len(missingSyms))
     logging.info("missing symbols written to missSym.txt")
 
     with open("missSym.txt", "w") as ofh:
         ofh.write("\n".join(missingSyms))
 
     # NOTE: the file names go into the command lines unquoted
     # bedSort is called with the same file as input and output
     cmd = "bedSort %s %s" % (bedOutFname, bedOutFname)
     runCmd(cmd)
 
     bbFname = join(outDir, "geneInteractions.%s.bb" % db)
     chromSizeFname = "/hive/data/genomes/%s/chrom.sizes" % db
     cmd = "bedToBigBed -tab %s %s %s"  % (bedOutFname, chromSizeFname, bbFname)
     runCmd(cmd)
     logging.info("bigBed file written to %s" % bbFname)
     
     
 def findBestPaths(genes, geneLinks):
     """ search the paths with at most two links between the members of genes.
 
     geneLinks is a dict gene -> iterable of (neighbor, docCount). Neighbors that
     are part of genes are direct hits. All other neighbors are only used as a
     bridge to a third gene that is part of genes. Every path is printed when it
     is found, and at the end, for every (start, end) pair, the list of its paths
     is printed, highest (average) docCount first.
     Returns the set of sorted (geneA, geneB) tuples of all links on these paths.
     """
     def sortedPair(geneA, geneB):
         " the pair as a tuple in sorted order, so A-B and B-A are the same pair "
         return tuple(sorted((geneA, geneB)))
 
     pairs = set()
     links = defaultdict(list) # dict (from, to) -> list of (docCount, path)
     for startGene in genes:
         for midGene, midDocs in geneLinks.get(startGene, []):
             if midGene != startGene and midGene in genes:
                 # direct link: no need to look any further from this neighbor
                 print("%s-%s" % (startGene, midGene))
                 pairs.add(sortedPair(startGene, midGene))
                 links[(startGene, midGene)].append((midDocs, [startGene, midGene]))
                 continue
 
             # the neighbor is not a hit itself, try to get to a hit in one more step
             for endGene, endDocs in geneLinks.get(midGene, []):
                 if endGene not in genes or endGene in (midGene, startGene):
                     continue
                 print("%s-%s-%s" % (startGene, midGene, endGene))
                 pairs.add(sortedPair(startGene, midGene))
                 pairs.add(sortedPair(midGene, endGene))
                 # a path with two links is ranked by the mean of the two counts
                 links[(startGene, endGene)].append(((midDocs+endDocs)/2, [startGene, midGene, endGene]))
 
     for genePair, paths in links.items():
         paths.sort(reverse=True)
         print("%s %s" % (genePair, paths))
 
     return pairs
 
 
 # ----------- MAIN --------------
 #if options.test:
     #import doctest
     #doctest.testmod()
     #sys.exit(0)
 
 if args==[]:
     parser.print_help()
     exit(1)
 
 cmd = args[0]
 if cmd == "build":
     wordFname = options.wordFname
 
     pathwayDir, dbDir, textDir, outDir = args[1:]
     # load the input files into memory
     dbRows = loadFiles(dbDir, prefix="ppi_") 
     pwRows = loadFiles(pathwayDir)
     textRows = loadFiles(textDir)
 
     # index and merge them
     dbPairs   = indexPairs(dbRows, "ppi databases")
     pwPairs   = indexPairs(pwRows, "pathways")
     textPairs = indexPairs(textRows, "text mining")
     pwDirPairs = getDirectedPairs(pwRows)
 
     curatedPairs = mergePairs([dbPairs, pwPairs])
     pairMinResultCounts, docToPairs = getResultCounts(curatedPairs)
 
     bestSentences = runSumBasic(textPairs, wordFname)
     allPairs = mergePairs([curatedPairs, textPairs])
 
     #ltPairs, ltDocs = getResultCounts(curatedPairs)
     # keep result counts for the "docs" step
     ofh = open(join(outDir, "resultCounts.tmp.txt"), "w")
     for docId, pairs in docToPairs.iteritems():
         ofh.write("%s\t%d\n" % (docId, len(pairs)))
     ofh.close()
 
     pairDirDocs = directedPairToDocs(textRows)
     pairDbs = pairToDbs(curatedPairs)
 
     outFname = join(outDir, "ggLink.tmp.txt") # needs the addContext step to complete it
     eventFname = join(outDir, "ggLinkEvent.tab")
     writeGraphTable(allPairs, pairDirDocs, pairDbs, pairMinResultCounts, pwDirPairs, \
         bestSentences, outFname, eventFname)
 
     pmidToId = indexPmids([dbRows,pwRows], textRows)
     outFname = join(outDir, "ggDocEvent.tab")
     writeDocEvents(pmidToId, outFname)
 
     outFname = join(outDir, "ggEventDb.tab")
     writeEventTable([dbRows, pwRows], outFname, colCount=13)
 
     outFname = join(outDir, "ggEventText.tab")
     writeEventTable([textRows], outFname)
 
     # make sure we don't forget to update the link table with context
     linkFname = join(outDir, "ggLink.tab")
     if isfile(linkFname):
         os.remove(linkFname)
 
 elif cmd == "medline":
     outDir = args[1]
     textDir = options.textDir
     medlineFname = join(outDir, allArtFname)
     writeAllDocInfo(textDir, medlineFname)
 
 elif cmd == "docs":
     outDir = args[1]
     outFname = join(outDir, "ggDoc.tab")
     pmidEventPath = join(outDir, "ggDocEvent.tab")
 
     medlineFname = join(outDir, allArtFname)
     meshTerms = parseMeshContext(options.meshFname)
 
     shortNames = parseShortNames(options.journalInfo)
 
     resCountFname = join(outDir, "resultCounts.tmp.txt")
     resCounts = readDict(resCountFname)
 
     docContext = writeDocsTable(pmidEventPath, medlineFname, shortNames, meshTerms, resCounts, outFname)
 
     # write docContext to file
     ctFname = join(outDir, "docContext.txt")
     ofh = open(ctFname, "w")
     for docId, context in docContext.iteritems():
         ofh.write("%s\t%s\n" % (docId, context))
     ofh.close()
     logging.info("Written document contexts to %s for %d documents" % (ctFname, len(docContext)))
 
 elif cmd == "context":
     outDir = args[1]
     ctFname = join(outDir, "docContext.txt")
     docEventFname = join(outDir, "ggDocEvent.tab")
     linkEventFname = join(outDir, "ggLinkEvent.tab")
     linkFname = join(outDir, "ggLink.tmp.txt")
     newLines = addContext(ctFname, docEventFname, linkEventFname, linkFname)
 
     outFname = join(outDir, "ggLink.tab")
     ofh = open(outFname, "w")
     for l in newLines:
         ofh.write("%s\n" % l)
     ofh.close()
     logging.info("appended document context to %s" % outFname)
 
 elif cmd == "bigBed":
     inDir = args[1]
     outDir = args[2]
     geneBedFile = args[3]
     db = args[4]
     makeBigBed(inDir, outDir, geneBedFile, db)
 
 elif cmd == "load":
     inDir = args[1]
     db = args[2]
     loadTables(inDir, db)
 
 # --- DEBUGGING / TESTING ----
 
 elif cmd=="sumBasic": # for debugging
     inFname = args[1]
     rows = []
     for row in lineFileNext(open(inFname)):
         rows.append(row)
     textPairs = indexPairs(rows, "text mining")
     for pair, sent in runSumBasic(textPairs, options.wordFname).iteritems():
         print pair, sent
     
 elif cmd == "graph":
     outDir = args[1]
     convGraph(outDir)
 
 elif cmd == "graph2":
     outDir = args[1]
     convGraph2(outDir)
 
 elif cmd == "graph3":
     outDir = args[1]
     convGraph3(outDir)
 
 elif cmd == "loadgraph":
     outDir = args[1]
     loadGraph(outDir)
 
 elif cmd == "load2":
     outDir = args[1]
     loadGraph2(outDir)
 
 elif cmd == "load3":
     outDir = args[1]
     loadGraph3(outDir)
 
 elif cmd == "subnet":
     outDir, geneFile = args[1:]
     geneLinks = loadGraph3(outDir)
     genes = set(open(geneFile).read().splitlines())
 
     print len(findBestPaths(genes, geneLinks))
 
 else:
     logging.error("unknown command %s" % cmd)