Move get_run_instances behavior from Adapter to runner.py (stanford-crfm#1985)

Co-authored-by: Yifan Mai <yifan@cs.stanford.edu>
brianwgoldman and yifanmai authored Nov 20, 2023
1 parent e98197c commit f264d65
Showing 2 changed files with 47 additions and 44 deletions.
40 changes: 1 addition & 39 deletions src/helm/benchmark/adaptation/adapters/adapter.py
@@ -1,15 +1,12 @@
 from abc import ABC, abstractmethod
 from typing import List
 
-import numpy as np
-
 from helm.benchmark.adaptation.adapter_spec import AdapterSpec
 from helm.benchmark.adaptation.scenario_state import ScenarioState
-from helm.benchmark.scenarios.scenario import Instance, TRAIN_SPLIT, EVAL_SPLITS
+from helm.benchmark.scenarios.scenario import Instance
 from helm.benchmark.window_services.tokenizer_service import TokenizerService
 from helm.benchmark.window_services.window_service import WindowService
 from helm.benchmark.window_services.window_service_factory import WindowServiceFactory
-from helm.common.hierarchical_logger import hlog
 
 
 class Adapter(ABC):
@@ -31,38 +28,3 @@ def adapt(self, instances: List[Instance], parallelism: int) -> ScenarioState:
         list of corresponding `RequestState`s.
         """
         pass
-
-    def get_run_instances(self, instances: List[Instance]) -> List[Instance]:
-        """
-        Get the instances necessary for this run:
-        Train instances (split=train): keep all (if any) for in-context learning
-        Eval instances (split=valid or test): keep at most `max_eval_instances` specified in `AdapterSpec` by sampling
-        Return the resulting train and eval instances.
-        """
-        all_train_instances: List[Instance] = [instance for instance in instances if instance.split == TRAIN_SPLIT]
-
-        all_eval_instances: List[Instance] = [instance for instance in instances if instance.split in EVAL_SPLITS]
-        if (
-            self.adapter_spec.max_eval_instances is not None
-            and len(all_eval_instances) > self.adapter_spec.max_eval_instances
-        ):
-            # Pick the first `self.adapter_spec.max_eval_instances`.
-            # The random sampling includes instances monotonically.
-            np.random.seed(0)
-            selected_eval_instances = list(
-                np.random.choice(
-                    all_eval_instances,  # type: ignore
-                    self.adapter_spec.max_eval_instances,
-                    replace=False,
-                )
-            )
-        else:
-            selected_eval_instances = all_eval_instances
-
-        hlog(
-            f"{len(instances)} instances, "
-            f"{len(all_train_instances)} train instances, "
-            f"{len(selected_eval_instances)}/{len(all_eval_instances)} eval instances"
-        )
-
-        return all_train_instances + selected_eval_instances
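
Note on the removed method: it reseeds NumPy's global RNG with a fixed seed of 0 before sampling, so repeated runs always select the same eval subset. Below is a minimal standalone sketch of just that selection behavior, not HELM code: plain strings stand in for `Instance` objects, and `sample_eval_instances` is an illustrative name.

```python
from typing import List

import numpy as np


def sample_eval_instances(eval_instances: List[str], max_eval_instances: int) -> List[str]:
    """Deterministic downsampling, mirroring the logic this commit relocates."""
    if len(eval_instances) <= max_eval_instances:
        return eval_instances
    np.random.seed(0)  # reseeding before every call makes the subset reproducible
    return list(np.random.choice(eval_instances, max_eval_instances, replace=False))


# Same seed, same subset: two calls agree, even across separate processes.
pool = [f"eval_{i}" for i in range(100)]
assert sample_eval_instances(pool, 10) == sample_eval_instances(pool, 10)
```

Because the seed is reset on every call, this mutates NumPy's global RNG state as a side effect; that behavior is preserved unchanged in the relocated function below.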
51 changes: 46 additions & 5 deletions src/helm/benchmark/runner.py
@@ -8,14 +8,23 @@
 import dataclasses
 from dataclasses import dataclass, field
 from typing import Any, Dict, List
+import numpy as np
 
 from tqdm import tqdm
 
 from helm.common.general import ensure_directory_exists, write, asdict_without_nones
 from helm.common.hierarchical_logger import hlog, htrack_block
 from helm.common.cache import cache_stats
 from .augmentations.data_augmenter import DataAugmenterSpec
-from .scenarios.scenario import Scenario, ScenarioSpec, create_scenario, Instance, with_instance_ids
+from .scenarios.scenario import (
+    EVAL_SPLITS,
+    TRAIN_SPLIT,
+    Scenario,
+    ScenarioSpec,
+    create_scenario,
+    Instance,
+    with_instance_ids,
+)
 from .adaptation.adapters.adapter import Adapter
 from .adaptation.adapters.adapter_factory import AdapterFactory
 from .adaptation.scenario_state import ScenarioState
@@ -103,6 +112,38 @@ def remove_per_instance_stats_nans(per_instance_stats_list: List[PerInstanceStats])
     return result
 
 
+def downsample_eval_instances(instances: List[Instance], max_eval_instances: int) -> List[Instance]:
+    """
+    Get the instances necessary for this run:
+    Train instances (split=train): keep all (if any) for in-context learning
+    Eval instances (split=valid or test): keep at most `max_eval_instances` specified in `AdapterSpec` by sampling
+    Return the resulting train and eval instances.
+    """
+    all_train_instances: List[Instance] = [instance for instance in instances if instance.split == TRAIN_SPLIT]
+
+    all_eval_instances: List[Instance] = [instance for instance in instances if instance.split in EVAL_SPLITS]
+    if len(all_eval_instances) > max_eval_instances:
+        # The random sampling includes instances monotonically.
+        np.random.seed(0)
+        selected_eval_instances = list(
+            np.random.choice(
+                all_eval_instances,  # type: ignore
+                max_eval_instances,
+                replace=False,
+            )
+        )
+    else:
+        selected_eval_instances = all_eval_instances
+
+    hlog(
+        f"{len(instances)} instances, "
+        f"{len(all_train_instances)} train instances, "
+        f"{len(selected_eval_instances)}/{len(all_eval_instances)} eval instances"
+    )
+
+    return all_train_instances + selected_eval_instances
+
+
 class Runner:
     """
     The main entry point for running the entire benchmark. Mostly just
@@ -203,9 +244,6 @@ def run_one(self, run_spec: RunSpec):
         input_instances_output_path = os.path.join(self.instances_path, scenario_name_with_args)
         input_instances_file_path = os.path.join(input_instances_output_path, "input_instances.json")
 
-        # Fetch and initialize the Adapter based on the `AdapterSpec`.
-        adapter: Adapter = AdapterFactory.get_adapter(run_spec.adapter_spec, self.tokenizer_service)
-
         instances: List[Instance]
         if self.skip_instances:
             instances = []
@@ -232,14 +270,17 @@ def run_one(self, run_spec: RunSpec):
             instances = with_instance_ids(instances)
 
         # Get the instances necessary for this run.
-        instances = adapter.get_run_instances(instances)
+        max_eval_instances = run_spec.adapter_spec.max_eval_instances
+        if max_eval_instances is not None:
+            instances = downsample_eval_instances(instances, max_eval_instances)
 
         # Data preprocessing
        instances = DataPreprocessor(run_spec.data_augmenter_spec).preprocess(
             instances, self.executor.execution_spec.parallelism
         )
 
         # Adapt (convert to requests)
+        adapter: Adapter = AdapterFactory.get_adapter(run_spec.adapter_spec, self.tokenizer_service)
         scenario_state: ScenarioState = adapter.adapt(instances, self.executor.execution_spec.parallelism)
 
         # Execute (fill up results)
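
Net effect: `Runner.run_one` now owns downsampling, skips it entirely when `max_eval_instances` is unset, and only constructs the `Adapter` once the instances are finalized. A hedged usage sketch of the relocated helper follows; the `Instance`/`Input` constructor calls assume this HELM version's dataclass fields, and the data is made up.

```python
from helm.benchmark.runner import downsample_eval_instances
from helm.benchmark.scenarios.scenario import Input, Instance, TEST_SPLIT, TRAIN_SPLIT

# Five in-context train instances plus a large eval pool (fabricated data).
instances = [
    Instance(input=Input(text=f"train {i}"), references=[], split=TRAIN_SPLIT) for i in range(5)
]
instances += [
    Instance(input=Input(text=f"test {i}"), references=[], split=TEST_SPLIT) for i in range(100)
]

# All train instances are kept; the 100 eval instances are downsampled to 10.
kept = downsample_eval_instances(instances, max_eval_instances=10)
assert len(kept) == 5 + 10
```

One behavioral subtlety of the move: the old method logged the instance counts on every run, whereas the runner now calls the helper (and thus logs) only when a `max_eval_instances` limit is actually set.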
