From dd9c8ecc78575babf7527f74e830e05667dfcc79 Mon Sep 17 00:00:00 2001 From: Tianhao Xie <52686796+tianhaoxie@users.noreply.github.com> Date: Fri, 30 Jun 2023 08:47:24 -0400 Subject: [PATCH 1/5] Add files via upload --- visualizer_gradio_custom.py | 964 ++++++++++++++++++++++++++++++++++++ 1 file changed, 964 insertions(+) create mode 100644 visualizer_gradio_custom.py diff --git a/visualizer_gradio_custom.py b/visualizer_gradio_custom.py new file mode 100644 index 0000000..6e84057 --- /dev/null +++ b/visualizer_gradio_custom.py @@ -0,0 +1,964 @@ +import os +import os.path as osp +from argparse import ArgumentParser +from functools import partial + +import gradio as gr +import numpy as np +import torch +from PIL import Image +import imageio +import dnnlib +from gradio_utils import (ImageMask, draw_mask_on_image, draw_points_on_image, + get_latest_points_pair, get_valid_mask, + on_change_single_global_state) +from viz.renderer import Renderer, add_watermark_np +from gan_inv.inversion import PTI +from gan_inv.lpips import util +parser = ArgumentParser() +# NOTE(review): the default is the string 'False', which is truthy if it is ever used as a boolean -- confirm how args.share is consumed +parser.add_argument('--share',default='False') +parser.add_argument('--cache-dir', type=str, default='./checkpoints') +args = parser.parse_args() + +cache_dir = args.cache_dir + +device = 'cuda' + + +def reverse_point_pairs(points): + """Return a new list in which the two coordinates of every point are swapped ([a, b] -> [b, a]).""" + new_points = [] + for p in points: + new_points.append([p[1], p[0]]) + return new_points + + +def clear_state(global_state, target=None): + """Clear target history state from global_state + If target is not defined, points and mask will be both removed. + 1. set global_state['points'] as empty dict + 2. set global_state['mask'] as full-one mask. 
+ """ + if target is None: + target = ['point', 'mask'] + if not isinstance(target, list): + target = [target] + if 'point' in target: + global_state['points'] = dict() + print('Clear Points State!') + if 'mask' in target: + image_raw = global_state["images"]["image_raw"] + global_state['mask'] = np.ones((image_raw.size[1], image_raw.size[0]), + dtype=np.uint8) + print('Clear mask State!') + + return global_state + + +def init_images(global_state): + """(Re-)initialize the renderer network and the images stored in the state. + Called at app start-up and by the model/seed/latent-space/reset handlers. + 0. pre-process global_state, unpack value from global_state if needed + 1. Re-init renderer + 2. run `renderer._render_drag_impl` with `is_drag=False` to generate + new image + 3. Assign images to global state and re-generate mask + """ + + # accept either the gr.State wrapper or the state it wraps + if isinstance(global_state, gr.State): + state = global_state.value + else: + state = global_state + + state['renderer'].init_network( + state['generator_params'], # res + valid_checkpoints_dict[state['pretrained_weight']], # pkl + state['params']['seed'], # w0_seed, + None, # w_load + state['params']['latent_space'] == 'w+', # w_plus + 'const', + state['params']['trunc_psi'], # trunc_psi, + state['params']['trunc_cutoff'], # trunc_cutoff, + None, # input_transform + state['params']['lr'] # lr, + ) + + state['renderer']._render_drag_impl(state['generator_params'], + is_drag=False, + to_pil=True) + + init_image = state['generator_params'].image + state['images']['image_orig'] = init_image + state['images']['image_raw'] = init_image + state['images']['image_show'] = Image.fromarray( + add_watermark_np(np.array(init_image))) + state['mask'] = np.ones((init_image.size[1], init_image.size[0]), + dtype=np.uint8) + return global_state + + +def update_image_draw(image, points, mask, show_mask, global_state=None): + """Draw the points on image, overlay the mask when show_mask is set and the mask is neither all 0 nor all 1, add the watermark and return the result. + When global_state is given the result is also stored in global_state['images']['image_show']. + """ + + image_draw = draw_points_on_image(image, points) + if show_mask and mask is not None and not (mask == 0).all() and not ( + mask == 1).all(): + image_draw = draw_mask_on_image(image_draw, mask) + + image_draw = 
Image.fromarray(add_watermark_np(np.array(image_draw))) + if global_state is not None: + global_state['images']['image_show'] = image_draw + return image_draw + + +def preprocess_mask_info(global_state, image): + """Function to handle mask information. + 1. last_mask is None: Do not need to change mask, return global_state unchanged + 2. last_mask is not None: + 2.1 editing_state is 'remove_mask': subtract last_mask from mask, clipped to [0, 1] + 2.2 editing_state is 'add_mask': add last_mask to mask, clipped to [0, 1] + 2.3 any other editing_state: keep the mask as it is + """ + if isinstance(image, dict): + last_mask = get_valid_mask(image['mask']) + else: + last_mask = None + mask = global_state['mask'] + + # mask in global state is a placeholder with all 1. + if (mask == 1).all(): + mask = last_mask + + # last_mask = global_state['last_mask'] + editing_mode = global_state['editing_state'] + + if last_mask is None: + return global_state + + if editing_mode == 'remove_mask': + # NOTE(review): global_state['mask'] is created as uint8 in this file; if last_mask is uint8 too the subtraction wraps around before the clip -- confirm the dtype returned by get_valid_mask + updated_mask = np.clip(mask - last_mask, 0, 1) + print(f'Last editing_state is {editing_mode}, do remove.') + elif editing_mode == 'add_mask': + updated_mask = np.clip(mask + last_mask, 0, 1) + print(f'Last editing_state is {editing_mode}, do add.') + else: + updated_mask = mask + print(f'Last editing_state is {editing_mode}, ' + 'do nothing to mask.') + + global_state['mask'] = updated_mask + # global_state['last_mask'] = None # clear buffer + return global_state + + +# map checkpoint name (file name up to its first '.') to its path, for every entry of cache_dir ending in 'pkl' +valid_checkpoints_dict = { + f.split('/')[-1].split('.')[0]: osp.join(cache_dir, f) + for f in os.listdir(cache_dir) + if (f.endswith('pkl') and osp.exists(osp.join(cache_dir, f))) +} +print(f'File under cache_dir ({cache_dir}):') +print(os.listdir(cache_dir)) +print('Valid checkpoint file:') +print(valid_checkpoints_dict) + +# key into valid_checkpoints_dict used as the initial 'pretrained_weight' +init_pkl = 'stylegan2_lions_512_pytorch' + + + +# Network & latents tab listeners +def on_change_pretrained_dropdown(pretrained_value, global_state): + """Function to handle model change. + 1. Set pretrained value to global_state + 2. 
Re-init images and clear all states + """ + global_state['pretrained_weight'] = pretrained_value + init_images(global_state) + clear_state(global_state) + + return global_state, global_state["images"]['image_show'] + + + +def on_click_reset_image(global_state): + """Reset image to the original one and clear all states + 1. Re-init images + 2. Clear all states + """ + + init_images(global_state) + # no target given: clear_state resets both the points and the mask + clear_state(global_state) + + return global_state, global_state['images']['image_show'] + + + + # Update parameters +def on_change_update_image_seed(seed, global_state): + """Function to handle generation seed change. + 1. Set seed (cast to int) to global_state + 2. Re-init images and clear all states + """ + + global_state["params"]["seed"] = int(seed) + init_images(global_state) + clear_state(global_state) + + return global_state, global_state['images']['image_show'] + + + +def on_click_latent_space(latent_space, global_state): + """Function to reset latent space to optimize. + NOTE: this function resets the image and all controls + 1. Set latent-space to global_state + 2. 
Re-init images and clear all state + """ + + global_state['params']['latent_space'] = latent_space + init_images(global_state) + clear_state(global_state) + + return global_state, global_state['images']['image_show'] + + + +def on_click_inverse_custom_image(custom_image,global_state): + print('inverse GAN') + + if isinstance(global_state, gr.State): + state = global_state.value + else: + state = global_state + + state['renderer'].init_network( + state['generator_params'], # res + valid_checkpoints_dict[state['pretrained_weight']], # pkl + state['params']['seed'], # w0_seed, + None, # w_load + state['params']['latent_space'] == 'w+', # w_plus + 'const', + state['params']['trunc_psi'], # trunc_psi, + state['params']['trunc_cutoff'], # trunc_cutoff, + None, # input_transform + state['params']['lr'] # lr, + ) + + percept = util.PerceptualLoss( + model="net-lin", net="vgg", use_gpu=True + ) + + image = Image.open(custom_image.name) + + pti = PTI(global_state['renderer'].G,percept) + inversed_img, w_pivot = pti.train(image,state['params']['latent_space'] == 'w+') + inversed_img = (inversed_img[0] * 127.5 + 128).clamp(0, 255).to(torch.uint8).permute(1, 2, 0) + inversed_img = inversed_img.cpu().numpy() + inversed_img = Image.fromarray(inversed_img) + global_state['images']['image_show'] = Image.fromarray( + add_watermark_np(np.array(inversed_img))) + + global_state['images']['image_orig'] = inversed_img + global_state['images']['image_raw'] = inversed_img + + global_state['mask'] = np.ones((inversed_img.size[1], inversed_img.size[0]), + dtype=np.uint8) + global_state['generator_params'].image = inversed_img + global_state['generator_params'].w = w_pivot.detach().cpu().numpy() + global_state['renderer'].set_latent(w_pivot,global_state['params']['trunc_psi'],global_state['params']['trunc_cutoff']) + + del percept + del pti + print('inverse end') + + return global_state, global_state['images']['image_show'], gr.Button.update(interactive=True) + +def 
on_save_image(global_state,form_save_image_path): + """Write global_state['images']['image_raw'] to form_save_image_path with imageio.""" + imageio.imsave(form_save_image_path,global_state['images']['image_raw']) + +def on_reset_custom_image(global_state): + """Reset the renderer latent to a copy of its w0, rebuild the Adam optimizer, re-render and reset points, mask and images. + Returns the state and the watermarked image to show. + """ + if isinstance(global_state, gr.State): + state = global_state.value + else: + state = global_state + clear_state(state) + state['renderer'].w = state['renderer'].w0.detach().clone() + state['renderer'].w.requires_grad = True + state['renderer'].w_optim = torch.optim.Adam([state['renderer'].w], lr=state['renderer'].lr) + state['renderer']._render_drag_impl(state['generator_params'], + is_drag=False, + to_pil=True) + + init_image = state['generator_params'].image + state['images']['image_orig'] = init_image + state['images']['image_raw'] = init_image + state['images']['image_show'] = Image.fromarray( + add_watermark_np(np.array(init_image))) + state['mask'] = np.ones((init_image.size[1], init_image.size[0]), + dtype=np.uint8) + return state, state['images']['image_show'] +def on_change_lr(lr, global_state): + """Store the new learning rate in global_state and pass it to renderer.update_lr; lr == 0 is ignored.""" + if lr == 0: + print('lr is 0, do nothing.') + return global_state + else: + global_state["params"]["lr"] = lr + renderer = global_state['renderer'] + renderer.update_lr(lr) + print('New optimizer: ') + print(renderer.w_optim) + return global_state + + +def on_click_start(global_state, image): + """Generator that runs the drag optimization and yields tuples of UI updates. + Without points it redraws the current image and yields once; otherwise it calls + `renderer._render_drag_impl` in a loop until global_state['temporal_params']['stop'] + is set, then sets editing_state back to 'add_points' and yields a final update. + """ + p_in_pixels = [] + t_in_pixels = [] + valid_points = [] + + # merge a pending mask edit into the state before the drag starts + global_state = preprocess_mask_info(global_state, image) + + # Prepare the points for the inference + if len(global_state["points"]) == 0: + # yield on_click_start_wo_points(global_state, image) + image_raw = global_state['images']['image_raw'] + update_image_draw( + image_raw, + global_state['points'], + global_state['mask'], + global_state['show_mask'], + global_state, + ) + + yield ( + global_state, + 0, + global_state['images']['image_show'], + # gr.File.update(visible=False), + gr.Button.update(interactive=True), + gr.Button.update(interactive=True), + gr.Button.update(interactive=True), + 
gr.Button.update(interactive=True), + gr.Button.update(interactive=True), + # latent space + gr.Radio.update(interactive=True), + gr.Button.update(interactive=True), + # NOTE: disable stop button + gr.Button.update(interactive=False), + + # update other comps + gr.Dropdown.update(interactive=True), + gr.Number.update(interactive=True), + gr.Number.update(interactive=True), + gr.Button.update(interactive=True), + gr.Button.update(interactive=True), + gr.Checkbox.update(interactive=True), + # gr.Number.update(interactive=True), + gr.Number.update(interactive=True), + ) + else: + + # Collect the start/target pairs that are complete; pairs with a missing key or a None entry are skipped + for key_point, point in global_state["points"].items(): + try: + p_start = point.get("start_temp", point["start"]) + p_end = point["target"] + + if p_start is None or p_end is None: + continue + + except KeyError: + continue + + p_in_pixels.append(p_start) + t_in_pixels.append(p_end) + valid_points.append(key_point) + + mask = torch.tensor(global_state['mask']).float() + # the renderer receives the inverted mask (1 - mask) + drag_mask = 1 - mask + + renderer: Renderer = global_state["renderer"] + global_state['temporal_params']['stop'] = False + global_state['editing_state'] = 'running' + + # reverse points order + p_to_opt = reverse_point_pairs(p_in_pixels) + t_to_opt = reverse_point_pairs(t_in_pixels) + #print('Running with:') + #print(f' Source: {p_in_pixels}') + #print(f' Target: {t_in_pixels}') + step_idx = 0 + while True: + if global_state["temporal_params"]["stop"]: + break + + # do drag here! 
+ renderer._render_drag_impl( + global_state['generator_params'], + p_to_opt, # point + t_to_opt, # target + drag_mask, # mask, + global_state['params']['motion_lambda'], # lambda_mask + reg=0, + feature_idx=5, # NOTE: do not support change for now + r1=global_state['params']['r1_in_pixels'], # r1 + r2=global_state['params']['r2_in_pixels'], # r2 + # random_seed = 0, + # noise_mode = 'const', + trunc_psi=global_state['params']['trunc_psi'], + # force_fp32 = False, + # layer_name = None, + # sel_channels = 3, + # base_channel = 0, + # img_scale_db = 0, + # img_normalize = False, + # untransform = False, + is_drag=True, + to_pil=True) + + if step_idx % global_state['draw_interval'] == 0: + #print('Current Source:') + # write the current points back into the state, swapping the coordinate order again + for key_point, p_i, t_i in zip(valid_points, p_to_opt, + t_to_opt): + global_state["points"][key_point]["start_temp"] = [ + p_i[1], + p_i[0], + ] + global_state["points"][key_point]["target"] = [ + t_i[1], + t_i[0], + ] + start_temp = global_state["points"][key_point][ + "start_temp"] + #print(f' {start_temp}') + + image_result = global_state['generator_params']['image'] + image_draw = update_image_draw( + image_result, + global_state['points'], + global_state['mask'], + global_state['show_mask'], + global_state, + ) + global_state['images']['image_raw'] = image_result + + yield ( + global_state, + step_idx, + global_state['images']['image_show'], + # gr.File.update(visible=False), + gr.Button.update(interactive=False), + gr.Button.update(interactive=False), + gr.Button.update(interactive=False), + gr.Button.update(interactive=False), + gr.Button.update(interactive=False), + # latent space + gr.Radio.update(interactive=False), + gr.Button.update(interactive=False), + # enable stop button in loop + gr.Button.update(interactive=True), + + # update other comps + gr.Dropdown.update(interactive=False), + gr.Number.update(interactive=False), + gr.Number.update(interactive=False), + gr.Button.update(interactive=False), + gr.Button.update(interactive=False), + 
gr.Checkbox.update(interactive=False), + # gr.Number.update(interactive=False), + gr.Number.update(interactive=False), + ) + + # increase step + step_idx += 1 + + image_result = global_state['generator_params']['image'] + global_state['images']['image_raw'] = image_result + image_draw = update_image_draw(image_result, + global_state['points'], + global_state['mask'], + global_state['show_mask'], + global_state) + + # fp = NamedTemporaryFile(suffix=".png", delete=False) + # image_result.save(fp, "PNG") + + global_state['editing_state'] = 'add_points' + + # NOTE(review): this tuple has 16 items while the two yields above have 18 (there are no gr.Button updates after the gr.Number pair) -- confirm against the outputs list of the click handler + yield ( + global_state, + 0, # reset step to 0 after stop. + global_state['images']['image_show'], + # gr.File.update(visible=True, value=fp.name), + gr.Button.update(interactive=True), + gr.Button.update(interactive=True), + gr.Button.update(interactive=True), + gr.Button.update(interactive=True), + gr.Button.update(interactive=True), + # latent space + gr.Radio.update(interactive=True), + gr.Button.update(interactive=True), + # NOTE: disable stop button with loop finish + gr.Button.update(interactive=False), + + # update other comps + gr.Dropdown.update(interactive=True), + gr.Number.update(interactive=True), + gr.Number.update(interactive=True), + gr.Checkbox.update(interactive=True), + gr.Number.update(interactive=True), + ) + + + +def on_click_stop(global_state): + """Function to handle stop button is clicked. + 1. send a stop signal by set global_state["temporal_params"]["stop"] as True + 2. 
Disable Stop button + """ + global_state["temporal_params"]["stop"] = True + + return global_state, gr.Button.update(interactive=False) + + + +def on_click_remove_point(global_state): + choice = global_state["curr_point"] + del global_state["points"][choice] + + choices = list(global_state["points"].keys()) + + if len(choices) > 0: + global_state["curr_point"] = choices[0] + + return ( + gr.Dropdown.update(choices=choices, value=choices[0]), + global_state, + ) + + # Mask +def on_click_reset_mask(global_state): + global_state['mask'] = np.ones( + ( + global_state["images"]["image_raw"].size[1], + global_state["images"]["image_raw"].size[0], + ), + dtype=np.uint8, + ) + image_draw = update_image_draw(global_state['images']['image_raw'], + global_state['points'], + global_state['mask'], + global_state['show_mask'], global_state) + return global_state, image_draw + + + + # Image +def on_click_enable_draw(global_state, image): + """Function to start add mask mode. + 1. Preprocess mask info from last state + 2. Change editing state to add_mask + 3. Set curr image with points and mask + """ + global_state = preprocess_mask_info(global_state, image) + global_state['editing_state'] = 'add_mask' + image_raw = global_state['images']['image_raw'] + image_draw = update_image_draw(image_raw, global_state['points'], + global_state['mask'], True, + global_state) + return (global_state, + gr.Image.update(value=image_draw, interactive=True)) + +def on_click_remove_draw(global_state, image): + """Function to start remove mask mode. + 1. Preprocess mask info from last state + 2. Change editing state to remove_mask + 3. 
Set curr image with points and mask + """ + global_state = preprocess_mask_info(global_state, image) + global_state['edinting_state'] = 'remove_mask' + image_raw = global_state['images']['image_raw'] + image_draw = update_image_draw(image_raw, global_state['points'], + global_state['mask'], True, + global_state) + return (global_state, + gr.Image.update(value=image_draw, interactive=True)) + + + +def on_click_add_point(global_state, image: dict): + """Function switch from add mask mode to add points mode. + 1. Updaste mask buffer if need + 2. Change global_state['editing_state'] to 'add_points' + 3. Set current image with mask + """ + + global_state = preprocess_mask_info(global_state, image) + global_state['editing_state'] = 'add_points' + mask = global_state['mask'] + image_raw = global_state['images']['image_raw'] + image_draw = update_image_draw(image_raw, global_state['points'], mask, + global_state['show_mask'], global_state) + + return (global_state, + gr.Image.update(value=image_draw, interactive=False)) + + + +def on_click_image(global_state, evt: gr.SelectData): + """This function only support click for point selection + """ + xy = evt.index + if global_state['editing_state'] != 'add_points': + print(f'In {global_state["editing_state"]} state. 
' + 'Do not add points.') + + return global_state, global_state['images']['image_show'] + + points = global_state["points"] + + # no pair yet: start pair 0; latest pair has no target: the click becomes its target; otherwise start the next pair + point_idx = get_latest_points_pair(points) + if point_idx is None: + points[0] = {'start': xy, 'target': None} + print(f'Click Image - Start - {xy}') + elif points[point_idx].get('target', None) is None: + points[point_idx]['target'] = xy + print(f'Click Image - Target - {xy}') + else: + points[point_idx + 1] = {'start': xy, 'target': None} + print(f'Click Image - Start - {xy}') + + image_raw = global_state['images']['image_raw'] + image_draw = update_image_draw( + image_raw, + global_state['points'], + global_state['mask'], + global_state['show_mask'], + global_state, + ) + + return global_state, image_draw + + + +def on_click_clear_points(global_state): + """Function to handle clear all control points + 1. clear global_state['points'] (clear_state) + 2. reset renderer.feat_refs to None + 3. re-draw image without points + """ + clear_state(global_state, target='point') + + renderer: Renderer = global_state["renderer"] + renderer.feat_refs = None + + image_raw = global_state['images']['image_raw'] + image_draw = update_image_draw(image_raw, {}, global_state['mask'], + global_state['show_mask'], global_state) + return global_state, image_draw + + + +def on_click_show_mask(global_state, show_mask): + """Function to control whether show mask on image.""" + global_state['show_mask'] = show_mask + + image_raw = global_state['images']['image_raw'] + image_draw = update_image_draw( + image_raw, + global_state['points'], + global_state['mask'], + global_state['show_mask'], + global_state, + ) + return global_state, image_draw + + +if __name__ == "__main__": + with gr.Blocks() as app: + # renderer = Renderer() + global_state = gr.State({ + "images": { + # image_orig: the original image, change with seed/model is changed + # image_raw: image with mask and points, change durning optimization + # image_show: image showed on screen + }, + "temporal_params": { + # stop + }, + 'mask': 
+ None, # mask for visualization, 1 for editing and 0 for unchange + 'last_mask': None, # last edited mask + 'show_mask': True, # add button + "generator_params": dnnlib.EasyDict(), + "params": { + "seed": 0, + "motion_lambda": 20, + "r1_in_pixels": 3, + "r2_in_pixels": 12, + "magnitude_direction_in_pixels": 1.0, + "latent_space": "w+", + "trunc_psi": 0.7, + "trunc_cutoff": None, + "lr": 0.001, + }, + "device": device, + "draw_interval": 1, + "renderer": Renderer(disable_timing=True), + "points": {}, + "curr_point": None, + "curr_type_point": "start", + 'editing_state': 'add_points', + 'pretrained_weight': init_pkl + }) + + # init image + global_state = init_images(global_state) + + with gr.Row(): + with gr.Row(): + # Left --> tools + with gr.Column(scale=3): + # Pickle + with gr.Row(): + with gr.Column(scale=1, min_width=10): + gr.Markdown(value='Pickle', show_label=False) + + with gr.Column(scale=4, min_width=10): + form_pretrained_dropdown = gr.Dropdown( + choices=list(valid_checkpoints_dict.keys()), + label="Pretrained Model", + value=init_pkl, + ) + + # Latent + with gr.Row(): + with gr.Column(scale=1, min_width=10): + gr.Markdown(value='Latent', show_label=False) + + with gr.Column(scale=4, min_width=10): + form_seed_number = gr.Number( + value=global_state.value['params']['seed'], + interactive=True, + label="Seed", + ) + form_lr_number = gr.Number( + value=global_state.value["params"]["lr"], + interactive=True, + label="Step Size") + + with gr.Row(): + with gr.Column(scale=2, min_width=10): + form_reset_image = gr.Button("Reset Image") + with gr.Column(scale=3, min_width=10): + form_latent_space = gr.Radio( + ['w', 'w+'], + value=global_state.value['params'] + ['latent_space'], + interactive=True, + label='Latent space to optimize', + show_label=False, + ) + with gr.Row(): + with gr.Column(scale=3, min_width=10): + form_custom_image = gr.UploadButton(label="inverse custom image", + file_types=['.png', '.jpg', '.jpeg']) + with gr.Column(scale=3, 
min_width=10): + form_reset_custom_image = gr.Button('reset custom image', interactive=False) + with gr.Row(): + with gr.Column(scale=3, min_width=10): + form_save_image_path = gr.Textbox(label="save image to",value='./test.png') + form_save_image = gr.Button('save',interactive=True) + + + # Drag + with gr.Row(): + with gr.Column(scale=1, min_width=10): + gr.Markdown(value='Drag', show_label=False) + with gr.Column(scale=4, min_width=10): + with gr.Row(): + with gr.Column(scale=1, min_width=10): + enable_add_points = gr.Button('Add Points') + with gr.Column(scale=1, min_width=10): + undo_points = gr.Button('Reset Points') + with gr.Row(): + with gr.Column(scale=1, min_width=10): + form_start_btn = gr.Button("Start") + with gr.Column(scale=1, min_width=10): + form_stop_btn = gr.Button("Stop") + + form_steps_number = gr.Number(value=0, + label="Steps", + interactive=False) + + # Mask + with gr.Row(): + with gr.Column(scale=1, min_width=10): + gr.Markdown(value='Mask', show_label=False) + with gr.Column(scale=4, min_width=10): + enable_add_mask = gr.Button('Edit Flexible Area') + with gr.Row(): + with gr.Column(scale=1, min_width=10): + form_reset_mask_btn = gr.Button("Reset mask") + with gr.Column(scale=1, min_width=10): + show_mask = gr.Checkbox( + label='Show Mask', + value=global_state.value['show_mask'], + show_label=False) + + with gr.Row(): + form_lambda_number = gr.Number( + value=global_state.value["params"] + ["motion_lambda"], + interactive=True, + label="Lambda", + ) + + form_draw_interval_number = gr.Number( + value=global_state.value["draw_interval"], + label="Draw Interval (steps)", + interactive=True, + visible=False) + + # Right --> Image + with gr.Column(scale=8): + form_image = ImageMask( + value=global_state.value['images']['image_show'], + brush_radius=20).style( + width=768, + height=768) # NOTE: hard image size code here. + gr.Markdown(""" + ## Quick Start + + 1. Select desired `Pretrained Model` and adjust `Seed` to generate an + initial image. 
+ 2. Click on image to add control points. + 3. Click `Start` and enjoy it! + + ## Advance Usage + + 1. Change `Step Size` to adjust learning rate in drag optimization. + 2. Select `w` or `w+` to change latent space to optimize: + * Optimize on `w` space may cause greater influence to the image. + * Optimize on `w+` space may work slower than `w`, but usually achieve + better results. + * Note that changing the latent space will reset the image, points and + mask (this has the same effect as `Reset Image` button). + 3. Click `Edit Flexible Area` to create a mask and constrain the + unmasked region to remain unchanged. + """) + gr.HTML(""" + +