Skip to content

Commit

Permalink
watermark (#603)
Browse files Browse the repository at this point in the history
* watermark

Change-Id: I271993e07bfc034d89cd74013339046a21d0b472

* fix typing

Change-Id: Ib026f3b66fb44c6ad35a15030b03204306b35443

* format

Change-Id: I24304a80e6689b4fcec3155f7274e24db712b402

* fix test

Change-Id: I68cb74f8ee5f224942417ea4e5fc4232ec688977

* make check_watermark a stand alone function

Change-Id: I2d72620359dcc70fe8e720a14f78d83f75a42d90

* simplify typing.

Change-Id: I1b901cc40b4b029cb09699fc6eac77690622b6e8

* Update google/generativeai/vision_models/_vision_models.py

* Typo

Co-authored-by: Mark McDonald <macd@google.com>

---------

Co-authored-by: Mark McDonald <macd@google.com>
  • Loading branch information
MarkDaoust and markmcd authored Oct 29, 2024
1 parent e09b902 commit aae0caf
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 5 deletions.
1 change: 0 additions & 1 deletion google/generativeai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@

__version__ = version.__version__

del embedding
del files
del generative_models
del models
Expand Down
5 changes: 4 additions & 1 deletion google/generativeai/types/content_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import IPython.display

IMAGE_TYPES = (PIL.Image.Image, IPython.display.Image)
ImageType = PIL.Image.Image | IPython.display.Image
else:
IMAGE_TYPES = ()
try:
Expand All @@ -52,6 +53,8 @@
except ImportError:
IPython = None

ImageType = Union["PIL.Image.Image", "IPython.display.Image"]


__all__ = [
"BlobDict",
Expand Down Expand Up @@ -123,7 +126,7 @@ def webp_blob(image: PIL.Image.Image) -> protos.Blob:
return file_blob(image) or webp_blob(image)


def image_to_blob(image) -> protos.Blob:
def image_to_blob(image: ImageType) -> protos.Blob:
if PIL is not None:
if isinstance(image, PIL.Image.Image):
return _pil_to_blob(image)
Expand Down
1 change: 1 addition & 0 deletions google/generativeai/vision_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Classes for working with vision models."""

from google.generativeai.vision_models._vision_models import (
check_watermark,
Image,
GeneratedImage,
ImageGenerationModel,
Expand Down
76 changes: 73 additions & 3 deletions google/generativeai/vision_models/_vision_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,16 @@
import base64
import collections
import dataclasses
import hashlib
import io
import json
import os
import pathlib
import typing
from typing import Any, Dict, List, Literal, Optional, Union

from google.generativeai import client
from google.generativeai import protos
from google.generativeai.types import content_types

from google.protobuf import struct_pb2

Expand Down Expand Up @@ -110,6 +111,52 @@ def to_mapping_value(value) -> struct_pb2.Struct:
PersonGeneration = Literal["dont_allow", "allow_adult"]
PERSON_GENERATIONS = PersonGeneration.__args__ # type: ignore

ImageLikeType = Union["Image", pathlib.Path, content_types.ImageType]


def check_watermark(
img: ImageLikeType, model_id: str = "models/image-verification-001"
) -> "CheckWatermarkResult":
"""Checks if an image has a Google-AI watermark.
Args:
img: can be a `pathlib.Path` or a `PIL.Image.Image`, `IPython.display.Image`, or `google.generativeai.Image`.
model_id: Which version of the image-verification model to send the image to.
Returns:
"""
if isinstance(img, Image):
pass
elif isinstance(img, pathlib.Path):
img = Image.load_from_file(img)
elif IPython_display is not None and isinstance(img, IPython_display.Image):
img = Image(image_bytes=img.data)
elif PIL_Image is not None and isinstance(img, PIL_Image.Image):
blob = content_types._pil_to_blob(img)
img = Image(image_bytes=blob.data)
elif isinstance(img, protos.Blob):
img = Image(image_bytes=img.data)
else:
raise TypeError(
f"Not implemented: Could not convert a {type(img)} into `Image`\n {img=}"
)

prediction_client = client.get_default_prediction_client()
if not model_id.startswith("models/"):
model_id = f"models/{model_id}"

instance = {"image": {"bytesBase64Encoded": base64.b64encode(img._loaded_bytes).decode()}}
parameters = {"watermarkVerification": True}

# This is to get around https://github.com/googleapis/proto-plus-python/issues/488
pr = protos.PredictRequest.pb()
request = pr(model=model_id, instances=[to_value(instance)], parameters=to_value(parameters))

response = prediction_client.predict(request)

return CheckWatermarkResult(response.predictions)


class Image:
"""Image."""
Expand All @@ -131,7 +178,7 @@ def __init__(
self._image_bytes = image_bytes

@staticmethod
def load_from_file(location: str) -> "Image":
def load_from_file(location: os.PathLike) -> "Image":
"""Loads image from local file or Google Cloud Storage.
Args:
Expand Down Expand Up @@ -206,6 +253,29 @@ def _as_base64_string(self) -> str:
def _repr_png_(self):
return self._pil_image._repr_png_() # type:ignore

check_watermark = check_watermark


class CheckWatermarkResult:
def __init__(self, predictions):
self._predictions = predictions

@property
def decision(self):
return self._predictions[0]["decision"]

def __str__(self):
return f"CheckWatermarkResult([{{'decision': {self.decision!r}}}])"

def __bool__(self):
decision = self.decision
if decision == "ACCEPT":
return True
elif decision == "REJECT":
return False
else:
raise ValueError(f"Unrecognized result: {decision}")


class ImageGenerationModel:
"""Generates images from text prompt.
Expand Down Expand Up @@ -479,7 +549,7 @@ def generation_parameters(self):
return self._generation_parameters

@staticmethod
def load_from_file(location: str) -> "GeneratedImage":
def load_from_file(location: os.PathLike) -> "GeneratedImage":
"""Loads image from file.
Args:
Expand Down

0 comments on commit aae0caf

Please sign in to comment.