Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

watermark #603

Merged
merged 8 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion google/generativeai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@

__version__ = version.__version__

del embedding
del files
del generative_models
del models
Expand Down
5 changes: 4 additions & 1 deletion google/generativeai/types/content_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import IPython.display

IMAGE_TYPES = (PIL.Image.Image, IPython.display.Image)
ImageType = PIL.Image.Image | IPython.display.Image
else:
IMAGE_TYPES = ()
try:
Expand All @@ -52,6 +53,8 @@
except ImportError:
IPython = None

ImageType = Union["PIL.Image.Image", "IPython.display.Image"]


__all__ = [
"BlobDict",
Expand Down Expand Up @@ -123,7 +126,7 @@ def webp_blob(image: PIL.Image.Image) -> protos.Blob:
return file_blob(image) or webp_blob(image)


def image_to_blob(image) -> protos.Blob:
def image_to_blob(image: ImageType) -> protos.Blob:
if PIL is not None:
if isinstance(image, PIL.Image.Image):
return _pil_to_blob(image)
Expand Down
1 change: 1 addition & 0 deletions google/generativeai/vision_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Classes for working with vision models."""

from google.generativeai.vision_models._vision_models import (
check_watermark,
Image,
GeneratedImage,
ImageGenerationModel,
Expand Down
76 changes: 73 additions & 3 deletions google/generativeai/vision_models/_vision_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,16 @@
import base64
import collections
import dataclasses
import hashlib
import io
import json
import os
import pathlib
import typing
from typing import Any, Dict, List, Literal, Optional, Union

from google.generativeai import client
from google.generativeai import protos
from google.generativeai.types import content_types

from google.protobuf import struct_pb2

Expand Down Expand Up @@ -110,6 +111,52 @@ def to_mapping_value(value) -> struct_pb2.Struct:
PersonGeneration = Literal["dont_allow", "allow_adult"]
PERSON_GENERATIONS = PersonGeneration.__args__ # type: ignore

ImageLikeType = Union["Image", pathlib.Path, content_types.ImageType]


def check_watermark(
img: ImageLikeType, model_id: str = "models/image-verification-001"
) -> "CheckWatermarkResult":
"""Checks if an image has a Google-AI watermark.
Args:
img: can be a `pathlib.Path` or a `PIL.Image.Image`, `IPython.display.Image`, or `google.generativeai.Image`.
model_id: Which version of the image-verification model to send the image to.
Returns:
"""
if isinstance(img, Image):
pass
elif isinstance(img, pathlib.Path):
img = Image.load_from_file(img)
elif IPython_display is not None and isinstance(img, IPython_display.Image):
img = Image(image_bytes=img.data)
elif PIL_Image is not None and isinstance(img, PIL_Image.Image):
blob = content_types._pil_to_blob(img)
img = Image(image_bytes=blob.data)
elif isinstance(img, protos.Blob):
img = Image(image_bytes=img.data)
else:
raise TypeError(
f"Not implemented: Could not convert a {type(img)} into `Image`\n {img=}"
)

prediction_client = client.get_default_prediction_client()
if not model_id.startswith("models/"):
model_id = f"models/{model_id}"

instance = {"image": {"bytesBase64Encoded": base64.b64encode(img._loaded_bytes).decode()}}
parameters = {"watermarkVerification": True}

# This is to get around https://github.com/googleapis/proto-plus-python/issues/488
pr = protos.PredictRequest.pb()
request = pr(model=model_id, instances=[to_value(instance)], parameters=to_value(parameters))

response = prediction_client.predict(request)

return CheckWatermarkResult(response.predictions)


class Image:
"""Image."""
Expand All @@ -131,7 +178,7 @@ def __init__(
self._image_bytes = image_bytes

@staticmethod
def load_from_file(location: str) -> "Image":
def load_from_file(location: os.PathLike) -> "Image":
"""Loads image from local file or Google Cloud Storage.
Args:
Expand Down Expand Up @@ -206,6 +253,29 @@ def _as_base64_string(self) -> str:
def _repr_png_(self):
return self._pil_image._repr_png_() # type:ignore

check_watermark = check_watermark


class CheckWatermarkResult:
def __init__(self, predictions):
self._predictions = predictions

@property
def decision(self):
return self._predictions[0]["decision"]

def __str__(self):
return f"CheckWatermarkResult([{{'decision': {self.decision!r}}}])"

def __bool__(self):
decision = self.decision
if decision == "ACCEPT":
return True
elif decision == "REJECT":
return False
else:
raise ValueError(f"Unrecognized result: {decision}")


class ImageGenerationModel:
"""Generates images from text prompt.
Expand Down Expand Up @@ -479,7 +549,7 @@ def generation_parameters(self):
return self._generation_parameters

@staticmethod
def load_from_file(location: str) -> "GeneratedImage":
def load_from_file(location: os.PathLike) -> "GeneratedImage":
"""Loads image from file.
Args:
Expand Down
Loading