This repository has been archived by the owner on Dec 6, 2023. It is now read-only.

Fix petals on amd #108

Merged · 6 commits · Sep 21, 2023
Changes from all commits:
cht-petals/build.sh: 5 changes (1 addition, 4 deletions)

```diff
@@ -1,10 +1,7 @@
 #!/bin/bash
 set -e
-export VERSION=1.0.0
+export VERSION=1.0.1
 source "$(dirname "${BASH_SOURCE[0]}")/../utils.sh"
 
-# TODO: support linux/amd64
-BUILDX_PLATFORM=linux/arm64 TESTS_SKIP_CPU=1 \
 build_cpu ghcr.io/premai-io/chat-stable-beluga-2-cpu petals-team/StableBeluga2 ${@:1}
-BUILDX_PLATFORM=linux/arm64 TESTS_SKIP_CPU=1 \
 build_cpu ghcr.io/premai-io/chat-codellama-34b-cpu premai-io/CodeLlama-34b-Instruct-hf ${@:1}
```
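Besides bumping `VERSION` from 1.0.0 to 1.0.1, this drops the stale `# TODO: support linux/amd64` comment and the `BUILDX_PLATFORM=linux/arm64 TESTS_SKIP_CPU=1` prefixes, so `build_cpu` (defined in `utils.sh`, which is not part of this diff) is no longer pinned to arm64. Presumably it then falls back to its default platform handling, which is what makes these images buildable on amd64 hosts, in line with the PR title.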
cht-petals/docker/cpu/Dockerfile: 8 changes (2 additions, 6 deletions)

```diff
@@ -1,7 +1,5 @@
 FROM python:3.10-slim-bullseye
 
-ARG MODEL_ID
-
 RUN apt update && apt install -y libopenblas-dev ninja-build build-essential wget git
 RUN python -m pip install --upgrade pip pytest cmake scikit-build setuptools
 
@@ -12,11 +10,9 @@ COPY requirements.txt ./
 RUN pip install --no-cache-dir -r ./requirements.txt --upgrade pip
 
 COPY download.py .
-
+ARG MODEL_ID
+ENV MODEL_ID=$MODEL_ID
 RUN python3 download.py --model $MODEL_ID
 
 COPY . .
-
-ENV MODEL_ID=$MODEL_ID
-
 CMD python main.py
```
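The Dockerfile change is mostly layer ordering. `ARG MODEL_ID` now appears only after `requirements.txt` has been installed, so switching models no longer invalidates the cached dependency layers, and `ENV MODEL_ID=$MODEL_ID` is set next to the download step rather than at the very end, so the runtime default matches the model baked into the image at build time.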
cht-petals/download.py: 14 changes (9 additions, 5 deletions)

```diff
@@ -1,5 +1,7 @@
 import argparse
+from platform import machine
 
+import torch
 from petals import AutoDistributedModelForCausalLM
 from tenacity import retry, stop_after_attempt, wait_fixed
 from transformers import AutoTokenizer, LlamaTokenizer
@@ -13,11 +15,13 @@
 
 @retry(stop=stop_after_attempt(3), wait=wait_fixed(5))
 def download_model() -> None:
-    if "llama" in args.model.lower():
-        _ = LlamaTokenizer.from_pretrained(args.model)
-    else:
-        _ = AutoTokenizer.from_pretrained(args.model)
-    _ = AutoDistributedModelForCausalLM.from_pretrained(args.model)
+    Tokenizer = LlamaTokenizer if "llama" in args.model.lower() else AutoTokenizer
+    _ = Tokenizer.from_pretrained(args.model)
+
+    kwargs = {}
+    if "x86_64" in machine():
+        kwargs["torch_dtype"] = torch.float32
+    _ = AutoDistributedModelForCausalLM.from_pretrained(args.model, **kwargs)
 
 
 download_model()
```
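Pieced together, the updated download.py should read roughly as follows. This is a sketch: the argparse block sits outside the diff hunks, so the parser lines below are an assumption inferred from the use of `args.model` and from the Dockerfile's `python3 download.py --model $MODEL_ID` invocation; everything else mirrors the diff.

```python
import argparse
from platform import machine

import torch
from petals import AutoDistributedModelForCausalLM
from tenacity import retry, stop_after_attempt, wait_fixed
from transformers import AutoTokenizer, LlamaTokenizer

# Assumed: not visible in the diff, reconstructed from `args.model` usage
# and the Dockerfile's `python3 download.py --model $MODEL_ID` call.
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, required=True)
args = parser.parse_args()


@retry(stop=stop_after_attempt(3), wait=wait_fixed(5))
def download_model() -> None:
    # Llama-family checkpoints get the dedicated slow tokenizer.
    Tokenizer = LlamaTokenizer if "llama" in args.model.lower() else AutoTokenizer
    _ = Tokenizer.from_pretrained(args.model)

    # The AMD fix: on x86_64 hosts, explicitly load weights as float32.
    kwargs = {}
    if "x86_64" in machine():
        kwargs["torch_dtype"] = torch.float32
    _ = AutoDistributedModelForCausalLM.from_pretrained(args.model, **kwargs)


download_model()
```

Running this once at image build time leaves the weights cached inside the image, so containers start without re-downloading the model.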
cht-petals/models.py: 21 changes (14 additions, 7 deletions)

```diff
@@ -1,7 +1,9 @@
 import os
 from abc import ABC, abstractmethod
+from platform import machine
 from typing import List
 
+import torch
 from petals import AutoDistributedModelForCausalLM
 from transformers import AutoTokenizer, LlamaTokenizer, logging
 
@@ -50,18 +52,23 @@ def generate(
     ) -> List:
         message = messages[-1]["content"]
         inputs = cls.tokenizer(message, return_tensors="pt")["input_ids"]
-        outputs = cls.model.generate(inputs, max_new_tokens=5)
-        print(cls.tokenizer.decode(outputs[0]))
+        outputs = cls.model.generate(inputs, max_new_tokens=max_tokens)
         return [cls.tokenizer.decode(outputs[0])]
 
     @classmethod
     def get_model(cls):
         if cls.model is None:
-            if "llama" in os.getenv("MODEL_ID").lower():
-                cls.tokenizer = LlamaTokenizer.from_pretrained(os.getenv("MODEL_ID"))
-            else:
-                cls.tokenizer = AutoTokenizer.from_pretrained(os.getenv("MODEL_ID"))
+            Tokenizer = (
+                LlamaTokenizer
+                if "llama" in os.getenv("MODEL_ID").lower()
+                else AutoTokenizer
+            )
+            cls.tokenizer = Tokenizer.from_pretrained(os.getenv("MODEL_ID"))
+
+            kwargs = {}
+            if "x86_64" in machine():
+                kwargs["torch_dtype"] = torch.float32
             cls.model = AutoDistributedModelForCausalLM.from_pretrained(
-                os.getenv("MODEL_ID")
+                os.getenv("MODEL_ID"), **kwargs
             )
         return cls.model
```
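For a self-contained view of the same pattern, here is a minimal sketch of the patched class. The class name and the default `max_tokens` are illustrative, since the diff only shows the changed method bodies:

```python
import os
from platform import machine
from typing import List

import torch
from petals import AutoDistributedModelForCausalLM
from transformers import AutoTokenizer, LlamaTokenizer


class PetalsModel:  # illustrative name; the surrounding class is not shown in the diff
    model = None
    tokenizer = None

    @classmethod
    def generate(cls, messages: List[dict], max_tokens: int = 128) -> List:
        message = messages[-1]["content"]
        inputs = cls.tokenizer(message, return_tensors="pt")["input_ids"]
        # max_new_tokens now honours the caller's max_tokens; the old code
        # hard-coded 5 tokens and leaked a debug print.
        outputs = cls.model.generate(inputs, max_new_tokens=max_tokens)
        return [cls.tokenizer.decode(outputs[0])]

    @classmethod
    def get_model(cls):
        if cls.model is None:  # lazy, one-time load cached on the class
            model_id = os.getenv("MODEL_ID")
            Tokenizer = LlamaTokenizer if "llama" in model_id.lower() else AutoTokenizer
            cls.tokenizer = Tokenizer.from_pretrained(model_id)

            kwargs = {}
            if "x86_64" in machine():
                kwargs["torch_dtype"] = torch.float32  # the core of the AMD fix
            cls.model = AutoDistributedModelForCausalLM.from_pretrained(model_id, **kwargs)
        return cls.model
```

Caching the tokenizer and model at class level keeps `from_pretrained` off the request path; only the first call pays the load cost.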