-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathgenerate_images.py
62 lines (52 loc) · 2.49 KB
/
generate_images.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import os
import csv
import torch
import argparse
from tqdm import tqdm
from PIL import Image
from .data import ImageDataset
from torchvision import transforms
from accelerate import Accelerator
from diffusers import StableDiffusionImageVariationPipeline
# ---------------------------------------------------------------------------
# CLI configuration. Parsed at import time: this file is a script, not a
# library module.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description = "Generate image variations with Stable Diffusion")
parser.add_argument("--batch_size", type = int, default = 2, help = "Batch size for the dataloader")
parser.add_argument("--split", type = str, default = "train", help = "Dataset split to load (e.g. train)")
parser.add_argument("--data_dir", type = str, default = "/home/data/ImageNet1K/validation", help = "Path to the input image dataset")
# required: the original default of None made os.makedirs() below raise
# TypeError on any run that omitted the flag — fail early with a clear error.
parser.add_argument("--save_image_gen", type = str, required = True, help = "Directory where generated images are saved")
args = parser.parse_args()

accelerator = Accelerator()

os.makedirs(args.save_image_gen, exist_ok = True)
def generate_images(pipe, dataloader, args):
    """Run the image-variation pipeline over ``dataloader`` and save results.

    Outputs are written under ``args.save_image_gen`` mirroring each item's
    relative location. Batch entries whose target file already exists are
    skipped, so an interrupted run can be resumed.

    Args:
        pipe: A ``StableDiffusionImageVariationPipeline`` instance.
        dataloader: Yields ``(image_locations, original_images)`` batches —
            relative paths plus the matching preprocessed image tensors.
        args: Parsed CLI namespace; only ``save_image_gen`` is read here.
    """
    pipe, dataloader = accelerator.prepare(pipe, dataloader)
    pipe = pipe.to(accelerator.device)
    with torch.no_grad():
        for image_locations, original_images in tqdm(dataloader):
            # Keep only the batch entries whose output does not exist yet.
            indices = [
                i for i in range(len(image_locations))
                if not os.path.exists(os.path.join(args.save_image_gen, image_locations[i]))
            ]
            if not indices:
                continue
            original_images = original_images[indices]
            image_locations = [image_locations[i] for i in indices]
            images = pipe(original_images, guidance_scale = 3).images
            for index, image in enumerate(images):
                out_path = os.path.join(args.save_image_gen, image_locations[index])
                os.makedirs(os.path.dirname(out_path), exist_ok = True)
                image.save(out_path)
def main():
    """Load the image-variation pipeline, build the dataset, and generate."""
    pipe = StableDiffusionImageVariationPipeline.from_pretrained(
        "lambdalabs/sd-image-variations-diffusers", revision = "v2.0"
    )
    # Preprocessing for the sd-image-variations model: to-tensor, then a
    # 224x224 bicubic resize without antialiasing, then CLIP normalization.
    resize = transforms.Resize(
        (224, 224),
        interpolation=transforms.InterpolationMode.BICUBIC,
        antialias=False,
    )
    normalize = transforms.Normalize(
        [0.48145466, 0.4578275, 0.40821073],
        [0.26862954, 0.26130258, 0.27577711],
    )
    preprocess = transforms.Compose([transforms.ToTensor(), resize, normalize])
    dataset = ImageDataset(args.data_dir, preprocess, split = args.split)
    loader = torch.utils.data.DataLoader(
        dataset, batch_size = args.batch_size, shuffle = False
    )
    generate_images(pipe, loader, args)

if __name__ == "__main__":
    main()