Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Transformer.compile improvements #480

Draft
wants to merge 9 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
274 changes: 192 additions & 82 deletions pyterrier/ops.py

Large diffs are not rendered by default.

18 changes: 14 additions & 4 deletions pyterrier/terrier/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from pyterrier.terrier import java
from pyterrier.terrier._text_loader import TerrierTextLoader, terrier_text_loader
from pyterrier.terrier.java import configure, set_version, set_helper_version, extend_classpath, J, set_property, set_properties, run, version, check_version, check_helper_version
from pyterrier.terrier.retriever import RetrieverBase, Retriever, FeaturesRetriever, TextScorer
from pyterrier.terrier.retriever import Retriever, FeaturesRetriever, TextScorer
from pyterrier.terrier.index_factory import IndexFactory
from pyterrier.terrier.stemmer import TerrierStemmer
from pyterrier.terrier.tokeniser import TerrierTokeniser
Expand All @@ -11,6 +11,7 @@
from pyterrier.terrier.index import TerrierIndexer, FilesIndexer, TRECCollectionIndexer, DFIndexer, DFIndexUtils, IterDictIndexer, IndexingType, treccollection2textgen
from pyterrier.terrier import rewrite
from deprecated import deprecated
import pyterrier as pt


@deprecated(version='0.11.0', reason="use pt.terrier.Retriever() instead")
Expand Down Expand Up @@ -40,9 +41,18 @@ def from_dataset(*args, **kwargs):
return FeaturesRetriever.from_dataset(*args, **kwargs)


@deprecated(version='0.11.0', reason="use pt.terrier.RetrieverBase() instead")
class BatchRetrieveBase(RetrieverBase):
pass
@deprecated(version='0.12.0', reason="This class provides no functionality; inherit from pt.Transformer and set a verbose flag in your constructor instead")
class RetrieverBase(pt.Transformer):
    """Deprecated no-op base class kept only for backward compatibility.

    Attributes:
        verbose(int): If non-zero, subclasses may display progress in transform().
    """
    def __init__(self, verbose=0, **kwargs):
        # Forward remaining keyword arguments as keywords. The previous code
        # passed the dict positionally (super().__init__(kwargs)), which only
        # worked when matchpy's Symbol was in the MRO and is invalid against
        # pt.Transformer -- TODO confirm pt.Transformer accepts **kwargs.
        super().__init__(**kwargs)
        self.verbose = verbose


@deprecated(version='0.12.0', reason="This class provides no functionality; inherit from pt.Transformer and set a verbose flag in your constructor instead")
class BatchRetrieveBase(pt.Transformer):
    """Deprecated no-op base class kept only for backward compatibility.

    Attributes:
        verbose(int): If non-zero, subclasses may display progress in transform().
    """
    def __init__(self, verbose=0, **kwargs):
        # Forward remaining keyword arguments as keywords. The previous code
        # passed the dict positionally (super().__init__(kwargs)), which only
        # worked when matchpy's Symbol was in the MRO and is invalid against
        # pt.Transformer -- TODO confirm pt.Transformer accepts **kwargs.
        super().__init__(**kwargs)
        self.verbose = verbose


__all__ = [
Expand Down
162 changes: 76 additions & 86 deletions pyterrier/terrier/retriever.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from typing import Union
from typing import Union, Optional
import pandas as pd
import numpy as np
from deprecated import deprecated
from warnings import warn
from pyterrier.datasets import Dataset
from pyterrier.transformer import Symbol
from pyterrier.model import coerce_queries_dataframe, FIRST_RANK
import concurrent
from concurrent.futures import ThreadPoolExecutor
Expand Down Expand Up @@ -89,19 +88,9 @@ def _parse_index_like(index_location):
or an pyterrier.index.TerrierIndexer object'''
)

class RetrieverBase(pt.Transformer, Symbol):
"""
A base class for retrieval

Attributes:
verbose(bool): If True transform method will display progress
"""
def __init__(self, verbose=0, **kwargs):
super().__init__(kwargs)
self.verbose = verbose


@pt.java.required
class Retriever(RetrieverBase):
class Retriever(pt.Transformer):
"""
Use this class for retrieval by Terrier
"""
Expand Down Expand Up @@ -171,7 +160,7 @@ def from_dataset(dataset : Union[str,Dataset],
"termpipelines": "Stopwords,PorterStemmer"
}

def __init__(self, index_location, controls=None, properties=None, metadata=["docno"], num_results=None, wmodel=None, threads=1, **kwargs):
def __init__(self, index_location, controls=None, properties=None, metadata=["docno"], num_results=None, wmodel=None, threads=1, verbose=False):
"""
Init method

Expand All @@ -183,7 +172,6 @@ def __init__(self, index_location, controls=None, properties=None, metadata=["do
num_results(int): Number of results to retrieve.
metadata(list): What metadata to retrieve
"""
super().__init__(kwargs)
self.indexref = _parse_index_like(index_location)
self.properties = _mergeDicts(Retriever.default_properties, properties)
self.concurrentIL = pt.java.autoclass("org.terrier.structures.ConcurrentIndexLoader")
Expand All @@ -193,6 +181,7 @@ def __init__(self, index_location, controls=None, properties=None, metadata=["do
self.threads = threads
self.RequestContextMatching = pt.java.autoclass("org.terrier.python.RequestContextMatching")
self.search_context = {}
self.verbose = verbose

for key, value in self.properties.items():
pt.terrier.J.ApplicationSetup.setProperty(str(key), str(value))
Expand Down Expand Up @@ -458,6 +447,30 @@ def setControls(self, controls):
def setControl(self, control, value):
    """Set a single retrieval control, coercing both the name and the value to str."""
    key = str(control)
    self.controls[key] = str(value)

def fuse_rank_cutoff(self, k: int) -> Optional[pt.Transformer]:
    """
    Support fusing with RankCutoffTransformer.

    Returns a new Retriever that retrieves only the top ``k`` results, ``self``
    when an equal-or-stricter cutoff is already in effect, or ``None`` when the
    fusion cannot be performed safely.
    """
    if self.controls.get('end', float('inf')) < k:
        return self  # a stricter (smaller) cutoff is already applied, so k adds nothing
    if self.controls.get('context_wmodel') == 'on':
        return None  # we don't store the original wmodel value so we can't reconstruct
    # apply the new k as num_results
    return Retriever(self.indexref, controls=self.controls, properties=self.properties, metadata=self.metadata,
                     num_results=k, wmodel=self.controls["wmodel"], threads=self.threads, verbose=self.verbose)

def fuse_feature_union(self, other: pt.Transformer, is_left: bool) -> Optional[pt.Transformer]:
    """Fuse two Retrievers over the same index into one FeaturesRetriever, or return None."""
    # Only another plain Retriever on the same index can be merged, and neither
    # side may use a context weighting model (its real wmodel is unrecoverable).
    if not isinstance(other, Retriever):
        return None
    if self.indexref != other.indexref:
        return None
    if self.controls.get('context_wmodel') == 'on' or other.controls.get('context_wmodel') == 'on':
        return None
    mine = "WMODEL:" + self.controls['wmodel']
    theirs = "WMODEL:" + other.controls['wmodel']
    features = [mine, theirs] if is_left else [theirs, mine]
    merged = dict(self.controls)
    merged.pop('wmodel')  # the per-feature WMODEL entries replace the single control
    return FeaturesRetriever(self.indexref, features, controls=merged, properties=self.properties,
                             metadata=self.metadata, threads=self.threads, verbose=self.verbose)


@pt.java.required
class TextIndexProcessor(pt.Transformer):
'''
Expand Down Expand Up @@ -613,7 +626,7 @@ def __init__(self, index_location, features, controls=None, properties=None, thr

# record the weighting model
self.wmodel = None
if "wmodel" in kwargs:
if "wmodel" in kwargs and kwargs['wmodel'] is not None:
assert isinstance(kwargs["wmodel"], str), "Non-string weighting models not yet supported by FBR"
self.wmodel = kwargs["wmodel"]
if "wmodel" in controls:
Expand All @@ -627,6 +640,8 @@ def __init__(self, index_location, features, controls=None, properties=None, thr
raise ValueError("Multi-threaded retrieval not yet supported by FeaturesRetriever")

super().__init__(index_location, controls, properties, **kwargs)
if self.wmodel is None and 'wmodel' in self.controls:
del self.controls['wmodel'] # Retriever sets a default controls['wmodel'], we only want this
cmacdonald marked this conversation as resolved.
Show resolved Hide resolved

def __reduce__(self):
return (
Expand Down Expand Up @@ -803,72 +818,47 @@ def __str__(self):
return "TerrierFeatRetr(" + str(len(self.features)) + " features)"
return "TerrierFeatRetr(" + self.controls["wmodel"] + " and " + str(len(self.features)) + " features)"

# Flag recording whether the matchpy rewrite rules below were registered;
# set to True at the end of setup_rewrites().
rewrites_setup = False

def setup_rewrites():
    """Register matchpy rules that simplify transformer pipelines.

    Appends rules to ``pyterrier.transformer.rewrite_rules`` which: flatten
    nested binary feature-unions and compositions into single polyadic
    operators; merge a feature union of two Retrievers on the same index into
    one FeaturesRetriever; and fold a Retriever preceding a FeaturesRetriever
    into that FeaturesRetriever's weighting model. Rule registration order is
    significant -- the flattening rules must precede the merging rules.
    """
    from pyterrier.transformer import rewrite_rules
    from pyterrier.ops import FeatureUnionPipeline, ComposedPipeline
    from matchpy import ReplacementRule, Wildcard, Pattern, CustomConstraint
    #three arbitrary "things".
    x = Wildcard.dot('x')
    xs = Wildcard.plus('xs')  # one-or-more remaining operands of a polyadic operator
    y = Wildcard.dot('y')
    z = Wildcard.dot('z')
    # wildcard symbols constrained to Retriever / FeaturesRetriever instances
    _br1 = Wildcard.symbol('_br1', Retriever)
    _br2 = Wildcard.symbol('_br2', Retriever)
    _fbr = Wildcard.symbol('_fbr', FeaturesRetriever)

    # batch retrieves for the same index
    BR_index_matches = CustomConstraint(lambda _br1, _br2: _br1.indexref == _br2.indexref)
    BR_FBR_index_matches = CustomConstraint(lambda _br1, _fbr: _br1.indexref == _fbr.indexref)

    # rewrite nested binary feature unions into one single polyadic feature union
    rewrite_rules.append(ReplacementRule(
        Pattern(FeatureUnionPipeline(x, FeatureUnionPipeline(y,z)) ),
        lambda x, y, z: FeatureUnionPipeline(x,y,z)
    ))
    rewrite_rules.append(ReplacementRule(
        Pattern(FeatureUnionPipeline(FeatureUnionPipeline(x,y), z) ),
        lambda x, y, z: FeatureUnionPipeline(x,y,z)
    ))
    rewrite_rules.append(ReplacementRule(
        Pattern(FeatureUnionPipeline(FeatureUnionPipeline(x,y), xs) ),
        lambda x, y, xs: FeatureUnionPipeline(*[x,y]+list(xs))
    ))

    # rewrite nested binary compose into one single polyadic compose
    rewrite_rules.append(ReplacementRule(
        Pattern(ComposedPipeline(x, ComposedPipeline(y,z)) ),
        lambda x, y, z: ComposedPipeline(x,y,z)
    ))
    rewrite_rules.append(ReplacementRule(
        Pattern(ComposedPipeline(ComposedPipeline(x,y), z) ),
        lambda x, y, z: ComposedPipeline(x,y,z)
    ))
    rewrite_rules.append(ReplacementRule(
        Pattern(ComposedPipeline(ComposedPipeline(x,y), xs) ),
        lambda x, y, xs: ComposedPipeline(*[x,y]+list(xs))
    ))

    # rewrite batch a feature union of BRs into an FBR
    rewrite_rules.append(ReplacementRule(
        Pattern(FeatureUnionPipeline(_br1, _br2), BR_index_matches),
        lambda _br1, _br2: FeaturesRetriever(_br1.indexref, ["WMODEL:" + _br1.controls["wmodel"], "WMODEL:" + _br2.controls["wmodel"]])
    ))

    def push_fbr_earlier(_br1, _fbr):
        # Mutates the matched FeaturesRetriever in place to adopt the
        # preceding Retriever's weighting model, then reuses it.
        #TODO copy more attributes
        _fbr.wmodel = _br1.controls["wmodel"]
        return _fbr

    # rewrite a BR followed by a FBR into a FBR
    rewrite_rules.append(ReplacementRule(
        Pattern(ComposedPipeline(_br1, _fbr), BR_FBR_index_matches),
        push_fbr_earlier
    ))

    global rewrites_setup
    rewrites_setup = True

setup_rewrites()
def fuse_left(self, left: pt.Transformer) -> Optional[pt.Transformer]:
    """
    Support fusing a preceding Retriever into this FeaturesRetriever.

    Merges ``Retriever >> FeaturesRetriever`` into a single FeaturesRetriever
    that also retrieves, when the indexrefs match, the Retriever does not use a
    context weighting model, and this FeaturesRetriever isn't already reranking
    (``self.wmodel is None``). Returns None when fusion is not possible.
    """
    if isinstance(left, Retriever) and \
            self.indexref == left.indexref and \
            left.controls.get('context_wmodel') != 'on' and \
            self.wmodel is None:
        return FeaturesRetriever(
            self.indexref,
            self.features,
            controls=self.controls,
            properties=self.properties,
            threads=self.threads,
            wmodel=left.controls['wmodel'],
            verbose=self.verbose,  # keep verbosity, consistent with the other fuse_* methods
        )
    # NOTE(review): unlike Retriever.fuse_feature_union, self.metadata is not
    # propagated to the fused retriever -- confirm whether it should be.
    return None

def fuse_rank_cutoff(self, k: int) -> Optional[pt.Transformer]:
    """
    Support fusing with RankCutoffTransformer.

    Returns a new FeaturesRetriever that retrieves only the top ``k`` results,
    ``self`` when an equal-or-stricter cutoff is already in effect, or ``None``
    when this instance only reranks and a cutoff cannot be pushed into it.
    """
    if self.controls.get('end', float('inf')) < k:
        return self  # a stricter (smaller) cutoff is already applied, so k adds nothing
    if self.wmodel is None:
        return None  # no wmodel means this instance reranks rather than retrieves
    # apply the new k as num_results
    return FeaturesRetriever(self.indexref, self.features, controls=self.controls, properties=self.properties,
                             threads=self.threads, wmodel=self.wmodel, verbose=self.verbose, num_results=k)

def fuse_feature_union(self, other: pt.Transformer, is_left: bool) -> Optional[pt.Transformer]:
    """Fuse with another feature-producing transformer over the same index, or return None."""
    # Case 1: two reranking FeaturesRetrievers on the same index -> concatenate features.
    if (isinstance(other, FeaturesRetriever)
            and self.indexref == other.indexref
            and self.wmodel is None
            and other.wmodel is None):
        combined = self.features + other.features if is_left else other.features + self.features
        return FeaturesRetriever(self.indexref, combined, controls=self.controls, properties=self.properties,
                                 threads=self.threads, wmodel=self.wmodel, verbose=self.verbose)

    # Case 2: fold a plain Retriever's weighting model in as an extra WMODEL: feature.
    # NOTE(review): deliberately no early return after case 1 -- if FeaturesRetriever
    # subclasses Retriever, an instance failing case 1 may still match here,
    # matching the original fall-through behaviour.
    if (isinstance(other, Retriever)
            and self.indexref == other.indexref
            and self.wmodel is None
            and other.controls.get('context_wmodel') != 'on'):
        extra = "WMODEL:" + other.controls['wmodel']
        combined = self.features + [extra] if is_left else [extra] + self.features
        return FeaturesRetriever(self.indexref, combined, controls=self.controls, properties=self.properties,
                                 threads=self.threads, wmodel=self.wmodel, verbose=self.verbose)
Loading
Loading