diff --git a/samples/controlled_generation.py b/samples/controlled_generation.py index 4942481f6..5c9c362a5 100644 --- a/samples/controlled_generation.py +++ b/samples/controlled_generation.py @@ -73,6 +73,36 @@ class Choice(enum.Enum): print(result) # "Keyboard" # [END json_enum] + def test_enum_in_json(self): + # [START enum_in_json] + import enum + from typing_extensions import TypedDict + + class Grade(enum.Enum): + A_PLUS = "a+" + A = "a" + B = "b" + C = "c" + D = "d" + F = "f" + + class Recipe(TypedDict): + recipe_name: str + grade: Grade + + + model = genai.GenerativeModel("gemini-1.5-pro-latest") + + result = model.generate_content( + "List about 10 cookie recipes, grade them based on popularity", + generation_config=genai.GenerationConfig( + response_mime_type="application/json", + response_schema=list[Recipe] + ), + ) + print(result) # [{"grade": "a+", "recipe_name": "Chocolate Chip Cookies"}, ...] + # [END enum_in_json] + def test_json_enum_raw(self): # [START json_enum_raw] model = genai.GenerativeModel("gemini-1.5-pro-latest") @@ -92,5 +122,46 @@ def test_json_enum_raw(self): # [END json_enum_raw] + def test_x_enum(self): + # [START x_enum] + import enum + + class Choice(enum.Enum): + PERCUSSION = "Percussion" + STRING = "String" + WOODWIND = "Woodwind" + BRASS = "Brass" + KEYBOARD = "Keyboard" + + model = genai.GenerativeModel("gemini-1.5-pro-latest") + + organ = genai.upload_file(media / "organ.jpg") + result = model.generate_content( + ["What kind of instrument is this:", organ], + generation_config=genai.GenerationConfig( + response_mime_type="text/x.enum", response_schema=Choice + ), + ) + print(result) # "Keyboard" + # [END x_enum] + + def test_x_enum_raw(self): + # [START x_enum_raw] + model = genai.GenerativeModel("gemini-1.5-pro-latest") + + organ = genai.upload_file(media / "organ.jpg") + result = model.generate_content( + ["What kind of instrument is this:", organ], + generation_config=genai.GenerationConfig( + response_mime_type="text/x.enum", + response_schema={ + "type": "STRING", + "enum": ["Percussion", "String", "Woodwind", "Brass", "Keyboard"], + }, + ), + ) + print(result) # "Keyboard" + # [END x_enum_raw] + if __name__ == "__main__": absltest.main()