Skip to content

Commit

Permalink
add
Browse files Browse the repository at this point in the history
  • Loading branch information
Anshler committed Dec 3, 2023
1 parent f5d6088 commit bd4d4f4
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 6 deletions.
5 changes: 4 additions & 1 deletion install.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,7 @@
launch.run_pip("install transformers==4.35.2", "requirements for Image Caption")

if not launch.is_installed("open_clip_torch"):
launch.run_pip("install open_clip_torch", "requirements for Image Caption")
launch.run_pip("install open_clip_torch", "requirements for Image Caption")

if not launch.is_installed("mediafire-dl"):
launch.run_pip("install git+https://github.com/Juvenal-Yescas/mediafire-dl", "requirements for Image Caption")
Binary file modified scripts/__pycache__/image_caption.cpython-310.pyc
Binary file not shown.
19 changes: 14 additions & 5 deletions scripts/image_caption.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@
import torch
import torch.nn as nn
import open_clip
import mediafire_dl
from torch.nn import functional as nnf
from modules import script_callbacks
from transformers import GPT2Tokenizer, GPT2LMHeadModel, pipeline, AutoModelForMaskedLM, AutoTokenizer
from typing import List, Optional, Union, Tuple, Dict, Any
from itertools import combinations

original_directory = os.getcwd()
current_directory = os.path.join(original_directory,'extensions','image-caption-for-sd','scripts')
current_directory = os.path.join(original_directory,'extensions','ICG_sd_extension','scripts')
previous_choice = 'ClipCap'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
encyclopedia, labels = None, None
Expand All @@ -21,6 +22,7 @@
prime_check_list = ['landscape','people','animal','plant','some object']
prime_mapping = {'landscape':'scenery','people':'human','animal':'animal','plant':'plant','some object':'object'} # need mapping because the wordings in label classification is different from those of CLIP's
secondary_check_list = ['time','weather','activity']
prefix_length = 10

secondary_text_labels_landscape = None
secondary_text_labels_human = None
Expand Down Expand Up @@ -80,7 +82,7 @@ def get_caption(choice, image = None):
global text_features_1000, prime_labels, prime_text_features, prime_check_list, prime_mapping, secondary_check_list
global secondary_text_labels_landscape, secondary_text_labels_human, secondary_text_labels_occupation, secondary_text_labels_animal, secondary_text_labels_plant, secondary_text_labels_object, secondary_text_labels_activity, secondary_text_labels_time, secondary_text_labels_weather, secondary_text_labels_clothing
global secondary_text_features_landscape, secondary_text_features_human, secondary_text_features_occupation, secondary_text_features_animal, secondary_text_features_plant, secondary_text_features_object, secondary_text_features_activity, secondary_text_features_time, secondary_text_features_weather, secondary_text_features_clothing

global prefix_length
if image is not None:
torch.cuda.empty_cache()

Expand All @@ -107,14 +109,16 @@ def get_caption(choice, image = None):
if change_model or model is None:
change_model = False
if choice == 'ClipCap':
model_path = os.path.join(current_directory, 'flickr8k_prefix-030.pt')
model_path = os.path.join(current_directory,'flickr8k_prefix-030.pt')
modelCLIP, _, preprocess = open_clip.create_model_and_transforms('ViT-L-14', pretrained='openai', device= device, precision= 'fp16' if device == torch.device('cuda') else 'fp32')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# Load model weights
prefix_length = 10
model = ClipCaptionModel(prefix_length)
if not os.path.exists(model_path):
url = 'https://download1649.mediafire.com/0jzo588ip91gq4T9uNZPEPaAJrIf8srGL3PIZ5kDUCpOsMHFjn-8xr4786EaZGA8IUxowMuUy3UnHc4aKGlWZHQBL7R4rArvkbhY_pzIuZyv_rGjA36yV38WAPkplghdo11g44kF5LBFvcLSp4dMHqdXUw4WWi2JnfRP03l7sTONEg/qof8qa7odm4dfck/flickr8k_prefix-030.pt'
mediafire_dl.download(url, model_path, quiet=False)

model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

model = model.eval()
model = model.to(device)
else:
Expand All @@ -125,6 +129,11 @@ def get_caption(choice, image = None):
tokenizer.sep_token = '###' # set separator
tokenizer.pad_token = tokenizer.eos_token # set padding token
model_name = 'top_hierarchy'
model_path = os.path.join(current_directory, model_name, 'model.safetensors')
if not os.path.exists(model_path):
url = 'https://download1323.mediafire.com/u0312q9aa6lgEsgJmP1dLaDF0p7Td3f4fnWzlRH_NXvgPtskQq_GpRL23mZ50cv_R-0Ls31a9cvqgX1lUvBevqDDfNrMVNHzJ_FDg0B-81Whaslf5NatrZ4-exVjkuQeBu3EKd8LO0FFiUSsveSoaEZ88M8aWaA_XWFibLMa-Moszg/9rjol6786rlmefx/model.safetensors'
mediafire_dl.download(url, model_path, quiet=False)

os.chdir(current_directory)
generator = pipeline("text-generation", model=model_name, tokenizer=tokenizer)
os.chdir(original_directory)
Expand Down

0 comments on commit bd4d4f4

Please sign in to comment.