16720e7dfab79a7f60e91f0cb102a213c3e4738a max Fri Apr 28 15:39:08 2017 -0700 first big commit for hgGeneGraph. Others will follow as QA progresses. refs #13634 diff --git src/utils/ggTables src/utils/ggTables new file mode 100755 index 0000000..bb07669 --- /dev/null +++ src/utils/ggTables @@ -0,0 +1,1255 @@ +#!/usr/bin/env python2.7 + +import logging, sys, optparse, glob, itertools, os, tempfile, gzip, operator, re, cPickle, gc +import marshal +from collections import defaultdict, namedtuple, Counter +from itertools import chain, combinations +from os.path import join, basename, dirname, isfile, expanduser +import ujson + +# saves 20% of time when loading graph marshal +import gc +gc.disable() + +#LTPMAXMEMBERS=5 # maximum number of proteins in a complex for an interaction to quality for low throughput +LTPMAX=5 # maximum number of interactions a PMID can have to be declared low throughput + +# don't even write out links with less than this number of documents. 2 = weeds out many false positives. +# not using right now, because a few pathway databases do not annotate ANY PMID. Increasing this filter would remove +# all interactions from these pathways databases. +MINSUPP=0 + +outFields = ["gene1", "gene2", "flags", "refs", "fwWeight", "revWeight", "snip"] + +# directory with autoSql descriptions of output tables +autoSqlDir = expanduser("~/kent/src/hg/lib/") + +# file with all of medline in short form +allArtFname = "textInfo.tab" +# file with just pmids and events +#pmidEventFname = "temp.pmidToEvent.tab" + +# RE to split sentences +wordRe = re.compile("[a-zA-Z0-9]+") + +# === COMMAND LINE INTERFACE, OPTIONS AND HELP === +parser = optparse.OptionParser("""usage: %prog [options] build|load pathwayDir ppiDir textDir outDir - given various tab sep files with text-mining, gene interaction or pathway information, build the table ggLink, ggDoc, ggDb and ggText + +run it like this: + +%prog medline - to reduce the big medline table to something smaller, only needed once + +%prog build pathways ppi text mysql +%prog docs mysql # creates the ggDocs.tab file, slow +%prog context mysql +%prog load mysql publications +%prog bigBed outDir bigBedFile db +""") + +parser.add_option("-d", "--debug", dest="debug", action="store_true", help="show debug messages") +#parser.add_option("-t", "--test", dest="test", action="store_true", help="run tests") +parser.add_option("-t", "--textDir", dest="textDir", action="store", help="directory with the parsed copy of medline, default %default", default="/hive/data/inside/pubs/text/medline") +parser.add_option("-m", "--meshFile", dest="meshFname", action="store", help="An mtrees<year>.bin file, default %default", default="/hive/data/outside/ncbi/mesh/mtrees2015.bin") +parser.add_option("-j", "--journalInfo", dest="journalInfo", action="store", help="tab-sep file with journal info from the NLM Catalog converted by 'pubPrepCrawl publishers'. Used to shorten the journal names. Optional and not used if file is not found. 
Default %default", default="/cluster/home/max/projects/pubs/tools/data/journals/journals.tab") +parser.add_option("-b", "--wordFname", dest="wordFname", action="store", help="a file with common English words", default="/hive/data/outside/pubs/wordFrequency/bnc/bnc.txt") +#parser.add_option("-f", "--file", dest="file", action="store", help="run on file") +#parser.add_option("", "--test", dest="test", action="store_true", help="do something") +(options, args) = parser.parse_args() + +if options.debug: + logging.basicConfig(level=logging.DEBUG) +else: + logging.basicConfig(level=logging.INFO) +# ==== FUNCTIONs ===== + +def parseAutoSql(asFname): + " parse auto sql file and return list of field names " + headers = [] + for line in open(asFname): + line = line.replace("(","").replace(")","").strip() + if line.startswith('"') or line.startswith("\n"): + continue + if line.startswith("table"): + continue + if not ";" in line: + continue + parts = line.split(";")[0] + parts = parts.split() + assert(len(parts)>=2) + headers.append(parts[1]) + return headers + +def lineFileNext(fh, headers=None, asFname=None): + """ parses tab-sep file with headers as field names , assumes that file starts with headers + yields collection.namedtuples + """ + line1 = fh.readline() + line1 = line1.strip("\n").strip("#") + if headers==None and asFname!=None: + headers = parseAutoSql(asFname) + elif headers==None: + headers = line1.split("\t") + headers = [h.replace(" ", "_") for h in headers] + headers = [h.replace("(", "") for h in headers] + headers = [h.replace(")", "") for h in headers] + Record = namedtuple('tsvRec', headers) + + for line in fh: + line = line.rstrip("\n") + fields = line.split("\t") + try: + rec = Record(*fields) + except Exception, msg: + logging.error("Exception occured while parsing line, %s" % msg) + logging.error("Filename %s" % fh.name) + logging.error("Line was: %s" % repr(line)) + logging.error("Does number of fields match headers?") + logging.error("Headers are: %s" % headers) + #raise Exception("wrong field count in line %s" % line) + continue + # convert fields to correct data type + yield rec + +def loadFiles(inDir, prefix=None): + """ load .tab files into dict fType -> list of rows and return tuple (ppiRows, textRows) + """ + #typeRows = defaultdict(list) + #jppiRows = list() + #jtextRows = list() + pairs = defaultdict(list) + inFnames = glob.glob(join(inDir, "*.tab")) + rows = list() + for inFname in inFnames: + logging.info("Loading %s" % inFname) + #fType = None + for row in lineFileNext(open(inFname)): + if prefix!=None: + row = row._replace(eventId=prefix+row.eventId) + rows.append(row) + return rows + +def getResultCounts(pairs): + """ + Input is a pair -> rows dictionary + for each PMID, count how many pairs are assigned to it. This is + something like the "resultCount" of a paper, the lower, the better. 
+ Then, for each pair, get the minimum resultCount and return as a dict + pair -> resultCount + """ + logging.info("Getting low-throughput studies in PPI and pathway data") + + # create dict doc -> set of gene pairs + docToPairs = defaultdict(set) + for pair, rows in pairs.iteritems(): + #print "pp", pair + for row in rows: + #print "r", row + #members = row.themeGenes.split("|") + #members.extend(row.causeGenes.split("|")) + # complexes with more than 5 proteins are not low-throughput anyways + # skip these right away + #if len(members)>LTPMAXMEMBERS: + #continue + docIds = row.pmids.split("|") + for docId in docIds: + if docId=="": + continue + docToPairs[docId].add(pair) + + #print "d2p", docToPairs + + pairMinResultCounts = {} + for pair, rows in pairs.iteritems(): + #print "pp", pair + resultCounts = [] + # get all docIds in rows + docIds = [] + for row in rows: + docIds.extend(row.pmids.split("|")) + + # get the minimal resultCount of all docIds + for docId in set(docIds): + if docId=="": + continue + #print "di2", docId + #print "d2p", docToPairs[docId] + resCount = len(docToPairs[docId]) + #print "rc", resCount + resultCounts.append(resCount) + if len(resultCounts)!=0: + minResCount = min(resultCounts) + else: + minResCount = 0 + #print "min", minResCount + pairMinResultCounts[pair] = minResCount + + #ltPairs = set() + #ltDocs = [] + #for pmid, pairList in docToPairs.iteritems(): + #if len(pairList) <= LTPMAX: + #ltPairs.update(pairList) + #ltDocs.append(pmid) + + #logging.info("Found %d low-throughput studies out of %d" % (len(ltDocs), len(pmidToPairs))) + #logging.info("Found %d low-throughput interactions out of %d" % (len(ltPairs), len(pairs))) + #return ltPairs, ltDocs + return pairMinResultCounts, docToPairs + +def allSubPairs(pair): + """ given a pair of two strings, where each can be a _-separate list of genes (a family), + return all combinations of each member + """ + x, y = pair + xs = x.split("_") + ys = y.split("_") + for subPair in [(a, b) for a in xs for b in ys]: + a, b = subPair + if a=="" or b=="" or a=="-" or b=="-" or \ + a.startswith("compound") or b.startswith("compound"): + continue + yield subPair + +def iterAllPairs(row): + """ yield all pairs of interacting genes for a given row. 
Handles families + >>> list(iterAllPairs("gene", ["TP1","TP2"], "complex", ["OMG1","OMG2"])) + [('OMG1', 'OMG2'), ('TP1', 'OMG1'), ('TP1', 'OMG2'), ('TP2', 'OMG1'), ('TP2', 'OMG2')] + >>> list(iterAllPairs("complex", ["TP1_TEST2","TP2"], "complex", ["OMG1","OMG2"])) + """ + + type1 = row.causeType + type2 = row.themeType + genes1 = set(row.causeGenes.split("|")) + genes2 = set(row.themeGenes.split("|")) + # all genes of complexes interact in some way + if type1=="complex": + for pair in itertools.combinations(genes1, 2): + # a complex can contain families + for subPair in allSubPairs(pair): + yield tuple(subPair) + if type2=="complex": + for pair in itertools.combinations(genes2, 2): + # a complex can contain families + for subPair in allSubPairs(pair): + yield tuple(subPair) + + # all genes from the left and the right side interact + if type2!="": + pairs = list([(aa, bb) for aa in genes1 for bb in genes2]) + for pair in pairs: + for subPair in allSubPairs(pair): + gene1, gene2 = subPair + if gene1=="-" or gene1=="" or gene2=="" or gene2=="-": + #skipCount += 1 + continue + if gene1.startswith("compound") or gene2.startswith("compound"): + continue + yield subPair + +def indexPairs(ppiRows, desc): + """ given rows with theme and cause genes, return + a dict with sorted (gene1, gene2) -> list of eventIds """ + logging.info("enumerating all interacting pairs: %s" % desc) + pairs = defaultdict(list) + for row in ppiRows: + for pair in iterAllPairs(row): + gene1, gene2 = pair + if gene1.startswith("compound") or gene2.startswith("compound"): + continue + pairs[tuple(sorted(pair))].append(row) + logging.info("got %d pairs" % len(pairs)) + return pairs + +def mergePairs(dicts): + " merge a list of defaultdict(list) into one defaultdict(set) " + logging.info("Merging pairs") + data = defaultdict(set) + for defDict in dicts: + for key, valList in defDict.iteritems(): + data[key].update(valList) + return data + +def directedPairToDocs(rows): + """ get documents of text mining pairs. a dict with pair -> text rows + create a dict pair -> set of document IDs. 
Note that these pairs are + DIRECTED - so they can be used to infer the direction of the interaction + """ + # create a dict with pair -> pmids + pairPmids = defaultdict(set) + for row in rows: + genes1 = set(row.causeGenes.split("|")) + genes2 = set(row.themeGenes.split("|")) + pairs = list([(aa, bb) for aa in genes1 for bb in genes2]) + for cause, theme in pairs: + pairPmids[(cause, theme)].add(row.pmid) + + return pairPmids + +def writeGraphTable(allPairs, pairDocs, pairToDbs, pairMinResCounts, pwDirPairs, bestSentences, outFname, outFname2): + " write the ggLink table " + logging.info("writing merged graph to %s" % outFname) + rows = [] + rows2 = [] + for pair,pairRows in allPairs.iteritems(): + gene1, gene2 = pair + + dbs = set() + flags = [] + if pair in dbPairs: + flags.append("ppi") + if pair in pwPairs: + flags.append("pwy") + if pair in textPairs: + flags.append("text") + refs = [row.eventId for row in pairRows] + #if pairMinResultCounts: + #flags.append("low") + # direction of interaction - only based on pathways + if pair in pwDirPairs: + flags.append("fwd") + if tuple(reversed(pair)) in pwDirPairs: + flags.append("rev") + + forwDocs = pairDocs.get(pair, []) + revDocs = pairDocs.get(tuple(reversed(pair)), []) + allDocs = set(forwDocs).union(set(revDocs)) + + if len(allDocs)<MINSUPP and "pwy" not in flags and "ppi" not in flags: + # if it's text-mining only and less than X documents, just skip it + continue + pairMinResCount = pairMinResCounts.get(pair, 0) + + pairDbs = "|".join(pairToDbs.get(pair, [])) + snippet = bestSentences.get(pair, "") + row = [gene1, gene2, ",".join(flags), str(len(forwDocs)), str(len(revDocs)), \ + str(len(allDocs)), pairDbs, str(pairMinResCount), snippet] + rows.append(row) + + refs = list(refs) + refs.sort() + for ref in refs: + #row2 = [gene1, gene2, ",".join(refs)] + row = [gene1, gene2, ref] + rows2.append(row) + + ofh = open(outFname, "w") + rows.sort() + for row in rows: + ofh.write("\t".join(row)) + ofh.write("\n") + ofh.close() + + ofh2 = open(outFname2, "w") + rows2.sort() + for row in rows2: + ofh2.write("\t".join(row)) + ofh2.write("\n") + ofh2.close() + +def runCmd(cmd): + """ run command in shell, exit if not successful """ + msg = "Running shell command: %s" % cmd + logging.debug(msg) + ret = os.system(cmd) + if ret!=0: + raise Exception("Could not run command (Exitcode %d): %s" % (ret, cmd)) + return ret + +def asToSql(table, sqlDir): + " given a table name, return the name of a .sql file with CREATE TABLE for it" + + asPath = join(autoSqlDir, table+".as") + #tempBase = tempfile.mktemp() + outBase = join(sqlDir, table) + cmd = "autoSql %s %s" % (asPath, outBase) + runCmd(cmd) + #sql = open("%s.sql" % sqlFname).read() + + # delete the files that are not needed + #assert(len(tempBase)>5) # paranoia check + #cmd = "rm -f %s.h %s.c" % (tempBase, tempBase) + #runCmd(cmd) + + return outBase+".sql" + +def loadTable(db, tableDir, table): + " load table into mysql, using autoSql " + #sqlFname = join(tableDir, table+".sql") + tmpSqlFname = asToSql(table, tableDir) + tabFname = join(tableDir, table+".tab") + + cmd = "hgLoadSqlTab %s %s %s %s" % (db, table, tmpSqlFname, tabFname) + try: + runCmd(cmd) + except: + # make sure that the temp file gets deleted + os.remove(tmpSqlFname) + raise + + os.remove(tmpSqlFname) + +def hgsql(db, query): + assert('"' not in sql) + cmd = "hgsql %s -NBe '%s'" % (db, query) + +def addIndexes(db): + " add the indexes for mysql " + query = "ALTER TABLE ggLinkEvent ADD INDEX gene12Idx (gene1, gene2);" + hgsql(db, query) 
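Note on the hgsql() helper just above: as committed it asserts on an undefined name (`sql` instead of `query`) and builds the command string without ever executing it, so loadTables() would fail at the addIndexes() step. A minimal working sketch, reusing runCmd() defined earlier in this file (the quote check is an assumption about intent, since the statement is wrapped in single quotes on the shell command line):

```python
def hgsql(db, query):
    " run a single SQL statement against a mysql database via the hgsql command-line tool "
    # the statement is wrapped in single quotes below, so refuse
    # queries that contain quote characters of either kind
    assert("'" not in query and '"' not in query)
    cmd = "hgsql %s -NBe '%s'" % (db, query)
    runCmd(cmd)
```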
+ + query = "ALTER TABLE ggEventText ADD INDEX docIdIdx (docId);" + hgsql(db, query) + +def loadTables(tableDir, db): + " load graph tables into mysql " + + loadTable(db, tableDir, "ggDoc") + loadTable(db, tableDir, "ggDocEvent") + loadTable(db, tableDir, "ggEventDb") + loadTable(db, tableDir, "ggEventText") + loadTable(db, tableDir, "ggLink") + loadTable(db, tableDir, "ggLinkEvent") + + addIndexes(db) + +def indexPmids(rowList, textRows): + " return dict pmid -> list of event Ids " + pmidToIds = defaultdict(set) + for rows in rowList: + for row in rows: + pmidStr = row.pmids + if pmidStr=="": + continue + pmids = pmidStr.split("|") + rowId = row.eventId + for pmid in pmids: + if pmid=="": + continue + pmidToIds[pmid].add(rowId) + + for row in textRows: + if row.pmid=="": + continue + pmidToIds[row.pmid].add(row.eventId) + + return pmidToIds + +def writeDocEvents(pmidToId, outFname): + " write a table with PMID -> list of event Ids " + logging.info("Writing docId-eventId %s" % outFname) + ofh = open(outFname, "w") + for docId, eventIds in pmidToId.iteritems(): + eventIds = sorted(list(eventIds)) + for eventId in eventIds: + ofh.write("%s\t%s\n" % (docId, eventId)) + ofh.close() + +def writeEventTable(rowList, outFname, colCount=None): + " write the event table with event details " + logging.info("Writing events to %s" % outFname) + ofh = open(outFname, "w") + for rows in rowList: + for row in rows: + if colCount: + row = row[:colCount] + ofh.write("%s\n" % ("\t".join(row))) + ofh.close() + +def pairToDbs(pairs): + """ given pairs and data rows, return a dict pair -> int + that indicates how many DBs a pair is referenced in + """ + # first make dict event -> source dbs + #eventDbs = defaultdict(set) + #for row in pwRows: + #eventDbs[row.eventId].add(row.sourceDb) + #for row in dbRows: + #sourceDbs = row.sourceDbs.split("|") + #eventDbs[row.eventId].update(sourceDbs) + + # construct a dict pair -> source dbs + pairDbs = defaultdict(set) + for pair, rows in pairs.iteritems(): + for row in rows: + pairDbs[pair].add(row.sourceDb) + + return pairDbs + +def parseMeshContext(fname): + " given a medline trees file, return the list of disease and pathway names in it " + # ex. 
filename is mtrees2013.bin (it's ascii) + # WAGR Syndrome;C10.597.606.643.969 + terms = set() + lines = open(fname) + for line in lines: + line = line.strip() + term, code = line.split(";") + term = term.strip() + # all disease terms start with C + if code.startswith("C"): + terms.add(term) + # all signalling pathways a below a very specific branch + elif code.startswith("G02.149.115.800") and not code=="G02.149.115.800": + terms.add(term) + logging.info("Read %d disease/context MESH terms from %s" % (len(terms), fname)) + return terms + + +def getDirectedPairs(pwRows): + " get the set of directed gene pairs from the rows, keep the direction " + pairs = set() + for row in pwRows: + for pair in iterAllPairs(row): + pairs.add(pair) + return pairs + +def writeAllDocInfo(textDir, outFname): + " get all author/year/journal/title as tab-sep from a pubtools-style input directory, ~5GB big " + mask = join(textDir, "*.articles.gz") + + ofh = open(outFname, "w") + fnames = glob.glob(mask) + doneDocs = set() + for i, fname in enumerate(fnames): + if i % 10 == 0: + logging.info("%d out of %d files" % (i, len(fnames))) + + for row in lineFileNext(gzip.open(fname)): + # skip duplicates + if row.pmid in doneDocs: + continue + doneDocs.add(row.pmid) + + if row.year.isdigit() and int(row.year)>1975: + newRow = (row.pmid, row.authors, row.year, row.journal, row.printIssn, \ + row.title, row.abstract, row.keywords) + ofh.write("\t".join(newRow)) + ofh.write("\n") + ofh.close() + logging.info("Article info written to %s" % outFname) + +def parseShortNames(journalFname): + # get dict ISSN -> short name + shortNames = {} + if isfile(journalFname): + for row in lineFileNext(open(journalFname)): + if row.medlineTA!="" and row.pIssn!="": + shortNames[row.pIssn] = row.medlineTA + logging.info("Read a short journal name for %d ISSNs from %s" % (len(shortNames), journalFname)) + else: + logging.info("%s not found, not shortening journal names" % journalFname) + return shortNames + +def writeDocsTable(pmidEventPath, medlinePath, shortNames, contextFilter, resCounts, outFname): + """ join pmid-Event info and our shortened medline version + resCount is a set of docIds with low-throughput data (fewer than LTPMAX interactions per doc) + """ + # parse the PMIDs to export + docIds = set() + for row in lineFileNext(open(pmidEventPath), headers=["docId", "eventId"]): + docIds.add(row.docId) + logging.info("read %d document IDs from %s" % (len(docIds), pmidEventPath)) + + docContexts = {} + + logging.info("Writing to %s" % outFname) + ofh = open(outFname, "w") + # fields are: docId, authors, year, journal, printIssn, title, abstract, keywords + foundIds = set() + for line in open(medlinePath): + fields = line.rstrip("\n").split("\t") + docId = fields[0] + if docId in docIds: + issn = fields[4] + shortName = shortNames.get(issn) + if shortName!=None: + fields[3] = shortName + + newKeywords = [] + for kw in fields[7].split("/"): + if kw in contextFilter: + newKeywords.append(kw) + docContext = "|".join(newKeywords) + fields[7] = docContext + if docContext!="": + docContexts[docId] = docContext + + # add a field: how many gene-pairs are associated to this paper + fields.append(resCounts.get(docId, "0")) + + line = "\t".join(fields)+"\n" + + ofh.write(line) + foundIds.add(docId) + ofh.close() + + notFoundIds = docIds - foundIds + logging.info("No info for %d documents" % len(notFoundIds)) + logging.debug("No info for these documents: %s" % notFoundIds) + + return docContexts + +def sumBasic(sentences, commonWords): + """ given 
probabilities of words, rank sentences by average prob + (removing commonWords). + Sentences is a list of list of words + Algorithm is described in http://ijcai.org/papers07/Papers/IJCAI07-287.pdf + + Returns sentence with highest score and shortest length, if several have a highest score + """ + + if len(sentences)==0: + return "" + sentWordsList = [set(wordRe.findall(sentence)) for sentence in sentences] + words = list(chain.from_iterable(sentWordsList)) + wordProbs = {word: float(count)/len(words) for word, count in Counter(words).items()} + + scoredSentences = [] + for sentWords, sentence in zip(sentWordsList, sentences): + mainWords = sentWords - commonWords + if len(mainWords)==0: + continue + avgProb = sum([wordProbs[word] for word in mainWords]) / len(mainWords) + scoredSentences.append((avgProb, sentence, sentWords)) + + # happens rarely: all words are common English words + if len(scoredSentences)==0: + return "" + + # get sentences with equally good top score + scoredSentences.sort(key=operator.itemgetter(0), reverse=True) + topScore = scoredSentences[0][0] + topSents = [(sent, words) for score, sent, words in scoredSentences if score >= topScore] + + # sort these by length and pick shortest one + topSentLens = [(len(s), s, w) for s, w in topSents] + topSentLens.sort(key=operator.itemgetter(0)) + topLen, topSent, topWords = topSentLens[0] + + # update word frequencies + for word in topWords: + wordProbs[word] *= wordProbs[word] + + return topSent + +def runSumBasic(textPairs, wordFname): + """ Get all sentences for an interaction and use sumBasic to pick the best one + """ + # get list of very common English words + bncWords = set([line.split()[0] for line in open(wordFname).read().splitlines()]) + logging.info("Loaded %d common English words from %s" % (len(bncWords), wordFname)) + + logging.info("Running SumBasic on sentences") + bestSentences = {} + for pair, rows in textPairs.iteritems(): + sentSet = set() + for row in rows: + if row.sentence!="": + sentSet.add(row.sentence) + + sentences = list(sentSet) + bestSentences[pair] = sumBasic(sentences, bncWords) + return bestSentences + +def readDictList(fname, reverse=False): + " read a key-value tab-sep table and return as a dict of key -> values" + logging.info("reading %s" % fname) + data = defaultdict(list) + for line in open(fname): + key, val = line.rstrip("\n").split("\t") + if reverse: + key, val = val, key + data[key].append(val) + return data + +def readDict(fname, reverse=False): + " read a key-value tab-sep table and return as a dict of key -> value" + logging.info("reading %s" % fname) + data = defaultdict(list) + for line in open(fname): + key, val = line.rstrip("\n").split("\t") + if reverse: + key, val = val, key + data[key] = val + return data + +def readPairEvent(fname): + """ read a tab-sep table in format (gene1, gene2, eventId) and return as + dict (gene1,gene2) -> list of eventId""" + logging.info("reading %s" % fname) + data = defaultdict(list) + for line in open(fname): + gene1, gene2, val = line.rstrip("\n").split("\t") + data[(gene1, gene2)].append(val) + return data + +def addContext(ctFname, docEventFname, linkEventFname, linkFname): + " read the data from the first three files, and put it into the last field of linkFname " + docContext = readDict(ctFname) + eventDocs = readDictList(docEventFname, reverse=True) + pairEvents = readPairEvent(linkEventFname) + + logging.info("Reading %s" % linkFname) + newLines = [] + for line in open(linkFname): + #print line + fields = line.rstrip("\n").split("\t") + 
pair = (fields[0], fields[1]) + + contextCounts = Counter() + for eventId in pairEvents[pair]: + #print "pair %s, event %s" % (pair, eventId) + for docId in eventDocs.get(eventId, []): + #print "doc %s" % docId + contexts = docContext.get(docId, "") + for context in contexts.split("|"): + #print "context %s" % context + if context=="" or context==" ": + continue + contextCounts[context]+=1 + # take best three contexts and reformat as a string + suffix = "" + if len(contextCounts)>3: + suffix = "..." + bestContexts = contextCounts.most_common() + contextStrings = ["%s (%d)" % (ct, count) for ct, count in bestContexts] + contextStr = ", ".join(contextStrings) + fields.append(contextStr) + newLines.append("\t".join(fields)) + return newLines + +def convGraph(outDir): + import networkx as nx + fname = join(outDir, "ggLink.tab") + G=nx.Graph() + for line in open(fname): + g1, g2 = line.split()[:2] + G.add_edge(g1, g2) + outFname = join(outDir, "graph.bin") + #nx.write_adjlist(G, outFname) + cPickle.dump(G, open(outFname, "w"), cPickle.HIGHEST_PROTOCOL) + logging.info("Wrote graph to %s" % outFname) + +def convGraph2(outDir): + import igraph as ig + fname = join(outDir, "ggLink.tab") + G=ig.Graph() + geneToId = {} + nextId = 0 + edges = [] + for line in open(fname): + g1, g2 = line.split()[:2] + + if g1 not in geneToId: + id1 = geneToId[g1] = nextId + nextId += 1 + else: + id1 = geneToId[g1] + + if g2 not in geneToId: + id2 = geneToId[g2] = nextId + nextId += 1 + else: + id2 = geneToId[g2] + + edges.append( (id1, id2) ) + #print nextId + #print edges + G.add_vertices(nextId+1) + G.add_edges(edges) + #outFname = join(outDir, "graph2.bin") + #nx.write_adjlist(G, outFname) + #cPickle.dump(G, open(outFname, "w"), cPickle.HIGHEST_PROTOCOL) + #logging.info("Wrote graph to %s" % outFname) + outFname = join(outDir, "graph.lgl") + G.write(outFname, "lgl") + logging.info("Wrote graph to %s" % outFname) + + outFname = join(outDir, "graph.genes.txt") + ofh = open(outFname, "w") + for gene, geneId in geneToId.iteritems(): + ofh.write("%s\t%s\n" % (gene, geneId)) + ofh.close() + logging.info("Wrote nodeId-symbol mapping to %s" % outFname) + #outFname = join(outDir, "graph.adj") + #G.write(outFname, "adjacency") + #outFname = join(outDir, "graph.leda") + #G.write(outFname, "leda") + #outFname = join(outDir, "graph.dot") + #G.write(outFname, "dot") + #outFname = join(outDir, "graph.pajek") + #G.write(outFname, "pajek") + #logging.info("Wrote graph to %s" % outFname) + +def convGraph3(outDir): + """ convert graph to a compact format gene -> list of connected genes + """ + fname = join(outDir, "ggLink.tab") + geneLinks = defaultdict(set) + #genes = set() + # format: gene1, gene2, flags, forwDocs, revDocs, allDocs, pairDbs, pairMinResCount, snippet + for line in open(fname): + fields = line.split("\t") + + g1, g2, flags = fields[:3] + #flags = flags.split(",") + pairMinResCount = int(fields[7]) + docCount = int(fields[5]) + + # require many documents or manually curated database + if docCount < 3 and not "ppi" in flags and not "pwy" in flags: + continue + #pairDbs = fields[6] + + # ignore all interactions that are derived from papers + # that reported more than 100 interactions, e.g. 
big complexes + if pairMinResCount > 100: + continue + + #geneLinks[g1].add((g2, flags)) + #geneLinks[g2].add((g1, flags)) + #geneLinks[g1].add(g2) + #geneLinks[g2].add(g1) + geneLinks[g1].add((g2, docCount)) + geneLinks[g2].add((g1, docCount)) + + #genes.add(g1) + #genes.add(g2) + geneLinks = dict(geneLinks) + + # rewrite to dict str -> list + #geneLinks = {k:list(v) for k,v in geneLinks.iteritems()} + + # map gene -> integer + #geneToId = {} + #for geneId, gene in enumerate(genes): + #geneToId[gene] = geneId + + # this is only slightly slower - probably best format + #outFname = join(outDir, "graph.txt") + #ofh = open(outFname, "w") + #for gene, neighbors in geneLinks.iteritems(): + #ofh.write("%s\t%s\n" % (gene, ",".join(neighbors))) + #ofh.close() + #logging.info("Wrote links to %s" % outFname) + + # this is not faster + #intLinks = {} + #for gene, neighbors in geneLinks.iteritems(): + #intLinks[geneToId[gene]] = [geneToId[n] for n in neighbors] + + # fastest when taking into account set-building time + outFname = join(outDir, "graph.marshal") + marshal.dump(geneLinks, open(outFname, "w"), 2) # prot 0, 1 not faster + logging.info("Wrote links to %s" % outFname) + + #outFname = join(outDir, "graph.ujson") + #ujson.dump(geneLinks, open(outFname, "w")) + #logging.info("Wrote json strings links to %s" % outFname) + + #outFname = join(outDir, "graphInt.marshal") + #idToGene = {v: k for k, v in geneToId.iteritems()} + #data = (idToGene, intLinks) + #marshal.dump(data, open(outFname, "w"), 2) + #logging.info("Wrote integer links to %s" % outFname) + + #outFname = join(outDir, "graphInt.ujson") + #ujson.dump(data, open(outFname, "w")) + #logging.info("Wrote json integer links to %s" % outFname) + +def loadGraph(outDir): + import networkx as nx + inFname = join(outDir, "graph.bin") + gc.disable() # no need for GC here, saves 2 seconds + G = cPickle.load(open(inFname)) + gc.enable() + geneList = ["OTX2", "PITX2", "APOE", "TP53", "TNF", "SP1"] + foundPairs = set() + G2 = G.subgraph(geneList) + print nx.nodes(G) + #for g1, g2 in combinations(geneList, 2): + #path = nx.shortest_path(G, g1, g2) + #print path + #for i in range(0, len(path)-1): + #pair = tuple(sorted(path[i:i+2])) + #foundPairs.add(pair) + #print foundPairs + +def parseGeneToId(inFname): + geneToId = {} + idToGene = {} + for l in open(inFname): + gene, geneId = l.rstrip("\n").split("\t") + geneId = int(geneId) + geneToId[gene] = geneId + idToGene[geneId] = gene + return geneToId, idToGene + +def loadGraph2(outDir): + import igraph as ig + inFname = join(outDir, "graph.lgl") + #gc.disable() # no need for GC here, saves 2 seconds + #G = cPickle.load(open(inFname)) + #gc.enable() + #G=ig.Graph() + G = ig.load(inFname) + #print G + + inFname = join(outDir, "graph.genes.txt") + geneToId, idToGene = parseGeneToId(inFname) + + geneList = ["OTX2", "PITX2", "APOE", "TP53", "TNF", "SP1","ABCA1", "CD4", "BRCA2", "APP", "SRY", "GAST", "MYOD1"] + idList = [geneToId[g] for g in geneList] + allSyms = set() + for i in range(0, len(idList)-1): + fromId = idList[i] + fromGene = geneList[i] + toIds = idList[i+1:] + toGenes = geneList[i+1:] + print fromId, fromGene, toIds, toGenes + paths = G.get_shortest_paths(fromId, toIds, mode=ig.ALL) + genePaths = [] + for idPath in paths: + genePath = [idToGene[i] for i in idPath] + genePaths.append(genePath) + allSyms.update(genePath) + print "paths", genePaths + print ",".join(allSyms) + #for i in range(0, len(path)-1): + #pair = tuple(sorted(path[i:i+2])) + #foundPairs.add(pair) + #print foundPairs + #geneList = 
["OTX2", "PITX2", "APOE", "TP53", "TNF", "SP1"] + #G2 = G.subgraph(geneList) + +def loadGraph3(outDir): + inFname = join(outDir, "graph.marshal") + #idToGene, idLinks = marshal.load(open(inFname)) + #idToGene, idLinks = ujson.load(open(inFname)) + #geneLinks = ujson.load(open(inFname)) + geneLinks = marshal.load(open(inFname)) + return geneLinks + + #idToGene = {int(k):v for k,v in idToGene.iteritems()} + + #geneLinks = {} + #for geneId, linkedIds in idLinks.iteritems(): + #geneLinks[idToGene[int(geneId)]] = [idToGene[i] for i in linkedIds] + + # reversing the list was 25% slower than reading it all from disk + # 1.1 seconds, so 20% slower than the marshal version + #inFname = join(outDir, "graph.links.txt") + #ofh = open(outFname, "w") + #graph = {} + #for line in open(inFname): + #gene, neighbors = line.rstrip("\n").split("\t") + #neighbors = set(neighbors.split(",")) + #graph[gene] = neighbors + ##logging.info("Wrote links to %s" % outFname) + +def parseLinkTargets(outDir, validSyms): + """ parse the ggLink table in outDir and return a dict gene -> Counter() of targetGenes -> count. + Count is either the article count or, if there is no text mining hit, the count of databases + """ + errFh = open("ggLink.errors.tab", "w") + + inFname = join(outDir, "ggLink.tab") + logging.info("Parsing %s" % inFname) + + asPath = join(autoSqlDir, "ggLink.as") + targets = defaultdict(Counter) + for row in lineFileNext(open(inFname), asFname=asPath): + gene1, gene2 = row.gene1, row.gene2 + count = int(row.docCount) + if count==0: + count = len(row.dbList.split("|")) + + if gene1 not in validSyms: + if gene2 not in validSyms: + errFh.write("BothSymsInvalid\t"+"\t".join(row)+"\n") + else: + errFh.write("sym1Invalid\t"+"\t".join(row)+"\n") + if gene2 not in validSyms: + errFh.write("sym2Invalid\t"+"\t".join(row)+"\n") + + targets[gene1][gene2]=count + targets[gene2][gene1]=count + + errFh.close() + logging.info("Wrote rows from ggLink.tab with invalid symbols to ggLink.errors.tab") + + return targets + +def makeBigBed(inDir, outDir, bedFname, db): + " create a file geneInteractions.<db>.bb in outDir from bedFname " + validSyms = set() + for line in open(bedFname): + sym = line.split("\t")[3].rstrip("\n") + validSyms.add(sym) + + # get interactors from our ggLink table + geneCounts = parseLinkTargets(inDir, validSyms) + + # get genes from knownGenes table and write to bed + #bedFname = join(outDir, "genes.%s.bed" % db) + #logging.info("Writing genes to %s" % bedFname) + #cmd = "hgsql %s -NBe 'select chrom, chromStart, chromEnd, geneSymbol from knownCanonical JOIN kgXref ON kgId=transcript' > %s" % (db, bedFname) + #runCmd(cmd) + #bedFname = "geneModels/gencode19.median.bed" + + # rewriting bed file and fill with counts + bedOutFname = join(outDir, "geneInteractions.%s.bed" % db) + ofh = open(bedOutFname, "w") + logging.info("Rewriting %s to %s" % (bedFname, bedOutFname)) + doneSymbols = set() + for line in open(bedFname): + row = line.rstrip("\n").split("\t") + gene = row[3] + counts = geneCounts.get(gene, None) + if counts==None: + # skip gene if not found + continue + + # create the new name field + docCount = 0 + strList = [] + geneCount = 0 + for targetGene, count in counts.most_common(): + #strList.append("%s:%d" % (targetGene, count)) + if geneCount < 10: + strList.append("%s" % (targetGene)) + docCount += count + geneCount += 1 + + score = min(docCount, 1000) + + targetGenes = ",".join(strList) + row[3] = gene+": "+targetGenes # why a space? 
see linkIdInName trackDb statement + + row.append( str(score) ) + row.append(".") + row.append(row[1]) + row.append(row[2]) + + if docCount > 100: + color = "0,0,0" + elif docCount > 10: + color = "0,0,128" + else: + color = "173,216,230" + + row.append(color) + + ofh.write("\t".join(row)) + ofh.write("\n") + doneSymbols.add(gene) + ofh.close() + + missingSyms = set(geneCounts) - set(doneSymbols) + logging.info("%d symbols in ggLink not found in BED file" % len(missingSyms)) + logging.info("missing symbols written to missSym.txt") + + ofh= open("missSym.txt", "w") + ofh.write("\n".join(missingSyms)) + ofh.close() + + cmd = "bedSort %s %s" % (bedOutFname, bedOutFname) + runCmd(cmd) + + bbFname = join(outDir, "geneInteractions.%s.bb" % db) + chromSizeFname = "/hive/data/genomes/%s/chrom.sizes" % db + cmd = "bedToBigBed -tab %s %s %s" % (bedOutFname, chromSizeFname, bbFname) + runCmd(cmd) + logging.info("bigBed file written to %s" % bbFname) + +def findBestPaths(genes, geneLinks): + " find best paths of max length 2 between genes using geneLinks. return pairs. " + pairs = set() + links = defaultdict(list) # dict (from, to) -> list of (docCountSum, path) + for gene1 in genes: + # search at distance 1 + for gene2, docCount2 in geneLinks.get(gene1, []): + if gene2 in genes and gene2!=gene1: + # stop if found + print "%s-%s" % (gene1, gene2) + pairs.add( tuple(sorted((gene1, gene2))) ) + links[ (gene1, gene2) ].append( (docCount2, [gene1, gene2]) ) + continue + + # search at distance 2 + for gene3, docCount3 in geneLinks.get(gene2, []): + if gene3 in genes and gene3!=gene2 and gene3!=gene1: + # distance = 2 + print "%s-%s-%s" % (gene1, gene2, gene3) + pairs.add( tuple(sorted((gene1, gene2))) ) + pairs.add( tuple(sorted((gene2, gene3))) ) + links[ (gene1, gene3) ].append( ((docCount2+docCount3)/2, [gene1, gene2, gene3]) ) + + for genePair, paths in links.iteritems(): + paths.sort(reverse=True) + print genePair, paths + + return pairs + +# ----------- MAIN -------------- +#if options.test: + #import doctest + #doctest.testmod() + #sys.exit(0) + +if args==[]: + parser.print_help() + exit(1) + +cmd = args[0] +if cmd == "build": + wordFname = options.wordFname + + pathwayDir, dbDir, textDir, outDir = args[1:] + # load the input files into memory + dbRows = loadFiles(dbDir, prefix="ppi_") + pwRows = loadFiles(pathwayDir) + textRows = loadFiles(textDir) + + # index and merge them + dbPairs = indexPairs(dbRows, "ppi databases") + pwPairs = indexPairs(pwRows, "pathways") + textPairs = indexPairs(textRows, "text mining") + pwDirPairs = getDirectedPairs(pwRows) + + curatedPairs = mergePairs([dbPairs, pwPairs]) + pairMinResultCounts, docToPairs = getResultCounts(curatedPairs) + + bestSentences = runSumBasic(textPairs, wordFname) + allPairs = mergePairs([curatedPairs, textPairs]) + + #ltPairs, ltDocs = getResultCounts(curatedPairs) + # keep result counts for the "docs" step + ofh = open(join(outDir, "resultCounts.tmp.txt"), "w") + for docId, pairs in docToPairs.iteritems(): + ofh.write("%s\t%d\n" % (docId, len(pairs))) + ofh.close() + + pairDirDocs = directedPairToDocs(textRows) + pairDbs = pairToDbs(curatedPairs) + + outFname = join(outDir, "ggLink.tmp.txt") # needs the addContext step to complete it + eventFname = join(outDir, "ggLinkEvent.tab") + writeGraphTable(allPairs, pairDirDocs, pairDbs, pairMinResultCounts, pwDirPairs, \ + bestSentences, outFname, eventFname) + + pmidToId = indexPmids([dbRows,pwRows], textRows) + outFname = join(outDir, "ggDocEvent.tab") + writeDocEvents(pmidToId, outFname) + + 
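For illustration, the pair expansion that indexPairs() relies on in the build step can be exercised on a hand-built row. The namedtuple below is a hypothetical stand-in for what lineFileNext() yields from a pathway/ppi event .tab file, trimmed to the columns iterAllPairs() actually reads; the gene symbols and values are made up:

```python
from collections import namedtuple

# hypothetical stand-in for a lineFileNext() record from an event file
EvRow = namedtuple("tsvRec", "eventId causeType causeGenes themeType themeGenes pmids sourceDb")
row = EvRow("demo1", "gene", "TP53", "complex", "MDM2|MDM4", "12345", "demoDb")

# complexes are expanded so that their members interact with each other
# and with the genes on the other side of the event
print sorted(set(iterAllPairs(row)))
# roughly: [('MDM2', 'MDM4'), ('TP53', 'MDM2'), ('TP53', 'MDM4')]
# (orientation of the first pair depends on set iteration order; indexPairs()
#  sorts each pair before using it as a key, so the orientation does not matter)
```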
outFname = join(outDir, "ggEventDb.tab") + writeEventTable([dbRows, pwRows], outFname, colCount=13) + + outFname = join(outDir, "ggEventText.tab") + writeEventTable([textRows], outFname) + + # make sure we don't forget to update the link table with context + linkFname = join(outDir, "ggLink.tab") + if isfile(linkFname): + os.remove(linkFname) + +elif cmd == "medline": + outDir = args[1] + textDir = options.textDir + medlineFname = join(outDir, allArtFname) + writeAllDocInfo(textDir, medlineFname) + +elif cmd == "docs": + outDir = args[1] + outFname = join(outDir, "ggDoc.tab") + pmidEventPath = join(outDir, "ggDocEvent.tab") + + medlineFname = join(outDir, allArtFname) + meshTerms = parseMeshContext(options.meshFname) + + shortNames = parseShortNames(options.journalInfo) + + resCountFname = join(outDir, "resultCounts.tmp.txt") + resCounts = readDict(resCountFname) + + docContext = writeDocsTable(pmidEventPath, medlineFname, shortNames, meshTerms, resCounts, outFname) + + # write docContext to file + ctFname = join(outDir, "docContext.txt") + ofh = open(ctFname, "w") + for docId, context in docContext.iteritems(): + ofh.write("%s\t%s\n" % (docId, context)) + ofh.close() + logging.info("Written document contexts to %s for %d documents" % (ctFname, len(docContext))) + +elif cmd == "context": + outDir = args[1] + ctFname = join(outDir, "docContext.txt") + docEventFname = join(outDir, "ggDocEvent.tab") + linkEventFname = join(outDir, "ggLinkEvent.tab") + linkFname = join(outDir, "ggLink.tmp.txt") + newLines = addContext(ctFname, docEventFname, linkEventFname, linkFname) + + outFname = join(outDir, "ggLink.tab") + ofh = open(outFname, "w") + for l in newLines: + ofh.write("%s\n" % l) + ofh.close() + logging.info("appended document context to %s" % outFname) + +elif cmd == "bigBed": + inDir = args[1] + outDir = args[2] + geneBedFile = args[3] + db = args[4] + makeBigBed(inDir, outDir, geneBedFile, db) + +elif cmd == "load": + inDir = args[1] + db = args[2] + loadTables(inDir, db) + +# --- DEBUGGING / TESTING ---- + +elif cmd=="sumBasic": # for debugging + inFname = args[1] + rows = [] + for row in lineFileNext(open(inFname)): + rows.append(row) + textPairs = indexPairs(rows, "text mining") + for pair, sent in runSumBasic(textPairs, options.wordFname).iteritems(): + print pair, sent + +elif cmd == "graph": + outDir = args[1] + convGraph(outDir) + +elif cmd == "graph2": + outDir = args[1] + convGraph2(outDir) + +elif cmd == "graph3": + outDir = args[1] + convGraph3(outDir) + +elif cmd == "loadgraph": + outDir = args[1] + loadGraph(outDir) + +elif cmd == "load2": + outDir = args[1] + loadGraph2(outDir) + +elif cmd == "load3": + outDir = args[1] + loadGraph3(outDir) + +elif cmd == "subnet": + outDir, geneFile = args[1:] + geneLinks = loadGraph3(outDir) + genes = set(open(geneFile).read().splitlines()) + + print len(findBestPaths(genes, geneLinks)) + +else: + logging.error("unknown command %s" % cmd) +
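For a quick sanity check of the graph3/subnet path: graph.marshal, as written by convGraph3(), is just a dict of gene symbol -> set of (neighborSymbol, docCount) tuples, dumped with marshal version 2. A toy round-trip with made-up symbols and counts, assuming findBestPaths() from this file is importable:

```python
import marshal

# made-up mini graph in the same shape convGraph3() writes
geneLinks = {
    "TP53": set([("MDM2", 120), ("ATM", 45)]),
    "MDM2": set([("TP53", 120)]),
    "ATM":  set([("TP53", 45)]),
}
marshal.dump(geneLinks, open("graph.marshal", "wb"), 2)

links = marshal.load(open("graph.marshal", "rb"))
# findBestPaths() searches paths of length <= 2 between the query genes;
# here MDM2 and ATM are joined through TP53, so it prints the paths it finds
# and returns two pairs: ('ATM', 'TP53') and ('MDM2', 'TP53')
print findBestPaths(set(["MDM2", "ATM"]), links)
```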