diff --git a/modules/core/src/main/java/org/terrier/querying/RM1.java b/modules/core/src/main/java/org/terrier/querying/RM1.java new file mode 100644 index 00000000..e3d8a09c --- /dev/null +++ b/modules/core/src/main/java/org/terrier/querying/RM1.java @@ -0,0 +1,294 @@ +package org.terrier.querying; + +import static java.util.stream.Collectors.toMap; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.terrier.matching.BaseMatching; +import org.terrier.matching.MatchingQueryTerms; +import org.terrier.matching.ResultSet; +import org.terrier.matching.matchops.SingleTermOp; +import org.terrier.querying.parser.Query.QTPBuilder; +import org.terrier.structures.Index; +import org.terrier.structures.LexiconEntry; +import org.terrier.structures.postings.IterablePosting; +import org.terrier.utility.ApplicationSetup; + +import it.unimi.dsi.fastutil.ints.Int2FloatMap; +import it.unimi.dsi.fastutil.ints.Int2FloatOpenHashMap; +import it.unimi.dsi.fastutil.ints.Int2IntMap; +import it.unimi.dsi.fastutil.ints.Int2IntOpenHashMap; +import it.unimi.dsi.fastutil.ints.IntOpenHashSet; +import it.unimi.dsi.fastutil.ints.IntSet; +import it.unimi.dsi.fastutil.objects.ObjectArrayList; + +/** + * RM1 preliminary implementation + * + * See: http://people.cs.vt.edu/~jiepu/cs5604_fall2018/10_qm.pdf + * + * @author Nicola Tonellotto + */ +@ProcessPhaseRequisites({ManagerRequisite.MQT, ManagerRequisite.RESULTSET}) +public class RM1 implements MQTRewritingProcess +{ + + protected static Logger logger = LoggerFactory.getLogger(RM1.class); + + /** + * This class represents a simple expansion term struct. 
+ * + * @author Nicola Tonellotto + */ + static class ExpansionTerm + { + protected int termid; + protected String text; + protected double weight; + + public ExpansionTerm(final int termid, final String text, final double weight) + { + this.termid = termid; + this.text = text; + this.weight = weight; + } + } + + /** + * This class represents one of the top documents used to perform pseudo-relevance feedback. + * It is composed by a list of terms composing the document, with the term frequency associated to each term in doc, + * its original score, computed by any matching model, e.g., BM25, and its length, i.e., the sum of the term frequencies. + * + * @author Nicola Tonellotto + */ + public class FeedbackDocument + { + protected final int MIN_DF = Integer.parseInt(ApplicationSetup.getProperty("prf.mindf", "2")); + + // if a term appears in more than 10% of documents, we ignore it + protected final double MAX_DOC_PERCENTAGE = Float.parseFloat(ApplicationSetup.getProperty("prf.maxdp", "0.1")); + + // termid -> term frequency in document map + protected Int2IntMap terms; + + protected int length; + protected double originalScore; + protected double qlScore; + + public FeedbackDocument(final int docid, final double originalScore, final Index index) throws IOException + { + this.originalScore = originalScore; + + this.terms = new Int2IntOpenHashMap(); + + final int MAX_DOC_FREQ = (int) (MAX_DOC_PERCENTAGE * index.getCollectionStatistics().getNumberOfDocuments()); + final IterablePosting dp = index.getDirectIndex().getPostings(index.getDocumentIndex().getDocumentEntry(docid)); + while (dp.next() != IterablePosting.EOL) { + this.length = dp.getDocumentLength(); //this supports the terrier-lucene better. 
+ LexiconEntry le = index.getLexicon().getLexiconEntry(dp.getId()).getValue(); + if (le.getDocumentFrequency() >= MIN_DF && le.getDocumentFrequency() < MAX_DOC_FREQ) + this.terms.put(dp.getId(), dp.getFrequency()); + } + if (this.length > 0 && this.terms.size() == 0) { + logger.warn("Did not identify any usable candidate expansion terms from docid " + docid + " among " + this.length + " possibilities"); + } + dp.close(); + + //this.length = index.getDocumentIndex().getDocumentLength(docid); + assert this.length > 0; + + } + + public IntSet getTermIds() + { + return terms.keySet(); + } + + public long getFrequency(final int termid) + { + return terms.get(termid); + } + } + + protected final int fbTerms; + protected final int fbDocs; + protected Index index = null; + + protected IntSet topLexicon; + protected List topDocs; + protected Int2FloatMap feedbackTermScores; + + protected double lambda = 1.0; + + /** + * Constructor + * + * @param fbTerms how many feedback terms to return + * @param fbDocs how many feedback documents to use (should be less than or equal to the top documents) + * @param index the index to used to access the direct index postings + */ + public RM1(final int fbTerms, final int fbDocs, final Index index) + { + this.fbTerms = fbTerms; + this.fbDocs = fbDocs; + this.index = index; + + this.topLexicon = new IntOpenHashSet(); + this.topDocs = new ObjectArrayList<>(); + this.feedbackTermScores = new Int2FloatOpenHashMap(); + } + + public RM1() + { + this.topLexicon = new IntOpenHashSet(); + this.topDocs = new ObjectArrayList<>(); + this.feedbackTermScores = new Int2FloatOpenHashMap(); + this.fbTerms = ApplicationSetup.EXPANSION_TERMS; + this.fbDocs = ApplicationSetup.EXPANSION_DOCUMENTS; + } + + public void process(Manager manager, Request q) { + try{ + this.expandQuery(q.getMatchingQueryTerms(), q); + + //THIS ASSUMES THAT QueryExpansion directly follows Matching + ((LocalManager)manager).runNamedProcess(q.getControl("previousprocess"), q); + }catch 
(IOException ioe) { + throw new RuntimeException(ioe); + } + } + + /** MQTRewriting implementation. */ + public boolean expandQuery(MatchingQueryTerms mqt, Request rq) throws IOException + { + this.index = rq.getIndex(); + List expansions = this.expand(rq); + mqt.clear(); + StringBuilder sQuery = new StringBuilder(); + for (ExpansionTerm et : expansions) + { + mqt.add(QTPBuilder.of(new SingleTermOp(et.text)) + .setWeight(et.weight) + .setTag(BaseMatching.BASE_MATCHING_TAG) + .build()); + sQuery.append(et.text + "^" + et.weight + " "); + } + logger.info("Reformulated query: " + sQuery.toString()); + //logger.info("Reformulated query: " + mqt.toString()); + return true; + } + + /** + * This method computes a list of expansion terms from a given search request from Terrier + * + * @param srq the processed search request from Terrier containing the top documents' docids and scores + * + * @return a list of expansion terms + * + * @throws IOException if there are problems in accessing the direct index + */ + public List expand(final Request srq) throws IOException + { + this.topLexicon.clear(); + this.topDocs.clear(); + this.feedbackTermScores.clear(); + + retrieveTopDocuments(srq.getResultSet()); + computeFeedbackTermScores(); + + clipTerms(); + normalizeFeedbackTermScores(); + + List rtr = new ObjectArrayList<>(); + for (int termid: feedbackTermScores.keySet()) + rtr.add(new ExpansionTerm(termid, index.getLexicon().getLexiconEntry(termid).getKey(), feedbackTermScores.get(termid))); + return rtr; + } + + /** + * This method retrieves from the direct index all terms if the top documents with the necessary statistics. + + * @param rs the search request returned by Terrier with top documents' docids & scores + * + * @throws IOException if there are problems in accessing the direct index + */ + protected void retrieveTopDocuments(final ResultSet rs) throws IOException + { + final int numDocs = rs.getResultSize() < fbDocs ? 
rs.getResultSize() : fbDocs; + final double norm = logSumExp(rs.getScores()); + for (int i = 0; i < numDocs; ++i) { + FeedbackDocument doc = new FeedbackDocument(rs.getDocids()[i], Math.exp(rs.getScores()[i] - norm), index); + topDocs.add(doc); + topLexicon.addAll(doc.getTermIds()); + } + if (topLexicon.size() > 0) { + logger.info("Found " + topLexicon.size() + " terms after feedback document analysis"); + } else { + logger.warn("Did not find any useful candidate expansion terms after analysis of "+ numDocs + " feedback documents"); + } + } + + /** + * This method computes the relevance scores of all terms in the top documents according to RM1 + */ + protected void computeFeedbackTermScores() + { + for (int termid: topLexicon) { + float fbWeight = 0.0f; + for (FeedbackDocument doc: topDocs) + fbWeight += (double) doc.getFrequency(termid) / (double) doc.length * doc.originalScore; + feedbackTermScores.put(termid, fbWeight * (1.0f/topDocs.size())); //see galago line 231 in scoreGrams(). + } + } + + /** + * This method reduces the number of feedback terms to a fixed amount + */ + protected void clipTerms() + { + feedbackTermScores = feedbackTermScores + .int2FloatEntrySet() + .stream() + .sorted( + Map.Entry.comparingByValue().reversed() // sort by descending weight + .thenComparing(Map.Entry.comparingByKey()) // tie-break by ascending termid + ) + .limit(fbTerms) + .collect(toMap(Map.Entry::getKey, + Map.Entry::getValue, + (e1, e2) -> e2, + Int2FloatOpenHashMap::new) + ); + } + + /** + * This method transforms the feedback term scores into a probability distribution + */ + protected void normalizeFeedbackTermScores() + { + float norm = feedbackTermScores.values().stream().reduce(0.0f, Float::sum); + feedbackTermScores.replaceAll((termid, score) -> score / norm); + } + + private static double logSumExp(final double[] scores) + { + double max = Double.NEGATIVE_INFINITY; + for (double score : scores) + max = Math.max(score, max); + + double sum = 0.0; + for (int i = 0; 
i < scores.length; i++) + sum += Math.exp(scores[i] - max); + + return max + Math.log(sum); + } + + public void setLambda(double value) { + this.lambda = value; + } + +} diff --git a/modules/core/src/main/java/org/terrier/querying/RM3.java b/modules/core/src/main/java/org/terrier/querying/RM3.java new file mode 100644 index 00000000..6133d924 --- /dev/null +++ b/modules/core/src/main/java/org/terrier/querying/RM3.java @@ -0,0 +1,128 @@ +package org.terrier.querying; + +import java.io.IOException; +import java.util.List; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.terrier.matching.BaseMatching; +import org.terrier.matching.MatchingQueryTerms; +import org.terrier.matching.MatchingQueryTerms.MatchingTerm; +import org.terrier.matching.matchops.SingleTermOp; +import org.terrier.querying.parser.Query.QTPBuilder; +import org.terrier.structures.Index; +import org.terrier.structures.LexiconEntry; + +import it.unimi.dsi.fastutil.ints.Int2FloatMap; +import it.unimi.dsi.fastutil.ints.Int2FloatOpenHashMap; +import it.unimi.dsi.fastutil.objects.ObjectArrayList; + +/** + * RM3 implementation. This has been closely compared to the Anserini implementation using a common index. 
+ * + * @author Nicola Tonellotto and Craig Macdonald + */ +@ProcessPhaseRequisites({ ManagerRequisite.MQT, ManagerRequisite.RESULTSET }) +public class RM3 extends RM1 { + + protected static Logger logger = LoggerFactory.getLogger(RM3.class); + + protected static final float DEFAULT_LAMBDA = 0.6F; + protected Int2FloatMap originalQueryTermScores; + protected float lambda; + + public RM3(final int fbTerms, final int fbDocs, final Index index) { + this(fbTerms, fbDocs, index, DEFAULT_LAMBDA); + } + + public RM3(final int fbTerms, final int fbDocs, final Index index, final float lambda) { + super(fbTerms, fbDocs, index); + + this.originalQueryTermScores = new Int2FloatOpenHashMap(); + this.lambda = lambda; + } + + public RM3() { + super(); + this.originalQueryTermScores = new Int2FloatOpenHashMap(); + this.lambda = DEFAULT_LAMBDA; + } + + public boolean expandQuery(MatchingQueryTerms mqt, Request rq) throws IOException { + this.index = rq.getIndex(); + computeOriginalTermScore(mqt); + if (rq.hasControl("rm3.lambda")) + this.lambda = Float.parseFloat(rq.getControl("rm3.lambda")); + List expansions = this.expand(rq); + mqt.clear(); + StringBuilder sQuery = new StringBuilder(); + for (ExpansionTerm et : expansions) { + mqt.add(QTPBuilder.of(new SingleTermOp(et.text)).setTag(BaseMatching.BASE_MATCHING_TAG) + .setWeight(et.weight).build()); + sQuery.append(et.text + "^" + et.weight + " "); + } + logger.info("Reformulated query "+ mqt.getQueryId() +" @ lambda="+this.lambda+": " + sQuery.toString()); + //logger.info("Reformulated query: " + mqt.toString()); + return true; + } + + protected void computeOriginalTermScore(final MatchingQueryTerms mqt) { + this.originalQueryTermScores.clear(); + final float queryLength = (float) mqt.stream().map(mt -> mt.getValue().getWeight()) + .mapToDouble(Double::doubleValue).sum(); + for (MatchingTerm mt : mqt) { + + LexiconEntry le = super.index.getLexicon().getLexiconEntry(mt.getKey().toString()); + if (le == null) + continue; + int 
termid = le.getTermId(); + float termCount = (float) mt.getValue().getWeight(); + originalQueryTermScores.put(termid, termCount / queryLength); + } + } + + @Override + protected void computeFeedbackTermScores() { + super.computeFeedbackTermScores(); + super.clipTerms(); + super.normalizeFeedbackTermScores(); + + for (int termid : feedbackTermScores.keySet()) { + //System.err.println("termid " + termid + " term " + super.index.getLexicon().getLexiconEntry(termid).getKey() +" " + feedbackTermScores.get(termid)); + if (originalQueryTermScores.containsKey(termid)) { + //System.err.println("termid " + termid + " term " + super.index.getLexicon().getLexiconEntry(termid).getKey() +" " +"not new: old weight = " + originalQueryTermScores.get(termid) + " fbweight=" + feedbackTermScores.get(termid)); + + float weight = lambda * originalQueryTermScores.get(termid) + + (1 - lambda) * feedbackTermScores.get(termid); + feedbackTermScores.put(termid, weight); + } else { + feedbackTermScores.put(termid, (1 - lambda) * feedbackTermScores.get(termid)); + //System.err.println("termid " + termid + " term " + super.index.getLexicon().getLexiconEntry(termid).getKey() +" " + feedbackTermScores.get(termid)); + } + } + + for (int termid : originalQueryTermScores.keySet()) { + if (!feedbackTermScores.containsKey(termid)) { + float weight = lambda * originalQueryTermScores.get(termid); + feedbackTermScores.put(termid, weight); + } + } + } + + @Override + public List expand(Request srq) throws IOException { + //return super.expand(srq); + + this.topLexicon.clear(); + this.topDocs.clear(); + this.feedbackTermScores.clear(); + + retrieveTopDocuments(srq.getResultSet()); + computeFeedbackTermScores(); + + List rtr = new ObjectArrayList<>(); + for (int termid: feedbackTermScores.keySet()) + rtr.add(new ExpansionTerm(termid, index.getLexicon().getLexiconEntry(termid).getKey(), feedbackTermScores.get(termid))); + return rtr; + } +} diff --git 
a/modules/tests/src/test/java/org/terrier/querying/TestRM.java b/modules/tests/src/test/java/org/terrier/querying/TestRM.java
new file mode 100644
index 00000000..76ddd1b7
--- /dev/null
+++ b/modules/tests/src/test/java/org/terrier/querying/TestRM.java
@@ -0,0 +1,42 @@
package org.terrier.querying;

import static org.junit.Assert.assertTrue;

import org.junit.Test;
import org.terrier.indexing.IndexTestUtils;
import org.terrier.matching.BaseMatching;
import org.terrier.structures.Index;
import org.terrier.tests.ApplicationSetupBasedTest;
import org.terrier.utility.ApplicationSetup;

/** Smoke tests for the RM1 and RM3 query-rewriting processes. */
public class TestRM extends ApplicationSetupBasedTest
{
	@Test public void testItWorksRM1() throws Exception
	{
		testModel("RM1");
	}

	@Test public void testItWorksRM3() throws Exception
	{
		testModel("RM3");
	}

	/**
	 * Indexes two tiny documents, runs a one-term query through the given
	 * relevance-model process, and checks that the query was expanded and that
	 * the first rewritten term carries the base matching tag.
	 */
	protected void testModel(String clzName) throws Exception
	{
		// relax the feedback-term filters so every term is a candidate,
		// and disable the term pipeline so terms match verbatim
		ApplicationSetup.setProperty("prf.mindf", "0");
		ApplicationSetup.setProperty("prf.maxdp", "1");
		ApplicationSetup.setProperty("termpipelines", "");
		// splice the model under test into the standard process chain
		ApplicationSetup.setProperty("querying.processes",
			"terrierql:TerrierQLParser,parsecontrols:TerrierQLToControls,parseql:TerrierQLToMatchingQueryTerms,matchopql:MatchingOpQLParser,applypipeline:ApplyTermPipeline,localmatching:LocalManager$ApplyLocalMatching,rm:"+clzName+",qe:QueryExpansion,labels:org.terrier.learning.LabelDecorator,filters:LocalManager$PostFilterProcess");
		Index index = IndexTestUtils.makeIndex(
			new String[]{"doc1", "doc2"},
			new String[]{"the lazy fox jumped over the dog", "but had the presence of mind"});
		Manager manager = ManagerFactory._from_(index.getIndexRef());
		SearchRequest srq = manager.newSearchRequest("testQ", "fox");
		srq.setControl("rm", "on");
		manager.runSearchRequest(srq);
		// the single-term query must have been expanded...
		assertTrue( ((Request)srq).getMatchingQueryTerms().size() > 1);
		// ...and the rewritten terms must be tagged for base matching
		assertTrue( ((Request)srq).getMatchingQueryTerms().get(0).getValue().getTags().contains(BaseMatching.BASE_MATCHING_TAG));
	}

}