Skip to content

Commit

Permalink
Enum (#529)
Browse files Browse the repository at this point in the history
* try to support enum types

Change-Id: I5141f751c4d6c578ef957aa8250cb26309ea9bd3

* format

Change-Id: I9619654247f0f7230c8ba4c76035ad0ff9324fd4

* Be clear that test uses enum value.

Change-Id: I03e319f2795c7c15f527316a145d021620936c57

* Add samples

Change-Id: Ifc5e5b2039c9f0532d37386f6d7b136961943bac

* Fix type annotations.

Change-Id: I6b7b769cf0ba17fc7188518cdcec3085f59760b0
  • Loading branch information
MarkDaoust authored Aug 27, 2024
1 parent e805b24 commit e0928fc
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 8 deletions.
2 changes: 1 addition & 1 deletion google/generativeai/responder.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def _generate_schema(
inspect.Parameter.POSITIONAL_ONLY,
)
}
parameters = pydantic.create_model(f.__name__, **fields_dict).schema()
parameters = pydantic.create_model(f.__name__, **fields_dict).model_json_schema()
# Postprocessing
# 4. Suppress unnecessary title generation:
# * https://github.com/pydantic/pydantic/issues/1051
Expand Down
7 changes: 5 additions & 2 deletions google/generativeai/types/content_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ def _schema_for_function(


def _build_schema(fname, fields_dict):
parameters = pydantic.create_model(fname, **fields_dict).schema()
parameters = pydantic.create_model(fname, **fields_dict).model_json_schema()
defs = parameters.pop("$defs", {})
# flatten the defs
for name, value in defs.items():
Expand All @@ -424,7 +424,10 @@ def _build_schema(fname, fields_dict):


def unpack_defs(schema, defs):
properties = schema["properties"]
properties = schema.get("properties", None)
if properties is None:
return

for name, value in properties.items():
ref_key = value.get("$ref", None)
if ref_key is not None:
Expand Down
6 changes: 3 additions & 3 deletions google/generativeai/types/generation_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import collections
import contextlib
import sys
from collections.abc import Iterable, AsyncIterable, Mapping
import dataclasses
import itertools
Expand Down Expand Up @@ -165,7 +164,7 @@ class GenerationConfig:
top_p: float | None = None
top_k: int | None = None
response_mime_type: str | None = None
response_schema: protos.Schema | Mapping[str, Any] | None = None
response_schema: protos.Schema | Mapping[str, Any] | type | None = None


GenerationConfigType = Union[protos.GenerationConfig, GenerationConfigDict, GenerationConfig]
Expand All @@ -186,7 +185,8 @@ def _normalize_schema(generation_config):
if not str(response_schema).startswith("list["):
raise ValueError(
f"Invalid input: Could not understand the type of '{response_schema}'. "
"Expected one of the following types: `int`, `float`, `str`, `bool`, `typing_extensions.TypedDict`, `dataclass`, or `list[...]`."
"Expected one of the following types: `int`, `float`, `str`, `bool`, `enum`, "
"`typing_extensions.TypedDict`, `dataclass` or `list[...]`."
)
response_schema = content_types._schema_for_class(response_schema)

Expand Down
51 changes: 49 additions & 2 deletions samples/controlled_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import absltest
import pathlib

import google.generativeai as genai

media = pathlib.Path(__file__).parents[1] / "third_party"


class UnitTests(absltest.TestCase):
def test_json_controlled_generation(self):
Expand All @@ -22,6 +25,7 @@ def test_json_controlled_generation(self):

class Recipe(typing.TypedDict):
recipe_name: str
ingredients: list[str]

model = genai.GenerativeModel("gemini-1.5-pro-latest")
result = model.generate_content(
Expand All @@ -36,14 +40,57 @@ class Recipe(typing.TypedDict):
def test_json_no_schema(self):
# [START json_no_schema]
model = genai.GenerativeModel("gemini-1.5-pro-latest")
prompt = """List a few popular cookie recipes using this JSON schema:
prompt = """List a few popular cookie recipes in JSON format.
Use this JSON schema:
Recipe = {'recipe_name': str}
Recipe = {'recipe_name': str, 'ingredients': list[str]}
Return: list[Recipe]"""
result = model.generate_content(prompt)
print(result)
# [END json_no_schema]

def test_json_enum(self):
# [START json_enum]
import enum

class Choice(enum.Enum):
PERCUSSION = "Percussion"
STRING = "String"
WOODWIND = "Woodwind"
BRASS = "Brass"
KEYBOARD = "Keyboard"

model = genai.GenerativeModel("gemini-1.5-pro-latest")

organ = genai.upload_file(media / "organ.jpg")
result = model.generate_content(
["What kind of instrument is this:", organ],
generation_config=genai.GenerationConfig(
response_mime_type="application/json", response_schema=Choice
),
)
print(result) # "Keyboard"
# [END json_enum]

def test_json_enum_raw(self):
# [START json_enum_raw]
model = genai.GenerativeModel("gemini-1.5-pro-latest")

organ = genai.upload_file(media / "organ.jpg")
result = model.generate_content(
["What kind of instrument is this:", organ],
generation_config=genai.GenerationConfig(
response_mime_type="application/json",
response_schema={
"type": "STRING",
"enum": ["Percussion", "String", "Woodwind", "Brass", "Keyboard"],
},
),
)
print(result) # "Keyboard"
# [END json_enum_raw]


if __name__ == "__main__":
absltest.main()
32 changes: 32 additions & 0 deletions tests/test_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import enum
import pathlib
import typing_extensions
from typing import Any, Union, Iterable
Expand Down Expand Up @@ -69,6 +70,18 @@ class ADataClassWithList:
a: list[int]


class Choices(enum.Enum):
A = "a"
B = "b"
C = "c"
D = "d"


@dataclasses.dataclass
class HasEnum:
choice: Choices


class UnitTests(parameterized.TestCase):
@parameterized.named_parameters(
["PIL", PIL.Image.open(TEST_PNG_PATH)],
Expand Down Expand Up @@ -551,6 +564,25 @@ def b():
},
),
],
["enum", Choices, protos.Schema(type=protos.Type.STRING, enum=["a", "b", "c", "d"])],
[
"enum_list",
list[Choices],
protos.Schema(
type="ARRAY",
items=protos.Schema(type=protos.Type.STRING, enum=["a", "b", "c", "d"]),
),
],
[
"has_enum",
HasEnum,
protos.Schema(
type=protos.Type.OBJECT,
properties={
"choice": protos.Schema(type=protos.Type.STRING, enum=["a", "b", "c", "d"])
},
),
],
)
def test_auto_schema(self, annotation, expected):
def fun(a: annotation):
Expand Down

0 comments on commit e0928fc

Please sign in to comment.