From 9849e6389e381891112dc3ab966aa060cfb8b976 Mon Sep 17 00:00:00 2001
From: Alexey Borsky
Date: Sun, 14 May 2023 05:58:37 +0300
Subject: [PATCH] critical fixes

---
 readme.md               |  2 ++
 scripts/base_ui.py      |  8 +++++++-
 scripts/core/txt2vid.py | 41 ++++++++++++++++++++++++++++++++++-------
 scripts/core/utils.py   |  5 +++--
 4 files changed, 46 insertions(+), 10 deletions(-)

diff --git a/readme.md b/readme.md
index 8129c3c..d29c255 100644
--- a/readme.md
+++ b/readme.md
@@ -71,3 +71,5 @@ To install the extension go to 'Extensions' tab in [Automatic1111 web-ui](https:
 * Added ability to export current parameters in a human readable form as a json.
 * Interpolation mode in the flow-applying stage is set to ‘nearest’ to reduce overtime image blurring.
 * Added ControlNet to txt2vid mode as well as fixing #86 issue, thanks to [@mariaWitch](https://github.com/mariaWitch)
+* Fixed a major issue where ControlNet used the wrong input images; because of this, vid2vid results were much worse than they should have been.
+* Text to video mode now supports a video as guidance for ControlNet, which allows much stronger video stylizations.
diff --git a/scripts/base_ui.py b/scripts/base_ui.py
index b51b5ee..343a98b 100644
--- a/scripts/base_ui.py
+++ b/scripts/base_ui.py
@@ -71,7 +71,7 @@ def inputs_ui():
 
     with gr.Tab('vid2vid') as tab_vid2vid:
       with gr.Row():
-        gr.HTML('Put your video here:')
+        gr.HTML('Input video (each frame will be used as initial image for SD and as input image to CN): *REQUIRED')
       with gr.Row():
         v2v_file = gr.File(label="Input video", interactive=True, file_count="single", file_types=["video"], elem_id="vid_to_vid_chosen_file")
 
@@ -110,7 +110,13 @@ def inputs_ui():
       v2v_custom_inputs = scripts.scripts_img2img.setup_ui()
 
     with gr.Tab('txt2vid') as tab_txt2vid:
+      with gr.Row():
+        gr.HTML('Control video (each frame will be used as input image to CN): *NOT REQUIRED')
+      with gr.Row():
+        t2v_file = gr.File(label="Input video", interactive=True, file_count="single", file_types=["video"], elem_id="tex_to_vid_chosen_file")
+
       t2v_width, t2v_height, t2v_prompt, t2v_n_prompt, t2v_cfg_scale, t2v_seed, t2v_processing_strength, t2v_fix_frame_strength, t2v_sampler_index, t2v_steps = setup_common_values('txt2vid', t2v_args)
+
       with gr.Row():
         t2v_length = gr.Slider(label='Length (in frames)', minimum=10, maximum=2048, step=10, value=40, interactive=True)
         t2v_fps = gr.Slider(label='Video FPS', minimum=4, maximum=64, step=4, value=12, interactive=True)
diff --git a/scripts/core/txt2vid.py b/scripts/core/txt2vid.py
index 77de60e..0a3c3e6 100644
--- a/scripts/core/txt2vid.py
+++ b/scripts/core/txt2vid.py
@@ -44,18 +44,30 @@ def FloweR_load_model(w, h):
 
   # Move the model to the device
   FloweR_model = FloweR_model.to(DEVICE)
-
+def read_frame_from_video(input_video):
+  if input_video is None: return None
+
+  # Reading video file
+  if input_video.isOpened():
+    ret, cur_frame = input_video.read()
+    if cur_frame is not None:
+      cur_frame = cv2.cvtColor(cur_frame, cv2.COLOR_BGR2RGB)
+    else:
+      cur_frame = None
+      input_video.release()
+      input_video = None
+
+  return cur_frame
 
 def start_process(*args):
   processing_start_time = time.time()
   args_dict = utils.args_to_dict(*args)
   args_dict = utils.get_mode_args('t2v', args_dict)
 
-  #utils.set_CNs_input_image(args_dict, Image.fromarray(curr_frame))
-  processed_frames, _, _, _ = utils.txt2img(args_dict)
-  processed_frame = np.array(processed_frames[0])
-  processed_frame = np.clip(processed_frame, 0, 255).astype(np.uint8)
-  init_frame = processed_frame.copy()
+  # Open the input video file
+  input_video = None
+  if args_dict['file'] is not None:
+    input_video = cv2.VideoCapture(args_dict['file'].name)
 
   # Create an output video file with the same fps, width, and height as the input video
   output_video_name = f'outputs/sd-cn-animation/txt2vid/{datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.mp4'
@@ -69,6 +81,16 @@ def save_result_to_image(image, ind):
     if args_dict['save_frames_check']:
       cv2.imwrite(os.path.join(output_video_folder, f'{ind:05d}.png'), cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
 
+  if input_video is not None:
+    curr_video_frame = read_frame_from_video(input_video)
+    curr_video_frame = cv2.resize(curr_video_frame, (args_dict['width'], args_dict['height']))
+    utils.set_CNs_input_image(args_dict, Image.fromarray(curr_video_frame))
+
+  processed_frames, _, _, _ = utils.txt2img(args_dict)
+  processed_frame = np.array(processed_frames[0])
+  processed_frame = np.clip(processed_frame, 0, 255).astype(np.uint8)
+  init_frame = processed_frame.copy()
+
   output_video = cv2.VideoWriter(output_video_name, cv2.VideoWriter_fourcc(*'mp4v'), args_dict['fps'], (args_dict['width'], args_dict['height']))
 
   output_video.write(cv2.cvtColor(processed_frame, cv2.COLOR_RGB2BGR))
@@ -125,7 +147,11 @@ def save_result_to_image(image, ind):
     args_dict['mask_img'] = Image.fromarray(pred_occl)
     args_dict['seed'] = -1
 
-    #utils.set_CNs_input_image(args_dict, Image.fromarray(curr_frame))
+    if input_video is not None:
+      curr_video_frame = read_frame_from_video(input_video)
+      curr_video_frame = cv2.resize(curr_video_frame, (args_dict['width'], args_dict['height']))
+      utils.set_CNs_input_image(args_dict, Image.fromarray(curr_video_frame))
+
     processed_frames, _, _, _ = utils.img2img(args_dict)
     processed_frame = np.array(processed_frames[0])
     processed_frame = skimage.exposure.match_histograms(processed_frame, init_frame, channel_axis=None)
@@ -150,6 +176,7 @@ def save_result_to_image(image, ind):
     stat = f"Frame: {ind + 2} / {args_dict['length']}; " + utils.get_time_left(ind+2, args_dict['length'], processing_start_time)
     yield stat, curr_frame, pred_occl, warped_frame, processed_frame, None, gr.Button.update(interactive=False), gr.Button.update(interactive=True)
 
+  if input_video is not None: input_video.release()
   output_video.release()
   FloweR_clear_memory()
 
diff --git a/scripts/core/utils.py b/scripts/core/utils.py
index b7a402f..bdcb105 100644
--- a/scripts/core/utils.py
+++ b/scripts/core/utils.py
@@ -10,7 +10,7 @@ def get_component_names():
     'v2v_sampler_index', 'v2v_steps', 'v2v_override_settings',
     'v2v_occlusion_mask_blur', 'v2v_occlusion_mask_trailing', 'v2v_occlusion_mask_flow_multiplier', 'v2v_occlusion_mask_difo_multiplier', 'v2v_occlusion_mask_difs_multiplier',
     'v2v_step_1_processing_mode', 'v2v_step_1_blend_alpha', 'v2v_step_1_seed', 'v2v_step_2_seed',
-    't2v_width', 't2v_height', 't2v_prompt', 't2v_n_prompt', 't2v_cfg_scale', 't2v_seed', 't2v_processing_strength', 't2v_fix_frame_strength',
+    't2v_file', 't2v_width', 't2v_height', 't2v_prompt', 't2v_n_prompt', 't2v_cfg_scale', 't2v_seed', 't2v_processing_strength', 't2v_fix_frame_strength',
     't2v_sampler_index', 't2v_steps', 't2v_length', 't2v_fps',
     'glo_save_frames_check'
   ]
@@ -121,7 +121,8 @@ def get_mode_args(mode, args_dict):
 def set_CNs_input_image(args_dict, image):
   for script_input in args_dict['script_inputs']:
     if type(script_input).__name__ == 'UiControlNetUnit':
-      script_input.batch_images = [image]
+      script_input.batch_images = [np.array(image)]
+      script_input.image = np.array(image)
 
 import time
 import datetime
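
For reviewers, a minimal self-contained sketch of the per-frame control-video flow that the txt2vid changes above introduce. It assumes the extension's own utils.set_CNs_input_image helper and args_dict values; the guidance file name and the next_control_frame helper are hypothetical, and the extra None check (not part of the patch) only illustrates how a control video shorter than the generated animation could be guarded against.

import cv2
from PIL import Image

# Hypothetical guidance clip; in the extension this comes from args_dict['file'].name.
input_video = cv2.VideoCapture('guidance.mp4')

def next_control_frame(video, width, height):
  # Read one frame, convert it from BGR to RGB and resize it to the render resolution.
  if video is None or not video.isOpened():
    return None
  ret, frame = video.read()
  if not ret or frame is None:
    video.release()
    return None
  frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
  return cv2.resize(frame, (width, height))

# Per generated frame the patch reads one control frame and hands it to every
# ControlNet unit; the None check here is an added safeguard for short clips:
#   frame = next_control_frame(input_video, args_dict['width'], args_dict['height'])
#   if frame is not None:
#     utils.set_CNs_input_image(args_dict, Image.fromarray(frame))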