From dd9c8ecc78575babf7527f74e830e05667dfcc79 Mon Sep 17 00:00:00 2001 From: Tianhao Xie <52686796+tianhaoxie@users.noreply.github.com> Date: Fri, 30 Jun 2023 08:47:24 -0400 Subject: [PATCH 1/5] Add files via upload --- visualizer_gradio_custom.py | 964 ++++++++++++++++++++++++++++++++++++ 1 file changed, 964 insertions(+) create mode 100644 visualizer_gradio_custom.py diff --git a/visualizer_gradio_custom.py b/visualizer_gradio_custom.py new file mode 100644 index 0000000..6e84057 --- /dev/null +++ b/visualizer_gradio_custom.py @@ -0,0 +1,964 @@ +import os +import os.path as osp +from argparse import ArgumentParser +from functools import partial + +import gradio as gr +import numpy as np +import torch +from PIL import Image +import imageio +import dnnlib +from gradio_utils import (ImageMask, draw_mask_on_image, draw_points_on_image, + get_latest_points_pair, get_valid_mask, + on_change_single_global_state) +from viz.renderer import Renderer, add_watermark_np +from gan_inv.inversion import PTI +from gan_inv.lpips import util +parser = ArgumentParser() +parser.add_argument('--share',default='False') +parser.add_argument('--cache-dir', type=str, default='./checkpoints') +args = parser.parse_args() + +cache_dir = args.cache_dir + +device = 'cuda' + + +def reverse_point_pairs(points): + new_points = [] + for p in points: + new_points.append([p[1], p[0]]) + return new_points + + +def clear_state(global_state, target=None): + """Clear target history state from global_state + If target is not defined, points and mask will be both removed. + 1. set global_state['points'] as empty dict + 2. set global_state['mask'] as full-one mask. + """ + if target is None: + target = ['point', 'mask'] + if not isinstance(target, list): + target = [target] + if 'point' in target: + global_state['points'] = dict() + print('Clear Points State!') + if 'mask' in target: + image_raw = global_state["images"]["image_raw"] + global_state['mask'] = np.ones((image_raw.size[1], image_raw.size[0]), + dtype=np.uint8) + print('Clear mask State!') + + return global_state + + +def init_images(global_state): + """This function is called only ones with Gradio App is started. + 0. pre-process global_state, unpack value from global_state of need + 1. Re-init renderer + 2. run `renderer._render_drag_impl` with `is_drag=False` to generate + new image + 3. 
Assign images to global state and re-generate mask + """ + + if isinstance(global_state, gr.State): + state = global_state.value + else: + state = global_state + + state['renderer'].init_network( + state['generator_params'], # res + valid_checkpoints_dict[state['pretrained_weight']], # pkl + state['params']['seed'], # w0_seed, + None, # w_load + state['params']['latent_space'] == 'w+', # w_plus + 'const', + state['params']['trunc_psi'], # trunc_psi, + state['params']['trunc_cutoff'], # trunc_cutoff, + None, # input_transform + state['params']['lr'] # lr, + ) + + state['renderer']._render_drag_impl(state['generator_params'], + is_drag=False, + to_pil=True) + + init_image = state['generator_params'].image + state['images']['image_orig'] = init_image + state['images']['image_raw'] = init_image + state['images']['image_show'] = Image.fromarray( + add_watermark_np(np.array(init_image))) + state['mask'] = np.ones((init_image.size[1], init_image.size[0]), + dtype=np.uint8) + return global_state + + +def update_image_draw(image, points, mask, show_mask, global_state=None): + + image_draw = draw_points_on_image(image, points) + if show_mask and mask is not None and not (mask == 0).all() and not ( + mask == 1).all(): + image_draw = draw_mask_on_image(image_draw, mask) + + image_draw = Image.fromarray(add_watermark_np(np.array(image_draw))) + if global_state is not None: + global_state['images']['image_show'] = image_draw + return image_draw + + +def preprocess_mask_info(global_state, image): + """Function to handle mask information. + 1. last_mask is None: Do not need to change mask, return mask + 2. last_mask is not None: + 2.1 global_state is remove_mask: + 2.2 global_state is add_mask: + """ + if isinstance(image, dict): + last_mask = get_valid_mask(image['mask']) + else: + last_mask = None + mask = global_state['mask'] + + # mask in global state is a placeholder with all 1. + if (mask == 1).all(): + mask = last_mask + + # last_mask = global_state['last_mask'] + editing_mode = global_state['editing_state'] + + if last_mask is None: + return global_state + + if editing_mode == 'remove_mask': + updated_mask = np.clip(mask - last_mask, 0, 1) + print(f'Last editing_state is {editing_mode}, do remove.') + elif editing_mode == 'add_mask': + updated_mask = np.clip(mask + last_mask, 0, 1) + print(f'Last editing_state is {editing_mode}, do add.') + else: + updated_mask = mask + print(f'Last editing_state is {editing_mode}, ' + 'do nothing to mask.') + + global_state['mask'] = updated_mask + # global_state['last_mask'] = None # clear buffer + return global_state + + +valid_checkpoints_dict = { + f.split('/')[-1].split('.')[0]: osp.join(cache_dir, f) + for f in os.listdir(cache_dir) + if (f.endswith('pkl') and osp.exists(osp.join(cache_dir, f))) +} +print(f'File under cache_dir ({cache_dir}):') +print(os.listdir(cache_dir)) +print('Valid checkpoint file:') +print(valid_checkpoints_dict) + +init_pkl = 'stylegan2_lions_512_pytorch' + + + +# Network & latents tab listeners +def on_change_pretrained_dropdown(pretrained_value, global_state): + """Function to handle model change. + 1. Set pretrained value to global_state + 2. Re-init images and clear all states + """ + global_state['pretrained_weight'] = pretrained_value + init_images(global_state) + clear_state(global_state) + + return global_state, global_state["images"]['image_show'] + + + +def on_click_reset_image(global_state): + """Reset image to the original one and clear all states + 1. Re-init images + 2. 
Clear all states + """ + + init_images(global_state) + clear_state(global_state) + + return global_state, global_state['images']['image_show'] + + + + # Update parameters +def on_change_update_image_seed(seed, global_state): + """Function to handle generation seed change. + 1. Set seed to global_state + 2. Re-init images and clear all states + """ + + global_state["params"]["seed"] = int(seed) + init_images(global_state) + clear_state(global_state) + + return global_state, global_state['images']['image_show'] + + + +def on_click_latent_space(latent_space, global_state): + """Function to reset latent space to optimize. + NOTE: this function we reset the image and all controls + 1. Set latent-space to global_state + 2. Re-init images and clear all state + """ + + global_state['params']['latent_space'] = latent_space + init_images(global_state) + clear_state(global_state) + + return global_state, global_state['images']['image_show'] + + + +def on_click_inverse_custom_image(custom_image,global_state): + print('inverse GAN') + + if isinstance(global_state, gr.State): + state = global_state.value + else: + state = global_state + + state['renderer'].init_network( + state['generator_params'], # res + valid_checkpoints_dict[state['pretrained_weight']], # pkl + state['params']['seed'], # w0_seed, + None, # w_load + state['params']['latent_space'] == 'w+', # w_plus + 'const', + state['params']['trunc_psi'], # trunc_psi, + state['params']['trunc_cutoff'], # trunc_cutoff, + None, # input_transform + state['params']['lr'] # lr, + ) + + percept = util.PerceptualLoss( + model="net-lin", net="vgg", use_gpu=True + ) + + image = Image.open(custom_image.name) + + pti = PTI(global_state['renderer'].G,percept) + inversed_img, w_pivot = pti.train(image,state['params']['latent_space'] == 'w+') + inversed_img = (inversed_img[0] * 127.5 + 128).clamp(0, 255).to(torch.uint8).permute(1, 2, 0) + inversed_img = inversed_img.cpu().numpy() + inversed_img = Image.fromarray(inversed_img) + global_state['images']['image_show'] = Image.fromarray( + add_watermark_np(np.array(inversed_img))) + + global_state['images']['image_orig'] = inversed_img + global_state['images']['image_raw'] = inversed_img + + global_state['mask'] = np.ones((inversed_img.size[1], inversed_img.size[0]), + dtype=np.uint8) + global_state['generator_params'].image = inversed_img + global_state['generator_params'].w = w_pivot.detach().cpu().numpy() + global_state['renderer'].set_latent(w_pivot,global_state['params']['trunc_psi'],global_state['params']['trunc_cutoff']) + + del percept + del pti + print('inverse end') + + return global_state, global_state['images']['image_show'], gr.Button.update(interactive=True) + +def on_save_image(global_state,form_save_image_path): + imageio.imsave(form_save_image_path,global_state['images']['image_raw']) + +def on_reset_custom_image(global_state): + if isinstance(global_state, gr.State): + state = global_state.value + else: + state = global_state + clear_state(state) + state['renderer'].w = state['renderer'].w0.detach().clone() + state['renderer'].w.requires_grad = True + state['renderer'].w_optim = torch.optim.Adam([state['renderer'].w], lr=state['renderer'].lr) + state['renderer']._render_drag_impl(state['generator_params'], + is_drag=False, + to_pil=True) + + init_image = state['generator_params'].image + state['images']['image_orig'] = init_image + state['images']['image_raw'] = init_image + state['images']['image_show'] = Image.fromarray( + add_watermark_np(np.array(init_image))) + state['mask'] = 
np.ones((init_image.size[1], init_image.size[0]), + dtype=np.uint8) + return state, state['images']['image_show'] +def on_change_lr(lr, global_state): + if lr == 0: + print('lr is 0, do nothing.') + return global_state + else: + global_state["params"]["lr"] = lr + renderer = global_state['renderer'] + renderer.update_lr(lr) + print('New optimizer: ') + print(renderer.w_optim) + return global_state + + +def on_click_start(global_state, image): + p_in_pixels = [] + t_in_pixels = [] + valid_points = [] + + # handle of start drag in mask editing mode + global_state = preprocess_mask_info(global_state, image) + + # Prepare the points for the inference + if len(global_state["points"]) == 0: + # yield on_click_start_wo_points(global_state, image) + image_raw = global_state['images']['image_raw'] + update_image_draw( + image_raw, + global_state['points'], + global_state['mask'], + global_state['show_mask'], + global_state, + ) + + yield ( + global_state, + 0, + global_state['images']['image_show'], + # gr.File.update(visible=False), + gr.Button.update(interactive=True), + gr.Button.update(interactive=True), + gr.Button.update(interactive=True), + gr.Button.update(interactive=True), + gr.Button.update(interactive=True), + # latent space + gr.Radio.update(interactive=True), + gr.Button.update(interactive=True), + # NOTE: disable stop button + gr.Button.update(interactive=False), + + # update other comps + gr.Dropdown.update(interactive=True), + gr.Number.update(interactive=True), + gr.Number.update(interactive=True), + gr.Button.update(interactive=True), + gr.Button.update(interactive=True), + gr.Checkbox.update(interactive=True), + # gr.Number.update(interactive=True), + gr.Number.update(interactive=True), + ) + else: + + # Transform the points into torch tensors + for key_point, point in global_state["points"].items(): + try: + p_start = point.get("start_temp", point["start"]) + p_end = point["target"] + + if p_start is None or p_end is None: + continue + + except KeyError: + continue + + p_in_pixels.append(p_start) + t_in_pixels.append(p_end) + valid_points.append(key_point) + + mask = torch.tensor(global_state['mask']).float() + drag_mask = 1 - mask + + renderer: Renderer = global_state["renderer"] + global_state['temporal_params']['stop'] = False + global_state['editing_state'] = 'running' + + # reverse points order + p_to_opt = reverse_point_pairs(p_in_pixels) + t_to_opt = reverse_point_pairs(t_in_pixels) + #print('Running with:') + #print(f' Source: {p_in_pixels}') + #print(f' Target: {t_in_pixels}') + step_idx = 0 + while True: + if global_state["temporal_params"]["stop"]: + break + + # do drage here! 
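+ # single drag-optimization step on the latent code: motion supervision
+ # pulls each handle point toward its target; drag_mask constrains which
+ # region of the image is allowed to change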
+ renderer._render_drag_impl( + global_state['generator_params'], + p_to_opt, # point + t_to_opt, # target + drag_mask, # mask, + global_state['params']['motion_lambda'], # lambda_mask + reg=0, + feature_idx=5, # NOTE: do not support change for now + r1=global_state['params']['r1_in_pixels'], # r1 + r2=global_state['params']['r2_in_pixels'], # r2 + # random_seed = 0, + # noise_mode = 'const', + trunc_psi=global_state['params']['trunc_psi'], + # force_fp32 = False, + # layer_name = None, + # sel_channels = 3, + # base_channel = 0, + # img_scale_db = 0, + # img_normalize = False, + # untransform = False, + is_drag=True, + to_pil=True) + + if step_idx % global_state['draw_interval'] == 0: + #print('Current Source:') + for key_point, p_i, t_i in zip(valid_points, p_to_opt, + t_to_opt): + global_state["points"][key_point]["start_temp"] = [ + p_i[1], + p_i[0], + ] + global_state["points"][key_point]["target"] = [ + t_i[1], + t_i[0], + ] + start_temp = global_state["points"][key_point][ + "start_temp"] + #print(f' {start_temp}') + + image_result = global_state['generator_params']['image'] + image_draw = update_image_draw( + image_result, + global_state['points'], + global_state['mask'], + global_state['show_mask'], + global_state, + ) + global_state['images']['image_raw'] = image_result + + yield ( + global_state, + step_idx, + global_state['images']['image_show'], + # gr.File.update(visible=False), + gr.Button.update(interactive=False), + gr.Button.update(interactive=False), + gr.Button.update(interactive=False), + gr.Button.update(interactive=False), + gr.Button.update(interactive=False), + # latent space + gr.Radio.update(interactive=False), + gr.Button.update(interactive=False), + # enable stop button in loop + gr.Button.update(interactive=True), + + # update other comps + gr.Dropdown.update(interactive=False), + gr.Number.update(interactive=False), + gr.Number.update(interactive=False), + gr.Button.update(interactive=False), + gr.Button.update(interactive=False), + gr.Checkbox.update(interactive=False), + # gr.Number.update(interactive=False), + gr.Number.update(interactive=False), + ) + + # increate step + step_idx += 1 + + image_result = global_state['generator_params']['image'] + global_state['images']['image_raw'] = image_result + image_draw = update_image_draw(image_result, + global_state['points'], + global_state['mask'], + global_state['show_mask'], + global_state) + + # fp = NamedTemporaryFile(suffix=".png", delete=False) + # image_result.save(fp, "PNG") + + global_state['editing_state'] = 'add_points' + + yield ( + global_state, + 0, # reset step to 0 after stop. + global_state['images']['image_show'], + # gr.File.update(visible=True, value=fp.name), + gr.Button.update(interactive=True), + gr.Button.update(interactive=True), + gr.Button.update(interactive=True), + gr.Button.update(interactive=True), + gr.Button.update(interactive=True), + # latent space + gr.Radio.update(interactive=True), + gr.Button.update(interactive=True), + # NOTE: disable stop button with loop finish + gr.Button.update(interactive=False), + + # update other comps + gr.Dropdown.update(interactive=True), + gr.Number.update(interactive=True), + gr.Number.update(interactive=True), + gr.Checkbox.update(interactive=True), + gr.Number.update(interactive=True), + ) + + + +def on_click_stop(global_state): + """Function to handle stop button is clicked. + 1. send a stop signal by set global_state["temporal_params"]["stop"] as True + 2. 
Disable Stop button + """ + global_state["temporal_params"]["stop"] = True + + return global_state, gr.Button.update(interactive=False) + + + +def on_click_remove_point(global_state): + choice = global_state["curr_point"] + del global_state["points"][choice] + + choices = list(global_state["points"].keys()) + + if len(choices) > 0: + global_state["curr_point"] = choices[0] + + return ( + gr.Dropdown.update(choices=choices, value=choices[0]), + global_state, + ) + + # Mask +def on_click_reset_mask(global_state): + global_state['mask'] = np.ones( + ( + global_state["images"]["image_raw"].size[1], + global_state["images"]["image_raw"].size[0], + ), + dtype=np.uint8, + ) + image_draw = update_image_draw(global_state['images']['image_raw'], + global_state['points'], + global_state['mask'], + global_state['show_mask'], global_state) + return global_state, image_draw + + + + # Image +def on_click_enable_draw(global_state, image): + """Function to start add mask mode. + 1. Preprocess mask info from last state + 2. Change editing state to add_mask + 3. Set curr image with points and mask + """ + global_state = preprocess_mask_info(global_state, image) + global_state['editing_state'] = 'add_mask' + image_raw = global_state['images']['image_raw'] + image_draw = update_image_draw(image_raw, global_state['points'], + global_state['mask'], True, + global_state) + return (global_state, + gr.Image.update(value=image_draw, interactive=True)) + +def on_click_remove_draw(global_state, image): + """Function to start remove mask mode. + 1. Preprocess mask info from last state + 2. Change editing state to remove_mask + 3. Set curr image with points and mask + """ + global_state = preprocess_mask_info(global_state, image) + global_state['edinting_state'] = 'remove_mask' + image_raw = global_state['images']['image_raw'] + image_draw = update_image_draw(image_raw, global_state['points'], + global_state['mask'], True, + global_state) + return (global_state, + gr.Image.update(value=image_draw, interactive=True)) + + + +def on_click_add_point(global_state, image: dict): + """Function switch from add mask mode to add points mode. + 1. Updaste mask buffer if need + 2. Change global_state['editing_state'] to 'add_points' + 3. Set current image with mask + """ + + global_state = preprocess_mask_info(global_state, image) + global_state['editing_state'] = 'add_points' + mask = global_state['mask'] + image_raw = global_state['images']['image_raw'] + image_draw = update_image_draw(image_raw, global_state['points'], mask, + global_state['show_mask'], global_state) + + return (global_state, + gr.Image.update(value=image_draw, interactive=False)) + + + +def on_click_image(global_state, evt: gr.SelectData): + """This function only support click for point selection + """ + xy = evt.index + if global_state['editing_state'] != 'add_points': + print(f'In {global_state["editing_state"]} state. 
' + 'Do not add points.') + + return global_state, global_state['images']['image_show'] + + points = global_state["points"] + + point_idx = get_latest_points_pair(points) + if point_idx is None: + points[0] = {'start': xy, 'target': None} + print(f'Click Image - Start - {xy}') + elif points[point_idx].get('target', None) is None: + points[point_idx]['target'] = xy + print(f'Click Image - Target - {xy}') + else: + points[point_idx + 1] = {'start': xy, 'target': None} + print(f'Click Image - Start - {xy}') + + image_raw = global_state['images']['image_raw'] + image_draw = update_image_draw( + image_raw, + global_state['points'], + global_state['mask'], + global_state['show_mask'], + global_state, + ) + + return global_state, image_draw + + + +def on_click_clear_points(global_state): + """Function to handle clear all control points + 1. clear global_state['points'] (clear_state) + 2. re-init network + 2. re-draw image + """ + clear_state(global_state, target='point') + + renderer: Renderer = global_state["renderer"] + renderer.feat_refs = None + + image_raw = global_state['images']['image_raw'] + image_draw = update_image_draw(image_raw, {}, global_state['mask'], + global_state['show_mask'], global_state) + return global_state, image_draw + + + +def on_click_show_mask(global_state, show_mask): + """Function to control whether show mask on image.""" + global_state['show_mask'] = show_mask + + image_raw = global_state['images']['image_raw'] + image_draw = update_image_draw( + image_raw, + global_state['points'], + global_state['mask'], + global_state['show_mask'], + global_state, + ) + return global_state, image_draw + + +if __name__ == "__main__": + with gr.Blocks() as app: + # renderer = Renderer() + global_state = gr.State({ + "images": { + # image_orig: the original image, change with seed/model is changed + # image_raw: image with mask and points, change durning optimization + # image_show: image showed on screen + }, + "temporal_params": { + # stop + }, + 'mask': + None, # mask for visualization, 1 for editing and 0 for unchange + 'last_mask': None, # last edited mask + 'show_mask': True, # add button + "generator_params": dnnlib.EasyDict(), + "params": { + "seed": 0, + "motion_lambda": 20, + "r1_in_pixels": 3, + "r2_in_pixels": 12, + "magnitude_direction_in_pixels": 1.0, + "latent_space": "w+", + "trunc_psi": 0.7, + "trunc_cutoff": None, + "lr": 0.001, + }, + "device": device, + "draw_interval": 1, + "renderer": Renderer(disable_timing=True), + "points": {}, + "curr_point": None, + "curr_type_point": "start", + 'editing_state': 'add_points', + 'pretrained_weight': init_pkl + }) + + # init image + global_state = init_images(global_state) + + with gr.Row(): + with gr.Row(): + # Left --> tools + with gr.Column(scale=3): + # Pickle + with gr.Row(): + with gr.Column(scale=1, min_width=10): + gr.Markdown(value='Pickle', show_label=False) + + with gr.Column(scale=4, min_width=10): + form_pretrained_dropdown = gr.Dropdown( + choices=list(valid_checkpoints_dict.keys()), + label="Pretrained Model", + value=init_pkl, + ) + + # Latent + with gr.Row(): + with gr.Column(scale=1, min_width=10): + gr.Markdown(value='Latent', show_label=False) + + with gr.Column(scale=4, min_width=10): + form_seed_number = gr.Number( + value=global_state.value['params']['seed'], + interactive=True, + label="Seed", + ) + form_lr_number = gr.Number( + value=global_state.value["params"]["lr"], + interactive=True, + label="Step Size") + + with gr.Row(): + with gr.Column(scale=2, min_width=10): + form_reset_image = 
gr.Button("Reset Image") + with gr.Column(scale=3, min_width=10): + form_latent_space = gr.Radio( + ['w', 'w+'], + value=global_state.value['params'] + ['latent_space'], + interactive=True, + label='Latent space to optimize', + show_label=False, + ) + with gr.Row(): + with gr.Column(scale=3, min_width=10): + form_custom_image = gr.UploadButton(label="inverse custom image", + file_types=['.png', '.jpg', '.jpeg']) + with gr.Column(scale=3, min_width=10): + form_reset_custom_image = gr.Button('reset custom image', interactive=False) + with gr.Row(): + with gr.Column(scale=3, min_width=10): + form_save_image_path = gr.Textbox(label="save image to",value='./test.png') + form_save_image = gr.Button('save',interactive=True) + + + # Drag + with gr.Row(): + with gr.Column(scale=1, min_width=10): + gr.Markdown(value='Drag', show_label=False) + with gr.Column(scale=4, min_width=10): + with gr.Row(): + with gr.Column(scale=1, min_width=10): + enable_add_points = gr.Button('Add Points') + with gr.Column(scale=1, min_width=10): + undo_points = gr.Button('Reset Points') + with gr.Row(): + with gr.Column(scale=1, min_width=10): + form_start_btn = gr.Button("Start") + with gr.Column(scale=1, min_width=10): + form_stop_btn = gr.Button("Stop") + + form_steps_number = gr.Number(value=0, + label="Steps", + interactive=False) + + # Mask + with gr.Row(): + with gr.Column(scale=1, min_width=10): + gr.Markdown(value='Mask', show_label=False) + with gr.Column(scale=4, min_width=10): + enable_add_mask = gr.Button('Edit Flexible Area') + with gr.Row(): + with gr.Column(scale=1, min_width=10): + form_reset_mask_btn = gr.Button("Reset mask") + with gr.Column(scale=1, min_width=10): + show_mask = gr.Checkbox( + label='Show Mask', + value=global_state.value['show_mask'], + show_label=False) + + with gr.Row(): + form_lambda_number = gr.Number( + value=global_state.value["params"] + ["motion_lambda"], + interactive=True, + label="Lambda", + ) + + form_draw_interval_number = gr.Number( + value=global_state.value["draw_interval"], + label="Draw Interval (steps)", + interactive=True, + visible=False) + + # Right --> Image + with gr.Column(scale=8): + form_image = ImageMask( + value=global_state.value['images']['image_show'], + brush_radius=20).style( + width=768, + height=768) # NOTE: hard image size code here. + gr.Markdown(""" + ## Quick Start + + 1. Select desired `Pretrained Model` and adjust `Seed` to generate an + initial image. + 2. Click on image to add control points. + 3. Click `Start` and enjoy it! + + ## Advance Usage + + 1. Change `Step Size` to adjust learning rate in drag optimization. + 2. Select `w` or `w+` to change latent space to optimize: + * Optimize on `w` space may cause greater influence to the image. + * Optimize on `w+` space may work slower than `w`, but usually achieve + better results. + * Note that changing the latent space will reset the image, points and + mask (this has the same effect as `Reset Image` button). + 3. Click `Edit Flexible Area` to create a mask and constrain the + unmasked region to remain unchanged. + """) + gr.HTML(""" + +
+ Gradio demo supported by OpenMMLab MMagic
+ """) + show_mask.change( + on_click_show_mask, + inputs=[global_state, show_mask], + outputs=[global_state, form_image], + ) + undo_points.click(on_click_clear_points, + inputs=[global_state], + outputs=[global_state, form_image]) + form_image.select( + on_click_image, + inputs=[global_state], + outputs=[global_state, form_image], + ) + enable_add_mask.click(on_click_enable_draw, + inputs=[global_state, form_image], + outputs=[ + global_state, + form_image, + ]) + enable_add_points.click(on_click_add_point, + inputs=[global_state, form_image], + outputs=[global_state, form_image]) + form_reset_mask_btn.click( + on_click_reset_mask, + inputs=[global_state], + outputs=[global_state, form_image], + ) + + form_stop_btn.click(on_click_stop, + inputs=[global_state], + outputs=[global_state, form_stop_btn]) + + form_draw_interval_number.change( + partial( + on_change_single_global_state, + "draw_interval", + map_transform=lambda x: int(x), + ), + inputs=[form_draw_interval_number, global_state], + outputs=[global_state], + ) + form_start_btn.click( + on_click_start, + inputs=[global_state, form_image], + outputs=[ + global_state, + form_steps_number, + form_image, + # form_download_result_file, + # >>> buttons + form_reset_image, + enable_add_points, + enable_add_mask, + undo_points, + form_reset_mask_btn, + form_latent_space, + form_start_btn, + form_stop_btn, + # <<< buttonm + # >>> inputs comps + form_pretrained_dropdown, + form_seed_number, + form_lr_number, + show_mask, + form_lambda_number, + ], + ) + form_lr_number.change( + on_change_lr, + inputs=[form_lr_number, global_state], + outputs=[global_state], + ) + form_custom_image.upload(on_click_inverse_custom_image, inputs=[form_custom_image, global_state], + outputs=[global_state, form_image,form_reset_custom_image]) + form_save_image.click(on_save_image,inputs=[global_state,form_save_image_path],outputs=[]) + + form_reset_custom_image.click(on_reset_custom_image,inputs=[global_state],outputs=[global_state,form_image]) + # ==== Params + form_lambda_number.change( + partial(on_change_single_global_state, ["params", "motion_lambda"]), + inputs=[form_lambda_number, global_state], + outputs=[global_state], + ) + form_latent_space.change(on_click_latent_space, + inputs=[form_latent_space, global_state], + outputs=[global_state, form_image]) + form_seed_number.change( + on_change_update_image_seed, + inputs=[form_seed_number, global_state], + outputs=[global_state, form_image], + ) + form_reset_image.click( + on_click_reset_image, + inputs=[global_state], + outputs=[global_state, form_image], + ) + form_pretrained_dropdown.change( + on_change_pretrained_dropdown, + inputs=[form_pretrained_dropdown, global_state], + outputs=[global_state, form_image], + ) + #gr.close_all() + app.queue(concurrency_count=3, max_size=20) + app.launch(share=args.share) From 54403de24585a8b8c41e8825480f1ded854b9899 Mon Sep 17 00:00:00 2001 From: Tianhao Xie <52686796+tianhaoxie@users.noreply.github.com> Date: Fri, 30 Jun 2023 08:47:46 -0400 Subject: [PATCH 2/5] Add files via upload --- gan_inv/PTI.py | 53 +++ gan_inv/__init__.py | 9 + gan_inv/__pycache__/PTI.cpython-39.pyc | Bin 0 -> 2054 bytes gan_inv/__pycache__/__init__.cpython-39.pyc | Bin 0 -> 145 bytes gan_inv/__pycache__/inversion.cpython-39.pyc | Bin 0 -> 7641 bytes gan_inv/checkpoints/weights/v0.1/vgg.pth | Bin 0 -> 7289 bytes gan_inv/inversion.py | 355 ++++++++++++++++++ gan_inv/lpips/__init__.py | 5 + .../lpips/__pycache__/__init__.cpython-39.pyc | Bin 0 -> 269 bytes 
.../__pycache__/base_model.cpython-39.pyc | Bin 0 -> 2791 bytes .../__pycache__/dist_model.cpython-39.pyc | Bin 0 -> 11720 bytes .../__pycache__/networks_basic.cpython-39.pyc | Bin 0 -> 8063 bytes .../pretrained_networks.cpython-39.pyc | Bin 0 -> 5237 bytes gan_inv/lpips/__pycache__/util.cpython-39.pyc | Bin 0 -> 5781 bytes gan_inv/lpips/base_model.py | 58 +++ gan_inv/lpips/dist_model.py | 314 ++++++++++++++++ gan_inv/lpips/networks_basic.py | 188 ++++++++++ gan_inv/lpips/pretrained_networks.py | 181 +++++++++ gan_inv/lpips/util.py | 160 ++++++++ 19 files changed, 1323 insertions(+) create mode 100644 gan_inv/PTI.py create mode 100644 gan_inv/__init__.py create mode 100644 gan_inv/__pycache__/PTI.cpython-39.pyc create mode 100644 gan_inv/__pycache__/__init__.cpython-39.pyc create mode 100644 gan_inv/__pycache__/inversion.cpython-39.pyc create mode 100644 gan_inv/checkpoints/weights/v0.1/vgg.pth create mode 100644 gan_inv/inversion.py create mode 100644 gan_inv/lpips/__init__.py create mode 100644 gan_inv/lpips/__pycache__/__init__.cpython-39.pyc create mode 100644 gan_inv/lpips/__pycache__/base_model.cpython-39.pyc create mode 100644 gan_inv/lpips/__pycache__/dist_model.cpython-39.pyc create mode 100644 gan_inv/lpips/__pycache__/networks_basic.cpython-39.pyc create mode 100644 gan_inv/lpips/__pycache__/pretrained_networks.cpython-39.pyc create mode 100644 gan_inv/lpips/__pycache__/util.cpython-39.pyc create mode 100644 gan_inv/lpips/base_model.py create mode 100644 gan_inv/lpips/dist_model.py create mode 100644 gan_inv/lpips/networks_basic.py create mode 100644 gan_inv/lpips/pretrained_networks.py create mode 100644 gan_inv/lpips/util.py diff --git a/gan_inv/PTI.py b/gan_inv/PTI.py new file mode 100644 index 0000000..1b39679 --- /dev/null +++ b/gan_inv/PTI.py @@ -0,0 +1,53 @@ +import torch +from inversion import inverse_image,get_lr + +from tqdm import tqdm +from torch.nn import functional as F +from lpips import util +def toogle_grad(model, flag=True): + for p in model.parameters(): + p.requires_grad = flag + + +class PTI: + def __init__(self,G,l2_lambda = 1,max_pti_step = 400, pti_lr = 3e-4 ): + self.g_ema = G + self.l2_lambda = l2_lambda + self.max_pti_step = max_pti_step + self.pti_lr = pti_lr + def cacl_loss(self,percept, generated_image,real_image): + + mse_loss = F.mse_loss(generated_image, real_image) + p_loss = percept(generated_image, real_image).sum() + loss = p_loss +self.l2_lambda * mse_loss + return loss + + def train(self,img): + inversed_result = inverse_image(self.g_ema,img,self.g_ema.img_resolution) + w_pivot = inversed_result['latent'] + ws = w_pivot.repeat([1, self.g_ema.mapping.num_ws, 1]) + toogle_grad(self.g_ema,True) + percept = util.PerceptualLoss( + model="net-lin", net="vgg", use_gpu='cuda:0' + ) + optimizer = torch.optim.Adam(self.g_ema.parameters(), lr=self.pti_lr) + print('start PTI') + pbar = tqdm(range(self.max_pti_step)) + for i in pbar: + lr = get_lr(i, self.pti_lr) + optimizer.param_groups[0]["lr"] = lr + + generated_image,feature = self.g_ema.synthesis(ws,noise_mode='const') + loss = self.cacl_loss(percept,generated_image,inversed_result['real']) + pbar.set_description( + ( + f"loss: {loss.item():.4f}" + ) + ) + optimizer.zero_grad() + loss.backward() + optimizer.step() + with torch.no_grad(): + generated_image = self.g_ema.synthesis(ws, noise_mode='const') + + return generated_image diff --git a/gan_inv/__init__.py b/gan_inv/__init__.py new file mode 100644 index 0000000..939e7c6 --- /dev/null +++ b/gan_inv/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 
2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +# empty diff --git a/gan_inv/__pycache__/PTI.cpython-39.pyc b/gan_inv/__pycache__/PTI.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1454320e702a0de59fc089746d365a778fedb3b3 GIT binary patch literal 2054 zcmZ8iOOF&c5Vqa#-kDukmd6I7tcVn4qJ&*YP6!bYC@Yj`g~UctT0|`~_Vo6oU+eZj z7L{DmBnRCNQ{V9Vu4xoy{1RbTVrFeFfpZ@d)WdW8Ielas}Q z$t`I0LlA-pnvsYICajF+EMmD8S-BnA6m@py$d z2jmq!7%>x8T=86Lc!L=#c|1#%tm8y+0grR?Qpn8MlZ+>ZRIhmy6lxMqlVPXIF!7eU>oxArB?_H^$nl)XPNhZ5nz!{{r zK@d8i>NNCYoSZCZ%`Is4CmP&jkRK}jc#tN1zWNeX`RMt$kb}=Epnc; zZsY!E3jN0TImBL8*0r(fS>C;L{N_w5VDe4KUqP1xHT$Od#B>U@X>zSI9NkP!_8qqA@ie5fImrxaZ@rBk-BAe@N3g_N(bm>3t_Z+^675ZW3>XZS=T__JU>Z)>P`@0k@Z6#BpR?B0 z>mpcC*uSAq!RqM+UC`ERJV+na`g*hRaqoh)K@&9Y^o$6%A-QjKFPm;PHl+UWIeW<3 zu!Yl?X94>VUb((aR#rP`yyZEu(F~U7S45n8Y5&eva986lFRTdkgXMPDyUP!-KmJt* z%Zn@W+Yac^ws+^Z;SWGWHq?bhNwImAyrA{vJ@^LekI))Sg3|}a&A@Gi?h@f}?iEsh zk)?&P9w*7%9nS>6zBjVe0DWkI(s(lnhoqVTx3EmjgIaT?cQBQZuhgfw(HWOTt&xk0 zGKK9(62``mn7ioz>pOGn%2$)oC54%*&f+)gX`aMD6lDgCRTjoovXUI`%Xw9$1u$JP z%i||?#{df5Zr1Pj;Wy*icNm|JE`el>Q$l{_#=b3hu8`EboH9nu}a}+c$}STnVVoPRoQWr@gMQ=&ri4# zkkXjqUePN?!#-0iurAto!{h?(oMk%lp2Su9xYUM$T_531LHG6)4mNi6h$}7@VfKC!Ia=|wz`U&YKoGh~eR>W!)28k}u5|%s>(C&r*cRw)w*I&6_^=W} z-DC3r6RkVr0342_r_P|*0x@!tg9`W!#IXrubm$D|IW~hhJ_h)Ejk<{Y5ixZa1^x<7 z2&6KF368;3qaWaI$M?Fm5j_N~uEL@TdYAy`K0J`b6nd%mo!@!+N>M0eIEAEaoUBT# ix{HcFKJc0oS9dSzsq{ww_L~^3HNM2BTfj9~L+4+_=<{j- literal 0 HcmV?d00001 diff --git a/gan_inv/__pycache__/__init__.cpython-39.pyc b/gan_inv/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e02497c77fba70a8c38f579a13812acb1995eacd GIT binary patch literal 145 zcmYe~<>g`k0^jxXQVuXOFgylvkO5GP!2yViS%5?eLokCTqu)w~B9JhG_+_S_k)NBY zUy_-amywvSUzA#$npl*aq3=?ZnC|ZA7atJfsh^&h7oV9|rXLR^GfU#*^$IF)aoFVM Pr9o&-+S&|uTm*#xZe2dABwL(rfL69g@e}w3Kwz5oNJoU1Pe7M ze=>)m%tPMR9i8QMhGU>k4~@3zn5xbQt#-k&6>o;ccF8Fz-U`RsWv8t8LRe{6oQmRE zSaquEJLAr{uthPPXiqwmqqdq;6D2YBuI5aMvZ&xWEnXHAV)9+)91=A#g_;>LEe_#1 z>sMQIuV`XM%)YCM*FdS|1T7#G~S*I3*rKeO;V>zbFn1+t*u1#gcgZ7E^CL zA)XW~V)gxTXIY$aj)|wl+Pj+12FLwlLv~rij2`iq#o2jna01V#=QW?o&&+G!mc_I4 zns^R38}aABpW5W&bEES4{a()>==H+9HiV4*^59YN`5|*oqBp+KI^~@5CtHt+7sUn4 z@Zv4qsr!rmDZh@jp7u`<9`{d*r6G5o@N0u7#Y_II|AfFicCpJCkz)tg7qt52yQCdg z>nyXS-x0p_C3rLHCP5oaX}#BJBtg{iLMkQF>%{AkY{%5;wYJyv!5GPwXy0WRhYGzU z2pd$`y9_Wz+?R02?}G@|XQ_5sd;6Pxp7IuyZeHu_DNA|6TBs4aFobzapVkKYP#YM6 z5814Cpp}2R6(j1x+GO pS;Oi>UzYWQJ{zV~ceLq=a5(D-%}qoRa9i}?E14BjJw z6t7B}{9POmP1?liAa{6X#T&tT5@&`8+VNf0av#@S6U}uiGvf`f>t|*X$;L)z$S4dm zqwRa0Oz%e9SwZ@$L*2-%P87sG1`4A%&P-Lxj5yfw?`U!nqkZ(;>PFP|SChc&Yb`mHwbiEAaf8l{RouReu|q4}o7sdK%$0t#7kUySPGY$D2#CfE zZnKiX+2Vc97w_@=+_D+xM_84Wxuhmw1`s3E`3uZ@6_s&Y@KvCPMw>G7%C% zPL{y!t*Ha8^Fl1ig;~>LlE21m`f+&z{D)MbY5#KJ!|#2Ut~(bRH2=MecFDr6)^iz5 zqAWnF%|Wpiv|ZkZnh7Z2fTsr35Xv`X(`fA*D5Gzxr8@PSrW%$S(^z3$zp7MZ^z9#A z!0zhppo5L~l8sDnL~+KuLEX$)l9hr^fE^9pP&&5s+FcQCcN~kjUe_sv(k1rc)mu$J z!5e2Vw5qTtQ8PlyvVRHKkjva0X1G4KV8=b_{dom_MJzoaIq8)w!I)BP-!N zi;f!xn|@u-tk6sRPLkPi0$b2bHu9Z$Xg|t%FITDBIgAo3IkPI4q$dw(5uK!cslSXu zs*(1z$oArx-{J$v9D0KTN!lY!P&2Vw1w3tG;g}ZCZVZaTMono*GRFZWC#D3Cb&xhd 
z=-xY$9N*ph)I^_^W2&Li?bXLJOZbV0V`())$k@$vD4~2rb!hi+PLrEmKbMHv3p&ZU ztblU6m-rb^A{^0PyL%J1&4PI1NItwrrM9=}y8&6l7cpkM45Hz*8TT#D?-@G(xnb~+ z4b%8v!?Hgz3jAl*eR&am8r1RNMH9b>JAMMh*T^F~TrdYl3Z7VZOz@D3QxK*D`8Y+? zl|%tjv9ro+$`bgOqRx8IlqWHQd=#XapNGG{&_n^(k1qs2hMwf5)=w{9g?wz+?Rah9 zb+eM|wj-*1HLjxv6Z$5XUrO|Wf!j=rLx!g{dKQKnN_JBd#s@5?L*9fg_Dg9gWxK3D z);XW(t>OSByu|Gwt4IReAC4-IC9aLbtaWI8#R zCI0Q1ehssnOKbQ>O_bA`JAa#P{bo8oI1Jak01q&g8o*TFPNxQoqJU8*TT7xa)cZhm z>6AO49BCa*r&dhR8+S;|D7aGt z7R9`u(=f)tcKb8hYoi&>ZnE3Dv~KgQ?_gFd7|#}SkZlb$f0Ne^R%v$7q+{t3$m>jc zbXrrp@Kkz?xc>2UdayQxE~Li?XVr6haDtvA4ZVn#J&icoe?dcVrnDo%#S#dZ$GI$Ph zS@^yo7Sp5Zxs)E;uRD^?Q{4~XAsBRf>z}d5hFnZ%#8JqZ_E>TM&mJ#BT1$B^l9M?0 zzTQ8Po`6+7el?gl1!MKgur*iex&*HsAY`t|EeQWc(D0Qo@L#Za z1hxXW!nbF>#QVUYLQB~VO%kl&-{M;aKFszEY9s2v8``W|W)Py@_^uDMsC4&;6%*=RWqLJQjTWycR)0=jnv z>UPR-vJJnRz{i|l{zzYWdVOd5`DOLp;XvicR@>HDo7po6w#734lEK`ha0yW6oO{mP7&gx5w`5kOG1RqzTlX63F2q=_-3UN_Dm-dSpV zg-DHFVOO#-z#w0CBLaJJnkq})ji%d;qHs-QHnzj1@I^U!`5LtscYHaLf_)v50ag&$ z6(9(Sut>&RGRekcpjqL^4H+Q7i~y8_#BbLh&CI6jx4q0%AWESxQVgf0%1UK2YuoLH zJ%|VzN&x3q=%diUX6#ftE=5{yj!flsHJ$xfHPeH3(=k9|r<`kbt}9M)+s(PyF}Gc; zj~)nBJ_$>x0vgB!fW zt18m7nZ;@!^ZQ1zq~|p?#C*V?z*!~44<)_E%j_`t^X4=w<5@Dw^bWJEBWeuPpuWnE zqhzrXpW_q72{yq?hJpGyzRZ@{1V)=d{5S`+i=R$!l-W66!xs!b!|eM&b@z>Xb}=8J zYFLaj{&UNcWTLM9;sAyM!#r332q{sxnR1tH*%VlgLYSrtAdLOUO5f0uob( zsEDE{p~e(ru-(;6M}V%p!QI<%5g&SMzjo&IowFAj6bkKK#B5S>OQ6zlL7OOBk^qu4n8O zRhEB@O8F;5zDwkx&<-I?YGtoOgG}PyMP`P|YAJLOROE@& z3J`~KYN@`9NDo-iq$Ra9HoOMHZkVfk=#q!gQ2r`@I~7{he*2XMZTa3+0YiBPwXPSi zlZY-5yV6D>O$81s;kzbsAi7|{8b-?H>D1_(otf02w3X^(wk*LFsx%YA$B0OFk$p?- zRTU28$DfJ3aLBFtTi7OKfOY9|yjE~EtWDlB0nV3C-P+?H;L>JQ4ZAySFhx1`a4wn zbr5HKWT6F34Hi6u59kIrqD1}y_4!>QzenWviR`-$`2#9`mIj+}WoxqWm`EhE-EJIY<(zHwl4yNhxorT{st985I=ll?-vDh2 z-~qFgC}G)+peeDl7$-MFAF0`Gs$04(UXA2c_To$AWUYGcE1Z05r~5V6pANKusW zdlQU`SYFK&z(`9lNoi7~eibzmGi5^<5sG34mY~E7vijf4#=2?M8 zXath=^e@WlnU!v*xl`!)$&(=_tlwxhaq@0sm`q1yR_FmiH@m%^N$f%XX5z;Qy@j<< zeoJNW>g-y5Mb@eKkLukU!A?H&DXJrHmDx)BRyrLOrzl670qEclgBt<;F(6gcqrdsu z>yl)k>|HiS$#`T#WBMn;5}i&MT+hs~8z3Yo(x^bH$Y%yoj_^;=i^?*t5qFV#6!(Wv z>AKuBD~n5}L0Oc{{v|KIc^P48X1_K1dr6%uDDywj7zGMeaCEa7=yo971KNlajn~Qw zkU(PFk=)FT2%(q!Dh>aAdcFGQegyeCMI&FMN*w@5en5S0Q4wb{_V6FcXV|dJej!hG zzDVvTCi725v#f4Hvh*9(81}4fmu%azs|K5x0E&eH=07WaMLS$Ft523VHvxGG-NG#A t&ja))#u$}Ks$JFp$U062+`ea-Rb{3*j5zh6J|*-Ahpeb|-_psj{~s!YYl{E? 
literal 0 HcmV?d00001 diff --git a/gan_inv/checkpoints/weights/v0.1/vgg.pth b/gan_inv/checkpoints/weights/v0.1/vgg.pth new file mode 100644 index 0000000000000000000000000000000000000000..47e943cfacabf7040b4af8cf4084ab91177f1b88 GIT binary patch literal 7289 zcma)Ad00)`|8AZLg;1#^Iu)8n=lv+^ra`GxNTpG8?K-9D)M-A=GeV@u6f#8S5*L{w z(>2^{7NtxXuIYE~P51k~-@kt6dG^_RuXjK1yPmz?^?ue~JdpyKh_CB^`K%VUZQ-O@ zJdp)2L42lg3vn-v)y7!)x}f7Z&N&=nyZMl666 z9UBlb%Qv1A8WAVs8yn;w9~u$p%L$5#i;ne8Fpy!ySuA(SIKg9nD}oq_q330OiHP>& zxD7SRFp};v`er7krluyQ#s-X(yF@^Ipr3^vBV8bD>BMny5Mnt9J6H-1VGe8PEV#;d z;`iGwPlm74XF#Tu& z{g@$oga4(M`)_&^Gb1B&BU1xL{vZ0W1tLPi!}JS;1fBn@`#&#+1;z;k#t#V?{;z<- ze+w8J8=DyDndvbT{t=ic5D@u~fXJ}G-xtFIlLP{aLjp$sE1>jm0X-8lQ$rI*`5%GF z{~rOBpr9OAMwQE`xl8*-266lX{WyNiln{Y&)CI*m?ARH&?({z@= zlmbIBnp}nz^0QHkjkkU`jFeY5abQm%lA{(}o9%IgBEQZ0^{|5CLHkZLLNJB~n%p5LbIoxP9 zr1aB1%v>%rZy2&3A{%y=9%I90ei;UB{|)Lh^SO-OFlaxdWXL#h8OPzq1w%?deZ)9% z8RubW;Skxd#q<~#F5@~3y8Ro}XBKf8_hHauNXdxtv4-hZW?APtd4`#dl^ap}&ojSryZy@1pF^Xip0gf-ms;Z&jn>rRady;747)pK6gk#5Mp>CH9 zrR;S?_4QR~>hi;nMi;zje}r(!MRZH(0km@YFm&09@y=V}XnzpNZ|5T6YXbrUrsCQ7 zQq;6>Maqmegr};~HyuA@9ZZM+3I)2c{WH#=ybPUvx^Osk0f8wg*wSCiz7*y}DuESv zyt@+~Y6>*_p*cmn*Fj%s0y#|l|TmJ@0P6sxfeuI~n{=m&gme|x=fTt2ZSkkl)DKhfpYrGMsU+}@i zzGuJ9zm6LRMCfVgH7t|Tz=X(=FdnSKyge=8y$Pp$gZ0?lF`N3S7>dDaSl5(-`RDv- zaOy))*h3hXsbRb5M`Tw|L+qdGkn{1OPfxF4(;WtD>wmyBT>^z^M%Z&P3y-%2pniWT zx<%LHrSmIzck!`QXBjSdZpS006Th^+Bq!My>=u_^d~i@C3mtWe?A3(pp#+G%6(gD{ zM)TwDKymABq;<22vsRsYoMs|v#(WsNykYm*9zx!dwfOW%n(AdXLt%>~=@`Z#UGRSH zjjO<`#)-7;%VYNb?>QLN$3srJ1cKS;uv*Ln3KuS7=R8U3{QM3zUE8tG-~xgts-thW zCH8e`Q&9IU$c#M$>)gwbQ6GTzx@L@@?SzZsi=be76%#fsrs+TCl5|r6)}}wf6NfFZ z_gIS^2QGqF+lbQ>_>jM=NOylPftJQYG+a%`w;*x4v09nVg@JBPuER`?EPQnM4d*B9 zV;_w^gpnF4=-6@;;(tV8+T~S9zC4D~a;i|Mx&bY%a`bSKKPFhp)5+L!#An}Uzt5;c zOA8@$#9?@fSW-r|JC)wj#k1Ag=>8stGc9E(eKnm*BDgs9WH0zy&mr|)8WInxpe?lq z1EG1)eCLC64&&fgIEu_-+i>E*To{AS+cJjw>#jrf{b~5$jX|l}D_nY<2-VyXR3`qN?Ofr8 z*vY3b@lg?sq!y5zUKtu?jp$Qc4YvGZKu;5->9M~iJ?RNXg{2hjd^4WP8Xh6cKLGi& z4zZOJ8t~O>3lS}`AmpE&PggaYp(CzB4Z_j*mS919n|C6+ z=Lu}@j~amOfl9bHHx@4kb@DJ{B~5QxcF+R%7J9pOLjvf00}(5S?sM;F@g zy@HFf4I(7f^@Oebax~7jN}+VcS@bmYLAC7yP83SP!B7wD-ygxJd^L!V&BWVN=5U(c z1SgX#Y(w+4IJAp!L|KK#rrN_`%3~<~`U4XxpW*6QZ)}p(qi#7R`tsD8BDGs!mYjey zmkvU>N1i_2*2R`9F4Uc7kDHbX5a}+$jcF4g{D%VCPiWFWmoRlaZNcWRilk{2iZvQh z5bYR4>%{m-DXzxsjth7>*Bi4$eud4BDKHGbhK+Nj5F8pqsGdsVkABBGMuK>z8hG6z zNzeHsspM`0csir$L*ZfkxbY4;_pNa)=Lgu#e%$HNLD5nr($1cR6(hbu(*7A%&9lHH zX<0HFH=RxwmO_#W;Zj+P_Q#X)h5Zc2W5!drkS(1*IS8lOY^0|4v&|H^$Uk)!_2=&( zSNs(W=ZVvAZ@N(zsYwA-wD4{go36Qg;q!G}yb6AeD1&9VY^@886Y8+*??>;Kn<((S zf{Q*UV5J)ex0rk+9W#M!lN!cydALhn*jp|K&E@eh=$wl^)^#}SFAdHiY4Yzpi|4yC z@Xk>Q7CK(YTn~sA?uN&QQ|R=rfTfH+l50M|Z`1?aLOOD`Mc~_*lc@W14|-Ltv5MCV z5s|m})^AFr2fC9wnkwH*(elohxaJjv?*lwUJlCbI*OjQTQwYB`_OKn^7o)Wq(3H4< z+Go?~jeiXK|FB2gjH{T-Hz#??5%j>>4u^j)!yH9Nhy>3>d6yocH57B}d!e*pFFLd| zpwQNXUYn`t)Dxo*PHJeX_o6^Q20u1A!yxA>LcU&rj(ai&?{37x*M`*AJ{|X>K0~F_ z1W#^$VP8Jzgwger=}eb83EwcIwIW>*ZaIWUtZ!^{J8#5=e8z@>r|cCI`*7;!6ZU$Q ztr%AE{Q5AWB5euxWrI1Ulf%W@rl+N5lzNdxKak=6$$w4z#uv~4etCDv}SuCX5MMayjVozrzGRN?Jqd4V}rrw z_xMvj4h<)S=+8ZcNORSqoxUZ=vyMl{;!x^|D1@PZ95fQfl0lvd{dw>m2Fz*^%c{cu z>r3!$PdWUL1yk5h(I63v?bl#O+dQh67z4Kp3FuFDLacfn ziK;~7_qBK7c5E`$%6p^pSOMY{IT+tE7q?rSY24BXI6R(A@A?!d`f)a67hULR2m;HHi|Y3gI6PnvHjkl0Dm4&) z7mYuUP}j$B}qZDMHO(@-b>qn0n6G!*Z7inUAQ$XGFqw zc?`B6sGykfk8$*SI(V;sKtpUZENvo5t}PDg12K5*AA%m?4E)*R4fi*KvAtJ-Nxo^= z9>YL(<0-7km`d*FRmk2d2J_Mf(LDYPb%&os!>tT>jJSmFktvvBUkZ(sJm?Ge;6gwN za^ll*-ysO~$7j%-RzdzhO9yvTCFa?$f~RvBB1Q{Q-$psQBbAJi15q#!%tZNv3<%%J z$KIGGeEYH#LOoG%p1U3+q)ITcJqp)dlHe|$33G8il#<#Jo05&)1=g@r<%54a4g0J- 
zaBf-&s?PF&kqpdgwBfbffE39V??2DPJSkNg8#|E{losGs$ui7-I0p--W#O~x3G8#{ zL(=jbp3mQf(N}`8`DhCc@ojLhb1WV1VjwD>1Fe-jjGM8FI!rfWlIH@Pij70*h#(xC z9e}SIGPL(!xXz>e2UG4*kXk_GJ~~b z``nCbeTtzk$PeX#G`KmK(Dv;?NMku5SNJKWEzP1H!!+>6)?&0Y7groZG40@c*gBNs zO#CHi9$taqZN1o+{Tgbwl2OxpjU;RGG5T~Ws(zhIx`KBw@k=NsmAm7ahhTkVV!A#3D|M$7xEm-!&BWmLa{&9KeS!Ip8Yh2+pn1D?> z;V3(hhz`>X%ubV~CuuM7VQ(QUKJ)Nk+5?mvEWneQk`&l+9Y#75)To|^ytr(^e4mG+ zw@Gkz(xm`{E4bupgqq3{te5qHLQyfg3_J0))r$;PAHh1-Hu8KE4gLWhuD;|W-A9wo zX(!^-f{~rhI2XU>@NrwN1WVYX zNa=V6#(gozy}lH3xL1QHL zuGkz9_@J;D#bIu!dYKIQ!b5O57=|_v!Rew6|^wNolTCL zQn9;Hn%vsk>FwGoOp2^S*bW~OaZbS?FAZ~*vM@5G9hWaA<5p%h9o-y(K9&$!ww6Ko zo*+jf0&(Pd6Nzm&fh?N|G;vE992YUj%jO}rI2PeOQlz?X4U%UR()!#|w0@1mVp>4= z;?yt9hE|_BgBg5R^cTlL zCbIzf*Ba5X{~o^AmmpxLE7i6M>U7#Ps`D2m9qwqFm5@)a`bJdBFpyri9g7+YsCN|) zum329ommlF!ffecy$->s5}gmSV4YTh;*H~I*7;CuJC%y*nk#YEEf1}(+1NO?0LwRf zAhtdl*B=QHt6z|deOwqN^N|*)NUxUmfbUie*-&Ge9r zXjT%MBlTe`K20qKD=8gw8;a2v8US6M3F#(QP`9B!c3I58NBCZrf)lukfD#x)^ zX{dkQg88Ro$@)Y!!m{h&B3Fo#m)c}waSmL7eDfbn5f)p|JqI9V0L>c_TI*(Zi&+}$?T??8up5WDq zqQ)_U9M8?@7Z~8K}gNF)MpH#+Vy+o`` z&V^}WG;|*DDN?wRGDAjS*?VnpLv!GF#tf_S>Y(s2w7i*w`MWBREf_byHLIa=PZJXPMVQzZ zg@|S?_)m35cC$Hr+l#TUwG_`DC&SKJ6&*LNFfyVD?w4e7rON@$C+s1q&PLr=A56Xa zoqajl4G{)w5Vk)TJvaF{<&+PNViP<|4u{th!Fu_p4<1=9z@YdP>>5Q_$|=UJ7y8(> z+yG&#ig8BZgW5YP=zU>KE&i_foFWIc%gRW3F9T;r7Yo<7la$E|dUyav#G8pvTfG?qPFrC@Nj$+9| zNO=~YFA=KQ4&_RXG3XT z7TN`w_VFtZ#Z5_gUdxAmN-i!WC?a897C!jKLVDXuv_39G@UwVWZRcZ}HXkMR-new2 z9Bm`xfK!TEu+`YxR|1}EI(D~b;Zfj9oNUcS(bilXD9eX(M=5rm%!gW79wPK| zvBWzYJ+JHt_Tu^gi==E2l;1xAF0U~Xvz z0)qKqf2zhJwK%+C@^N5mH99U9;`SeF&=in|2iKm{#>EBL^)d)@zvLlRDGfzu7K0@o zD%c|_hSVz%mENt|g=O6%SIHRXCLBi`)EM^fgzbHL6PRxl@jQ z!{w0PpA2JjE<^&?plkUCK^7)rurM0j2nI8v6R=U>)3#An;9o7myO|t>TrNT9m2CV# z0mk&_A^vJPx(?=pua^bgEFMIoim+@=1vU!yKN5Fl;m;kp@K20DiC}#d+!T)c9>Fkw zU<21N(U{K8N2zu^+DwYEM!pEPp(S8?+;FrYA7vIrc)f2ub+Do^uEZZ!t(B+~U5yKR zJY;;2MCF$p%$och;ty+~C0m2FWfHWw=^b8vEcHkuzKK;EJP za-UbgtttrLxWO=azY49pbMQ=B;NiUyX#BVa><@9c=+DPO`BF69%7m)O3Y5QKu$j+; zvwtB1Ywd9BbvR@OGH_hO9nJS+A<<`m%lAv5VVnbrqk=J+l!ddRm9UBR!l-o|s9A8( PyEGODb&D{1dnW!5&PYuU literal 0 HcmV?d00001 diff --git a/gan_inv/inversion.py b/gan_inv/inversion.py new file mode 100644 index 0000000..b22a8c7 --- /dev/null +++ b/gan_inv/inversion.py @@ -0,0 +1,355 @@ +import math +import os +from viz import renderer +import torch +from torch import optim +from torch.nn import functional as F +from torchvision import transforms +from PIL import Image +from tqdm import tqdm +import dataclasses +import dnnlib +from .lpips import util +import imageio + + + +def noise_regularize(noises): + loss = 0 + + for noise in noises: + size = noise.shape[2] + + while True: + loss = ( + loss + + (noise * torch.roll(noise, shifts=1, dims=3)).mean().pow(2) + + (noise * torch.roll(noise, shifts=1, dims=2)).mean().pow(2) + ) + + if size <= 8: + break + + noise = noise.reshape([-1, 1, size // 2, 2, size // 2, 2]) + noise = noise.mean([3, 5]) + size //= 2 + + return loss + + +def noise_normalize_(noises): + for noise in noises: + mean = noise.mean() + std = noise.std() + + noise.data.add_(-mean).div_(std) + + +def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05): + lr_ramp = min(1, (1 - t) / rampdown) + lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi) + lr_ramp = lr_ramp * min(1, t / rampup) + + return initial_lr * lr_ramp + + +def latent_noise(latent, strength): + noise = torch.randn_like(latent) * strength + + return latent + noise + + +def make_image(tensor): + return ( + tensor.detach() + .clamp_(min=-1, max=1) + .add(1) + .div_(2) + .mul(255) + .type(torch.uint8) + 
.permute(0, 2, 3, 1) + .to("cpu") + .numpy() + ) + + +@dataclasses.dataclass +class InverseConfig: + lr_warmup = 0.05 + lr_decay = 0.25 + lr = 0.1 + noise = 0.05 + noise_decay = 0.75 + step = 1000 + noise_regularize = 1e5 + mse = 0.1 + + + +def inverse_image( + g_ema, + image, + percept, + image_size=256, + w_plus = False, + config=InverseConfig(), + device='cuda:0' +): + args = config + + n_mean_latent = 10000 + + resize = min(image_size, 256) + + if torch.is_tensor(image)==False: + transform = transforms.Compose( + [ + transforms.Resize(resize,), + transforms.CenterCrop(resize), + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), + ] + ) + + img = transform(image) + + else: + img = transforms.functional.resize(image,resize) + transform = transforms.Compose( + [ + transforms.CenterCrop(resize), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), + ] + ) + img = transform(img) + imgs = [] + imgs.append(img) + imgs = torch.stack(imgs, 0).to(device) + + with torch.no_grad(): + + #noise_sample = torch.randn(n_mean_latent, 512, device=device) + noise_sample = torch.randn(n_mean_latent, g_ema.z_dim, device=device) + #label = torch.zeros([n_mean_latent,g_ema.c_dim],device = device) + w_samples = g_ema.mapping(noise_sample,None) + w_samples = w_samples[:, :1, :] + w_avg = w_samples.mean(0) + w_std = ((w_samples - w_avg).pow(2).sum() / n_mean_latent) ** 0.5 + + + + + noises = {name: buf for (name, buf) in g_ema.synthesis.named_buffers() if 'noise_const' in name} + for noise in noises.values(): + noise = torch.randn_like(noise) + noise.requires_grad = True + + + + w_opt = w_avg.detach().clone() + if w_plus: + w_opt = w_opt.repeat(1,g_ema.mapping.num_ws, 1) + w_opt.requires_grad = True + #if args.w_plus: + #latent_in = latent_in.unsqueeze(1).repeat(1, g_ema.n_latent, 1) + + + + optimizer = optim.Adam([w_opt] + list(noises.values()), lr=args.lr) + + pbar = tqdm(range(args.step)) + latent_path = [] + + for i in pbar: + t = i / args.step + lr = get_lr(t, args.lr) + optimizer.param_groups[0]["lr"] = lr + noise_strength = w_std * args.noise * max(0, 1 - t / args.noise_decay) ** 2 + + w_noise = torch.randn_like(w_opt) * noise_strength + if w_plus: + ws = w_opt + w_noise + else: + ws = (w_opt + w_noise).repeat([1, g_ema.mapping.num_ws, 1]) + + img_gen = g_ema.synthesis(ws, noise_mode='const', force_fp32=True) + + #latent_n = latent_noise(latent_in, noise_strength.item()) + + #latent, noise = g_ema.prepare([latent_n], input_is_latent=True, noise=noises) + #img_gen, F = g_ema.generate(latent, noise) + + # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images. + + if img_gen.shape[2] > 256: + img_gen = F.interpolate(img_gen, size=(256, 256), mode='area') + + p_loss = percept(img_gen,imgs) + + + # Noise regularization. + reg_loss = 0.0 + for v in noises.values(): + noise = v[None, None, :, :] # must be [1,1,H,W] for F.avg_pool2d() + while True: + reg_loss += (noise * torch.roll(noise, shifts=1, dims=3)).mean() ** 2 + reg_loss += (noise * torch.roll(noise, shifts=1, dims=2)).mean() ** 2 + if noise.shape[2] <= 8: + break + noise = F.avg_pool2d(noise, kernel_size=2) + mse_loss = F.mse_loss(img_gen, imgs) + + loss = p_loss + args.noise_regularize * reg_loss + args.mse * mse_loss + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + # Normalize noise. 
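+ # re-center each noise buffer to zero mean and rescale it to unit variance,
+ # so that image content cannot leak into the noise maps during optimization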
+ with torch.no_grad(): + for buf in noises.values(): + buf -= buf.mean() + buf *= buf.square().mean().rsqrt() + + if (i + 1) % 100 == 0: + latent_path.append(w_opt.detach().clone()) + + pbar.set_description( + ( + f"perceptual: {p_loss.item():.4f}; noise regularize: {reg_loss:.4f};" + f" mse: {mse_loss.item():.4f}; lr: {lr:.4f}" + ) + ) + + #latent, noise = g_ema.prepare([latent_path[-1]], input_is_latent=True, noise=noises) + #img_gen, F = g_ema.generate(latent, noise) + if w_plus: + ws = latent_path[-1] + else: + ws = latent_path[-1].repeat([1, g_ema.mapping.num_ws, 1]) + + img_gen = g_ema.synthesis(ws, noise_mode='const') + + + result = { + "latent": latent_path[-1], + "sample": img_gen, + "real": imgs, + } + + return result + +def toogle_grad(model, flag=True): + for p in model.parameters(): + p.requires_grad = flag + + +class PTI: + def __init__(self,G, percept, l2_lambda = 1,max_pti_step = 400, pti_lr = 3e-4 ): + self.g_ema = G + self.l2_lambda = l2_lambda + self.max_pti_step = max_pti_step + self.pti_lr = pti_lr + self.percept = percept + def cacl_loss(self,percept, generated_image,real_image): + + mse_loss = F.mse_loss(generated_image, real_image) + p_loss = percept(generated_image, real_image).sum() + loss = p_loss +self.l2_lambda * mse_loss + return loss + + def train(self,img,w_plus=False): + inversed_result = inverse_image(self.g_ema,img,self.percept,self.g_ema.img_resolution,w_plus) + w_pivot = inversed_result['latent'] + if w_plus: + ws = w_pivot + else: + ws = w_pivot.repeat([1, self.g_ema.mapping.num_ws, 1]) + toogle_grad(self.g_ema,True) + optimizer = torch.optim.Adam(self.g_ema.parameters(), lr=self.pti_lr) + print('start PTI') + pbar = tqdm(range(self.max_pti_step)) + for i in pbar: + t = i / self.max_pti_step + lr = get_lr(t, self.pti_lr) + optimizer.param_groups[0]["lr"] = lr + + generated_image = self.g_ema.synthesis(ws,noise_mode='const') + loss = self.cacl_loss(self.percept,generated_image,inversed_result['real']) + pbar.set_description( + ( + f"loss: {loss.item():.4f}" + ) + ) + optimizer.zero_grad() + loss.backward() + optimizer.step() + with torch.no_grad(): + generated_image = self.g_ema.synthesis(ws, noise_mode='const') + + return generated_image,ws + +if __name__ == "__main__": + state = { + "images": { + # image_orig: the original image, change with seed/model is changed + # image_raw: image with mask and points, change durning optimization + # image_show: image showed on screen + }, + "temporal_params": { + # stop + }, + 'mask': + None, # mask for visualization, 1 for editing and 0 for unchange + 'last_mask': None, # last edited mask + 'show_mask': True, # add button + "generator_params": dnnlib.EasyDict(), + "params": { + "seed": 0, + "motion_lambda": 20, + "r1_in_pixels": 3, + "r2_in_pixels": 12, + "magnitude_direction_in_pixels": 1.0, + "latent_space": "w+", + "trunc_psi": 0.7, + "trunc_cutoff": None, + "lr": 0.001, + }, + "device": 'cuda:0', + "draw_interval": 1, + "renderer": renderer.Renderer(disable_timing=True), + "points": {}, + "curr_point": None, + "curr_type_point": "start", + 'editing_state': 'add_points', + 'pretrained_weight': 'stylegan2_horses_256_pytorch' + } + cache_dir = '../checkpoints' + valid_checkpoints_dict = { + f.split('/')[-1].split('.')[0]: os.path.join(cache_dir, f) + for f in os.listdir(cache_dir) + if (f.endswith('pkl') and os.path.exists(os.path.join(cache_dir, f))) + } + state['renderer'].init_network(state['generator_params'], # res + valid_checkpoints_dict[state['pretrained_weight']], # pkl + state['params']['seed'], # 
w0_seed, + None, # w_load + state['params']['latent_space'] == 'w+', # w_plus + 'const', + state['params']['trunc_psi'], # trunc_psi, + state['params']['trunc_cutoff'], # trunc_cutoff, + None, # input_transform + state['params']['lr'] # lr + ) + image = Image.open('/home/tianhao/research/drag3d/horse/render/0.png') + G = state['renderer'].G + #result = inverse_image(G,image,G.img_resolution) + percept = util.PerceptualLoss( + model="net-lin", net="vgg", use_gpu=True + ) + pti = PTI(G,percept) + result = pti.train(image,True) + imageio.imsave('../horse/test.png', make_image(result[0])[0]) + + + diff --git a/gan_inv/lpips/__init__.py b/gan_inv/lpips/__init__.py new file mode 100644 index 0000000..25f4ddc --- /dev/null +++ b/gan_inv/lpips/__init__.py @@ -0,0 +1,5 @@ + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + diff --git a/gan_inv/lpips/__pycache__/__init__.cpython-39.pyc b/gan_inv/lpips/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..acb6b4ab5e54d4fcc56c523a26fb40bfa48829b2 GIT binary patch literal 269 zcmYj~v1$V`42JEyYY8DqU!%9gFAz!sh0-A*&`!4qCw6m!JKM;3fA^^F`oNZ4*ykfSaZckHY1GeL@ygDGZ)N7Q~YtsTs9?h^(N}O zZ{8NQ93S){`ff-Tun(RRk0u7s7|}cCMIdwuJ_dJ8JlWONI0cSrAPYeGIhENewyU7~ z%-Mccs`b593AJmrS7f8rbZpf#>C?+*2mAL|b2E~$(j5G{^UcM literal 0 HcmV?d00001 diff --git a/gan_inv/lpips/__pycache__/base_model.cpython-39.pyc b/gan_inv/lpips/__pycache__/base_model.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..87775177d07ffb97ab9f4cd129ab6935429af87d GIT binary patch literal 2791 zcmbtWOK;mo5Z>imFTdkQo=#gA6K{Xe|+)W48ZXI7F)$Z~VY5}cjo4(IXh%#dEAQDr!u|NR&IY%%r^Ey_d3 z#kY88Z6v^eN31iOc}G)gy`$r-MaIzVn4I?X$Qs%mNAX5f8CE-0&Ym(b!Mb2zz0o=~ z;3uqU|BE_VyUBCs7f*!VX%rw|$pYz0;rX;MgW+iaKFO8Gl~;e@(|~n2F!Y9lcEQja z9Soceve)BX4Qy~Oc*leaRPnh5HK^lkLj%@uc3>SgaIU~6Y~fsmZP>xN2De}r=Q`}c zKF$p|fZI5)!JYa0hj0(>qvmyZfYCh6E8lxr@H~Ma@_ETy;_3bY@612~?1a7L?YBsB zBMYLgu<@xP?;f}M$uMZi(2M(C(h@-yc*5_ueiUBs>341S#mi@{o)^1eeAbH6FwI)0 zXw)5&)#LO$cU;^GrR!G6nudh&@z#}N$BXUvcu2>iwS~FxOE_1Y@|BV%H)3xX)R9}* zSu5<&)=NQQ`*c`~K|DF7u-HIm!Y9mS6AdXI;t3O&glIaqM;KKv9pyF>H-xOI3kpdz zh^!I00!=7ysuXwVMd3JD$8Ci);}#$96kZh~?}>lm$;`;A;O%ZAUU>qx$<`d_&BD@r z_k+!>_plnmq1OxC)RX;ew`A_1Ft;T2!DhFURE9%@60Hh6N*Rdkpk(@9VS0t+&F1NJ zCv32QmnuAr(~(@6U)vOwKN2E{rF#};BQMJCu6AOD z;`ir^2D$Hc!ziF;fqkErWM~ltzr#`5GVZ7ROK)^9hEQBesg02FOI$0E^tRYX9pk;95)Z<-cR2U# zuq%?`;ZfFf1eJ;cK-@ttca&!kirh>^7|VP;!dKk_JV9xl8}zhk&HCJZQbGA94v9P@ zLTPXf9HnaR!WSrzQMV#+Ht;UPQx^0|rab;9KT8&i-oGmy1$2 z==~C|>H{__08;-o#iA0J@AL6?f#fWYPy&CNo-{Y;YfrV~x_Q;5e}GX$eBE__!(KVt z5j9jM=qoMgvnc4Js{j*Q#O)BdMT81kd`#psB26OHLP~C49l!c9pgdFk*x`=u=o<#n z`DHj3Ubs?{&ALJ*H{;PTJr|#oR+aH&rd$8l=C$=VI$xfnJJVK<5Gi9=S-#EK``-L|5lrcZTG z^;B22PgT!&njUsphusC_B|-?0Rtqz*Tms_qv_d=}@x&YA35j|D1QIA9E${*;4;#M! 
zVeS+gp%eA|E=Zcc9u{}QFmY~PoL{}bMGePp%A%OzRA$erv7n_Rv3>b>{w{cObgVZQ*clWi9%EIHDFvkF8f-4WX(73vn` zdTq6al+{VaU0?9FE)c<4rPJFgv({(T7^zLn(M;K#Ch9C2B60#MYwAi%X{kG>4%vUW zxUvF|0^|v^fPB-Vh}4gXZh78qgpPVd0!*%5O%4zngmLV72VOse_X#NjITC?9Cp6j0 z*_L#zFaG?Ozy1B6D(H9nr9Kt$ z8t5a!kS(MzB=bNRl>~rT&~3@!&>Fe*P^>dU(e%|vR7FU!>dmr$Uv6uHURfD&+FiOR z7D1zHiTIYOU(Z5+>#&@9mM4;DDCwJEJtFA{1+y3>Nue|-j1cfZhjjYT?!*X_yG`yF zz9%)~wtD@zh=L72Ny2y_ExRHNu!@V+jJru?@EkdrG&Yp!JD^v_ph@VZ1*H0Z0D@95 zRZ6?4j!NXi5~|vT^2AE1p)z70tAkeQXme2OEF$RP{!8eC-St(Od65AOnYtd4_F!-; z-;yCuyp89|LXKXJ{q?r)n!s+<>$(j>h8KiMF$r2WAwv{sD!MYS z$o_zm{_rB2P%C#y;3vnMV|#2LxW$dEzo&6^U%3JKYN=syvO(>DG{3D^T7WccbmA25 zWhDJxwp}#}*e1`?JZvj&v-Tq|UEe$;D)^*}EhpJ`u~uylRlEoKh@c~kkU!RrHKt%` zghJiaIy$zXb7dcY5L{ht>yK!ck7brycxZ?h(Oa2oQiK}YD3tE{_o7_9iy??^CfO6k zPO{}%^6!CfDE2DowJ~f{I2gg`*LWDg2!(Y45ZzaH^-|l@N*!=AN>!|uRA@g5PDlCy zCq;Z(s`;n?*BpbqY1b(9(zp-(pzxA#6qg{LAS--<8`m$cv<-0uL=Co~rJ|n1k)X{Y zfwUxK2a(dXiqxa<=~}(OOOtrJo~r?_5a=5P>wSgmCmD3EKS9gJ3_IF?iE*3j8R>Se<0r6SI1o~Q~#V7gHgJH0! zF>7p2QqTYvtYV6+Qd%7pRyBPCJh0NS16Hccs0%HJ;zR6NnoL|E>ig7?n^Z<0Cz3Y= zogu_qXmIGz{UVL)JVOt_c>CA^^V%pl64c&#HPwmYBBTODK1*k;okI`eD03~CX&J$( zGX1>RM)6GHrm~_uK%pK=Tu{Z5a`6ETG@{&i7MLlgiD-=pVI^l zJg$uwcx+lzLK;ePi38Ef5l= zYtl9bl@Qce;yf|te0cg;JsP;h`_cKz zgY>ryrX2jkqGJp6ShApUjdUt?hT}FgCDu5q4ALDz5@Afc#%{d@03bJ`MAEkmgd%-Q z70X=THIixcqb8|G=BPrUGq}NxGi-+KYPU}%e*_0oV;4N~Qvwy&=BN*E!RH0)vkIfM zHt}_biZ5>8r}sX;YlG7&tr9_Smd!Du@D_F0bBW#l6*z31_$}V}SOtfPj}tagTUeDt zlX49ep%H7$&>EoTP=MBJP|wM_vZ6|?IKf^K|q=dy|g;7hm^! zVIiX6R$CsYuBJWsAdZ_KUUlBN?sTI&I3yCk{8S^)jd<|mxITvB2s*Y39YgWm2B$-G z%VVWnl~%j>y0f(8WGTX`gI8{(+2goSH+Clw!|saM!U%rPwnj!Pgpg#KsdO zpjh&Il0eqh_}Ws(n6@IGM=yRTCmjA+X{&21Z5+7qf0j<#ojy_@PTGP}kI<+g(FMg5 zi39i8;4rXI13VWL8R~+y{t$(;Vp5s-5QQVVJY&m~sXPdZH;I=HHFc>*<+RMpvV@ZP zPyB@XNS!h|3JONjpa(QAbx-NAvM<%22EI$;=5z?emj$F!n^LUSQcph9)SH_8vuxeg Tn^VncTc0+jbz5sbH*Nk8jtl-1 literal 0 HcmV?d00001 diff --git a/gan_inv/lpips/base_model.py b/gan_inv/lpips/base_model.py new file mode 100644 index 0000000..8de1d16 --- /dev/null +++ b/gan_inv/lpips/base_model.py @@ -0,0 +1,58 @@ +import os +import numpy as np +import torch +from torch.autograd import Variable +from pdb import set_trace as st +from IPython import embed + +class BaseModel(): + def __init__(self): + pass; + + def name(self): + return 'BaseModel' + + def initialize(self, use_gpu=True, gpu_ids=[0]): + self.use_gpu = use_gpu + self.gpu_ids = gpu_ids + + def forward(self): + pass + + def get_image_paths(self): + pass + + def optimize_parameters(self): + pass + + def get_current_visuals(self): + return self.input + + def get_current_errors(self): + return {} + + def save(self, label): + pass + + # helper saving function that can be used by subclasses + def save_network(self, network, path, network_label, epoch_label): + save_filename = '%s_net_%s.pth' % (epoch_label, network_label) + save_path = os.path.join(path, save_filename) + torch.save(network.state_dict(), save_path) + + # helper loading function that can be used by subclasses + def load_network(self, network, network_label, epoch_label): + save_filename = '%s_net_%s.pth' % (epoch_label, network_label) + save_path = os.path.join(self.save_dir, save_filename) + print('Loading network from %s'%save_path) + network.load_state_dict(torch.load(save_path)) + + def update_learning_rate(): + pass + + def get_image_paths(self): + return self.image_paths + + def save_done(self, flag=False): + np.save(os.path.join(self.save_dir, 'done_flag'),flag) + np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i') diff --git 
a/gan_inv/lpips/dist_model.py b/gan_inv/lpips/dist_model.py new file mode 100644 index 0000000..23bf66a --- /dev/null +++ b/gan_inv/lpips/dist_model.py @@ -0,0 +1,314 @@ + +from __future__ import absolute_import + +import sys +import numpy as np +import torch +from torch import nn +import os +from collections import OrderedDict +from torch.autograd import Variable +import itertools +from .base_model import BaseModel +from scipy.ndimage import zoom +import fractions +import functools +import skimage.transform +from tqdm import tqdm +import urllib + +from IPython import embed + +from . import networks_basic as networks +from . import util + + +class DownloadProgressBar(tqdm): + def update_to(self, b=1, bsize=1, tsize=None): + if tsize is not None: + self.total = tsize + self.update(b * bsize - self.n) + + +def get_path(base_path): + BASE_DIR = os.path.join('checkpoints') + + save_path = os.path.join(BASE_DIR, base_path) + if not os.path.exists(save_path): + url = f"https://huggingface.co/aaronb/StyleGAN2/resolve/main/{base_path}" + print(f'{base_path} not found') + print('Try to download from huggingface: ', url) + os.makedirs(os.path.dirname(save_path), exist_ok=True) + download_url(url, save_path) + print('Downloaded to ', save_path) + return save_path + + +def download_url(url, output_path): + with DownloadProgressBar(unit='B', unit_scale=True, + miniters=1, desc=url.split('/')[-1]) as t: + urllib.request.urlretrieve(url, filename=output_path, reporthook=t.update_to) + + +class DistModel(BaseModel): + def name(self): + return self.model_name + + def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False, model_path=None, + use_gpu=True, printNet=False, spatial=False, + is_train=False, lr=.0001, beta1=0.5, version='0.1', gpu_ids=[0]): + ''' + INPUTS + model - ['net-lin'] for linearly calibrated network + ['net'] for off-the-shelf network + ['L2'] for L2 distance in Lab colorspace + ['SSIM'] for ssim in RGB colorspace + net - ['squeeze','alex','vgg'] + model_path - if None, will look in weights/[NET_NAME].pth + colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM + use_gpu - bool - whether or not to use a GPU + printNet - bool - whether or not to print network architecture out + spatial - bool - whether to output an array containing varying distances across spatial dimensions + spatial_shape - if given, output spatial shape. if None then spatial shape is determined automatically via spatial_factor (see below). + spatial_factor - if given, specifies upsampling factor relative to the largest spatial extent of a convolutional layer. if None then resized to size of input images. + spatial_order - spline order of filter for upsampling in spatial mode, by default 1 (bilinear). 
+ is_train - bool - [True] for training mode + lr - float - initial learning rate + beta1 - float - initial momentum term for adam + version - 0.1 for latest, 0.0 was original (with a bug) + gpu_ids - int array - [0] by default, gpus to use + ''' + BaseModel.initialize(self, use_gpu=use_gpu, gpu_ids=gpu_ids) + + self.model = model + self.net = net + self.is_train = is_train + self.spatial = spatial + self.gpu_ids = gpu_ids + self.model_name = '%s [%s]' % (model, net) + + if(self.model == 'net-lin'): # pretrained net + linear layer + self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net, + use_dropout=True, spatial=spatial, version=version, lpips=True) + kw = {} + if not use_gpu: + kw['map_location'] = 'cpu' + if(model_path is None): + model_path = get_path('weights/v%s/%s.pth' % (version, net)) + + if(not is_train): + print('Loading model from: %s' % model_path) + self.net.load_state_dict(torch.load(model_path, **kw), strict=False) + + elif(self.model == 'net'): # pretrained network + self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False) + elif(self.model in ['L2', 'l2']): + self.net = networks.L2(use_gpu=use_gpu, colorspace=colorspace) # not really a network, only for testing + self.model_name = 'L2' + elif(self.model in ['DSSIM', 'dssim', 'SSIM', 'ssim']): + self.net = networks.DSSIM(use_gpu=use_gpu, colorspace=colorspace) + self.model_name = 'SSIM' + else: + raise ValueError("Model [%s] not recognized." % self.model) + + self.parameters = list(self.net.parameters()) + + if self.is_train: # training mode + # extra network on top to go from distances (d0,d1) => predicted human judgment (h*) + self.rankLoss = networks.BCERankingLoss() + self.parameters += list(self.rankLoss.net.parameters()) + self.lr = lr + self.old_lr = lr + self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999)) + else: # test mode + self.net.eval() + + if(use_gpu): + self.net.to(gpu_ids[0]) + self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids) + if(self.is_train): + self.rankLoss = self.rankLoss.to(device=gpu_ids[0]) # just put this on GPU0 + + if(printNet): + print('---------- Networks initialized -------------') + networks.print_network(self.net) + print('-----------------------------------------------') + + def forward(self, in0, in1, retPerLayer=False): + ''' Function computes the distance between image patches in0 and in1 + INPUTS + in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1] + OUTPUT + computed distances between in0 and in1 + ''' + + return self.net.forward(in0, in1, retPerLayer=retPerLayer) + + # ***** TRAINING FUNCTIONS ***** + def optimize_parameters(self): + self.forward_train() + self.optimizer_net.zero_grad() + self.backward_train() + self.optimizer_net.step() + self.clamp_weights() + + def clamp_weights(self): + for module in self.net.modules(): + if(hasattr(module, 'weight') and module.kernel_size == (1, 1)): + module.weight.data = torch.clamp(module.weight.data, min=0) + + def set_input(self, data): + self.input_ref = data['ref'] + self.input_p0 = data['p0'] + self.input_p1 = data['p1'] + self.input_judge = data['judge'] + + if(self.use_gpu): + self.input_ref = self.input_ref.to(device=self.gpu_ids[0]) + self.input_p0 = self.input_p0.to(device=self.gpu_ids[0]) + self.input_p1 = self.input_p1.to(device=self.gpu_ids[0]) + self.input_judge = self.input_judge.to(device=self.gpu_ids[0]) + + self.var_ref = Variable(self.input_ref, requires_grad=True) + self.var_p0 = 
Variable(self.input_p0, requires_grad=True) + self.var_p1 = Variable(self.input_p1, requires_grad=True) + + def forward_train(self): # run forward pass + # print(self.net.module.scaling_layer.shift) + # print(torch.norm(self.net.module.net.slice1[0].weight).item(), torch.norm(self.net.module.lin0.model[1].weight).item()) + + self.d0 = self.forward(self.var_ref, self.var_p0) + self.d1 = self.forward(self.var_ref, self.var_p1) + self.acc_r = self.compute_accuracy(self.d0, self.d1, self.input_judge) + + self.var_judge = Variable(1. * self.input_judge).view(self.d0.size()) + + self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge * 2. - 1.) + + return self.loss_total + + def backward_train(self): + torch.mean(self.loss_total).backward() + + def compute_accuracy(self, d0, d1, judge): + ''' d0, d1 are Variables, judge is a Tensor ''' + d1_lt_d0 = (d1 < d0).cpu().data.numpy().flatten() + judge_per = judge.cpu().numpy().flatten() + return d1_lt_d0 * judge_per + (1 - d1_lt_d0) * (1 - judge_per) + + def get_current_errors(self): + retDict = OrderedDict([('loss_total', self.loss_total.data.cpu().numpy()), + ('acc_r', self.acc_r)]) + + for key in retDict.keys(): + retDict[key] = np.mean(retDict[key]) + + return retDict + + def get_current_visuals(self): + zoom_factor = 256 / self.var_ref.data.size()[2] + + ref_img = util.tensor2im(self.var_ref.data) + p0_img = util.tensor2im(self.var_p0.data) + p1_img = util.tensor2im(self.var_p1.data) + + ref_img_vis = zoom(ref_img, [zoom_factor, zoom_factor, 1], order=0) + p0_img_vis = zoom(p0_img, [zoom_factor, zoom_factor, 1], order=0) + p1_img_vis = zoom(p1_img, [zoom_factor, zoom_factor, 1], order=0) + + return OrderedDict([('ref', ref_img_vis), + ('p0', p0_img_vis), + ('p1', p1_img_vis)]) + + def save(self, path, label): + if(self.use_gpu): + self.save_network(self.net.module, path, '', label) + else: + self.save_network(self.net, path, '', label) + self.save_network(self.rankLoss.net, path, 'rank', label) + + def update_learning_rate(self, nepoch_decay): + lrd = self.lr / nepoch_decay + lr = self.old_lr - lrd + + for param_group in self.optimizer_net.param_groups: + param_group['lr'] = lr + + print('update lr [%s] decay: %f -> %f' % (type, self.old_lr, lr)) + self.old_lr = lr + + +def score_2afc_dataset(data_loader, func, name=''): + ''' Function computes Two Alternative Forced Choice (2AFC) score using + distance function 'func' in dataset 'data_loader' + INPUTS + data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside + func - callable distance function - calling d=func(in0,in1) should take 2 + pytorch tensors with shape Nx3xXxY, and return numpy array of length N + OUTPUTS + [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators + [1] - dictionary with following elements + d0s,d1s - N arrays containing distances between reference patch to perturbed patches + gts - N array in [0,1], preferred patch selected by human evaluators + (closer to "0" for left patch p0, "1" for right patch p1, + "0.6" means 60pct people preferred right patch, 40pct preferred left) + scores - N array in [0,1], corresponding to what percentage function agreed with humans + CONSTS + N - number of test triplets in data_loader + ''' + + d0s = [] + d1s = [] + gts = [] + + for data in tqdm(data_loader.load_data(), desc=name): + d0s += func(data['ref'], data['p0']).data.cpu().numpy().flatten().tolist() + d1s += func(data['ref'], data['p1']).data.cpu().numpy().flatten().tolist() + gts += 
data['judge'].cpu().numpy().flatten().tolist() + + d0s = np.array(d0s) + d1s = np.array(d1s) + gts = np.array(gts) + scores = (d0s < d1s) * (1. - gts) + (d1s < d0s) * gts + (d1s == d0s) * .5 + + return(np.mean(scores), dict(d0s=d0s, d1s=d1s, gts=gts, scores=scores)) + + +def score_jnd_dataset(data_loader, func, name=''): + ''' Function computes JND score using distance function 'func' in dataset 'data_loader' + INPUTS + data_loader - CustomDatasetDataLoader object - contains a JNDDataset inside + func - callable distance function - calling d=func(in0,in1) should take 2 + pytorch tensors with shape Nx3xXxY, and return pytorch array of length N + OUTPUTS + [0] - JND score in [0,1], mAP score (area under precision-recall curve) + [1] - dictionary with following elements + ds - N array containing distances between two patches shown to human evaluator + sames - N array containing fraction of people who thought the two patches were identical + CONSTS + N - number of test triplets in data_loader + ''' + + ds = [] + gts = [] + + for data in tqdm(data_loader.load_data(), desc=name): + ds += func(data['p0'], data['p1']).data.cpu().numpy().tolist() + gts += data['same'].cpu().numpy().flatten().tolist() + + sames = np.array(gts) + ds = np.array(ds) + + sorted_inds = np.argsort(ds) + ds_sorted = ds[sorted_inds] + sames_sorted = sames[sorted_inds] + + TPs = np.cumsum(sames_sorted) + FPs = np.cumsum(1 - sames_sorted) + FNs = np.sum(sames_sorted) - TPs + + precs = TPs / (TPs + FPs) + recs = TPs / (TPs + FNs) + score = util.voc_ap(recs, precs) + + return(score, dict(ds=ds, sames=sames)) diff --git a/gan_inv/lpips/networks_basic.py b/gan_inv/lpips/networks_basic.py new file mode 100644 index 0000000..ea45e4c --- /dev/null +++ b/gan_inv/lpips/networks_basic.py @@ -0,0 +1,188 @@ + +from __future__ import absolute_import + +import sys +import torch +import torch.nn as nn +import torch.nn.init as init +from torch.autograd import Variable +import numpy as np +from pdb import set_trace as st +from skimage import color +from IPython import embed +from . import pretrained_networks as pn + +from . 
import util + + +def spatial_average(in_tens, keepdim=True): + return in_tens.mean([2,3],keepdim=keepdim) + +def upsample(in_tens, out_H=64): # assumes scale factor is same for H and W + in_H = in_tens.shape[2] + scale_factor = 1.*out_H/in_H + + return nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)(in_tens) + +# Learned perceptual metric +class PNetLin(nn.Module): + def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, use_dropout=True, spatial=False, version='0.1', lpips=True): + super(PNetLin, self).__init__() + + self.pnet_type = pnet_type + self.pnet_tune = pnet_tune + self.pnet_rand = pnet_rand + self.spatial = spatial + self.lpips = lpips + self.version = version + self.scaling_layer = ScalingLayer() + + if(self.pnet_type in ['vgg','vgg16']): + net_type = pn.vgg16 + self.chns = [64,128,256,512,512] + elif(self.pnet_type=='alex'): + net_type = pn.alexnet + self.chns = [64,192,384,256,256] + elif(self.pnet_type=='squeeze'): + net_type = pn.squeezenet + self.chns = [64,128,256,384,384,512,512] + self.L = len(self.chns) + + self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune) + + if(lpips): + self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) + self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) + self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) + self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) + self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) + self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4] + if(self.pnet_type=='squeeze'): # 7 layers for squeezenet + self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout) + self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout) + self.lins+=[self.lin5,self.lin6] + + def forward(self, in0, in1, retPerLayer=False): + # v0.0 - original release had a bug, where input was not scaled + in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1) + outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input) + feats0, feats1, diffs = {}, {}, {} + + for kk in range(self.L): + feats0[kk], feats1[kk] = util.normalize_tensor(outs0[kk]), util.normalize_tensor(outs1[kk]) + diffs[kk] = (feats0[kk]-feats1[kk])**2 + + if(self.lpips): + if(self.spatial): + res = [upsample(self.lins[kk].model(diffs[kk]), out_H=in0.shape[2]) for kk in range(self.L)] + else: + res = [spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True) for kk in range(self.L)] + else: + if(self.spatial): + res = [upsample(diffs[kk].sum(dim=1,keepdim=True), out_H=in0.shape[2]) for kk in range(self.L)] + else: + res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)] + + val = res[0] + for l in range(1,self.L): + val += res[l] + + if(retPerLayer): + return (val, res) + else: + return val + +class ScalingLayer(nn.Module): + def __init__(self): + super(ScalingLayer, self).__init__() + self.register_buffer('shift', torch.Tensor([-.030,-.088,-.188])[None,:,None,None]) + self.register_buffer('scale', torch.Tensor([.458,.448,.450])[None,:,None,None]) + + def forward(self, inp): + return (inp - self.shift) / self.scale + + +class NetLinLayer(nn.Module): + ''' A single linear layer which does a 1x1 conv ''' + def __init__(self, chn_in, chn_out=1, use_dropout=False): + super(NetLinLayer, self).__init__() + + layers = [nn.Dropout(),] if(use_dropout) else [] + layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),] + 
self.model = nn.Sequential(*layers) + + +class Dist2LogitLayer(nn.Module): + ''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) ''' + def __init__(self, chn_mid=32, use_sigmoid=True): + super(Dist2LogitLayer, self).__init__() + + layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),] + layers += [nn.LeakyReLU(0.2,True),] + layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),] + layers += [nn.LeakyReLU(0.2,True),] + layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),] + if(use_sigmoid): + layers += [nn.Sigmoid(),] + self.model = nn.Sequential(*layers) + + def forward(self,d0,d1,eps=0.1): + return self.model.forward(torch.cat((d0,d1,d0-d1,d0/(d1+eps),d1/(d0+eps)),dim=1)) + +class BCERankingLoss(nn.Module): + def __init__(self, chn_mid=32): + super(BCERankingLoss, self).__init__() + self.net = Dist2LogitLayer(chn_mid=chn_mid) + # self.parameters = list(self.net.parameters()) + self.loss = torch.nn.BCELoss() + + def forward(self, d0, d1, judge): + per = (judge+1.)/2. + self.logit = self.net.forward(d0,d1) + return self.loss(self.logit, per) + +# L2, DSSIM metrics +class FakeNet(nn.Module): + def __init__(self, use_gpu=True, colorspace='Lab'): + super(FakeNet, self).__init__() + self.use_gpu = use_gpu + self.colorspace=colorspace + +class L2(FakeNet): + + def forward(self, in0, in1, retPerLayer=None): + assert(in0.size()[0]==1) # currently only supports batchSize 1 + + if(self.colorspace=='RGB'): + (N,C,X,Y) = in0.size() + value = torch.mean(torch.mean(torch.mean((in0-in1)**2,dim=1).view(N,1,X,Y),dim=2).view(N,1,1,Y),dim=3).view(N) + return value + elif(self.colorspace=='Lab'): + value = util.l2(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)), + util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float') + ret_var = Variable( torch.Tensor((value,) ) ) + if(self.use_gpu): + ret_var = ret_var.cuda() + return ret_var + +class DSSIM(FakeNet): + + def forward(self, in0, in1, retPerLayer=None): + assert(in0.size()[0]==1) # currently only supports batchSize 1 + + if(self.colorspace=='RGB'): + value = util.dssim(1.*util.tensor2im(in0.data), 1.*util.tensor2im(in1.data), range=255.).astype('float') + elif(self.colorspace=='Lab'): + value = util.dssim(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)), + util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float') + ret_var = Variable( torch.Tensor((value,) ) ) + if(self.use_gpu): + ret_var = ret_var.cuda() + return ret_var + +def print_network(net): + num_params = 0 + for param in net.parameters(): + num_params += param.numel() + print('Network',net) + print('Total number of parameters: %d' % num_params) diff --git a/gan_inv/lpips/pretrained_networks.py b/gan_inv/lpips/pretrained_networks.py new file mode 100644 index 0000000..077a244 --- /dev/null +++ b/gan_inv/lpips/pretrained_networks.py @@ -0,0 +1,181 @@ +from collections import namedtuple +import torch +from torchvision import models as tv +from IPython import embed + +class squeezenet(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(squeezenet, self).__init__() + pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.slice6 = torch.nn.Sequential() + self.slice7 = 
torch.nn.Sequential() + self.N_slices = 7 + for x in range(2): + self.slice1.add_module(str(x), pretrained_features[x]) + for x in range(2,5): + self.slice2.add_module(str(x), pretrained_features[x]) + for x in range(5, 8): + self.slice3.add_module(str(x), pretrained_features[x]) + for x in range(8, 10): + self.slice4.add_module(str(x), pretrained_features[x]) + for x in range(10, 11): + self.slice5.add_module(str(x), pretrained_features[x]) + for x in range(11, 12): + self.slice6.add_module(str(x), pretrained_features[x]) + for x in range(12, 13): + self.slice7.add_module(str(x), pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1 = h + h = self.slice2(h) + h_relu2 = h + h = self.slice3(h) + h_relu3 = h + h = self.slice4(h) + h_relu4 = h + h = self.slice5(h) + h_relu5 = h + h = self.slice6(h) + h_relu6 = h + h = self.slice7(h) + h_relu7 = h + vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7']) + out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7) + + return out + + +class alexnet(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(alexnet, self).__init__() + alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(2): + self.slice1.add_module(str(x), alexnet_pretrained_features[x]) + for x in range(2, 5): + self.slice2.add_module(str(x), alexnet_pretrained_features[x]) + for x in range(5, 8): + self.slice3.add_module(str(x), alexnet_pretrained_features[x]) + for x in range(8, 10): + self.slice4.add_module(str(x), alexnet_pretrained_features[x]) + for x in range(10, 12): + self.slice5.add_module(str(x), alexnet_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1 = h + h = self.slice2(h) + h_relu2 = h + h = self.slice3(h) + h_relu3 = h + h = self.slice4(h) + h_relu4 = h + h = self.slice5(h) + h_relu5 = h + alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5']) + out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5) + + return out + +class vgg16(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(vgg16, self).__init__() + vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(4): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(4, 9): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(9, 16): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(16, 23): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(23, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1_2 = h + h = self.slice2(h) + h_relu2_2 = h + h = self.slice3(h) + h_relu3_3 = h + 
h = self.slice4(h) + h_relu4_3 = h + h = self.slice5(h) + h_relu5_3 = h + vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) + out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) + + return out + + + +class resnet(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True, num=18): + super(resnet, self).__init__() + if(num==18): + self.net = tv.resnet18(pretrained=pretrained) + elif(num==34): + self.net = tv.resnet34(pretrained=pretrained) + elif(num==50): + self.net = tv.resnet50(pretrained=pretrained) + elif(num==101): + self.net = tv.resnet101(pretrained=pretrained) + elif(num==152): + self.net = tv.resnet152(pretrained=pretrained) + self.N_slices = 5 + + self.conv1 = self.net.conv1 + self.bn1 = self.net.bn1 + self.relu = self.net.relu + self.maxpool = self.net.maxpool + self.layer1 = self.net.layer1 + self.layer2 = self.net.layer2 + self.layer3 = self.net.layer3 + self.layer4 = self.net.layer4 + + def forward(self, X): + h = self.conv1(X) + h = self.bn1(h) + h = self.relu(h) + h_relu1 = h + h = self.maxpool(h) + h = self.layer1(h) + h_conv2 = h + h = self.layer2(h) + h_conv3 = h + h = self.layer3(h) + h_conv4 = h + h = self.layer4(h) + h_conv5 = h + + outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5']) + out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5) + + return out diff --git a/gan_inv/lpips/util.py b/gan_inv/lpips/util.py new file mode 100644 index 0000000..4f8b582 --- /dev/null +++ b/gan_inv/lpips/util.py @@ -0,0 +1,160 @@ + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from skimage.metrics import structural_similarity +import torch + + +from . import dist_model + +class PerceptualLoss(torch.nn.Module): + def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=[0]): # VGG using our perceptually-learned weights (LPIPS metric) + # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss + super(PerceptualLoss, self).__init__() + print('Setting up Perceptual loss...') + self.use_gpu = use_gpu + self.spatial = spatial + self.gpu_ids = gpu_ids + self.model = dist_model.DistModel() + self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, spatial=self.spatial, gpu_ids=gpu_ids) + print('...[%s] initialized'%self.model.name()) + print('...Done') + + def forward(self, pred, target, normalize=False): + """ + Pred and target are Variables. + If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1] + If normalize is False, assumes the images are already between [-1,+1] + + Inputs pred and target are Nx3xHxW + Output pytorch Variable N long + """ + + if normalize: + target = 2 * target - 1 + pred = 2 * pred - 1 + + return self.model.forward(target, pred) + +def normalize_tensor(in_feat,eps=1e-10): + norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True)) + return in_feat/(norm_factor+eps) + +def l2(p0, p1, range=255.): + return .5*np.mean((p0 / range - p1 / range)**2) + +def psnr(p0, p1, peak=255.): + return 10*np.log10(peak**2/np.mean((1.*p0-1.*p1)**2)) + +def dssim(p0, p1, range=255.): + return (1 - structural_similarity(p0, p1, data_range=range, multichannel=True)) / 2. 
+ +def rgb2lab(in_img,mean_cent=False): + from skimage import color + img_lab = color.rgb2lab(in_img) + if(mean_cent): + img_lab[:,:,0] = img_lab[:,:,0]-50 + return img_lab + +def tensor2np(tensor_obj): + # change dimension of a tensor object into a numpy array + return tensor_obj[0].cpu().float().numpy().transpose((1,2,0)) + +def np2tensor(np_obj): + # change dimenion of np array into tensor array + return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1))) + +def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False): + # image tensor to lab tensor + from skimage import color + + img = tensor2im(image_tensor) + img_lab = color.rgb2lab(img) + if(mc_only): + img_lab[:,:,0] = img_lab[:,:,0]-50 + if(to_norm and not mc_only): + img_lab[:,:,0] = img_lab[:,:,0]-50 + img_lab = img_lab/100. + + return np2tensor(img_lab) + +def tensorlab2tensor(lab_tensor,return_inbnd=False): + from skimage import color + import warnings + warnings.filterwarnings("ignore") + + lab = tensor2np(lab_tensor)*100. + lab[:,:,0] = lab[:,:,0]+50 + + rgb_back = 255.*np.clip(color.lab2rgb(lab.astype('float')),0,1) + if(return_inbnd): + # convert back to lab, see if we match + lab_back = color.rgb2lab(rgb_back.astype('uint8')) + mask = 1.*np.isclose(lab_back,lab,atol=2.) + mask = np2tensor(np.prod(mask,axis=2)[:,:,np.newaxis]) + return (im2tensor(rgb_back),mask) + else: + return im2tensor(rgb_back) + +def rgb2lab(input): + from skimage import color + return color.rgb2lab(input / 255.) + +def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): + image_numpy = image_tensor[0].cpu().float().numpy() + image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor + return image_numpy.astype(imtype) + +def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): + return torch.Tensor((image / factor - cent) + [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) + +def tensor2vec(vector_tensor): + return vector_tensor.data.cpu().numpy()[:, :, 0, 0] + +def voc_ap(rec, prec, use_07_metric=False): + """ ap = voc_ap(rec, prec, [use_07_metric]) + Compute VOC AP given precision and recall. + If use_07_metric is true, uses the + VOC 07 11 point method (default:False). + """ + if use_07_metric: + # 11 point metric + ap = 0. + for t in np.arange(0., 1.1, 0.1): + if np.sum(rec >= t) == 0: + p = 0 + else: + p = np.max(prec[rec >= t]) + ap = ap + p / 11. 
+ else: + # correct AP calculation + # first append sentinel values at the end + mrec = np.concatenate(([0.], rec, [1.])) + mpre = np.concatenate(([0.], prec, [0.])) + + # compute the precision envelope + for i in range(mpre.size - 1, 0, -1): + mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) + + # to calculate area under PR curve, look for points + # where X axis (recall) changes value + i = np.where(mrec[1:] != mrec[:-1])[0] + + # and sum (\Delta recall) * prec + ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) + return ap + +def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): +# def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.): + image_numpy = image_tensor[0].cpu().float().numpy() + image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor + return image_numpy.astype(imtype) + +def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): +# def im2tensor(image, imtype=np.uint8, cent=1., factor=1.): + return torch.Tensor((image / factor - cent) + [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) From 2c428ff9bdc719cb021e40f6561ced494403f440 Mon Sep 17 00:00:00 2001 From: Tianhao Xie <52686796+tianhaoxie@users.noreply.github.com> Date: Fri, 30 Jun 2023 08:48:48 -0400 Subject: [PATCH 3/5] Add files via upload --- viz/renderer.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/viz/renderer.py b/viz/renderer.py index 26e4f11..6955f3d 100644 --- a/viz/renderer.py +++ b/viz/renderer.py @@ -225,7 +225,7 @@ def init_network(self, res, res.num_ws = G.num_ws res.has_noise = any('noise_const' in name for name, _buf in G.synthesis.named_buffers()) res.has_input_transform = (hasattr(G.synthesis, 'input') and hasattr(G.synthesis.input, 'transform')) - + self.lr = lr # Set input transform. 
if res.has_input_transform: m = np.eye(3) @@ -262,6 +262,17 @@ def init_network(self, res, self.feat_refs = None self.points0_pt = None + def set_latent(self,w,trunc_psi,trunc_cutoff): + #label = torch.zeros([1, self.G.c_dim], device=self._device) + #w = self.G.mapping(z, label, truncation_psi=trunc_psi, truncation_cutoff=trunc_cutoff) + self.w0 = w.detach().clone() + if self.w_plus: + self.w = w.detach() + else: + self.w = w[:, 0, :].detach() + self.w.requires_grad = True + self.w_optim = torch.optim.Adam([self.w], lr=self.lr) + def update_lr(self, lr): del self.w_optim From d6aa972708e3c6a3b7d8159f5e9dac5e771209a5 Mon Sep 17 00:00:00 2001 From: Tianhao Xie <52686796+tianhaoxie@users.noreply.github.com> Date: Fri, 30 Jun 2023 10:28:07 -0400 Subject: [PATCH 4/5] Add files via upload --- gan_inv/inversion.py | 56 +++++++++++++++++--------------------------- 1 file changed, 22 insertions(+), 34 deletions(-) diff --git a/gan_inv/inversion.py b/gan_inv/inversion.py index b22a8c7..c03a383 100644 --- a/gan_inv/inversion.py +++ b/gan_inv/inversion.py @@ -14,36 +14,6 @@ -def noise_regularize(noises): - loss = 0 - - for noise in noises: - size = noise.shape[2] - - while True: - loss = ( - loss - + (noise * torch.roll(noise, shifts=1, dims=3)).mean().pow(2) - + (noise * torch.roll(noise, shifts=1, dims=2)).mean().pow(2) - ) - - if size <= 8: - break - - noise = noise.reshape([-1, 1, size // 2, 2, size // 2, 2]) - noise = noise.mean([3, 5]) - size //= 2 - - return loss - - -def noise_normalize_(noises): - for noise in noises: - mean = noise.mean() - std = noise.std() - - noise.data.add_(-mean).div_(std) - def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05): lr_ramp = min(1, (1 - t) / rampdown) @@ -53,10 +23,7 @@ def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05): return initial_lr * lr_ramp -def latent_noise(latent, strength): - noise = torch.randn_like(latent) * strength - return latent + noise def make_image(tensor): @@ -259,6 +226,27 @@ def cacl_loss(self,percept, generated_image,real_image): return loss def train(self,img,w_plus=False): + if torch.is_tensor(img) == False: + transform = transforms.Compose( + [ + transforms.Resize(self.g_ema.img_resolution, ), + transforms.CenterCrop(self.g_ema.img_resolution), + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), + ] + ) + + real_img = transform(img).to('cuda').unsqueeze(0) + + else: + img = transforms.functional.resize(img, self.g_ema.img_resolution) + transform = transforms.Compose( + [ + transforms.CenterCrop(self.g_ema.img_resolution), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), + ] + ) + real_img = transform(img).to('cuda').unsqueeze(0) inversed_result = inverse_image(self.g_ema,img,self.percept,self.g_ema.img_resolution,w_plus) w_pivot = inversed_result['latent'] if w_plus: @@ -275,7 +263,7 @@ def train(self,img,w_plus=False): optimizer.param_groups[0]["lr"] = lr generated_image = self.g_ema.synthesis(ws,noise_mode='const') - loss = self.cacl_loss(self.percept,generated_image,inversed_result['real']) + loss = self.cacl_loss(self.percept,generated_image,real_img) pbar.set_description( ( f"loss: {loss.item():.4f}" From 5dc61562291ec15c0e410cefa156266d5310ba58 Mon Sep 17 00:00:00 2001 From: Tianhao Xie <52686796+tianhaoxie@users.noreply.github.com> Date: Mon, 10 Jul 2023 07:09:35 -0400 Subject: [PATCH 5/5] Update .gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 8b626aa..2c898a0 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ 
__pycache__/ *.py[cod] *$py.class +*.pyc # C extensions *.so
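Taken together, the series vendors LPIPS and PTI inversion under gan_inv/, gives the renderer a set_latent hook so drag optimization can start from an arbitrary latent instead of a seeded one, and adds a .gitignore rule to keep future *.pyc files out of the tree. A rough sketch, not part of the patch, of what the new Renderer.set_latent expects: `renderer` is assumed to be a viz.renderer.Renderer on which init_network has already run (set_latent reuses self.lr and self.w_plus), the mapping call mirrors the commented-out lines inside set_latent, and in the real-image path the latent handed in would be the PTI pivot rather than a mapped z.

import numpy as np
import torch

def start_drag_from_latent(renderer, seed=0, trunc_psi=0.7, trunc_cutoff=None):
    G = renderer.G
    device = next(G.parameters()).device
    z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim))
    z = z.to(device=device, dtype=torch.float32)
    label = torch.zeros([1, G.c_dim], device=device)
    w = G.mapping(z, label, truncation_psi=trunc_psi, truncation_cutoff=trunc_cutoff)
    # Replaces the seed-derived w0/w, re-enables grads on w, and rebuilds the
    # Adam optimizer at the lr stored by init_network (see the viz/renderer.py hunk above).
    renderer.set_latent(w, trunc_psi, trunc_cutoff)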