Clip_vit_large14 and t5 models (nod-ai#535)
dan-garvey authored Mar 20, 2024
1 parent 6e3adb3 commit f2c025c
Showing 3 changed files with 341 additions and 145 deletions.

In summary, the commit branches on the Hugging Face model name so that the sd_inference clip export and runner scripts also support google/t5 encoder-decoder models (512-token inputs and a decoder_input_ids argument) and openai/clip-vit-large-patch14 loaded through CLIPProcessor.
94 changes: 65 additions & 29 deletions models/turbine_models/custom_models/sd_inference/clip.py
@@ -5,17 +5,13 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

 import os
-import sys
-import re

 from iree import runtime as ireert
 import iree.compiler as ireec
 from iree.compiler.ir import Context
-import numpy as np
 from shark_turbine.aot import *
 from turbine_models.custom_models.sd_inference import utils
 import torch
-import torch._dynamo as dynamo
-from transformers import CLIPTextModel, CLIPTokenizer
+from transformers import CLIPTextModel, CLIPTokenizer, CLIPProcessor
 from turbine_models.turbine_tank import turbine_tank

 import argparse
@@ -60,37 +56,77 @@ def export_clip_model(
     max_alloc=None,
     upload_ir=False,
 ):
-    # Load the tokenizer and text encoder to tokenize and encode the text.
-    tokenizer = CLIPTokenizer.from_pretrained(
-        hf_model_name,
-        subfolder="tokenizer",
-        token=hf_auth_token,
-    )
+    input_len = 77
+    if "google/t5" in hf_model_name:
+        from transformers import T5Tokenizer, T5Model

-    text_encoder_model = CLIPTextModel.from_pretrained(
-        hf_model_name,
-        subfolder="text_encoder",
-        token=hf_auth_token,
-    )
+        tokenizer = T5Tokenizer.from_pretrained(hf_model_name)
+        text_encoder_model = T5Model.from_pretrained(hf_model_name)
+        input_len = 512
+
+    else:
+        # TODO: Add better filtering mechanism for things that require CLIPProcessor
+        if "openai" in hf_model_name:
+            tokenizer = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
+            hf_subfolder = ""  # CLIPProcessor does not have a subfolder
+            input_len = 10
+        else:
+            # Load the tokenizer and text encoder to tokenize and encode the text.
+            tokenizer = CLIPTokenizer.from_pretrained(
+                hf_model_name,
+                subfolder="tokenizer",
+                token=hf_auth_token,
+            )
+            hf_subfolder = "text_encoder"
+
+        text_encoder_model = CLIPTextModel.from_pretrained(
+            hf_model_name,
+            subfolder=hf_subfolder,
+            token=hf_auth_token,
+        )

     mapper = {}
     utils.save_external_weights(
         mapper, text_encoder_model, external_weights, external_weight_path
     )

-    class CompiledClip(CompiledModule):
-        if external_weights:
-            params = export_parameters(
-                text_encoder_model,
-                external=True,
-                external_scope="",
-                name_mapper=mapper.get,
-            )
-        else:
-            params = export_parameters(text_encoder_model)
+    if "google/t5" in hf_model_name:
+
+        class CompiledClip(CompiledModule):
+            if external_weights:
+                params = export_parameters(
+                    text_encoder_model,
+                    external=True,
+                    external_scope="",
+                    name_mapper=mapper.get,
+                )
+            else:
+                params = export_parameters(text_encoder_model)
+
+            def main(
+                self,
+                inp=AbstractTensor(1, input_len, dtype=torch.int64),
+                decoder_input_ids=AbstractTensor(1, input_len, dtype=torch.int64),
+            ):
+                return jittable(text_encoder_model.forward)(
+                    input_ids=inp, decoder_input_ids=decoder_input_ids
+                )
+
+    else:
+
+        class CompiledClip(CompiledModule):
+            if external_weights:
+                params = export_parameters(
+                    text_encoder_model,
+                    external=True,
+                    external_scope="",
+                    name_mapper=mapper.get,
+                )
+            else:
+                params = export_parameters(text_encoder_model)

-        def main(self, inp=AbstractTensor(1, 77, dtype=torch.int64)):
-            return jittable(text_encoder_model.forward)(inp)
+            def main(self, inp=AbstractTensor(1, input_len, dtype=torch.int64)):
+                return jittable(text_encoder_model.forward)(input_ids=inp)

     import_to = "INPUT" if compile_to == "linalg" else "IMPORT"
     inst = CompiledClip(context=Context(), import_to=import_to)
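As a sanity check on the T5 branch above: T5Model is an encoder-decoder, so its forward() needs decoder_input_ids in addition to input_ids, which is why the new CompiledClip.main takes two 512-token tensors and why the export passes the same ids to both. A minimal torch-only sketch of that signature follows; the checkpoint name "google/t5-v1_1-small" is only an illustrative match for the "google/t5" check, and the snippet does not touch the IREE export itself.

import torch
from transformers import T5Tokenizer, T5Model

hf_model_name = "google/t5-v1_1-small"  # illustrative; any "google/t5" checkpoint takes this path
tokenizer = T5Tokenizer.from_pretrained(hf_model_name)
model = T5Model.from_pretrained(hf_model_name)

ids = tokenizer(
    "a photograph of an astronaut riding a horse",
    padding="max_length",
    max_length=tokenizer.model_max_length,  # 512 for T5, matching input_len above
    truncation=True,
    return_tensors="pt",
).input_ids

with torch.no_grad():
    # Feed the same ids to encoder and decoder, as the exported main() does.
    out = model(input_ids=ids, decoder_input_ids=ids)
print(out.last_hidden_state.shape)  # torch.Size([1, 512, d_model])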
139 changes: 108 additions & 31 deletions models/turbine_models/custom_models/sd_inference/clip_runner.py
@@ -3,6 +3,7 @@
 from transformers import CLIPTokenizer
 from iree import runtime as ireert
 import torch
+from PIL import Image

 parser = argparse.ArgumentParser()

@@ -52,49 +53,125 @@ def run_clip(
 ):
     runner = vmfbRunner(device, vmfb_path, external_weight_path)

-    tokenizer = CLIPTokenizer.from_pretrained(
-        hf_model_name,
-        subfolder="tokenizer",
-        token=hf_auth_token,
-    )
-    text_input = tokenizer(
-        prompt,
-        padding="max_length",
-        max_length=tokenizer.model_max_length,
-        truncation=True,
-        return_tensors="pt",
-    )
+    if "google/t5" in hf_model_name:
+        from transformers import T5Tokenizer, T5Model
+
+        tokenizer = T5Tokenizer.from_pretrained(hf_model_name)
+        text_input = tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=tokenizer.model_max_length,
+            truncation=True,
+            return_tensors="pt",
+        )
+    # TODO: Integrate with HFTransformerBuilder
+    else:
+        if "openai" in hf_model_name:
+            from transformers import CLIPProcessor
+            import requests
+
+            url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+            image = Image.open(requests.get(url, stream=True).raw)
+            tokenizer = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
+            text_input = tokenizer(
+                text=prompt,
+                images=image,
+                truncation=True,
+                padding=True,
+                return_tensors="pt",
+            )
+        else:
+            hf_subfolder = "tokenizer"
+
+            tokenizer = CLIPTokenizer.from_pretrained(
+                hf_model_name,
+                subfolder=hf_subfolder,
+                token=hf_auth_token,
+            )
+
+            text_input = tokenizer(
+                prompt,
+                padding="max_length",
+                max_length=tokenizer.model_max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
     example_input = text_input.input_ids
     inp = [ireert.asdevicearray(runner.config.device, example_input)]

+    if "google/t5" in hf_model_name:
+        inp += [ireert.asdevicearray(runner.config.device, example_input)]
     results = runner.ctx.modules.compiled_clip["main"](*inp)
     return results


 def run_torch_clip(hf_model_name, hf_auth_token, prompt):
+    if "google/t5" in hf_model_name:
+        from transformers import T5Tokenizer, T5Model
+
+        tokenizer = T5Tokenizer.from_pretrained(hf_model_name)
+        model = T5Model.from_pretrained(hf_model_name)
+        text_input = tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=tokenizer.model_max_length,
+            truncation=True,
+            return_tensors="pt",
+        )
     # TODO: Integrate with HFTransformerBuilder
-    from transformers import CLIPTextModel
+    else:
+        if hf_model_name == "openai/clip-vit-large-patch14":
+            from transformers import CLIPProcessor
+            import requests

-    model = CLIPTextModel.from_pretrained(
-        hf_model_name,
-        subfolder="text_encoder",
-        token=hf_auth_token,
-    )
-    tokenizer = CLIPTokenizer.from_pretrained(
-        hf_model_name,
-        subfolder="tokenizer",
-        token=hf_auth_token,
-    )
-    text_input = tokenizer(
-        prompt,
-        padding="max_length",
-        max_length=tokenizer.model_max_length,
-        truncation=True,
-        return_tensors="pt",
-    )
+            url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+            image = Image.open(requests.get(url, stream=True).raw)
+
+            tokenizer = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
+            hf_subfolder = ""  # CLIPProcessor does not have a subfolder
+            from transformers import CLIPTextModel
+
+            model = CLIPTextModel.from_pretrained(
+                hf_model_name,
+                subfolder=hf_subfolder,
+                token=hf_auth_token,
+            )
+            text_input = tokenizer(
+                text=prompt,
+                images=image,
+                truncation=True,
+                padding=True,
+                return_tensors="pt",
+            )
+        else:
+            hf_subfolder = "text_encoder"
+
+            tokenizer = CLIPTokenizer.from_pretrained(
+                hf_model_name,
+                subfolder="tokenizer",
+                token=hf_auth_token,
+            )
+
+            from transformers import CLIPTextModel
+
+            model = CLIPTextModel.from_pretrained(
+                hf_model_name,
+                subfolder=hf_subfolder,
+                token=hf_auth_token,
+            )
+            text_input = tokenizer(
+                prompt,
+                padding="max_length",
+                max_length=tokenizer.model_max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
     example_input = text_input.input_ids

-    results = model.forward(example_input)[0]
+    if "google/t5" in hf_model_name:
+        results = model.forward(example_input, decoder_input_ids=example_input)[0]
+    else:
+        results = model.forward(example_input)[0]
     np_torch_output = results.detach().cpu().numpy()
     return np_torch_output
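One note on the openai/clip-vit-large-patch14 path above: CLIPProcessor bundles the tokenizer with an image preprocessor, which is why the runner fetches a sample COCO image even though only input_ids ever reach the text encoder. A standalone sketch of that flow, assuming network access for the image (the URL is the same one the runner uses):

import requests
from PIL import Image
from transformers import CLIPProcessor, CLIPTextModel

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

inputs = processor(
    text="two cats lying on a couch",
    images=image,
    truncation=True,
    padding=True,
    return_tensors="pt",
)
# Only the token ids matter for the text encoder; pixel_values are discarded,
# just as run_clip forwards nothing but text_input.input_ids.
hidden = model(inputs.input_ids)[0]  # last_hidden_state: (1, seq_len, 768)

Note that the export in clip.py fixes input_len = 10 for this model, so a prompt fed to the compiled module has to tokenize (with padding) to exactly ten ids to match the compiled signature; the torch-side check here has no such constraint.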
