Skip to content

Commit

Permalink
Add more enum samples
Browse files Browse the repository at this point in the history
Change-Id: I743d5967cc1cc91576b8ddf5a60db1767d94508d
  • Loading branch information
MarkDaoust committed Sep 9, 2024
1 parent e0928fc commit 165aeb0
Showing 1 changed file with 71 additions and 0 deletions.
71 changes: 71 additions & 0 deletions samples/controlled_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,36 @@ class Choice(enum.Enum):
print(result) # "Keyboard"
# [END json_enum]

def test_enum_in_json(self):
# [START enum_in_json]
import enum
from typing_extensions import TypedDict

class Grade(enum.Enum):
A_PLUS = "a+"
A = "a"
B = "b"
C = "c"
D = "d"
F = "f"

class Recipe(TypedDict):
recipe_name: str
grade: Grade


model = genai.GenerativeModel("gemini-1.5-pro-latest")

result = model.generate_content(
"List about 10 cookie recipes, grade them based on popularity",
generation_config=genai.GenerationConfig(
response_mime_type="application/json",
response_schema=list[Recipe]
),
)
print(result) # [{"grade": "a+", "recipe_name": "Chocolate Chip Cookies"}, ...]
# [END enum_in_json]

def test_json_enum_raw(self):
# [START json_enum_raw]
model = genai.GenerativeModel("gemini-1.5-pro-latest")
Expand All @@ -92,5 +122,46 @@ def test_json_enum_raw(self):
# [END json_enum_raw]


def test_x_enum(self):
# [START x_enum]
import enum

class Choice(enum.Enum):
PERCUSSION = "Percussion"
STRING = "String"
WOODWIND = "Woodwind"
BRASS = "Brass"
KEYBOARD = "Keyboard"

model = genai.GenerativeModel("gemini-1.5-pro-latest")

organ = genai.upload_file(media / "organ.jpg")
result = model.generate_content(
["What kind of instrument is this:", organ],
generation_config=genai.GenerationConfig(
response_mime_type="text/x.enum", response_schema=Choice
),
)
print(result) # "Keyboard"
# [END x_enum]

def test_x_enum_raw(self):
# [START x_enum_raw]
model = genai.GenerativeModel("gemini-1.5-pro-latest")

organ = genai.upload_file(media / "organ.jpg")
result = model.generate_content(
["What kind of instrument is this:", organ],
generation_config=genai.GenerationConfig(
response_mime_type="text/x.enum",
response_schema={
"type": "STRING",
"enum": ["Percussion", "String", "Woodwind", "Brass", "Keyboard"],
},
),
)
print(result) # "Keyboard"
# [END x_enum_raw]

if __name__ == "__main__":
absltest.main()

0 comments on commit 165aeb0

Please sign in to comment.