Skip to content

Commit

Permalink
Merge pull request #461 from terrier-org/module_loader
Browse files Browse the repository at this point in the history
dynamic module loading
  • Loading branch information
seanmacavaney authored Aug 20, 2024
2 parents 1fb1705 + 0342ad7 commit 299da7d
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 20 deletions.
52 changes: 36 additions & 16 deletions pyterrier/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,21 +48,6 @@
cast = deprecated(version='0.11.0', reason="use pt.java.cast(...) instead")(java.cast)


# Additional setup performed in a function to avoid polluting the namespace with other imports like platform
def _():
    """Run one-off package setup without leaking helper imports into the module namespace."""
    import platform
    from packaging.version import Version

    # Refuse to run on interpreters older than the minimum supported release.
    minimum = Version('3.7.0')
    if Version(platform.python_version()) < minimum:
        raise RuntimeError("From PyTerrier 0.8, Python 3.7 minimum is required, you currently have %s" % platform.python_version())

    # pt.apply is an instance of _apply rather than a module, because it
    # relies on a dynamic __getattr__() implementation.
    from pyterrier.apply import _apply
    namespace = globals()
    namespace['apply'] = _apply()

    utils.set_tqdm()
_()

__all__ = [
'java', 'terrier', 'anserini', 'cache', 'debug', 'io', 'measures', 'model', 'new', 'ltr', 'parallel', 'pipelines',
'text', 'transformer', 'datasets', 'get_dataset', 'find_datasets', 'list_datasets', 'Experiment', 'GridScan',
Expand All @@ -71,9 +56,44 @@ def _():
'BatchRetrieve', 'TerrierRetrieve', 'FeaturesBatchRetrieve', 'IndexFactory',
'run', 'rewrite', 'index', 'FilesIndexer', 'TRECCollectionIndexer', 'DFIndexer', 'DFIndexUtils', 'IterDictIndexer',
'IndexingType', 'TerrierStemmer', 'TerrierStopwords', 'TerrierTokeniser',
'IndexRef', 'ApplicationSetup', 'properties', 'apply',
'IndexRef', 'ApplicationSetup', 'properties',

# Deprecated:
'init', 'started', 'logging', 'version', 'check_version', 'extend_classpath', 'set_tqdm', 'set_property', 'set_properties',
'redirect_stdouterr', 'autoclass', 'cast',

# Entry point modules (appended when loaded below):
]


# Additional setup performed in a function to avoid polluting the namespace with other imports like platform
def _():
    """Run one-off package setup without leaking helper imports into the module namespace.

    In order, this:
      1. raises RuntimeError if the running Python is older than 3.7.0;
      2. installs an ``_apply`` instance as the module global ``apply``;
      3. loads every entry point in the ``pyterrier.modules`` group into the
         module globals, skipping (with a warning) names that already exist;
      4. calls ``utils.set_tqdm()``.

    Every name installed in steps 2 and 3 is also appended to ``__all__``.
    """
    import platform
    from warnings import warn
    from packaging.version import Version

    # check python version
    if Version(platform.python_version()) < Version('3.7.0'):
        raise RuntimeError("From PyTerrier 0.8, Python 3.7 minimum is required, you currently have %s" % platform.python_version())

    namespace = globals()

    # pt.apply must be an _apply instance (not a module) so that its dynamic
    # __getattr__ lookups work.
    from pyterrier.apply import _apply
    namespace['apply'] = _apply()
    __all__.append('apply')

    # Expose modules registered as package entry points in the pyterrier namespace.
    for entry_point in utils.entry_points('pyterrier.modules'):
        name = entry_point.name
        if name not in namespace:
            loaded = entry_point.load()
            # An entry point may refer to a function/class instead of a module;
            # in that case, call it to obtain the object to expose.
            namespace[name] = loaded() if callable(loaded) else loaded
            __all__.append(name)
        else:
            # Never overwrite a name that is already defined in the package.
            warn(f'skipping loading {entry_point} because a module with this name is already loaded.')

    # guess the environment and set an appropriate tqdm as pt.tqdm
    utils.set_tqdm()
_()
7 changes: 3 additions & 4 deletions pyterrier/ltr.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@

import pyterrier as pt
from . import Transformer, Estimator
from .apply import doc_score, doc_features
from .model import add_ranks
from typing import Sequence, Union, Tuple
import numpy as np, pandas as pd
Expand Down Expand Up @@ -244,7 +243,7 @@ def feature_to_score(fid : int) -> Transformer:
Args:
fid: a single feature id that should be kept
"""
return doc_score(lambda row : row["features"][fid])
return pt.apply.doc_score(lambda row : row["features"][fid])

def apply_learned_model(learner, form : str = 'regression', **kwargs) -> Transformer:
"""
Expand Down Expand Up @@ -284,4 +283,4 @@ def score_to_feature() -> Transformer:
three_features = cands >> (bm25f ** pl2f ** pt.ltr.score_to_feature())
"""
return doc_features(lambda row : np.array(row["score"]))
return pt.apply.doc_features(lambda row : np.array(row["score"]))

0 comments on commit 299da7d

Please sign in to comment.