
Commit

Merge pull request #3 from llm-efficiency-challenge/msaroufim/corr2cause
Corr2Cause
msaroufim authored Nov 8, 2023
2 parents 1111ae1 + 8b40dd7 commit a210323
Showing 3 changed files with 114 additions and 2 deletions.
7 changes: 5 additions & 2 deletions private_run_specs.conf
@@ -1,3 +1,4 @@

entries: [
## Real

@@ -82,6 +83,8 @@ entries: [
{description: "math:model=text_code,subject=counting_and_probability,level=5,use_chain_of_thought=True", priority: 3}
{description: "math:model=text_code,subject=precalculus,level=5,use_chain_of_thought=True", priority: 3}

{description: "sam_sum:model=neurips/local,max_train_instances=3", priority: 1}
{description: "ethics_utilitarianism:model=neurips/local", priority: 1}
{description: "sam_sum:model=neurips/local,max_train_instances=3", priority: 1}
{description: "ethics_utilitarianism:model=neurips/local", priority: 1}
{description: "corr2cause:model=neurips/local", priority: 1}

]
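
The new "corr2cause" entry in this conf file is resolved by name to the run spec function registered in run_specs.py (next diff). A minimal sketch, assuming a local HELM development install, of checking that mapping from a Python shell; the expected values in the comments follow from the RunSpec constructed below:

# Minimal sketch, assuming a local HELM development install: the "corr2cause"
# conf entry maps to the function registered with @run_spec_function("corr2cause").
from helm.benchmark.run_specs import get_corr2cause_spec

run_spec = get_corr2cause_spec()
print(run_spec.name)    # expected: corr2cause,method=multiple_choice_joint
print(run_spec.groups)  # expected: ['corr2cause']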
21 changes: 21 additions & 0 deletions src/helm/benchmark/run_specs.py
@@ -1040,6 +1040,27 @@ def get_ethics_utilitarianism_spec(method: str = ADAPT_MULTIPLE_CHOICE_JOINT) ->
    )


@run_spec_function("corr2cause")
def get_corr2cause_spec(method: str = ADAPT_MULTIPLE_CHOICE_JOINT) -> RunSpec:
scenario_spec = ScenarioSpec(class_name="helm.benchmark.scenarios.corr2cause_scenario.Corr2CauseScenario", args={})

prompt = """
Given a scenario with a premise and a hypothesis, determine if the hypothesis can be inferred from the premise.
"""

adapter_spec = get_multiple_choice_adapter_spec(
method=method, max_tokens=1, instructions=prompt, input_noun="Scenario\n", output_noun="Answer"
)

return RunSpec(
name=f"corr2cause,method={method}",
scenario_spec=scenario_spec,
adapter_spec=adapter_spec,
metric_specs=get_exact_match_metric_specs(),
groups=["corr2cause"],
)


@run_spec_function("twitter_aae")
def get_twitter_aae_spec(demographic: str) -> RunSpec:
    scenario_spec = ScenarioSpec(
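
For intuition, here is a rough, self-contained illustration (not HELM's actual adapter code) of the shape of a single request produced by the joint multiple-choice adaptation configured above; the premise/hypothesis text is made up and the exact formatting of real requests may differ:

# Rough illustration only; HELM's ADAPT_MULTIPLE_CHOICE_JOINT adapter builds real
# prompts. This sketch just shows the shape implied by the settings above:
# the instructions, input_noun="Scenario\n", output_noun="Answer", and max_tokens=1.
instructions = (
    "Given a scenario with a premise and a hypothesis, determine if the "
    "hypothesis can be inferred from the premise."
)
scenario_text = "Premise: A correlates with B. ... Hypothesis: A causes B."  # made-up example
choices = ["No", "Yes"]

prompt = instructions + "\n\nScenario\n: " + scenario_text + "\n"
for letter, choice in zip("AB", choices):
    prompt += f"{letter}. {choice}\n"
prompt += "Answer:"
print(prompt)  # the model replies with a single choice letter, hence max_tokens=1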
88 changes: 88 additions & 0 deletions src/helm/benchmark/scenarios/corr2cause_scenario.py
@@ -0,0 +1,88 @@
import csv
import os
import random
from typing import List, Dict, Any
from helm.common.general import ensure_file_downloaded, ensure_directory_exists
from .scenario import Scenario, Instance, Reference, ALL_SPLITS, CORRECT_TAG, VALID_SPLIT, Input, Output

# TODO: Should I just get rid of the train/test split?


class Corr2CauseScenario(Scenario):
    """Corr2Cause: given a premise describing correlations among variables, decide
    whether a causal hypothesis can be inferred from it (binary Yes/No classification)."""

    name = "corr2cause"
    description = "Can Large Language Models Infer Causation from Correlation?"
    tags = ["classification"]
    DATASET_FILE_NAME = "corr2cause.csv"
    TRAIN_RATIO = 0.8  # 80% for training, 20% for validation
    TRAIN_SPLIT = "train"
    VALID_SPLIT = "valid"

    def download_dataset(self, output_path: str):
        """Downloads the Corr2Cause dataset if not already present."""
        # Define the target path for the dataset
        data_dir = os.path.join(output_path, "data")
        dataset_path = os.path.join(data_dir, self.DATASET_FILE_NAME)

        # Check if the dataset already exists
        if os.path.exists(dataset_path):
            print(f"The dataset '{self.DATASET_FILE_NAME}' already exists at '{dataset_path}'. Skipping download.")
            return

        # Download the raw data
        url = "https://gist.githubusercontent.com/msaroufim/2835e9a27490bb183de86c54c0614169/raw/4160842cd2574716355a5fe9134387a20baed9f8/corr2cause.csv"
        ensure_directory_exists(data_dir)
        ensure_file_downloaded(source_url=url, target_path=dataset_path)

    def load_dataset(self, output_path: str) -> List[Dict[str, Any]]:
        file_path = os.path.join(output_path, "data", self.DATASET_FILE_NAME)

        data = []
        with open(file_path, encoding="utf-8") as f:
            csv_reader = csv.reader(f)
            next(csv_reader)  # Skip the header row
            for label, question in csv_reader:
                data_point = {"label": int(label), "input": question.strip()}
                data.append(data_point)
        # Shuffle with a fixed seed so the train/valid split is reproducible
        random.seed(0)
        random.shuffle(data)
        return data

    def get_label(self, label: int) -> str:
        return "No" if label == 0 else "Yes"

    def data_to_instance(self, data_point: Dict[str, Any], split: str, instance_id: str) -> Instance:
        input_text = Input(text=data_point["input"])

        # Create reference choices with "No" and "Yes"
        choices = [
            Reference(output=Output(text="No"), tags=[]),
            Reference(output=Output(text="Yes"), tags=[]),
        ]

        # Assign the CORRECT_TAG to the correct choice
        correct_label = self.get_label(data_point["label"])
        for choice in choices:
            if choice.output.text == correct_label:
                choice.tags.append(CORRECT_TAG)

        return Instance(input=input_text, references=choices, split=split)

    def get_instances(self, output_path: str) -> List[Instance]:
        """Returns the instances for this scenario."""
        self.download_dataset(output_path)
        data = self.load_dataset(output_path)

        # Split the data
        split_k = int(len(data) * self.TRAIN_RATIO)
        train_data = data[:split_k]
        valid_data = data[split_k:]

        train_instances = [self.data_to_instance(dt, self.TRAIN_SPLIT, f"id{i}") for i, dt in enumerate(train_data)]
        valid_instances = [
            self.data_to_instance(dt, self.VALID_SPLIT, f"id{i + len(train_data)}") for i, dt in enumerate(valid_data)
        ]

        return train_instances + valid_instances
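
A minimal sketch, assuming a local HELM development checkout, of exercising the new scenario directly outside of helm-run; the scratch directory name is arbitrary. On first use it downloads the gist-hosted CSV into <output_path>/data, then builds train/valid instances with "No"/"Yes" references:

# Minimal sketch, assuming a local HELM development checkout.
from helm.benchmark.scenarios.corr2cause_scenario import Corr2CauseScenario
from helm.benchmark.scenarios.scenario import CORRECT_TAG

scenario = Corr2CauseScenario()
instances = scenario.get_instances(output_path="scratch/corr2cause")  # arbitrary scratch dir

print(len(instances), "instances")
for instance in instances[:3]:
    gold = [ref.output.text for ref in instance.references if CORRECT_TAG in ref.tags]
    print(instance.split, gold, instance.input.text[:80])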
