Skip to content

Commit

Permalink
lots of type fixes, mostly for tqdm
Browse files Browse the repository at this point in the history
  • Loading branch information
cmacdonald committed Sep 19, 2024
1 parent d44182a commit bbd09f4
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 22 deletions.
2 changes: 2 additions & 0 deletions pyterrier/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@

# will be set in terrier.terrier.java once java is loaded
IndexRef = None
# will be set by utils.set_tqdm() once _() runs
tqdm = None


# deprecated functions exported to the main namespace, which will be removed in a future version
Expand Down
16 changes: 8 additions & 8 deletions pyterrier/apply_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
# batching
iterator = pt.model.split_df(inp, batch_size=self.batch_size)
if self.verbose:
iterator = pt.tqdm(iterator, desc="pt.apply", unit='row')
iterator = pt.tqdm(iterator, desc="pt.apply", unit='row') # type: ignore
return pd.concat([self._apply_df(chunk_df) for chunk_df in iterator])

def _apply_df(self, inp: pd.DataFrame) -> pd.DataFrame:
Expand Down Expand Up @@ -148,7 +148,7 @@ def transform(self, res: pd.DataFrame) -> pd.DataFrame:
it = res.groupby("qid")
lastqid = None
if self.verbose:
it = pt.tqdm(it, unit='query')
it = pt.tqdm(it, unit='query') # type: ignore
try:
if self.batch_size is None:
query_dfs = []
Expand Down Expand Up @@ -275,7 +275,7 @@ def transform(self, inp: pd.DataFrame) -> pd.DataFrame:

iterator = pt.model.split_df(outputRes, batch_size=self.batch_size)
if self.verbose:
iterator = pt.tqdm(iterator, desc="pt.apply", unit='row')
iterator = pt.tqdm(iterator, desc="pt.apply", unit='row') # type: ignore
rtr = pd.concat([self._transform_batchwise(chunk_df) for chunk_df in iterator])
rtr = pt.model.add_ranks(rtr)
return rtr
Expand Down Expand Up @@ -313,7 +313,7 @@ def transform_iter(self, inp: pt.model.IterDict) -> pt.model.IterDict:
# we assume that the function can take a dictionary as well as a pandas.Series. As long as [""] notation is used
# to access fields, both should work
if self.verbose:
inp = pt.tqdm(inp, desc="pt.apply.doc_features")
inp = pt.tqdm(inp, desc="pt.apply.doc_features") # type: ignore
for row in inp:
row["features"] = self.fn(row)
yield row
Expand All @@ -322,7 +322,7 @@ def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
fn = self.fn
outputRes = inp.copy()
if self.verbose:
pt.tqdm.pandas(desc="pt.apply.doc_features", unit="d")
pt.tqdm.pandas(desc="pt.apply.doc_features", unit="d") # type: ignore
outputRes["features"] = outputRes.progress_apply(fn, axis=1)
else:
outputRes["features"] = outputRes.apply(fn, axis=1)
Expand Down Expand Up @@ -368,7 +368,7 @@ def transform_iter(self, inp: pt.model.IterDict) -> pt.model.IterDict:
# we assume that the function can take a dictionary as well as a pandas.Series. As long as [""] notation is used
# to access fields, both should work
if self.verbose:
inp = pt.tqdm(inp, desc="pt.apply.query")
inp = pt.tqdm(inp, desc="pt.apply.query") # type: ignore
for row in inp:
row = row.copy()
if "query" in row:
Expand All @@ -384,7 +384,7 @@ def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
outputRes = inp.copy()
try:
if self.verbose:
pt.tqdm.pandas(desc="pt.apply.query", unit="d")
pt.tqdm.pandas(desc="pt.apply.query", unit="d") # type: ignore
outputRes["query"] = outputRes.progress_apply(self.fn, axis=1)
else:
outputRes["query"] = outputRes.apply(self.fn, axis=1)
Expand Down Expand Up @@ -444,7 +444,7 @@ def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
# batching
iterator = pt.model.split_df(inp, batch_size=self.batch_size)
if self.verbose:
iterator = pt.tqdm(iterator, desc="pt.apply", unit='row')
iterator = pt.tqdm(iterator, desc="pt.apply", unit='row') # type: ignore
rtr = pd.concat([self.fn(chunk_df) for chunk_df in iterator])
return rtr

Expand Down
6 changes: 3 additions & 3 deletions pyterrier/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def download(URLs : Union[str,List[str]], filename : str, **kwargs):
r = requests.get(url, allow_redirects=True, stream=True, **kwargs)
r.raise_for_status()
total = int(r.headers.get('content-length', 0))
with pt.io.finalized_open(filename, 'b') as file, pt.tqdm(
with pt.io.finalized_open(filename, 'b') as file, pt.tqdm( # type: ignore
desc=basename,
total=total,
unit='iB',
Expand Down Expand Up @@ -610,7 +610,7 @@ def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
set_docnos = set(docnos)
it = (tuple(getattr(doc, f) for f in fields) for doc in docstore.get_many_iter(set_docnos))
if self.verbose:
it = pd.tqdm(it, unit='d', total=len(set_docnos), desc='IRDSTextLoader')
it = pt.tqdm(it, unit='d', total=len(set_docnos), desc='IRDSTextLoader') # type: ignore
metadata = pd.DataFrame(list(it), columns=fields).set_index('doc_id')
metadata_frame = metadata.loc[docnos].reset_index(drop=True)

Expand Down Expand Up @@ -1104,7 +1104,7 @@ def _merge_years(self, component, variant):
"corpus_iter" : lambda dataset, **kwargs : pt.index.treccollection2textgen(dataset.get_corpus(), num_docs=11429, verbose=kwargs.get("verbose", False))
}

DATASET_MAP = {
DATASET_MAP : Dict[str, Dataset] = {
# used for UGlasgow teaching
"50pct" : RemoteDataset("50pct", FIFTY_PCT_FILES),
# umass antique corpus - see http://ciir.cs.umass.edu/downloads/Antique/
Expand Down
30 changes: 19 additions & 11 deletions pyterrier/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from . import Transformer
from .model import coerce_dataframe_types
import ir_measures
import tqdm as tqdm_module
from ir_measures.measures import BaseMeasure
import pyterrier as pt
MEASURE_TYPE=Union[str,BaseMeasure]
Expand Down Expand Up @@ -135,17 +136,17 @@ def _ir_measures_to_dict(
for m in seq:
metric = m.measure
metric = rev_mapping.get(metric, str(metric))
rtr[metric].add(m.value)
rtr[metric].add(m.value) # type: ignore # there is no typing for aggregators in ir_measures
for m in rtr:
rtr[m] = rtr[m].result()
rtr[m] = rtr[m].result() # type: ignore # there is no typing for aggregators in ir_measures
return rtr

def _run_and_evaluate(
system : SYSTEM_OR_RESULTS_TYPE,
topics : Optional[pd.DataFrame],
qrels: pd.DataFrame,
metrics : MEASURES_TYPE,
pbar : Optional[pt.tqdm] = None,
pbar : Optional[tqdm_module.tqdm] = None,
save_mode : Optional[SAVEMODE_TYPE] = None,
save_file : Optional[str] = None,
perquery : bool = False,
Expand All @@ -155,7 +156,7 @@ def _run_and_evaluate(
from .io import read_results, write_results

if pbar is None:
pbar = pt.tqdm(disable=True)
pbar = pt.tqdm(disable=True) # type: ignore

metrics, rev_mapping = _convert_measures(metrics)
qrels = qrels.rename(columns={'qid': 'query_id', 'docno': 'doc_id', 'label': 'relevance'})
Expand All @@ -178,12 +179,16 @@ def _run_and_evaluate(
else:
raise ValueError("Unknown save_mode argument '%s', valid options are 'error', 'warn', 'reuse' or 'overwrite'." % save_mode)

res : pd.DataFrame
# if its a DataFrame, use it as the results
if isinstance(system, pd.DataFrame):
res = system
res = coerce_dataframe_types(res)
if len(res) == 0:
raise ValueError("%d topics, but no results in dataframe" % len(topics))
if topics is None:
raise ValueError("No topics specified, and no results in dataframe")
else:
raise ValueError("%d topics, but no results in dataframe" % len(topics))
evalMeasuresDict = _ir_measures_to_dict(
ir_measures.iter_calc(metrics, qrels, res.rename(columns=_irmeasures_columns)),
metrics,
Expand All @@ -194,6 +199,8 @@ def _run_and_evaluate(
pbar.update()

elif batch_size is None:

assert topics is not None, "topics must be specified"
#transformer, evaluate all queries at once

starttime = timer()
Expand All @@ -219,13 +226,14 @@ def _run_and_evaluate(
backfill_qids)
pbar.update()
else:
assert topics is not None, "topics must be specified"

#transformer, evaluate queries in batches
assert batch_size > 0
starttime = timer()
evalMeasuresDict = {}
remaining_qrel_qids = set(qrels.query_id)
try:
res : pd.DataFrame
batch_topics : pd.DataFrame
for i, (res, batch_topics) in enumerate( system.transform_gen(topics, batch_size=batch_size, output_topics=True)):
if len(res) == 0:
Expand Down Expand Up @@ -474,7 +482,7 @@ def _apply_round(measure, value):
# round number of batches up for each system
tqdm_args['total'] = math.ceil((len(topics) / batch_size)) * len(retr_systems)

with pt.tqdm(**tqdm_args) as pbar:
with pt.tqdm(**tqdm_args) as pbar: # type: ignore
# run and evaluate each system
for name, system in zip(names, retr_systems):
save_file = None
Expand Down Expand Up @@ -523,15 +531,15 @@ def _apply_round(measure, value):
if dataframe:
if perquery:
df = pd.DataFrame(evalsRows, columns=["name", "qid", "measure", "value"]).sort_values(['name', 'qid'])
if round is not None:
if round is not None and isinstance(round, int):
df["value"] = df["value"].round(round)
return df

highlight_cols = { m : "+" for m in actual_metric_names }
if mrt_needed:
highlight_cols["mrt"] = "-"

p_col_names=[]
p_col_names : List[str] = []
if baseline is not None:
assert len(evalDictsPerQ) == len(retr_systems)
baselinePerQuery={}
Expand Down Expand Up @@ -570,7 +578,7 @@ def _apply_round(measure, value):

# multiple testing correction. This adds two new columns for each measure experience statistical significance testing
if baseline is not None and correction is not None:
import statsmodels.stats.multitest
import statsmodels.stats.multitest # type: ignore
for pcol in p_col_names:
pcol_reject = pcol.replace("p-value", "reject")
pcol_corrected = pcol + " corrected"
Expand Down Expand Up @@ -910,7 +918,7 @@ def _evaluate_several_settings(inputs : List[Tuple]):
eval_list = []
#for each combination of parameter values
if jobs == 1:
for v in pt.tqdm(combinations, total=len(combinations), desc="GridScan", mininterval=0.3) if verbose else combinations:
for v in pt.tqdm(combinations, total=len(combinations), desc="GridScan", mininterval=0.3) if verbose else combinations: # type: ignore
parameter_list, eval_scores = _evaluate_one_setting(keys, v)
eval_list.append( (parameter_list, eval_scores) )
else:
Expand Down

0 comments on commit bbd09f4

Please sign in to comment.