-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
merge with master + tests + fixes and scripts
- Loading branch information
Showing
6 changed files
with
180 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
import argparse | ||
import asyncio | ||
import os | ||
from pprint import pprint | ||
|
||
from faker.utils.decorators import lowercase | ||
|
||
from ai_engine_sdk import AiEngine, FunctionGroup, ApiBaseMessage | ||
from ai_engine_sdk.client import Session | ||
from tests.conftest import function_groups | ||
|
||
|
||
async def main( | ||
target_environment: str, | ||
agentverse_api_key: str, | ||
function_uuid: str, | ||
function_group_uuid: str | ||
): | ||
# Request from cli args. | ||
options = {} | ||
if target_environment: | ||
options = {"api_base_url": target_environment} | ||
|
||
ai_engine = AiEngine(api_key=agentverse_api_key, options=options) | ||
|
||
session: Session = await ai_engine.create_session(function_group=function_group_uuid) | ||
await session.execute_function(function_ids=[function_uuid], objective="", context="") | ||
|
||
try: | ||
empty_count = 0 | ||
session_ended = False | ||
|
||
print("Waiting for execution:") | ||
while empty_count < 100: | ||
messages: list[ApiBaseMessage] = await session.get_messages() | ||
if messages: | ||
pprint(messages) | ||
if any((msg.type.lower() == "stop" for msg in messages)): | ||
print("DONE") | ||
break | ||
if len(messages) % 10 == 0: | ||
print("Wait...") | ||
if len(messages) == 0: | ||
empty_count += 1 | ||
else: | ||
empty_count = 0 | ||
|
||
|
||
except Exception as ex: | ||
pprint(ex) | ||
raise | ||
|
||
if __name__ == '__main__': | ||
from dotenv import load_dotenv | ||
load_dotenv() | ||
api_key = os.getenv("AV_API_KEY", "") | ||
|
||
parser = argparse.ArgumentParser() | ||
parser.add_argument( | ||
"-e", | ||
"--target_environment", | ||
type=str, | ||
required=False, | ||
help="The target environment: staging, localhost, production... You need to explicitly add the domain. By default it will be production." | ||
) | ||
parser.add_argument( | ||
"-fg", | ||
"--function_group_uuid", | ||
type=str, | ||
required=True, | ||
) | ||
parser.add_argument( | ||
"-f", | ||
"--function_uuid", | ||
type=str, | ||
required=True, | ||
) | ||
args = parser.parse_args() | ||
|
||
result = asyncio.run( | ||
main( | ||
agentverse_api_key=api_key, | ||
target_environment=args.target_environment, | ||
function_group_uuid=args.function_group_uuid, | ||
function_uuid=args.function_uuid | ||
) | ||
) | ||
pprint(result) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
import argparse | ||
import asyncio | ||
import os | ||
from pprint import pprint | ||
|
||
from ai_engine_sdk import FunctionGroup, AiEngine | ||
from tests.integration.test_ai_engine_client import api_key | ||
|
||
|
||
async def main( | ||
function_group_name: str, | ||
agentverse_api_key: str, | ||
target_environment: str | None = None, | ||
): | ||
# Request from cli args. | ||
options = {} | ||
if target_environment: | ||
options = {"api_base_url": target_environment} | ||
|
||
ai_engine: AiEngine = AiEngine(api_key=agentverse_api_key, options=options) | ||
function_groups: list[FunctionGroup] = await ai_engine.get_function_groups() | ||
|
||
target_function_group = next((g for g in function_groups if g.name == function_group_name), None) | ||
if target_function_group is None: | ||
raise Exception(f'Could not find "{target_function_group}" function group.') | ||
|
||
return await ai_engine.get_functions_by_function_group(function_group_id=target_function_group.uuid) | ||
|
||
|
||
|
||
if __name__ == "__main__": | ||
from dotenv import load_dotenv | ||
load_dotenv() | ||
api_key = os.getenv("AV_API_KEY", "") | ||
|
||
# Parse CLI arguments | ||
parser = argparse.ArgumentParser() | ||
|
||
parser.add_argument( | ||
"-e", | ||
"--target_environment", | ||
type=str, | ||
required=False, | ||
help="The target environment: staging, localhost, production... You need to explicitly add the domain. By default it will be production." | ||
) | ||
parser.add_argument( | ||
"-fgn", | ||
"--fg_name", | ||
type=str, | ||
required=True, | ||
) | ||
args = parser.parse_args() | ||
|
||
target_environment = args.target_environment | ||
|
||
res = asyncio.run( | ||
main( | ||
agentverse_api_key=api_key, | ||
function_group_name=args.fg_name, | ||
target_environment=args.target_environment | ||
) | ||
) | ||
pprint(res) |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters