Add lit-llama benchmarks (logits, autoregressive generation, lora fine tuning) #1730

Closed

wants to merge 9 commits into from

Changes from 7 commits
3 changes: 3 additions & 0 deletions .gitmodules
@@ -1,3 +1,6 @@
[submodule "submodules/FAMBench"]
path = submodules/FAMBench
url = https://github.com/facebookresearch/FAMBench.git
[submodule "submodules/lit-llama"]
path = submodules/lit-llama
url = https://github.com/Lightning-AI/lit-llama.git
1 change: 1 addition & 0 deletions submodules/lit-llama
Submodule lit-llama added at 8aa65b
54 changes: 54 additions & 0 deletions torchbenchmark/models/lit_llama/__init__.py
@@ -0,0 +1,54 @@
from ...util.model import BenchmarkModel
from torchbenchmark.tasks import NLP
import torch
import os
from torchbenchmark import REPO_PATH
import sys
import lightning as L

LIT_LLAMA_PATH = os.path.join(REPO_PATH, "submodules", "lit-llama")

sys.path.insert(0, LIT_LLAMA_PATH)
ezyang marked this conversation as resolved.

from lit_llama.utils import EmptyInitOnDevice, lazy_load, llama_model_lookup
from lit_llama import LLaMA, Tokenizer

class Model(BenchmarkModel):
task = NLP.LANGUAGE_MODELING
DEFAULT_EVAL_BSIZE = 1
DEFAULT_TRAIN_BSIZE = 32
Contributor:
Curious why the default train batch size is 32?

Contributor (Author): I think I should just delete this; it's meaningless, since you can't train a 7B model without some sort of distributed training.

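For a rough sense of why full fine-tuning of a 7B model won't fit on a single 40 GB A100, here is a back-of-envelope sketch. The assumptions (fp16 weights and gradients, fp32 Adam moments) are illustrative, not measured:

# Rough memory estimate for full fine-tuning of a 7B-parameter model on one GPU.
# Assumptions (not measured): fp16 weights and gradients, fp32 Adam moments.
params = 7e9
weights_gb = params * 2 / 1024**3   # fp16 weights              ~13 GB
grads_gb = params * 2 / 1024**3     # fp16 gradients            ~13 GB
adam_gb = params * 4 * 2 / 1024**3  # fp32 exp_avg + exp_avg_sq ~52 GB
print(f"~{weights_gb + grads_gb + adam_gb:.0f} GB before activations")  # ~78 GB, well over 40 GB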

def __init__(self, test, device, jit=False, batch_size=None, extra_args=[]):
super().__init__(test=test, device=device, jit=jit, batch_size=batch_size, extra_args=extra_args)

checkpoint_path = os.path.join(LIT_LLAMA_PATH, "checkpoints/lit-llama/7B/lit-llama.pth")
if not os.path.exists(checkpoint_path):
raise NotImplementedError("checkpoint doesn't exist")
with lazy_load(checkpoint_path) as checkpoint:
name = llama_model_lookup(checkpoint)

with EmptyInitOnDevice(device=device):
model = LLaMA.from_name(name)
model.load_state_dict(checkpoint)

self.model = model
self.seq_len = 32
self.max_seq_len = 64
self.example_inputs = (
torch.ones([self.batch_size, self.seq_len], dtype=torch.int32, device=self.device),
self.max_seq_len,
torch.arange(self.seq_len, dtype=torch.int64, device=self.device) # positions
)


def get_module(self):
return self.model, self.example_inputs

def train(self):
raise NotImplementedError("you will OOM trying to train directly")

def eval(self):
self.model.eval()
with torch.no_grad():
logits = self.model(*self.example_inputs)
return (logits,)
4 changes: 4 additions & 0 deletions torchbenchmark/models/lit_llama/install.py
@@ -0,0 +1,4 @@
from torchbenchmark.util.framework.lit_llama import install_lit_llama

if __name__ == '__main__':
install_lit_llama()
10 changes: 10 additions & 0 deletions torchbenchmark/models/lit_llama/metadata.yaml
@@ -0,0 +1,10 @@
devices:
NVIDIA A100-SXM4-40GB:
eval_batch_size: 32
eval_benchmark: false
eval_deterministic: false
eval_nograd: true
train_benchmark: false
train_deterministic: false
not_implemented:
- test: eval
41 changes: 41 additions & 0 deletions torchbenchmark/models/lit_llama_generate/__init__.py
@@ -0,0 +1,41 @@
from .. import lit_llama as lit_llama
from ..lit_llama import LIT_LLAMA_PATH
import importlib.util
import os.path
import torch
import torch.nn as nn
import sys
from lit_llama import Tokenizer

def import_from_file_path(module_name, file_path):
spec = importlib.util.spec_from_file_location(module_name, file_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
sys.modules[module_name] = module
return module

lit_llama_generate = import_from_file_path("lit_llama_generate", os.path.join(LIT_LLAMA_PATH, 'generate.py'))

class GenerationWrapper(nn.Module):
def __init__(self, model):
super().__init__()
self.model = model

def forward(self, idx, max_new_tokens):
return lit_llama_generate.generate(self.model, idx, max_new_tokens)

class Model(lit_llama.Model):
def __init__(self, test, device, jit=False, batch_size=None, extra_args=[]):
super().__init__(test=test, device=device, jit=jit, batch_size=batch_size, extra_args=extra_args)
self.model = GenerationWrapper(self.model)
tokenizer = Tokenizer(os.path.join(LIT_LLAMA_PATH, "checkpoints/lit-llama/tokenizer.model"))
# max_new_tokens matches lit-llama/generate.py
self.example_inputs = (tokenizer.encode("The meaning of life is", bos=True, eos=False, device=device), 50)
Member:
is 50 the max number of tokens to generate?

Contributor (Author):
yes


def train(self):
raise NotImplementedError("cannot train on autoregressive generation")

def eval(self):
self.model.eval()
with torch.no_grad():
y = self.model(*self.example_inputs)
Member:
Do you mind printing the input prompt and the output? It would be nice to be able to do vibe checks later.

Contributor (Author):
Hmm, but I don't want to print it here, because then the detokenization would also count as part of the benchmark?

return (y,)
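One way to address the vibe-check request without timing the detokenization is to decode the prompt and output after eval() has returned. A minimal sketch, assuming it runs in this module's context (Model, Tokenizer, LIT_LLAMA_PATH, and os already imported), that a 7B checkpoint has been downloaded, and that max_new_tokens stays at 50 as above:

# Illustrative sketch only, not part of the benchmark: decode outside the timed eval().
m = Model(test="eval", device="cuda")
(y,) = m.eval()  # timed region: autoregressive generation of up to 50 new tokens

tokenizer = Tokenizer(os.path.join(LIT_LLAMA_PATH, "checkpoints/lit-llama/tokenizer.model"))
print("prompt:", tokenizer.decode(m.example_inputs[0]))  # "The meaning of life is"
print("output:", tokenizer.decode(y))  # detokenization happens after eval(), so it is not benchmarked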
4 changes: 4 additions & 0 deletions torchbenchmark/models/lit_llama_generate/install.py
@@ -0,0 +1,4 @@
from torchbenchmark.util.framework.lit_llama import install_lit_llama

if __name__ == '__main__':
install_lit_llama()
10 changes: 10 additions & 0 deletions torchbenchmark/models/lit_llama_generate/metadata.yaml
@@ -0,0 +1,10 @@
devices:
NVIDIA A100-SXM4-40GB:
eval_batch_size: 32
eval_benchmark: false
eval_deterministic: false
eval_nograd: true
train_benchmark: false
train_deterministic: false
not_implemented:
- test: eval
66 changes: 66 additions & 0 deletions torchbenchmark/models/lit_llama_lora/__init__.py
@@ -0,0 +1,66 @@
from ...util.model import BenchmarkModel
from torchbenchmark.tasks import NLP
import torch
from ..lit_llama import LIT_LLAMA_PATH
import importlib.util
import os.path
import torch.nn as nn
import sys
from lit_llama.lora import mark_only_lora_as_trainable, lora, lora_state_dict
from torchbenchmark import REPO_PATH

LIT_LLAMA_PATH = os.path.join(REPO_PATH, "submodules", "lit-llama")

sys.path.insert(0, LIT_LLAMA_PATH)

from lit_llama.utils import EmptyInitOnDevice, lazy_load, llama_model_lookup
from lit_llama import LLaMA, Tokenizer

class Model(BenchmarkModel):
task = NLP.LANGUAGE_MODELING
DEFAULT_EVAL_BSIZE = 1
DEFAULT_TRAIN_BSIZE = 4 # micro_batch_size in lora.py

def __init__(self, test, device, jit=False, batch_size=None, extra_args=[]):
super().__init__(test=test, device=device, jit=jit, batch_size=batch_size, extra_args=extra_args)

# From finetune/lora.py hyperparameters
lora_r = 8
lora_alpha = 16
lora_dropout = 0.05

checkpoint_path = os.path.join(LIT_LLAMA_PATH, "checkpoints/lit-llama/7B/lit-llama.pth")
if not os.path.exists(checkpoint_path):
raise NotImplementedError("checkpoint doesn't exist")
with lazy_load(checkpoint_path) as checkpoint, lora(r=lora_r, alpha=lora_alpha, dropout=lora_dropout, enabled=True):
name = llama_model_lookup(checkpoint)

with EmptyInitOnDevice(device=device):
model = LLaMA.from_name(name)
# LoRA weights won't be in base checkpoint
model.load_state_dict(checkpoint, strict=False)

mark_only_lora_as_trainable(model)

self.model = model
self.seq_len = 32
self.max_seq_len = 64
self.example_inputs = (
torch.ones([self.batch_size, self.seq_len], dtype=torch.int32, device=self.device),
self.max_seq_len,
)


def get_module(self):
return self.model, self.example_inputs

def train(self):
logits = self.model(*self.example_inputs)
logits.sum().backward()
# meh this sucks
Member:
This might be a good dataset: https://huggingface.co/datasets/OpenAssistant/oasst1. Even fine-tuning on a couple of example questions you make up might not be bad as a sanity check.

Contributor (Author):
Will fix this later, I think. Not needed for dynamo benchmarks.


def eval(self):
self.model.eval()
with torch.no_grad():
logits = self.model(*self.example_inputs)
return (logits,)
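A possible shape for the suggested sanity check: fine-tune the LoRA weights on a couple of made-up prompts with a next-token cross-entropy loss instead of logits.sum(). This is a hedged sketch under assumptions (tokenizer path, AdamW with lr=3e-4, invented prompts), not part of the PR:

# Illustrative sketch only: LoRA sanity-check steps on made-up examples.
# Assumes this module's context (Model, Tokenizer, LIT_LLAMA_PATH, torch, os) and a downloaded checkpoint.
import torch.nn.functional as F

m = Model(test="train", device="cuda")
tokenizer = Tokenizer(os.path.join(LIT_LLAMA_PATH, "checkpoints/lit-llama/tokenizer.model"))
optimizer = torch.optim.AdamW((p for p in m.model.parameters() if p.requires_grad), lr=3e-4)

prompts = ["What is the capital of France?", "Explain LoRA in one sentence."]
for text in prompts:
    idx = tokenizer.encode(text, bos=True, eos=True, device=m.device).unsqueeze(0)  # (1, T)
    logits = m.model(idx, m.max_seq_len)  # (1, T, vocab)
    # next-token prediction: targets are the inputs shifted left by one
    loss = F.cross_entropy(logits[:, :-1].reshape(-1, logits.size(-1)), idx[:, 1:].reshape(-1).long())
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()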
4 changes: 4 additions & 0 deletions torchbenchmark/models/lit_llama_lora/install.py
@@ -0,0 +1,4 @@
from torchbenchmark.util.framework.lit_llama import install_lit_llama

if __name__ == '__main__':
install_lit_llama()
10 changes: 10 additions & 0 deletions torchbenchmark/models/lit_llama_lora/metadata.yaml
@@ -0,0 +1,10 @@
devices:
NVIDIA A100-SXM4-40GB:
eval_batch_size: 32
eval_benchmark: false
eval_deterministic: false
eval_nograd: true
train_benchmark: false
train_deterministic: false
not_implemented:
- test: train
50 changes: 50 additions & 0 deletions torchbenchmark/util/framework/lit_llama.py
@@ -0,0 +1,50 @@
import os
import sys
import subprocess
import traceback
from pathlib import Path
from torchbenchmark import REPO_PATH

LIT_LLAMA_PATH = os.path.join(REPO_PATH, "submodules", "lit-llama")

def update_lit_llama_submodule():
update_command = ["git", "submodule", "update",
"--init", "--recursive", os.path.join("submodules", "lit-llama")]
subprocess.check_call(update_command, cwd=REPO_PATH)

def pip_install_requirements():
subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', '-r', os.path.join(LIT_LLAMA_PATH, "requirements.txt")])

def openllama_download():
if os.path.exists(os.path.join(LIT_LLAMA_PATH, "checkpoints/lit-llama/7B/lit-llama.pth")):
return
subprocess.check_call([
sys.executable,
os.path.join(LIT_LLAMA_PATH, 'scripts/download.py'),
'--repo_id',
'openlm-research/open_llama_7b_700bt_preview',
'--local_dir',
os.path.join(LIT_LLAMA_PATH, 'checkpoints/open-llama/7B')
])
subprocess.check_call([
sys.executable,
os.path.join(LIT_LLAMA_PATH, 'scripts/convert_hf_checkpoint.py'),
'--checkpoint_dir', os.path.join(LIT_LLAMA_PATH, 'checkpoints/open-llama/7B'),
'--model_size', '7B',
], cwd=LIT_LLAMA_PATH)

def install_lit_llama():
import torch
update_lit_llama_submodule()
pip_install_requirements()
try:
from pynvml import nvmlDeviceGetMemoryInfo
info = nvmlDeviceGetMemoryInfo(torch.cuda._get_pynvml_handler())
if info.total < 40 * 1024 ** 3:
print("not enough GPU memory for 7B parameters, skipping llama (avail: {info.total / 1024 ** 3}GB)")
return
except Exception as e:
print("failed to test GPU memory, skipping llama weights")
traceback.print_exc()
return
openllama_download()