Skip to content

Commit

Permalink
merge with master + tests + fixes and scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
XaviPeiro committed Sep 8, 2024
1 parent c2d4e74 commit abb2a63
Show file tree
Hide file tree
Showing 6 changed files with 180 additions and 4 deletions.
8 changes: 5 additions & 3 deletions ai_engine_sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,10 +273,11 @@ async def delete(self):

async def execute_function(self, function_ids: list[str], objective: str, context: str|None = None):
await self._submit_message(
payload=ApiUserMessageExecuteFunctions.parse_obj({
payload=ApiUserMessageExecuteFunctions.model_validate({
"functions": function_ids,
"objective": objective,
"context": context or ""
"context": context or "",
'session_id': self.session_id,
})
)

Expand All @@ -285,6 +286,7 @@ def __init__(self, api_key: str, options: Optional[dict] = None):
self._api_base_url = options.get('api_base_url') if options and 'api_base_url' in options else default_api_base_url
self._api_key = api_key


####
# Function groups
####
Expand Down Expand Up @@ -391,7 +393,7 @@ async def get_functions_by_function_group(self, function_group_id: str) -> list[
if "functions" in raw_response:
list(
map(
lambda function_name: FunctionGroupFunctions.parse_obj({"name": function_name}),
lambda function_name: FunctionGroupFunctions.model_validate({"name": function_name}),
raw_response["functions"]
)
)
Expand Down
88 changes: 88 additions & 0 deletions examples/execute_function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import argparse
import asyncio
import os
from pprint import pprint

from faker.utils.decorators import lowercase

from ai_engine_sdk import AiEngine, FunctionGroup, ApiBaseMessage
from ai_engine_sdk.client import Session
from tests.conftest import function_groups


async def main(
target_environment: str,
agentverse_api_key: str,
function_uuid: str,
function_group_uuid: str
):
# Request from cli args.
options = {}
if target_environment:
options = {"api_base_url": target_environment}

ai_engine = AiEngine(api_key=agentverse_api_key, options=options)

session: Session = await ai_engine.create_session(function_group=function_group_uuid)
await session.execute_function(function_ids=[function_uuid], objective="", context="")

try:
empty_count = 0
session_ended = False

print("Waiting for execution:")
while empty_count < 100:
messages: list[ApiBaseMessage] = await session.get_messages()
if messages:
pprint(messages)
if any((msg.type.lower() == "stop" for msg in messages)):
print("DONE")
break
if len(messages) % 10 == 0:
print("Wait...")
if len(messages) == 0:
empty_count += 1
else:
empty_count = 0


except Exception as ex:
pprint(ex)
raise

if __name__ == '__main__':
from dotenv import load_dotenv
load_dotenv()
api_key = os.getenv("AV_API_KEY", "")

parser = argparse.ArgumentParser()
parser.add_argument(
"-e",
"--target_environment",
type=str,
required=False,
help="The target environment: staging, localhost, production... You need to explicitly add the domain. By default it will be production."
)
parser.add_argument(
"-fg",
"--function_group_uuid",
type=str,
required=True,
)
parser.add_argument(
"-f",
"--function_uuid",
type=str,
required=True,
)
args = parser.parse_args()

result = asyncio.run(
main(
agentverse_api_key=api_key,
target_environment=args.target_environment,
function_group_uuid=args.function_group_uuid,
function_uuid=args.function_uuid
)
)
pprint(result)
63 changes: 63 additions & 0 deletions examples/get_function_from_function_group_name.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import argparse
import asyncio
import os
from pprint import pprint

from ai_engine_sdk import FunctionGroup, AiEngine
from tests.integration.test_ai_engine_client import api_key


async def main(
function_group_name: str,
agentverse_api_key: str,
target_environment: str | None = None,
):
# Request from cli args.
options = {}
if target_environment:
options = {"api_base_url": target_environment}

ai_engine: AiEngine = AiEngine(api_key=agentverse_api_key, options=options)
function_groups: list[FunctionGroup] = await ai_engine.get_function_groups()

target_function_group = next((g for g in function_groups if g.name == function_group_name), None)
if target_function_group is None:
raise Exception(f'Could not find "{target_function_group}" function group.')

return await ai_engine.get_functions_by_function_group(function_group_id=target_function_group.uuid)



if __name__ == "__main__":
from dotenv import load_dotenv
load_dotenv()
api_key = os.getenv("AV_API_KEY", "")

# Parse CLI arguments
parser = argparse.ArgumentParser()

parser.add_argument(
"-e",
"--target_environment",
type=str,
required=False,
help="The target environment: staging, localhost, production... You need to explicitly add the domain. By default it will be production."
)
parser.add_argument(
"-fgn",
"--fg_name",
type=str,
required=True,
)
args = parser.parse_args()

target_environment = args.target_environment

res = asyncio.run(
main(
agentverse_api_key=api_key,
function_group_name=args.fg_name,
target_environment=args.target_environment
)
)
pprint(res)
File renamed without changes.
15 changes: 14 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,17 @@ async def function_groups(ai_engine_client) -> list[FunctionGroup]:
# session: Session = await ai_engine_client.create_session(
# function_group=function_groups, opts={"model": "next-gen"}
# )
# return session
# return session


@pytest.fixture(scope="session")
def valid_public_function_uuid() -> str:
# TODO: Do it programmatically (when test fails bc of it will be good moment)
# 'Cornerstone Software' from Public fg and staging
return "312712ae-eb70-42f7-bb5a-ad21ce6d73c3"


@pytest.fixture(scope="session")
def public_function_group() -> FunctionGroup:
# TODO: Do it programmatically (when test fails bc of it will be good moment)
return FunctionGroup(uuid="e504eabb-4bc7-458d-aa8c-7c3748f8952c", name="Public", isPrivate=False)
10 changes: 10 additions & 0 deletions tests/integration/test_ai_engine_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,16 @@ async def test_create_session(self, ai_engine_client: AiEngine):
# await ai_engine_client.delete_function_group()


@pytest.mark.asyncio
async def test_execute_function(self, ai_engine_client: AiEngine, public_function_group: FunctionGroup, valid_public_function_uuid: str):
session: Session = await ai_engine_client.create_session(function_group=public_function_group.uuid)
result = await session.execute_function(
function_ids=[valid_public_function_uuid],
objective="Test software",
context=""
)


@pytest.mark.asyncio
async def test_create_function_group_and_list_them(self, ai_engine_client: AiEngine):
name = fake.company()
Expand Down

0 comments on commit abb2a63

Please sign in to comment.