Skip to content

Commit

Permalink
Use safety checker in OV format
Browse files Browse the repository at this point in the history
  • Loading branch information
adrianboguszewski committed Jan 3, 2025
1 parent 21ef346 commit 72f6133
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 6 deletions.
21 changes: 16 additions & 5 deletions demos/paint_your_dreams_demo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
import openvino_genai as genai
from PIL import Image
from huggingface_hub import snapshot_download
from transformers import Pipeline, pipeline
from optimum.intel.openvino import OVModelForImageClassification
from transformers import Pipeline, pipeline, AutoProcessor

SCRIPT_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "utils")
sys.path.append(os.path.dirname(SCRIPT_DIR))
Expand All @@ -41,15 +42,24 @@ def get_available_devices() -> list[str]:
def download_models(model_name: str, safety_checker_model: str) -> None:
global safety_checker

is_openvino_model = model_name.split("/")[0] == "OpenVINO"

output_dir = MODEL_DIR / model_name
if not output_dir.exists():
snapshot_download(model_name, local_dir=output_dir)
if is_openvino_model:
snapshot_download(model_name, local_dir=output_dir)
else:
raise ValueError(f"Model {model_name} is not from OpenVINO Hub and not supported")

safety_checker_dir = MODEL_DIR / safety_checker_model
if not safety_checker_dir.exists():
snapshot_download(safety_checker_model, local_dir=safety_checker_dir)
model = OVModelForImageClassification.from_pretrained(safety_checker_model, export=True, compile=False)
model.save_pretrained(safety_checker_dir)
processor = AutoProcessor.from_pretrained(safety_checker_model)
processor.save_pretrained(safety_checker_dir)

safety_checker = pipeline("image-classification", model=str(safety_checker_dir), device="cpu")
safety_checker = pipeline("image-classification", model=OVModelForImageClassification.from_pretrained(safety_checker_dir),
image_processor=AutoProcessor.from_pretrained(safety_checker_dir))


async def load_pipeline(model_name: str, device: str):
Expand Down Expand Up @@ -203,7 +213,8 @@ def run_endless_lcm(model_name: str, safety_checker_model: str, local_network: b

parser = argparse.ArgumentParser()
parser.add_argument("--model_name", type=str, default="OpenVINO/LCM_Dreamshaper_v7-fp16-ov",
choices=["OpenVINO/LCM_Dreamshaper_v7-int8-ov", "OpenVINO/LCM_Dreamshaper_v7-fp16-ov"], help="Visual GenAI model to be used")
choices=["OpenVINO/LCM_Dreamshaper_v7-int8-ov", "OpenVINO/LCM_Dreamshaper_v7-fp16-ov"],
help="Visual GenAI model to be used")
parser.add_argument("--safety_checker_model", type=str, default="Falconsai/nsfw_image_detection",
choices=["Falconsai/nsfw_image_detection"], help="The model to verify if the generated image is NSFW")
parser.add_argument("--local_network", action="store_true", help="Whether demo should be available in local network")
Expand Down
7 changes: 6 additions & 1 deletion demos/paint_your_dreams_demo/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,15 @@

openvino==2024.6.0
openvino-genai==2024.6.0
optimum-intel==1.21.0
optimum==1.23.3
# onnx>1.16.1 doesn't work on windows
onnx==1.16.1
huggingface-hub==0.27.0
diffusers==0.32.1
transformers==4.47.1
transformers==4.46.3
torch==2.5.1
accelerate==1.2.1
pillow==11.1.0
opencv-python==4.10.0.84
numpy==2.1.3
Expand Down

0 comments on commit 72f6133

Please sign in to comment.