Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add timing helper #18

Merged
merged 2 commits into from
Oct 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 86 additions & 0 deletions experiments/scaling/pythia.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from gptchem.data import get_photoswitch_data

from chemlift.finetune.peftmodels import PEFTClassifier, ChemLIFTClassifierFactory
from sklearn.model_selection import train_test_split

from fastcore.xtras import load_pickle, save_pickle
from gptchem.evaluator import evaluate_classification
import time
import os


def get_timestr():
return time.strftime("%Y-%m-%d_%H-%M-%S")


models = [
"EleutherAI/pythia-12b-deduped",
"EleutherAI/pythia-6.9b-deduped",
"EleutherAI/pythia-2.8b-deduped",
"EleutherAI/pythia-1.4b-deduped",
"EleutherAI/pythia-1b-deduped",
"EleutherAI/pythia-410m-deduped",
"EleutherAI/pythia-160m-deduped",
"EleutherAI/pythia-70m-deduped",
]


def train_test(train_size, model_name, random_state=42):
data = get_photoswitch_data()

data = data.dropna(subset=["SMILES", "E isomer pi-pi* wavelength in nm"])

data["binned"] = data["E isomer pi-pi* wavelength in nm"].apply(
lambda x: 1 if x > data["E isomer pi-pi* wavelength in nm"].median() else 0
)

train, test = train_test_split(
data, train_size=train_size, stratify=data["binned"], random_state=random_state
)

train_median = train["E isomer pi-pi* wavelength in nm"].median()
train["binned"] = train["E isomer pi-pi* wavelength in nm"].apply(
lambda x: 1 if x > train_median else 0
)
test["binned"] = test["E isomer pi-pi* wavelength in nm"].apply(
lambda x: 1 if x > train_median else 0
)

model = ChemLIFTClassifierFactory(
"transition wavelength class",
model_name=model_name,
load_in_8bit=True,
inference_batch_size=32,
tokenizer_kwargs={"cutoff_len": 50},
tune_settings={"num_train_epochs": 32},
).create_model()

model.fit(train["SMILES"].values, train["binned"].values)

start = time.time()
predictions = model.predict(test["SMILES"].values)
end = time.time()

report = evaluate_classification(test["binned"].values, predictions)

if not os.path.exists("results"):
os.makedirs("results")

outname = f"results/{get_timestr()}_peft_{model_name}_{train_size}.pkl"

report["model_name"] = model_name
report["train_size"] = train_size
report["random_state"] = random_state
report["predictions"] = predictions
report["targets"] = test["binned"].values
report["fine_tune_time"] = model.fine_tune_time
report["inference_time"] = end - start

save_pickle(outname, report)


if __name__ == "__main__":
for seed in range(5):
for model in models:
for train_size in [10, 50, 100, 200, 300]:
train_test(train_size, model, random_state=seed)
9 changes: 9 additions & 0 deletions src/chemlift/finetune/peftmodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from functools import partial
from peft.utils.save_and_load import set_peft_model_state_dict
from fastcore.basics import basic_repr
import time


class ChemLIFTClassifierFactory:
Expand Down Expand Up @@ -125,8 +126,14 @@ def __init__(

self.tune_settings["per_device_train_batch_size"] = self.batch_size

self._fine_tune_time = None

__repr__ = basic_repr(["property_name", "_base_model"])

@property
def fine_tune_time(self):
return self._fine_tune_time

def _prepare_df(self, X: ArrayLike, y: ArrayLike):
rows = []
for i in range(len(X)):
Expand Down Expand Up @@ -255,6 +262,7 @@ def fit(
dfs.append(formatted)

formatted = pd.concat(dfs)
start_time = time.time()
train_model(
self.model,
self.tokenizer,
Expand All @@ -263,6 +271,7 @@ def fit(
hub_model_name=None,
report_to=None,
)
self._fine_tune_time = time.time() - start_time

def _predict(
self,
Expand Down
6 changes: 6 additions & 0 deletions src/chemlift/icl/fewshotpredictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import enum
from typing import Union
from chemlift.icl.utils import LangChainChatModelWrapper
import time


class Strategy(enum.Enum):
Expand Down Expand Up @@ -86,6 +87,11 @@ def __init__(
self._materialclass = "molecules"
self._max_test = max_test
self._prefix = prefix
self._prediction_time = None

@property
def prediction_time(self):
return self._prediction_time

def _format_examples(self, examples, targets):
"""Format examples and targets into a string.
Expand Down
Loading