diff --git a/intro/index.html b/intro/index.html index f9dc5ca..c1c7af1 100644 --- a/intro/index.html +++ b/intro/index.html @@ -541,6 +541,15 @@ + + +
- +
In this project, our objective is to develop a modular and flexible intelligent agent and society system, designed as a virtual assistant capable of performing diverse tasks, learning from data, environment, and interactions, and self-evolving over time. The system will leverage deep learning models, primarily transformers, while also exploring innovative models and learning methods.
-Our ultimate goal is to develop a General AI Agent System capable of forming a “genius society” of AI agents. These agents will:
+Our ultimate goal is to develop a General AI Agent System capable of forming a “genius society” of AI agents. These agents will:
Currently, Aeiva supports the following interaction modes:
Maid
, we can use our agent as the backend and call it through Maid desktop assistant. Maid
, we can use our agent as the backend and call it through Maid desktop assistant. ⭐️ Documentation 👉 aeiva documentation
Currently, we support the following functionalities:
More functionalities and modules will be implemented gradually. Stay tuned! If you find any errors or bugs, feel free to report them by opening an issue. Thanks a lot!
@@ -890,7 +909,7 @@To install AEIVA, follow these steps:
Python 3.9
or newerPython 3.10
or newerpip
(Python package manager)pip
[recommended]You will see a terminal interface like the one below:
- + + ++ +
+Run the following command in terminal:
aeiva-chat-gradio --config configs/agent_config.yaml --verbose
@@ -1005,7 +1029,12 @@ 🪄⭐Aeiva Chat in Gradio Mode
When you visit the Gradio interface, you will see a web UI like the one below:
-
+
+
+
+
+
+
🪄⭐Aeiva Server
Run the following command in terminal:
@@ -1044,7 +1073,7 @@ 🪄⭐Maid Chat (Your
Download Maid.app
:
-- Download
Maid.app
from [provide download link or instructions].
+- Download
Maid.app
from here.
@@ -1112,10 +1141,26 @@ 🪄⭐Maid Chat (Your
-Screenshot of Maid-chat:
-
+Demo of Maid-chat:
+
+
+
+
+
+
+
+
+Citation
+To cite Aeiva in publications, please use the following BibTeX entry.
+@misc{bang2024aeiva,
+ title={Aeiva: An Evolving Intelligent Virtual Assistant},
+ author={Bang Liu},
+ year={2024},
+ url={https://github.com/chatsci/Aeiva}
+}
+
Contact
-
+
diff --git a/search/search_index.json b/search/search_index.json
index 541b0d0..a01a5bf 100644
--- a/search/search_index.json
+++ b/search/search_index.json
@@ -1 +1 @@
-{"config":{"lang":["en"],"separator":"[\\s\\-]+","pipeline":["stopWordFilter"]},"docs":[{"location":"","title":"Welcome to Aeiva","text":"Home page...
"},{"location":"blogs/","title":"Blogs","text":"Here we summarize some experience we learned during developing Aeiva.
Thoughts on Several Key Concepts for Agentic Intelligence
"},{"location":"intro/","title":"Introduction","text":""},{"location":"intro/#aeiva-an-evolving-intelligent-virtual-assistant","title":"AEIVA: An Evolving Intelligent Virtual Assistant","text":""},{"location":"intro/#introduction","title":"Introduction","text":"In this project, our objective is to develop a modular and flexible intelligent agent and society system, designed as a virtual assistant capable of performing diverse tasks, learning from data, environment, and interactions, and self-evolving over time. The system will leverage deep learning models, primarily transformers, while also exploring innovative models and learning methods.
Our ultimate goal is to develop a General AI Agent System capable of forming a \u201cgenius society\u201d of AI agents. These agents will:
- Collaboratively address and solve societal challenges across domains.
- Function in diverse environments, from virtual simulations to real-world applications.
- Continuously evolve and improve through self-assessment and adaptation.
- Serve as versatile assistants in various roles, such as AI researchers, software engineers, game players, or digital society members.
Currently, Aeiva supports the following interaction modes:
- Chat in terminal: chat with an agent in the terminal interface
- Chat with Gradio Webui: we developed a Gradio web UI that allows users to chat with the agent. We plan to support multimodality in the near future.
- Chat with desktop Waifu mode: by combining it with another of our projects,
Maid
, we can use our agent as the backend and call it through the Maid desktop assistant.
"},{"location":"intro/#key-features","title":"Key Features","text":"Currently, we support the following functionalities:
- Rich Toolkits: I have implemented a series of API tools and I keep improving the API library.
- Open Operator: By implementing computer-use related tools, aeiva is able to understand and operate the user's computer and complete daily tasks. We keep enhancing this functionality. Note: use this feature with caution!
- Memory Palace: I have designed and implemented a layered memory palace for storing agent memories. It is flexible and can be customized to represent and query different types of memories.
More functionalities and modules will be implemented gradually. Stay tuned! If you find any errors or bugs, feel free to report them by opening an issue. Thanks a lot!
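The "layered memory palace" idea above can be illustrated with a minimal sketch. Everything here (`MemoryItem`, `LayeredMemory`, the layer names) is hypothetical and is not Aeiva's actual API; it only shows what "layered, customizable, queryable" memories can mean in practice:

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class MemoryItem:
    content: str
    layer: str                      # e.g. "working", "episodic", "semantic"
    metadata: Dict[str, Any] = field(default_factory=dict)


class LayeredMemory:
    """Toy layered store: one bucket per layer, keyword-filterable queries."""

    def __init__(self, layers: List[str]):
        self.layers = {name: [] for name in layers}

    def add(self, item: MemoryItem) -> None:
        self.layers[item.layer].append(item)

    def query(self, layer: str, keyword: str) -> List[MemoryItem]:
        return [m for m in self.layers[layer] if keyword in m.content]


memory = LayeredMemory(["working", "episodic", "semantic"])
memory.add(MemoryItem("user asked about the weather", "episodic"))
print(len(memory.query("episodic", "weather")))  # 1
```

A real implementation would back each layer with a database (vector, graph, or relational, as described in the Dependencies section) instead of in-memory lists.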
"},{"location":"intro/#installation","title":"Installation","text":"To install AEIVA, follow these steps:
"},{"location":"intro/#prerequisites","title":"Prerequisites","text":" Python 3.9
or newer pip
(Python package manager)
"},{"location":"intro/#option-1-install-via-pip-recommended","title":"Option 1: Install via pip
[recommended]","text":"You can easily install via pip:
pip install aeiva\n
"},{"location":"intro/#option-2-install-from-repository","title":"Option 2: Install from Repository","text":" -
Clone the AEIVA Repository
First, clone the AEIVA repository to your local machine using Git:
bash git clone https://github.com/chatsci/Aeiva.git cd Aeiva
-
Create a Virtual Environment (Recommended) It's a good practice to create a virtual environment for Python projects. This keeps dependencies required by different projects separate. Use the following command to create a virtual environment with conda
:
bash conda create --name <my-env>
Replace <my-env>
with the name of your environment.
To activate your env:
bash conda activate <my-env>
For more advanced configurations or options, please check the online documentation of conda
.
-
Install Dependencies Install all dependencies listed in requirements.txt:
bash pip install -r requirements.txt
-
Install Aeiva Finally, install AEIVA using the setup.py script:
bash python setup.py install
-
Verify Installation To verify that AEIVA has been installed correctly, you can run the following command:
bash python -c \"import aeiva; print(aeiva.__version__)\"
"},{"location":"intro/#dependencies","title":"Dependencies","text":"Our memory module utilizes different types of databases.
-
Vector Database: Please install a vector database such as milvus
(recommended), chroma
, qdrant
, or weaviate
.
-
Graph Database: Ensure Neo4j is installed and the NEO4J_HOME
environment variable is set.
-
Relational Database: We use sqlite
(recommended) or postgresql
.
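The three prerequisites above can be sanity-checked from Python before launching aeiva. This is an illustrative helper, not part of Aeiva; `pymilvus` is assumed here to be the client package you would install for the Milvus vector-database option:

```python
import importlib.util
import os


def check_memory_prerequisites() -> dict:
    """Report which memory-module backends look available on this machine."""
    return {
        # Graph database: the docs require NEO4J_HOME to be set.
        "neo4j_home_set": bool(os.environ.get("NEO4J_HOME")),
        # Relational database: sqlite3 ships with CPython's standard library.
        "sqlite_available": importlib.util.find_spec("sqlite3") is not None,
        # Vector database: pymilvus is the usual Milvus client (assumption).
        "milvus_client_installed": importlib.util.find_spec("pymilvus") is not None,
    }


status = check_memory_prerequisites()
print(status["sqlite_available"])  # True
```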
"},{"location":"intro/#commands","title":"Commands","text":"After installing Neo4j and setting the environment variable, follow these steps to run different aeiva chat commands.
"},{"location":"intro/#aeiva-chat-in-terminal-mode","title":"\ud83e\ude84\u2b50Aeiva Chat in Terminal Mode","text":"Run the following command in terminal:
aeiva-chat-terminal --config configs/agent_config.yaml --verbose\n
-
Options:
--config
or -c
: Path to the configuration file (default: configs/agent_config.yaml
). --verbose
or -v
: Enable verbose logging for detailed output.
-
Using the Interface:
- Interact with the chatbot directly in your terminal after running the command. - View Logs:
- Logs are stored at
~/.aeiva/logs/aeiva-chat-terminal.log
. - To monitor logs in real-time, use:
shell tail -f ~/.aeiva/logs/aeiva-chat-terminal.log
You will see a terminal interface like the one below:
"},{"location":"intro/#aeiva-chat-in-gradio-mode","title":"\ud83e\ude84\u2b50Aeiva Chat in Gradio Mode","text":"Run the following command in terminal:
aeiva-chat-gradio --config configs/agent_config.yaml --verbose\n
-
Options:
--config
or -c
: Path to the configuration file (default: configs/agent_config.yaml
). --verbose
or -v
: Enable verbose logging for detailed output.
-
Access the Gradio Interface:
- Open your web browser and navigate to http://localhost:7860.
- Alternatively, use the public URL provided in the terminal output (e.g., https://1b1f89328e57b2f2e1.gradio.live) to access the interface remotely.
- View Logs:
- Logs are stored at
~/.aeiva/logs/aeiva-chat-gradio.log
. - To monitor logs in real-time, use:
shell tail -f ~/.aeiva/logs/aeiva-chat-gradio.log
When you visit the Gradio interface, you will see a web UI like the one below:
"},{"location":"intro/#aeiva-server","title":"\ud83e\ude84\u2b50Aeiva Server","text":"Run the following command in terminal:
aeiva-server --config configs/agent_config.yaml --host 0.0.0.0 --port 8000 --verbose\n
- Options:
--config
or -c
: Path to the configuration file (default: configs/agent_config.yaml). --host
or -H
: Host address to run the server on (default: 0.0.0.0). --port
or -p
: Port number to run the server on (default: 8000). --verbose
or -v
: Enable verbose logging for detailed output.
- Access the Server:
- Open your web browser and navigate to
http://localhost:8000/docs
to access the interactive API documentation.
- View Logs:
- Logs are stored at
~/.aeiva/logs/aeiva-server.log
. - To monitor logs in real-time, use:
shell tail -f ~/.aeiva/logs/aeiva-server.log
"},{"location":"intro/#maid-chat-your-intelligent-assistant-on-desktop","title":"\ud83e\ude84\u2b50Maid Chat (Your Intelligent Assistant on Desktop!)","text":"Run the following command in terminal to get an animated virtual assistant on your desktop that you can talk to in voice mode or by typing:
maid-chat --config configs/agent_config.yaml --host 0.0.0.0 --port 8000 --verbose\n
- Options:
--config
or -c
: Path to the configuration file (default: configs/agent_config.yaml
). --host
or -H
: Host address to run the server on (default: 0.0.0.0
). --port
or -p
: Port number to run the server on (default: 8000
). --verbose
or -v
: Enable verbose logging for detailed output.
- Download
Maid.app
: - Download
Maid.app
from [provide download link or instructions].
-
Set MAID_HOME
Environment Variable:
- Unix/Linux/macOS:
shell export MAID_HOME='/path/to/my/unity.app/Contents/MacOS/Maid - Your Intelligent Waifu !' source ~/.bashrc # or source ~/.zshrc
- Windows (Command Prompt):
shell set MAID_HOME=C:\\path\\to\\my\\unity\\app
- Windows (PowerShell):
shell $env:MAID_HOME = \"C:\\path\\to\\my\\unity\\app\"
Replace /path/to/my/unity/app
or C:\\path\\to\\my\\unity\\app
with the actual path to your Unity application.
-
Using the Interface:
- Interact with the server through the Maid.app Unity application after running the command.
- View Logs:
- Logs are stored at
~/.aeiva/logs/maid-chat.log
. - To monitor logs in real-time, use:
shell tail -f ~/.aeiva/logs/maid-chat.log
- Troubleshooting:
-
Permission Denied Error When Starting Unity Application: If you encounter an error like: Error: Failed to start Unity application: [Errno 13] Permission denied: '/path/to/my/unity/app'
Solution:
-
macOS Users:
- Open System Preferences.
- Navigate to Security & Privacy.
- Click on the Privacy tab.
- Select Accessibility from the sidebar.
- Click the lock icon to make changes and enter your password.
- Click the \"+\" button and add your terminal application (e.g., Terminal, iTerm).
- Ensure that your terminal application is checked, granting it the necessary permissions to run the Unity application.
-
Windows Users:
- Right-click on the Unity application executable.
- Select Properties.
- Go to the Compatibility tab.
- Check Run this program as an administrator.
- Click Apply, then OK.
- Try running the command again.
Ensure that the MAID_HOME
environment variable points to the correct path of your Unity application.
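The MAID_HOME checks above can be sketched in Python. This helper is illustrative only (the name `resolve_maid_binary` is hypothetical, not part of Aeiva); it shows how a launcher might surface a missing variable or the [Errno 13] permission problem before trying to start the Unity application:

```python
import os


def resolve_maid_binary() -> str:
    """Resolve MAID_HOME and flag the common permission problem early."""
    path = os.environ.get("MAID_HOME")
    if not path:
        raise RuntimeError("MAID_HOME is not set; export it as shown above.")
    if os.path.exists(path) and not os.access(path, os.X_OK):
        # Corresponds to the [Errno 13] Permission denied failure
        # covered in Troubleshooting.
        raise PermissionError(f"{path} exists but is not executable")
    return path


os.environ["MAID_HOME"] = "/tmp/example-unity-app"  # hypothetical path
print(resolve_maid_binary())  # /tmp/example-unity-app
```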
Screenshot of Maid-chat:
"},{"location":"intro/#contact","title":"Contact","text":""},{"location":"reference/","title":"Reference","text":"This part of the project documentation focuses on an information-oriented approach. Use it as a reference for the technical implementation of the Aeiva
project code.
"},{"location":"reference/#aeiva-api-references","title":"Aeiva API references","text":""},{"location":"reference/#src.aeiva.action","title":"action
","text":""},{"location":"reference/#src.aeiva.action.action","title":"action
","text":""},{"location":"reference/#src.aeiva.action.action.Action","title":"Action
","text":" Bases: Step
Represents an action that can be executed, extending from the Step class. An action is a tool with states and state management methods. It can execute functionality.
Source code in src/aeiva/action/action.py
class Action(Step):\n \"\"\"\n Represents an action that can be executed, extending from the Step class.\n An action is a tool with states and state management methods. It can execute functionality.\n \"\"\"\n\n def __init__(self, name: str, params: Dict[str, Any] = None,\n id: str = None, dependent_ids: Optional[List[str]] = None, \n type: Optional[str] = None, description: Optional[str] = None,\n metadata: Optional[Dict[str, Any]] = None):\n super().__init__(name=name, params=params,\n id=id, dependent_ids=dependent_ids,\n type=type, description=description,\n metadata=metadata)\n self.type = \"Action\"\n self.tool = Tool(name)\n self.result = None\n\n def reset(self) -> None:\n \"\"\"\n Resets the step status, making it ready for re-execution.\n \"\"\"\n self.result = None\n self.status = Status.NOT_EXECUTED\n\n async def execute(self, params: Dict[str, Any]) -> Any:\n if self.tool is None:\n raise ValueError(f\"Action {self.id} has no tool assigned for execution.\")\n\n self.start()\n try:\n result = await self.tool.execute(params) # Assuming the tool's execute method is async\n self.end(success=True)\n self.result = result\n return result\n except Exception as e:\n self.end(success=False)\n raise RuntimeError(f\"Action {self.id} failed: {str(e)}\")\n
"},{"location":"reference/#src.aeiva.action.action.Action.reset","title":"reset()
","text":"Resets the step status, making it ready for re-execution.
Source code in src/aeiva/action/action.py
def reset(self) -> None:\n \"\"\"\n Resets the step status, making it ready for re-execution.\n \"\"\"\n self.result = None\n self.status = Status.NOT_EXECUTED\n
"},{"location":"reference/#src.aeiva.action.action_system","title":"action_system
","text":""},{"location":"reference/#src.aeiva.action.action_system.ActionSystem","title":"ActionSystem
","text":"A concrete Action System responsible for translating Plans into executable Skills and managing the execution of Skills.
Source code in src/aeiva/action/action_system.py
class ActionSystem:\n \"\"\"\n A concrete Action System responsible for translating Plans into executable Skills\n and managing the execution of Skills.\n \"\"\"\n\n def __init__(self, config: Dict):\n self.config = config\n self.state = {\n \"current_skill\": None,\n \"execution_status\": \"Not Started\",\n }\n self.tools = []\n self.skill = None\n\n def setup(self) -> None:\n if \"tools\" in self.config.keys():\n for tool_name in self.config[\"tools\"]:\n self.tools.append(Tool.load_tool_schema(tool_name))\n print(\"ActionSystem setup complete.\")\n\n def plan_to_skill(self, plan: Plan) -> Skill:\n actions = []\n\n for task in plan.steps:\n if isinstance(task, Task):\n action = Action(\n name=task.name,\n params=task.params,\n id=task.id,\n dependent_ids=task.dependent_ids,\n type=\"Action\",\n description=task.description,\n metadata=task.metadata\n )\n actions.append(action)\n elif isinstance(task, Plan):\n sub_skill = self.plan_to_skill(task) # Recursively handle sub-plans\n actions.append(sub_skill)\n else:\n raise TypeError(f\"Unexpected step type: {type(task)} in plan {plan.id}\")\n\n if not actions:\n raise ValueError(f\"The plan {plan.id} does not contain any valid actions or sub-plans.\")\n\n return Skill(\n name=plan.name,\n steps=actions,\n id=plan.id,\n dependent_ids=plan.dependent_ids,\n type=\"Skill\",\n description=plan.description,\n metadata=plan.metadata\n )\n\n async def execute(self, plan: Plan) -> None:\n self.state[\"execution_status\"] = \"Executing\"\n\n try:\n self.skill = self.plan_to_skill(plan) \n await self.skill.execute() \n self.state[\"execution_status\"] = \"Completed\" if self.skill.is_successful else \"Failed\"\n except Exception as e:\n self.state[\"execution_status\"] = \"Failed\"\n self.handle_error(e)\n raise # Ensure to re-throw the exception\n\n def handle_error(self, error: Exception) -> None:\n print(f\"ActionSystem encountered an error: {error}\")\n
"},{"location":"reference/#src.aeiva.action.experience","title":"experience
","text":""},{"location":"reference/#src.aeiva.action.experience.Experience","title":"Experience
","text":" Bases: Procedure
Represents an experience, which is a structured composition of actions. Unlike a skill, an experience cannot be executed until it is validated and transformed into a skill.
Attributes:
Name Type Description owner
str
The person or agent who owns the experience.
reliable
bool
A flag indicating whether the experience is reliable enough to be transformed into a skill.
Source code in src/aeiva/action/experience.py
class Experience(Procedure):\n \"\"\"\n Represents an experience, which is a structured composition of actions.\n Unlike a skill, an experience cannot be executed until it is validated and transformed into a skill.\n\n Attributes:\n owner (str): The person or agent who owns the experience.\n reliable (bool): A flag indicating whether the experience is reliable enough to be transformed into a skill.\n \"\"\"\n\n def __init__(self, name: str, steps: List[Union['Experience', Action]],\n owner: Optional[str] = None, reliable: Optional[bool] = False,\n id: Optional[str] = None, dependent_ids: Optional[List[str]] = None,\n type: Optional[str] = None, description: Optional[str] = None,\n metadata: Optional[Dict[str, Any]] = None):\n \"\"\"\n Initializes a Skill by extending Procedure.\n \"\"\"\n super().__init__(name=name, steps=steps,\n id=id, dependent_ids=dependent_ids,\n type=type, description=description,\n metadata=metadata)\n self.type = \"Experience\"\n self.owner = owner # The owner of the experience\n self.reliable = reliable # Whether the experience can be transformed into a skill. 
\n # We can use metadata to store some scored and decide whether it is reliable.\n\n @property\n def is_reliable(self) -> bool:\n \"\"\"\n Checks if the experience is reliable enough to be transformed into a skill.\n \"\"\"\n return self.reliable\n\n def mark_reliable(self) -> None:\n \"\"\"\n Marks the experience as reliable, allowing it to be transformed into a skill.\n \"\"\"\n self.reliable = True\n\n def to_skill(self) -> Skill:\n \"\"\"\n Converts this experience into a skill, but only if the experience is marked as reliable.\n If the experience is not reliable, raises a ValueError.\n\n Returns:\n Skill: A new Skill object that is based on the actions from this experience.\n \"\"\"\n if not self.reliable:\n raise ValueError(f\"Experience {self.id} cannot be transformed into a skill because it is not marked as reliable.\")\n\n # Create and return a new Skill instance\n return Skill(\n name=self.name,\n steps=self.steps, # Use the same steps (actions) from the experience\n id=self.id,\n dependent_ids=self.dependent_ids,\n type=\"Skill\",\n description=f\"Skill derived from Experience: {self.description}\", \n metadata=self.metadata\n )\n\n def to_dict(self) -> Dict[str, Any]:\n \"\"\"\n Returns a dictionary representation of the object.\n \"\"\"\n experience_dict = super().to_dict()\n experience_dict.update({\n \"owner\": self.owner,\n \"reliable\": self.reliable,\n })\n return experience_dict\n
"},{"location":"reference/#src.aeiva.action.experience.Experience.is_reliable","title":"is_reliable: bool
property
","text":"Checks if the experience is reliable enough to be transformed into a skill.
"},{"location":"reference/#src.aeiva.action.experience.Experience.__init__","title":"__init__(name, steps, owner=None, reliable=False, id=None, dependent_ids=None, type=None, description=None, metadata=None)
","text":"Initializes a Skill by extending Procedure.
Source code in src/aeiva/action/experience.py
def __init__(self, name: str, steps: List[Union['Experience', Action]],\n owner: Optional[str] = None, reliable: Optional[bool] = False,\n id: Optional[str] = None, dependent_ids: Optional[List[str]] = None,\n type: Optional[str] = None, description: Optional[str] = None,\n metadata: Optional[Dict[str, Any]] = None):\n \"\"\"\n Initializes a Skill by extending Procedure.\n \"\"\"\n super().__init__(name=name, steps=steps,\n id=id, dependent_ids=dependent_ids,\n type=type, description=description,\n metadata=metadata)\n self.type = \"Experience\"\n self.owner = owner # The owner of the experience\n self.reliable = reliable # Whether the experience can be transformed into a skill. \n
"},{"location":"reference/#src.aeiva.action.experience.Experience.mark_reliable","title":"mark_reliable()
","text":"Marks the experience as reliable, allowing it to be transformed into a skill.
Source code in src/aeiva/action/experience.py
def mark_reliable(self) -> None:\n \"\"\"\n Marks the experience as reliable, allowing it to be transformed into a skill.\n \"\"\"\n self.reliable = True\n
"},{"location":"reference/#src.aeiva.action.experience.Experience.to_dict","title":"to_dict()
","text":"Returns a dictionary representation of the object.
Source code in src/aeiva/action/experience.py
def to_dict(self) -> Dict[str, Any]:\n \"\"\"\n Returns a dictionary representation of the object.\n \"\"\"\n experience_dict = super().to_dict()\n experience_dict.update({\n \"owner\": self.owner,\n \"reliable\": self.reliable,\n })\n return experience_dict\n
"},{"location":"reference/#src.aeiva.action.experience.Experience.to_skill","title":"to_skill()
","text":"Converts this experience into a skill, but only if the experience is marked as reliable. If the experience is not reliable, raises a ValueError.
Returns:
Name Type Description Skill
Skill
A new Skill object that is based on the actions from this experience.
Source code in src/aeiva/action/experience.py
def to_skill(self) -> Skill:\n \"\"\"\n Converts this experience into a skill, but only if the experience is marked as reliable.\n If the experience is not reliable, raises a ValueError.\n\n Returns:\n Skill: A new Skill object that is based on the actions from this experience.\n \"\"\"\n if not self.reliable:\n raise ValueError(f\"Experience {self.id} cannot be transformed into a skill because it is not marked as reliable.\")\n\n # Create and return a new Skill instance\n return Skill(\n name=self.name,\n steps=self.steps, # Use the same steps (actions) from the experience\n id=self.id,\n dependent_ids=self.dependent_ids,\n type=\"Skill\",\n description=f\"Skill derived from Experience: {self.description}\", \n metadata=self.metadata\n )\n
"},{"location":"reference/#src.aeiva.action.plan","title":"plan
","text":""},{"location":"reference/#src.aeiva.action.plan.Plan","title":"Plan
","text":" Bases: Procedure
Represents a plan, which is a structured roadmap for achieving a goal by executing tasks and sub-plans. Inherits common functionality from Procedure.
Source code in src/aeiva/action/plan.py
class Plan(Procedure):\n \"\"\"\n Represents a plan, which is a structured roadmap for achieving a goal by executing tasks and sub-plans.\n Inherits common functionality from Procedure.\n \"\"\"\n\n def __init__(self, name: str, steps: List[Union['Plan', Task]],\n id: Optional[str] = None, dependent_ids: Optional[List[str]] = None,\n type: Optional[str] = None, description: Optional[str] = None,\n metadata: Optional[Dict[str, Any]] = None):\n \"\"\"\n Initializes a Skill by extending Procedure.\n \"\"\"\n super().__init__(name=name, steps=steps,\n id=id, dependent_ids=dependent_ids,\n type=type, description=description,\n metadata=metadata)\n self.type = \"Plan\"\n
"},{"location":"reference/#src.aeiva.action.plan.Plan.__init__","title":"__init__(name, steps, id=None, dependent_ids=None, type=None, description=None, metadata=None)
","text":"Initializes a Skill by extending Procedure.
Source code in src/aeiva/action/plan.py
def __init__(self, name: str, steps: List[Union['Plan', Task]],\n id: Optional[str] = None, dependent_ids: Optional[List[str]] = None,\n type: Optional[str] = None, description: Optional[str] = None,\n metadata: Optional[Dict[str, Any]] = None):\n \"\"\"\n Initializes a Skill by extending Procedure.\n \"\"\"\n super().__init__(name=name, steps=steps,\n id=id, dependent_ids=dependent_ids,\n type=type, description=description,\n metadata=metadata)\n self.type = \"Plan\"\n
"},{"location":"reference/#src.aeiva.action.procedure","title":"procedure
","text":""},{"location":"reference/#src.aeiva.action.procedure.Procedure","title":"Procedure
","text":"Abstract base class for composite structures like Plan and Skill. Contains shared attributes and methods for organizing and managing steps (e.g., tasks, sub-procedures) in a directed acyclic graph (DAG).
Source code in src/aeiva/action/procedure.py
class Procedure:\n \"\"\"\n Abstract base class for composite structures like Plan and Skill.\n Contains shared attributes and methods for organizing and managing steps (e.g., tasks, sub-procedures) \n in a directed acyclic graph (DAG).\n \"\"\"\n\n def __init__(self, name: str, steps: List[Union['Procedure', Step]],\n id: Optional[str] = None, dependent_ids: Optional[List[str]] = None,\n type: Optional[str] = None, description: Optional[str] = None,\n metadata: Optional[Dict[str, Any]] = None,\n *args, **kwargs):\n self.name = name\n self.steps = steps\n self.id = id\n self.dependent_ids = dependent_ids or []\n self.type = type\n self.description = description\n self.metadata = metadata or {}\n\n self.graph = nx.DiGraph()\n self.step_map = {step.id: step for step in steps}\n self.status = Status.NOT_EXECUTED\n\n # Add all steps as nodes in the graph\n for step in steps:\n self.graph.add_node(step)\n\n # Handle dependencies for steps\n for step in steps:\n for dep_id in step.dependent_ids:\n if dep_id in self.step_map:\n self.graph.add_edge(self.step_map[dep_id], step)\n else:\n raise ValueError(f\"Dependency {dep_id} not found for step {step.id}.\")\n\n def get_topological_sort(self):\n \"\"\"\n Returns the steps in topologically sorted order based on the dependency graph.\n Ensures that all prerequisite steps are executed before the dependent ones.\n \"\"\"\n try:\n return list(nx.topological_sort(self.graph))\n except nx.NetworkXUnfeasible:\n raise ValueError(\"The dependency graph contains cycles, which is not allowed in a procedure.\")\n\n def reset(self) -> None:\n \"\"\"\n Resets the status of the procedure and all its steps.\n \"\"\"\n self.status = Status.NOT_EXECUTED\n for step in self.steps:\n step.reset()\n\n def start(self) -> None:\n \"\"\"\n Marks the procedure as in progress. 
Raises an error if it's already in progress or finished.\n \"\"\"\n if self.status != Status.NOT_EXECUTED:\n raise ValueError(f\"{self.type} {self.id} has already been started or finished.\")\n self.status = Status.EXECUTING\n\n def end(self, success: bool) -> None:\n \"\"\"\n Marks the procedure as completed. Raises an error if it hasn't started yet.\n \"\"\"\n if self.status != Status.EXECUTING:\n raise ValueError(f\"Cannot finish {self.type} that hasn't started.\")\n self.status = Status.SUCCESS if success else Status.FAIL\n\n @property\n def is_successful(self) -> bool:\n return all(step.is_successful for step in self.steps)\n\n @property\n def is_failed(self) -> bool:\n return any(step.is_failed for step in self.steps)\n\n @property\n def is_in_progress(self) -> bool:\n return any(step.is_in_progress for step in self.steps)\n\n @property\n def is_not_started(self) -> bool:\n return all(step.is_not_started for step in self.steps)\n\n @property\n def is_finished(self) -> bool:\n return all(step.is_finished for step in self.steps)\n\n def visualize(self, save_path: Optional[str] = None):\n \"\"\"\n Visualizes the procedure's structure using networkx and matplotlib.\n \"\"\"\n pos = nx.spring_layout(self.graph) # Layout for the graph\n labels = {node: f\"{node.id} ({node.description}, {node.status})\" for node in self.graph.nodes()}\n\n # Draw the graph with labels\n nx.draw(self.graph, pos, with_labels=True, labels=labels, node_color='lightblue', node_size=2000, font_size=10, font_weight='bold', arrows=True)\n\n plt.title(f\"{self.type} {self.description} Visualization\")\n if save_path:\n plt.savefig(save_path)\n else:\n plt.show()\n\n def to_dict(self) -> Dict[str, Any]:\n return {\n \"name\": self.name,\n \"steps\": [step.to_dict() for step in self.steps],\n \"id\": self.id,\n \"dependent_ids\": self.dependent_ids,\n \"type\": self.type,\n \"description\": self.description,\n \"metadata\": self.metadata,\n \"status\": self.status\n }\n
"},{"location":"reference/#src.aeiva.action.procedure.Procedure.end","title":"end(success)
","text":"Marks the procedure as completed. Raises an error if it hasn't started yet.
Source code in src/aeiva/action/procedure.py
def end(self, success: bool) -> None:\n \"\"\"\n Marks the procedure as completed. Raises an error if it hasn't started yet.\n \"\"\"\n if self.status != Status.EXECUTING:\n raise ValueError(f\"Cannot finish {self.type} that hasn't started.\")\n self.status = Status.SUCCESS if success else Status.FAIL\n
"},{"location":"reference/#src.aeiva.action.procedure.Procedure.get_topological_sort","title":"get_topological_sort()
","text":"Returns the steps in topologically sorted order based on the dependency graph. Ensures that all prerequisite steps are executed before the dependent ones.
Source code in src/aeiva/action/procedure.py
def get_topological_sort(self):\n \"\"\"\n Returns the steps in topologically sorted order based on the dependency graph.\n Ensures that all prerequisite steps are executed before the dependent ones.\n \"\"\"\n try:\n return list(nx.topological_sort(self.graph))\n except nx.NetworkXUnfeasible:\n raise ValueError(\"The dependency graph contains cycles, which is not allowed in a procedure.\")\n
"},{"location":"reference/#src.aeiva.action.procedure.Procedure.reset","title":"reset()
","text":"Resets the status of the procedure and all its steps.
Source code in src/aeiva/action/procedure.py
def reset(self) -> None:\n \"\"\"\n Resets the status of the procedure and all its steps.\n \"\"\"\n self.status = Status.NOT_EXECUTED\n for step in self.steps:\n step.reset()\n
"},{"location":"reference/#src.aeiva.action.procedure.Procedure.start","title":"start()
","text":"Marks the procedure as in progress. Raises an error if it's already in progress or finished.
Source code in src/aeiva/action/procedure.py
def start(self) -> None:\n \"\"\"\n Marks the procedure as in progress. Raises an error if it's already in progress or finished.\n \"\"\"\n if self.status != Status.NOT_EXECUTED:\n raise ValueError(f\"{self.type} {self.id} has already been started or finished.\")\n self.status = Status.EXECUTING\n
"},{"location":"reference/#src.aeiva.action.procedure.Procedure.visualize","title":"visualize(save_path=None)
","text":"Visualizes the procedure's structure using networkx and matplotlib.
Source code in src/aeiva/action/procedure.py
def visualize(self, save_path: Optional[str] = None):\n \"\"\"\n Visualizes the procedure's structure using networkx and matplotlib.\n \"\"\"\n pos = nx.spring_layout(self.graph) # Layout for the graph\n labels = {node: f\"{node.id} ({node.description}, {node.status})\" for node in self.graph.nodes()}\n\n # Draw the graph with labels\n nx.draw(self.graph, pos, with_labels=True, labels=labels, node_color='lightblue', node_size=2000, font_size=10, font_weight='bold', arrows=True)\n\n plt.title(f\"{self.type} {self.description} Visualization\")\n if save_path:\n plt.savefig(save_path)\n else:\n plt.show()\n
"},{"location":"reference/#src.aeiva.action.skill","title":"skill
","text":""},{"location":"reference/#src.aeiva.action.skill.Skill","title":"Skill
","text":" Bases: Procedure
Represents a skill, which is a structured roadmap for executing actions. Skills are composed of actions and can be executed. Inherits common functionality from Procedure.
Source code in src/aeiva/action/skill.py
class Skill(Procedure):\n \"\"\"\n Represents a skill, which is a structured roadmap for executing actions.\n Skills are composed of actions and can be executed.\n Inherits common functionality from Procedure.\n \"\"\"\n\n def __init__(self, name: str, steps: List[Union['Skill', Action]],\n id: Optional[str] = None, dependent_ids: Optional[List[str]] = None,\n type: Optional[str] = None, description: Optional[str] = None,\n metadata: Optional[Dict[str, Any]] = None):\n \"\"\"\n Initializes a Skill by extending Procedure.\n \"\"\"\n super().__init__(name=name, steps=steps,\n id=id, dependent_ids=dependent_ids,\n type=type, description=description,\n metadata=metadata)\n self.type = \"Skill\"\n\n def get_topological_sort(self):\n \"\"\"\n Returns the steps in topologically sorted order based on the dependency graph.\n Ensures that all prerequisite steps are executed before the dependent ones.\n \"\"\"\n return list(nx.topological_sort(self.graph))\n\n async def execute(self):\n \"\"\"\n Executes all actions in the skill based on the dependencies defined in the graph.\n This will run the actions asynchronously, respecting their dependencies.\n \"\"\"\n self.start()\n\n # Perform topological sort right before execution\n sorted_steps = self.get_topological_sort()\n\n for step in sorted_steps:\n if isinstance(step, Action):\n print(f\"Executing Action: {step.id} - {step.description}\")\n await step.execute(step.params) # Execute the action asynchronously\n elif isinstance(step, Skill):\n print(f\"Executing Sub-Skill: {step.id}\")\n await step.execute() # If it's a sub-skill, execute the sub-skill\n\n self.end(success=self.is_successful)\n
"},{"location":"reference/#src.aeiva.action.skill.Skill.__init__","title":"__init__(name, steps, id=None, dependent_ids=None, type=None, description=None, metadata=None)
","text":"Initializes a Skill by extending Procedure.
Source code in src/aeiva/action/skill.py
def __init__(self, name: str, steps: List[Union['Skill', Action]],\n id: Optional[str] = None, dependent_ids: Optional[List[str]] = None,\n type: Optional[str] = None, description: Optional[str] = None,\n metadata: Optional[Dict[str, Any]] = None):\n \"\"\"\n Initializes a Skill by extending Procedure.\n \"\"\"\n super().__init__(name=name, steps=steps,\n id=id, dependent_ids=dependent_ids,\n type=type, description=description,\n metadata=metadata)\n self.type = \"Skill\"\n
"},{"location":"reference/#src.aeiva.action.skill.Skill.execute","title":"execute()
async
","text":"Executes all actions in the skill based on the dependencies defined in the graph. This will run the actions asynchronously, respecting their dependencies.
Source code in src/aeiva/action/skill.py
async def execute(self):\n \"\"\"\n Executes all actions in the skill based on the dependencies defined in the graph.\n This will run the actions asynchronously, respecting their dependencies.\n \"\"\"\n self.start()\n\n # Perform topological sort right before execution\n sorted_steps = self.get_topological_sort()\n\n for step in sorted_steps:\n if isinstance(step, Action):\n print(f\"Executing Action: {step.id} - {step.description}\")\n await step.execute(step.params) # Execute the action asynchronously\n elif isinstance(step, Skill):\n print(f\"Executing Sub-Skill: {step.id}\")\n await step.execute() # If it's a sub-skill, execute the sub-skill\n\n self.end(success=self.is_successful)\n
"},{"location":"reference/#src.aeiva.action.skill.Skill.get_topological_sort","title":"get_topological_sort()
","text":"Returns the steps in topologically sorted order based on the dependency graph. Ensures that all prerequisite steps are executed before the dependent ones.
Source code in src/aeiva/action/skill.py
def get_topological_sort(self):\n \"\"\"\n Returns the steps in topologically sorted order based on the dependency graph.\n Ensures that all prerequisite steps are executed before the dependent ones.\n \"\"\"\n return list(nx.topological_sort(self.graph))\n
"},{"location":"reference/#src.aeiva.action.status","title":"status
","text":""},{"location":"reference/#src.aeiva.action.status.Status","title":"Status
","text":"A class to hold status constants.
Source code in src/aeiva/action/status.py
class Status:\n \"\"\"\n A class to hold status constants.\n \"\"\"\n NOT_EXECUTED = \"Not Executed\"\n EXECUTING = \"Executing\"\n SUCCESS = \"Success\"\n FAIL = \"Fail\"\n
"},{"location":"reference/#src.aeiva.action.step","title":"step
","text":""},{"location":"reference/#src.aeiva.action.step.Step","title":"Step
","text":"Abstract base class for atomic units like Task and Action. Contains shared attributes and methods for managing their execution and dependencies.
Source code in src/aeiva/action/step.py
class Step:\n \"\"\"\n Abstract base class for atomic units like Task and Action.\n Contains shared attributes and methods for managing their execution and dependencies.\n \"\"\"\n\n def __init__(self, name: str, params: Dict[str, Any] = None,\n id: Optional[str] = None, dependent_ids: Optional[List[str]] = None, \n type: Optional[str] = None, description: Optional[str] = None,\n metadata: Optional[Dict[str, Any]] = None,\n *args, **kwargs):\n self.name = name # The name of the step. It can be a task/action/tool/api/function name\n self.params = params # The parameters for this step. it can be a task/action/tool/api/function's params\n self.id = id # Unique identifier for the step\n self.dependent_ids = dependent_ids or [] # List of IDs of steps that must be completed before this one\n self.type = type # The type of this step, e.g., task or action\n self.description = description # A description for this step\n self.metadata = metadata or {} # Optional metadata (e.g., id, type, description, priority, etc.)\n self.status = Status.NOT_EXECUTED # Initial status\n\n def reset(self) -> None:\n \"\"\"\n Resets the step status, making it ready for re-execution.\n \"\"\"\n self.status = Status.NOT_EXECUTED\n\n def start(self) -> None:\n \"\"\"\n Marks the step as in progress. 
Raises an error if the step is already started or finished.\n \"\"\"\n if self.status != Status.NOT_EXECUTED:\n raise ValueError(f\"{self.type} {self.description} {self.id} has already been started or finished.\")\n self.status = Status.EXECUTING\n\n def end(self, success: bool) -> None:\n \"\"\"\n Marks the step as finished and indicates whether it was successful.\n Can only be called if the step is in progress.\n \"\"\"\n if self.status != Status.EXECUTING:\n raise ValueError(f\"Cannot finish a {self.type} that hasn't started.\")\n self.status = Status.SUCCESS if success else Status.FAIL\n\n @property\n def is_successful(self) -> bool:\n \"\"\"\n Returns True if the step was completed successfully.\n \"\"\"\n return self.status == Status.SUCCESS\n\n @property\n def is_failed(self) -> bool:\n \"\"\"\n Returns True if the step has finished but failed.\n \"\"\"\n return self.status == Status.FAIL\n\n @property\n def is_in_progress(self) -> bool:\n \"\"\"\n Returns True if the step is in progress (executing but not finished).\n \"\"\"\n return self.status == Status.EXECUTING\n\n @property\n def is_not_started(self) -> bool:\n \"\"\"\n Returns True if the step has not started yet.\n \"\"\"\n return self.status == Status.NOT_EXECUTED\n\n @property\n def is_finished(self) -> bool:\n \"\"\"\n Returns True if the step has finished execution, either successfully or failed.\n \"\"\"\n return self.status == Status.SUCCESS or self.status == Status.FAIL\n\n def to_dict(self) -> Dict[str, Any]:\n \"\"\"\n Converts the step into a dictionary representation.\n \"\"\"\n return {\n \"name\": self.name,\n \"params\": self.params,\n \"id\": self.id,\n \"dependent_ids\": self.dependent_ids,\n \"type\": self.type,\n \"description\": self.description,\n \"status\": self.status,\n \"metadata\": self.metadata\n }\n
"},{"location":"reference/#src.aeiva.action.step.Step.is_failed","title":"is_failed: bool
property
","text":"Returns True if the step has finished but failed.
"},{"location":"reference/#src.aeiva.action.step.Step.is_finished","title":"is_finished: bool
property
","text":"Returns True if the step has finished execution, either successfully or failed.
"},{"location":"reference/#src.aeiva.action.step.Step.is_in_progress","title":"is_in_progress: bool
property
","text":"Returns True if the step is in progress (executing but not finished).
"},{"location":"reference/#src.aeiva.action.step.Step.is_not_started","title":"is_not_started: bool
property
","text":"Returns True if the step has not started yet.
"},{"location":"reference/#src.aeiva.action.step.Step.is_successful","title":"is_successful: bool
property
","text":"Returns True if the step was completed successfully.
"},{"location":"reference/#src.aeiva.action.step.Step.end","title":"end(success)
","text":"Marks the step as finished and indicates whether it was successful. Can only be called if the step is in progress.
Source code in src/aeiva/action/step.py
def end(self, success: bool) -> None:\n \"\"\"\n Marks the step as finished and indicates whether it was successful.\n Can only be called if the step is in progress.\n \"\"\"\n if self.status != Status.EXECUTING:\n raise ValueError(f\"Cannot finish a {self.type} that hasn't started.\")\n self.status = Status.SUCCESS if success else Status.FAIL\n
"},{"location":"reference/#src.aeiva.action.step.Step.reset","title":"reset()
","text":"Resets the step status, making it ready for re-execution.
Source code in src/aeiva/action/step.py
def reset(self) -> None:\n \"\"\"\n Resets the step status, making it ready for re-execution.\n \"\"\"\n self.status = Status.NOT_EXECUTED\n
"},{"location":"reference/#src.aeiva.action.step.Step.start","title":"start()
","text":"Marks the step as in progress. Raises an error if the step is already started or finished.
Source code in src/aeiva/action/step.py
def start(self) -> None:\n \"\"\"\n Marks the step as in progress. Raises an error if the step is already started or finished.\n \"\"\"\n if self.status != Status.NOT_EXECUTED:\n raise ValueError(f\"{self.type} {self.description} {self.id} has already been started or finished.\")\n self.status = Status.EXECUTING\n
"},{"location":"reference/#src.aeiva.action.step.Step.to_dict","title":"to_dict()
","text":"Converts the step into a dictionary representation.
Source code in src/aeiva/action/step.py
def to_dict(self) -> Dict[str, Any]:\n \"\"\"\n Converts the step into a dictionary representation.\n \"\"\"\n return {\n \"name\": self.name,\n \"params\": self.params,\n \"id\": self.id,\n \"dependent_ids\": self.dependent_ids,\n \"type\": self.type,\n \"description\": self.description,\n \"status\": self.status,\n \"metadata\": self.metadata\n }\n
"},{"location":"reference/#src.aeiva.action.task","title":"task
","text":""},{"location":"reference/#src.aeiva.action.task.Task","title":"Task
","text":" Bases: Step
Represents the fundamental unit of work, extending from the Step class. Inherits shared attributes and methods from Step and adds task-specific functionality.
Source code in src/aeiva/action/task.py
class Task(Step):\n \"\"\"\n Represents the fundamental unit of work, extending from the Step class.\n Inherits shared attributes and methods from Step and adds task-specific functionality.\n \"\"\"\n\n def __init__(self, name: str, params: Dict[str, Any] = None,\n id: Optional[str] = None, dependent_ids: Optional[List[str]] = None, \n type: Optional[str] = None, description: Optional[str] = None,\n metadata: Optional[Dict[str, Any]] = None):\n super().__init__(name=name, params=params,\n id=id, dependent_ids=dependent_ids,\n type=type, description=description,\n metadata=metadata)\n self.type = \"Task\"\n\n def show(self) -> None:\n print(\"---- Task Information ----\")\n pprint(self.to_dict(), sort_dicts=False)\n print(\"---- End of Task ----\")\n
"},{"location":"reference/#src.aeiva.agent","title":"agent
","text":""},{"location":"reference/#src.aeiva.agent.agent","title":"agent
","text":""},{"location":"reference/#src.aeiva.agent.agent.Agent","title":"Agent
","text":"Represents the agent that integrates perception, cognition, and action systems.
Source code in src/aeiva/agent/agent.py
class Agent:\n \"\"\"\n Represents the agent that integrates perception, cognition, and action systems.\n \"\"\"\n def __init__(self, config: Dict):\n self.config_dict = config\n self.config = None\n self.event_bus = EventBus()\n self.perception_system = None\n self.cognition_system = None\n self.action_system = None\n\n def setup(self) -> None:\n \"\"\"\n Set up all systems.\n \"\"\"\n perception_config = self.config_dict.get('perception_config', {})\n cognition_config = self.config_dict # NOTE: we didn't define a cognition config class yet.\n action_config = self.config_dict.get('action_config', {})\n\n self.perception_system = PerceptionSystem(perception_config, self.event_bus)\n self.cognition_system = CognitionSystem(cognition_config)\n self.action_system = ActionSystem(action_config)\n\n self.perception_system.setup()\n self.cognition_system.setup()\n self.action_system.setup()\n\n async def run(self) -> None:\n \"\"\"\n Run the agent by connecting perception, cognition, and action systems using the event bus.\n \"\"\"\n # Start the event bus within the running event loop\n self.event_bus.start()\n # Assign the current running loop to the EventBus\n self.event_bus.loop = asyncio.get_running_loop()\n # Set up event handlers\n self.setup_event_handlers()\n # Start the perception system\n await self.perception_system.start()\n\n # Keep the event loop running until interrupted\n try:\n while True:\n await asyncio.sleep(1)\n except KeyboardInterrupt:\n # Handle graceful shutdown\n self.perception_system.stop()\n await self.event_bus.wait_until_all_events_processed()\n self.event_bus.stop()\n except asyncio.CancelledError:\n pass\n except Exception as e:\n # logger.error(f\"Unexpected error in agent run loop: {e}\")\n print(f\"Unexpected error in agent run loop: {e}\", flush=True)\n await self.perception_system.stop()\n await self.event_bus.wait_until_all_events_processed()\n self.event_bus.stop()\n\n async def process_input(self, input_text: str) -> str:\n 
\"\"\"\n Process input text and return the agent's response.\n \"\"\"\n stream = self.config_dict.get(\"llm_gateway_config\").get(\"llm_stream\")\n use_async = self.config_dict.get(\"llm_gateway_config\").get(\"llm_use_async\")\n stimuli = Stimuli(signals=[Signal(data=input_text, modularity='text')])\n output = \"\"\n try:\n response_gen = self.cognition_system.think(stimuli, tools=self.action_system.tools, stream=stream, use_async=use_async)\n async for chunk in response_gen:\n if isinstance(chunk, str):\n # For streaming chunks\n output += chunk\n elif isinstance(chunk, Thought) or isinstance(chunk, Plan):\n # For non-streaming responses\n output += chunk.content\n except Exception as e:\n logger.error(f\"Error in response: {e}\")\n return output\n\n def setup_event_handlers(self) -> None:\n \"\"\"\n Set up event handlers for perception, cognition, and action events.\n \"\"\"\n\n @self.event_bus.on('perception.stimuli')\n async def handle_stimuli(event: Event):\n # print(\"handle_stimuli called\", flush=True)\n user_input = event.payload\n stimuli = Stimuli(signals=[Signal(data=user_input, modularity='text')])\n #print(f\"Received stimuli: {stimuli}\", flush=True)\n # Process stimuli through cognition system\n #stimuli = [{\"role\": \"user\", \"content\": stimuli}]\n\n stream = self.config_dict.get(\"llm_gateway_config\").get(\"llm_stream\")\n use_async = self.config_dict.get(\"llm_gateway_config\").get(\"llm_use_async\")\n sys.stdout.write(\"\\r\\033[K\") # Return to start of the line and clear it\\\n print(\"Response: \", end='', flush=True)\n\n try:\n response_gen = self.cognition_system.think(stimuli, tools=self.action_system.tools, stream=stream, use_async=use_async)\n async for chunk in response_gen:\n if isinstance(chunk, str):\n # For streaming chunks\n print(f\"{chunk}\", end='', flush=True)\n elif isinstance(chunk, Thought) or isinstance(chunk, Plan):\n # For non-streaming responses\n print(f\"{chunk.content}\", end='', flush=True)\n except Exception as 
e:\n logger.error(f\"Error in response: {e}\")\n\n print(\"\\nYou: \", end='', flush=True)\n\n # # Determine if output is a Plan or Thought\n # if isinstance(output, Plan): # TODO: change later\n # print(\"Output is a Plan\", flush=True)\n # await self.event_bus.emit('action.plan', payload=output)\n # elif isinstance(output, Thought):\n # print(\"Output is a Thought\", flush=True)\n # print(f\"Agent Response: {output.content}\", flush=True)\n # else:\n # print(\"Unknown output from cognition system.\", flush=True)\n\n @self.event_bus.on('action.plan')\n async def handle_plan(event: Event):\n print(\"handle_plan called\", flush=True)\n plan = event.payload\n await self.action_system.execute(plan)\n\n @self.event_bus.on('perception.gradio')\n async def handle_gradio_input(event: Event):\n \"\"\"\n Handle input from Gradio and emit response.gradio events.\n \"\"\"\n user_input = event.payload\n stimuli = Stimuli(signals=[Signal(data=user_input, modularity='text')])\n\n stream = self.config_dict.get(\"llm_gateway_config\").get(\"llm_stream\")\n use_async = self.config_dict.get(\"llm_gateway_config\").get(\"llm_use_async\")\n logger.info(f\"Handling Gradio input: {user_input} | Stream: {stream}\")\n try:\n response_gen = self.cognition_system.think(stimuli, tools=self.action_system.tools, stream=stream, use_async=use_async)\n\n async for chunk in response_gen:\n if isinstance(chunk, str):\n # For streaming chunks\n await self.event_bus.emit('response.gradio', payload=chunk)\n elif isinstance(chunk, Thought) or isinstance(chunk, Plan):\n # For non-streaming responses\n await self.event_bus.emit('response.gradio', payload=chunk.content if hasattr(chunk, 'content') else str(chunk))\n\n if stream:\n await self.event_bus.emit('response.gradio', payload=\"<END_OF_RESPONSE>\")\n except Exception as e:\n logger.error(f\"Error in streaming response: {e}\")\n await self.event_bus.emit('response.gradio', payload=\"An error occurred during response generation.\")\n if stream:\n 
await self.event_bus.emit('response.gradio', payload=\"<END_OF_RESPONSE>\")\n
"},{"location":"reference/#src.aeiva.agent.agent.Agent.process_input","title":"process_input(input_text)
async
","text":"Process input text and return the agent's response.
Source code in src/aeiva/agent/agent.py
async def process_input(self, input_text: str) -> str:\n \"\"\"\n Process input text and return the agent's response.\n \"\"\"\n stream = self.config_dict.get(\"llm_gateway_config\").get(\"llm_stream\")\n use_async = self.config_dict.get(\"llm_gateway_config\").get(\"llm_use_async\")\n stimuli = Stimuli(signals=[Signal(data=input_text, modularity='text')])\n output = \"\"\n try:\n response_gen = self.cognition_system.think(stimuli, tools=self.action_system.tools, stream=stream, use_async=use_async)\n async for chunk in response_gen:\n if isinstance(chunk, str):\n # For streaming chunks\n output += chunk\n elif isinstance(chunk, Thought) or isinstance(chunk, Plan):\n # For non-streaming responses\n output += chunk.content\n except Exception as e:\n logger.error(f\"Error in response: {e}\")\n return output\n
"},{"location":"reference/#src.aeiva.agent.agent.Agent.run","title":"run()
async
","text":"Run the agent by connecting perception, cognition, and action systems using the event bus.
Source code in src/aeiva/agent/agent.py
async def run(self) -> None:\n \"\"\"\n Run the agent by connecting perception, cognition, and action systems using the event bus.\n \"\"\"\n # Start the event bus within the running event loop\n self.event_bus.start()\n # Assign the current running loop to the EventBus\n self.event_bus.loop = asyncio.get_running_loop()\n # Set up event handlers\n self.setup_event_handlers()\n # Start the perception system\n await self.perception_system.start()\n\n # Keep the event loop running until interrupted\n try:\n while True:\n await asyncio.sleep(1)\n except KeyboardInterrupt:\n # Handle graceful shutdown\n self.perception_system.stop()\n await self.event_bus.wait_until_all_events_processed()\n self.event_bus.stop()\n except asyncio.CancelledError:\n pass\n except Exception as e:\n # logger.error(f\"Unexpected error in agent run loop: {e}\")\n print(f\"Unexpected error in agent run loop: {e}\", flush=True)\n await self.perception_system.stop()\n await self.event_bus.wait_until_all_events_processed()\n self.event_bus.stop()\n
"},{"location":"reference/#src.aeiva.agent.agent.Agent.setup","title":"setup()
","text":"Set up all systems.
Source code in src/aeiva/agent/agent.py
def setup(self) -> None:\n \"\"\"\n Set up all systems.\n \"\"\"\n perception_config = self.config_dict.get('perception_config', {})\n cognition_config = self.config_dict # NOTE: we didn't define a cognition config class yet.\n action_config = self.config_dict.get('action_config', {})\n\n self.perception_system = PerceptionSystem(perception_config, self.event_bus)\n self.cognition_system = CognitionSystem(cognition_config)\n self.action_system = ActionSystem(action_config)\n\n self.perception_system.setup()\n self.cognition_system.setup()\n self.action_system.setup()\n
"},{"location":"reference/#src.aeiva.agent.agent.Agent.setup_event_handlers","title":"setup_event_handlers()
","text":"Set up event handlers for perception, cognition, and action events.
Source code in src/aeiva/agent/agent.py
def setup_event_handlers(self) -> None:\n \"\"\"\n Set up event handlers for perception, cognition, and action events.\n \"\"\"\n\n @self.event_bus.on('perception.stimuli')\n async def handle_stimuli(event: Event):\n # print(\"handle_stimuli called\", flush=True)\n user_input = event.payload\n stimuli = Stimuli(signals=[Signal(data=user_input, modularity='text')])\n #print(f\"Received stimuli: {stimuli}\", flush=True)\n # Process stimuli through cognition system\n #stimuli = [{\"role\": \"user\", \"content\": stimuli}]\n\n stream = self.config_dict.get(\"llm_gateway_config\").get(\"llm_stream\")\n use_async = self.config_dict.get(\"llm_gateway_config\").get(\"llm_use_async\")\n sys.stdout.write(\"\\r\\033[K\") # Return to start of the line and clear it\\\n print(\"Response: \", end='', flush=True)\n\n try:\n response_gen = self.cognition_system.think(stimuli, tools=self.action_system.tools, stream=stream, use_async=use_async)\n async for chunk in response_gen:\n if isinstance(chunk, str):\n # For streaming chunks\n print(f\"{chunk}\", end='', flush=True)\n elif isinstance(chunk, Thought) or isinstance(chunk, Plan):\n # For non-streaming responses\n print(f\"{chunk.content}\", end='', flush=True)\n except Exception as e:\n logger.error(f\"Error in response: {e}\")\n\n print(\"\\nYou: \", end='', flush=True)\n\n # # Determine if output is a Plan or Thought\n # if isinstance(output, Plan): # TODO: change later\n # print(\"Output is a Plan\", flush=True)\n # await self.event_bus.emit('action.plan', payload=output)\n # elif isinstance(output, Thought):\n # print(\"Output is a Thought\", flush=True)\n # print(f\"Agent Response: {output.content}\", flush=True)\n # else:\n # print(\"Unknown output from cognition system.\", flush=True)\n\n @self.event_bus.on('action.plan')\n async def handle_plan(event: Event):\n print(\"handle_plan called\", flush=True)\n plan = event.payload\n await self.action_system.execute(plan)\n\n @self.event_bus.on('perception.gradio')\n async def 
handle_gradio_input(event: Event):\n \"\"\"\n Handle input from Gradio and emit response.gradio events.\n \"\"\"\n user_input = event.payload\n stimuli = Stimuli(signals=[Signal(data=user_input, modularity='text')])\n\n stream = self.config_dict.get(\"llm_gateway_config\").get(\"llm_stream\")\n use_async = self.config_dict.get(\"llm_gateway_config\").get(\"llm_use_async\")\n logger.info(f\"Handling Gradio input: {user_input} | Stream: {stream}\")\n try:\n response_gen = self.cognition_system.think(stimuli, tools=self.action_system.tools, stream=stream, use_async=use_async)\n\n async for chunk in response_gen:\n if isinstance(chunk, str):\n # For streaming chunks\n await self.event_bus.emit('response.gradio', payload=chunk)\n elif isinstance(chunk, Thought) or isinstance(chunk, Plan):\n # For non-streaming responses\n await self.event_bus.emit('response.gradio', payload=chunk.content if hasattr(chunk, 'content') else str(chunk))\n\n if stream:\n await self.event_bus.emit('response.gradio', payload=\"<END_OF_RESPONSE>\")\n except Exception as e:\n logger.error(f\"Error in streaming response: {e}\")\n await self.event_bus.emit('response.gradio', payload=\"An error occurred during response generation.\")\n if stream:\n await self.event_bus.emit('response.gradio', payload=\"<END_OF_RESPONSE>\")\n
"},{"location":"reference/#src.aeiva.agent.base_agent","title":"base_agent
","text":""},{"location":"reference/#src.aeiva.agent.base_agent.BaseAgent","title":"BaseAgent
","text":" Bases: ABC
Abstract base class for autonomous agents with perception, cognition, and action capabilities.
Source code in src/aeiva/agent/base_agent.py
class BaseAgent(ABC):\n \"\"\"\n Abstract base class for autonomous agents with perception, cognition, and action capabilities.\n \"\"\"\n\n def __init__(self, config: Any):\n \"\"\"\n Initialize the agent with configuration.\n\n Args:\n config (Any): Configuration settings for the agent.\n \"\"\"\n self.config = config\n self.state = self.initialize_state() # can be a dict that includes: id, profile, motivation, goal, task, plan, etc.\n self.stop_event = asyncio.Event()\n\n # Systems will be initialized in the setup method\n self.perception_system: PerceptionSystem = None\n self.cognition_system: CognitionSystem = None\n self.action_system: ActionSystem = None\n\n @abstractmethod\n def initialize_state(self) -> Any:\n \"\"\"\n Initialize the agent's state.\n\n Returns:\n Any: The initial state of the agent.\n \"\"\"\n pass\n\n @abstractmethod\n def setup(self) -> None:\n \"\"\"\n Set up the agent's components (perception, cognition, action, etc.).\n Perform any asynchronous initialization if necessary.\n \"\"\"\n pass\n\n @abstractmethod\n async def cycle(self) -> None:\n \"\"\"\n Execute one cycle of perception, cognition, and action.\n This method should be overridden to define the agent's behavior per cycle.\n \"\"\"\n pass\n\n async def run(self) -> None:\n \"\"\"\n Run the agent, continuously executing cycles until stopped.\n \"\"\"\n await self.setup()\n cycle_interval = self.config.get('cycle_interval', 1.0)\n while not self.stop_event.is_set():\n try:\n await self.cycle()\n except Exception as e:\n self.handle_error(e)\n await asyncio.sleep(cycle_interval)\n\n def stop(self) -> None:\n \"\"\"\n Signal the agent to stop running.\n \"\"\"\n self.stop_event.set()\n\n def handle_error(self, error: Exception) -> None:\n \"\"\"\n Handle errors that occur during cycle execution.\n\n Args:\n error (Exception): The exception that was raised.\n \"\"\"\n # Implement your error handling logic here (e.g., logging)\n print(f\"Error during agent cycle: {error}\")\n
"},{"location":"reference/#src.aeiva.agent.base_agent.BaseAgent.__init__","title":"__init__(config)
","text":"Initialize the agent with configuration.
Parameters:
Name Type Description Default config
Any
Configuration settings for the agent.
required Source code in src/aeiva/agent/base_agent.py
def __init__(self, config: Any):\n \"\"\"\n Initialize the agent with configuration.\n\n Args:\n config (Any): Configuration settings for the agent.\n \"\"\"\n self.config = config\n self.state = self.initialize_state() # can be a dict that includes: id, profile, motivation, goal, task, plan, etc.\n self.stop_event = asyncio.Event()\n\n # Systems will be initialized in the setup method\n self.perception_system: PerceptionSystem = None\n self.cognition_system: CognitionSystem = None\n self.action_system: ActionSystem = None\n
"},{"location":"reference/#src.aeiva.agent.base_agent.BaseAgent.cycle","title":"cycle()
abstractmethod
async
","text":"Execute one cycle of perception, cognition, and action. This method should be overridden to define the agent's behavior per cycle.
Source code in src/aeiva/agent/base_agent.py
@abstractmethod\nasync def cycle(self) -> None:\n \"\"\"\n Execute one cycle of perception, cognition, and action.\n This method should be overridden to define the agent's behavior per cycle.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.agent.base_agent.BaseAgent.handle_error","title":"handle_error(error)
","text":"Handle errors that occur during cycle execution.
Parameters:
Name Type Description Default error
Exception
The exception that was raised.
required Source code in src/aeiva/agent/base_agent.py
def handle_error(self, error: Exception) -> None:\n \"\"\"\n Handle errors that occur during cycle execution.\n\n Args:\n error (Exception): The exception that was raised.\n \"\"\"\n # Implement your error handling logic here (e.g., logging)\n print(f\"Error during agent cycle: {error}\")\n
"},{"location":"reference/#src.aeiva.agent.base_agent.BaseAgent.initialize_state","title":"initialize_state()
abstractmethod
","text":"Initialize the agent's state.
Returns:
Name Type Description Any
Any
The initial state of the agent.
Source code in src/aeiva/agent/base_agent.py
@abstractmethod\ndef initialize_state(self) -> Any:\n \"\"\"\n Initialize the agent's state.\n\n Returns:\n Any: The initial state of the agent.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.agent.base_agent.BaseAgent.run","title":"run()
async
","text":"Run the agent, continuously executing cycles until stopped.
Source code in src/aeiva/agent/base_agent.py
async def run(self) -> None:\n \"\"\"\n Run the agent, continuously executing cycles until stopped.\n \"\"\"\n await self.setup()\n cycle_interval = self.config.get('cycle_interval', 1.0)\n while not self.stop_event.is_set():\n try:\n await self.cycle()\n except Exception as e:\n self.handle_error(e)\n await asyncio.sleep(cycle_interval)\n
"},{"location":"reference/#src.aeiva.agent.base_agent.BaseAgent.setup","title":"setup()
abstractmethod
","text":"Set up the agent's components (perception, cognition, action, etc.). Perform any asynchronous initialization if necessary.
Source code in src/aeiva/agent/base_agent.py
@abstractmethod\ndef setup(self) -> None:\n \"\"\"\n Set up the agent's components (perception, cognition, action, etc.).\n Perform any asynchronous initialization if necessary.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.agent.base_agent.BaseAgent.stop","title":"stop()
","text":"Signal the agent to stop running.
Source code in src/aeiva/agent/base_agent.py
def stop(self) -> None:\n \"\"\"\n Signal the agent to stop running.\n \"\"\"\n self.stop_event.set()\n
"},{"location":"reference/#src.aeiva.cognition","title":"cognition
","text":""},{"location":"reference/#src.aeiva.cognition.brain","title":"brain
","text":""},{"location":"reference/#src.aeiva.cognition.brain.brain","title":"brain
","text":""},{"location":"reference/#src.aeiva.cognition.brain.brain.Brain","title":"Brain
","text":" Bases: ABC
Abstract base class representing the cognitive processing unit.
The Brain is responsible for processing input stimuli to generate cognitive states that the CognitionSystem will translate into actions.
Attributes:
Name Type Description config
Any
Configuration settings for the Brain.
state
Any
The internal state of the Brain.
Source code in src/aeiva/cognition/brain/brain.py
class Brain(ABC):\n \"\"\"\n Abstract base class representing the cognitive processing unit.\n\n The Brain is responsible for processing input stimuli to generate cognitive states\n that the CognitionSystem will translate into actions.\n\n Attributes:\n config (Any): Configuration settings for the Brain.\n state (Any): The internal state of the Brain.\n \"\"\"\n\n def __init__(self, config: Any):\n \"\"\"\n Initialize the Brain with the provided configuration.\n\n Args:\n config (Any): Configuration settings for the Brain.\n \"\"\"\n self.config = config\n self.state = self.init_state()\n\n @abstractmethod\n def init_state(self) -> Any:\n \"\"\"\n Initialize the internal state of the Brain.\n\n This method should set up the initial state required for the Brain's operations.\n\n Returns:\n Any: The initial state of the Brain.\n \"\"\"\n pass\n\n @abstractmethod\n def setup(self) -> None:\n \"\"\"\n Asynchronously set up the Brain's components.\n\n This method should initialize any necessary components or resources\n based on the provided configuration.\n\n Raises:\n ConfigurationError: If the configuration is invalid or incomplete.\n \"\"\"\n pass\n\n @abstractmethod\n async def think(self, stimuli: Any, *args, **kwargs) -> Any:\n \"\"\"\n Asynchronously process input stimuli to update the cognitive state.\n\n Args:\n stimuli (Any): The input stimuli to process.\n\n Returns:\n Any: The updated cognitive state.\n\n Raises:\n ProcessingError: If processing the stimuli fails.\n \"\"\"\n pass\n\n def handle_error(self, error: Exception) -> None:\n \"\"\"\n Handle errors that occur during cognitive processing.\n\n This method can be overridden to implement custom error handling logic.\n\n Args:\n error (Exception): The exception that was raised.\n \"\"\"\n # Default error handling: log the error\n print(f\"Brain encountered an error: {error}\")\n
"},{"location":"reference/#src.aeiva.cognition.brain.brain.Brain.__init__","title":"__init__(config)
","text":"Initialize the Brain with the provided configuration.
Parameters:
Name Type Description Default config
Any
Configuration settings for the Brain.
required Source code in src/aeiva/cognition/brain/brain.py
def __init__(self, config: Any):\n \"\"\"\n Initialize the Brain with the provided configuration.\n\n Args:\n config (Any): Configuration settings for the Brain.\n \"\"\"\n self.config = config\n self.state = self.init_state()\n
"},{"location":"reference/#src.aeiva.cognition.brain.brain.Brain.handle_error","title":"handle_error(error)
","text":"Handle errors that occur during cognitive processing.
This method can be overridden to implement custom error handling logic.
Parameters:
Name Type Description Default error
Exception
The exception that was raised.
required Source code in src/aeiva/cognition/brain/brain.py
def handle_error(self, error: Exception) -> None:\n \"\"\"\n Handle errors that occur during cognitive processing.\n\n This method can be overridden to implement custom error handling logic.\n\n Args:\n error (Exception): The exception that was raised.\n \"\"\"\n # Default error handling: log the error\n print(f\"Brain encountered an error: {error}\")\n
"},{"location":"reference/#src.aeiva.cognition.brain.brain.Brain.init_state","title":"init_state()
abstractmethod
","text":"Initialize the internal state of the Brain.
This method should set up the initial state required for the Brain's operations.
Returns:
Name Type Description Any
Any
The initial state of the Brain.
Source code in src/aeiva/cognition/brain/brain.py
@abstractmethod\ndef init_state(self) -> Any:\n \"\"\"\n Initialize the internal state of the Brain.\n\n This method should set up the initial state required for the Brain's operations.\n\n Returns:\n Any: The initial state of the Brain.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.brain.brain.Brain.setup","title":"setup()
abstractmethod
","text":"Asynchronously set up the Brain's components.
This method should initialize any necessary components or resources based on the provided configuration.
Raises:
Type Description ConfigurationError
If the configuration is invalid or incomplete.
Source code in src/aeiva/cognition/brain/brain.py
@abstractmethod\ndef setup(self) -> None:\n \"\"\"\n Asynchronously set up the Brain's components.\n\n This method should initialize any necessary components or resources\n based on the provided configuration.\n\n Raises:\n ConfigurationError: If the configuration is invalid or incomplete.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.brain.brain.Brain.think","title":"think(stimuli, *args, **kwargs)
abstractmethod
async
","text":"Asynchronously process input stimuli to update the cognitive state.
Parameters:
Name Type Description Default stimuli
Any
The input stimuli to process.
required Returns:
Name Type Description Any
Any
The updated cognitive state.
Raises:
Type Description ProcessingError
If processing the stimuli fails.
Source code in src/aeiva/cognition/brain/brain.py
@abstractmethod\nasync def think(self, stimuli: Any, *args, **kwargs) -> Any:\n \"\"\"\n Asynchronously process input stimuli to update the cognitive state.\n\n Args:\n stimuli (Any): The input stimuli to process.\n\n Returns:\n Any: The updated cognitive state.\n\n Raises:\n ProcessingError: If processing the stimuli fails.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.brain.llm_brain","title":"llm_brain
","text":""},{"location":"reference/#src.aeiva.cognition.brain.llm_brain.LLMBrain","title":"LLMBrain
","text":" Bases: Brain
Concrete implementation of the Brain, using an LLM to process stimuli and generate cognitive states.
This brain uses the LLMClient to communicate with a language model to process input stimuli and produce outputs.
Source code in src/aeiva/cognition/brain/llm_brain.py
class LLMBrain(Brain):\n \"\"\"\n Concrete implementation of the Brain, using an LLM to process stimuli\n and generate cognitive states.\n\n This brain uses the LLMClient to communicate with a language model to\n process input stimuli and produce outputs.\n \"\"\"\n\n def __init__(self, config: Dict):\n \"\"\"\n Initialize the LLMBrain with the provided LLM configuration.\n\n Args:\n config (LLMGatewayConfig): Configuration settings for the LLMBrain.\n \"\"\"\n super().__init__(config)\n self.config_dict = config\n self.config = None\n self.llm_client = None\n\n def init_state(self) -> Any:\n \"\"\"\n Initialize the internal state of the Brain.\n\n The state can track the ongoing conversation or task context.\n\n Returns:\n dict: Initial empty state.\n \"\"\"\n return {\"conversation\": [], \"cognitive_state\": None}\n\n def setup(self) -> None:\n \"\"\"\n Set up the Brain's components.\n\n For the LLMBrain, this might involve validating the LLM configuration\n and ensuring that all necessary resources are in place.\n \"\"\"\n llm_conf_dict = self.config_dict.get('llm_gateway_config', {})\n self.config = LLMGatewayConfig(\n llm_api_key=llm_conf_dict.get('llm_api_key'),\n llm_model_name=llm_conf_dict.get('llm_model_name', 'gpt-4o'),\n llm_temperature=llm_conf_dict.get('llm_temperature', 0.7),\n llm_max_output_tokens=llm_conf_dict.get('llm_max_output_tokens', 10000),\n llm_use_async=llm_conf_dict.get('llm_use_async', False),\n llm_stream=llm_conf_dict.get('llm_stream', False)\n )\n self.llm_client = LLMClient(self.config)\n\n system_prompt = llm_conf_dict.get('llm_system_prompt', None)\n if system_prompt is not None: # TODO: only add system prompt for llms that support it.\n self.state[\"conversation\"] += [{ \"role\": \"system\", \"content\": system_prompt }]\n\n print(\"LLMBrain setup complete.\")\n\n async def think(\n self,\n stimuli: Any,\n tools: List[Dict[str, Any]] = None,\n stream: bool = False,\n use_async: bool = False\n ) -> AsyncGenerator[str, None]:\n \"\"\"\n Asynchronously process input stimuli to update the cognitive state.\n\n Args:\n stimuli (Any): The input stimuli to process.\n stream (bool): Whether to use streaming mode. Default is False.\n\n Returns:\n str: The full response in both streaming and non-streaming modes.\n \"\"\"\n try:\n # Assume stimuli is a list of messages (conversation context)\n if not isinstance(stimuli, list):\n raise ValueError(\"Stimuli must be a list of messages.\")\n\n self.state[\"conversation\"] += stimuli #!! NOTE: to let LLM remember the history. \n\n if not use_async: # NOTE: stream mode only works when use_async!!!\n response = self.llm_client(self.state[\"conversation\"], tools=tools, stream=stream) #!! NOTE: llm client will update conversation\n # self.state[\"conversation\"] += [{\"role\": \"assistant\", \"content\": response}]\n self.state[\"cognitive_state\"] = response\n yield response\n elif stream:\n # Stream mode: collect all parts of the streamed response\n response = \"\"\n # messages = self.state[\"conversation\"].copy()\n async for delta in self.llm_client(self.state[\"conversation\"], tools=tools, stream=stream): #!! NOTE: llm client will update conversation\n response += delta # Collect the streamed content\n yield delta\n # self.state[\"conversation\"] += [{\"role\": \"assistant\", \"content\": response}]\n self.state[\"cognitive_state\"] = response\n #return response\n else:\n # messages = self.state[\"conversation\"].copy()\n response = await self.llm_client(self.state[\"conversation\"], tools=tools, stream=stream) #!! NOTE: llm client will update conversation\n # self.state[\"conversation\"] += [{\"role\": \"assistant\", \"content\": response}]\n self.state[\"cognitive_state\"] = response\n yield response\n #return response\n\n except Exception as e:\n self.handle_error(e)\n raise\n\n def handle_error(self, error: Exception) -> None:\n \"\"\"\n Handle errors that occur during cognitive processing.\n\n Args:\n error (Exception): The exception that was raised.\n \"\"\"\n super().handle_error(error)\n # Custom error handling logic for LLM-related issues\n print(f\"LLMBrain encountered an error: {error}\")\n
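Note that `think` is an async generator in every branch: callers always iterate it, and in streaming mode each delta is both yielded to the caller and accumulated into `cognitive_state`. A minimal sketch of that accumulate-while-yielding pattern, with a stub in place of `LLMClient` (no API key required; the token list is made up):

```python
import asyncio
from typing import AsyncGenerator, Dict, List

async def fake_llm_stream(messages: List[Dict]) -> AsyncGenerator[str, None]:
    # Stub standing in for LLMClient called in streaming mode.
    for token in ["Hel", "lo", "!"]:
        yield token

async def think(messages: List[Dict]) -> AsyncGenerator[str, None]:
    response = ""
    async for delta in fake_llm_stream(messages):
        response += delta  # accumulate the full reply, as LLMBrain does
        yield delta        # while still streaming each delta to the caller

async def main() -> str:
    chunks = []
    async for delta in think([{"role": "user", "content": "hi"}]):
        chunks.append(delta)
    return "".join(chunks)

full = asyncio.run(main())
print(full)  # Hello!
```

This is why the caller of the non-streaming branch still receives exactly one yielded item: the full response string.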
"},{"location":"reference/#src.aeiva.cognition.brain.llm_brain.LLMBrain.__init__","title":"__init__(config)
","text":"Initialize the LLMBrain with the provided LLM configuration.
Parameters:
Name Type Description Default config
LLMGatewayConfig
Configuration settings for the LLMBrain.
required Source code in src/aeiva/cognition/brain/llm_brain.py
def __init__(self, config: Dict):\n \"\"\"\n Initialize the LLMBrain with the provided LLM configuration.\n\n Args:\n config (LLMGatewayConfig): Configuration settings for the LLMBrain.\n \"\"\"\n super().__init__(config)\n self.config_dict = config\n self.config = None\n self.llm_client = None\n
"},{"location":"reference/#src.aeiva.cognition.brain.llm_brain.LLMBrain.handle_error","title":"handle_error(error)
","text":"Handle errors that occur during cognitive processing.
Parameters:
Name Type Description Default error
Exception
The exception that was raised.
required Source code in src/aeiva/cognition/brain/llm_brain.py
def handle_error(self, error: Exception) -> None:\n \"\"\"\n Handle errors that occur during cognitive processing.\n\n Args:\n error (Exception): The exception that was raised.\n \"\"\"\n super().handle_error(error)\n # Custom error handling logic for LLM-related issues\n print(f\"LLMBrain encountered an error: {error}\")\n
"},{"location":"reference/#src.aeiva.cognition.brain.llm_brain.LLMBrain.init_state","title":"init_state()
","text":"Initialize the internal state of the Brain.
The state can track the ongoing conversation or task context.
Returns:
Name Type Description dict
Any
Initial empty state.
Source code in src/aeiva/cognition/brain/llm_brain.py
def init_state(self) -> Any:\n \"\"\"\n Initialize the internal state of the Brain.\n\n The state can track the ongoing conversation or task context.\n\n Returns:\n dict: Initial empty state.\n \"\"\"\n return {\"conversation\": [], \"cognitive_state\": None}\n
"},{"location":"reference/#src.aeiva.cognition.brain.llm_brain.LLMBrain.setup","title":"setup()
","text":"Set up the Brain's components.
For the LLMBrain, this might involve validating the LLM configuration and ensuring that all necessary resources are in place.
Source code in src/aeiva/cognition/brain/llm_brain.py
def setup(self) -> None:\n \"\"\"\n Set up the Brain's components.\n\n For the LLMBrain, this might involve validating the LLM configuration\n and ensuring that all necessary resources are in place.\n \"\"\"\n llm_conf_dict = self.config_dict.get('llm_gateway_config', {})\n self.config = LLMGatewayConfig(\n llm_api_key=llm_conf_dict.get('llm_api_key'),\n llm_model_name=llm_conf_dict.get('llm_model_name', 'gpt-4o'),\n llm_temperature=llm_conf_dict.get('llm_temperature', 0.7),\n llm_max_output_tokens=llm_conf_dict.get('llm_max_output_tokens', 10000),\n llm_use_async=llm_conf_dict.get('llm_use_async', False),\n llm_stream=llm_conf_dict.get('llm_stream', False)\n )\n self.llm_client = LLMClient(self.config)\n\n system_prompt = llm_conf_dict.get('llm_system_prompt', None)\n if system_prompt is not None: # TODO: only add system prompt for llms that support it.\n self.state[\"conversation\"] += [{ \"role\": \"system\", \"content\": system_prompt }]\n\n print(\"LLMBrain setup complete.\")\n
"},{"location":"reference/#src.aeiva.cognition.brain.llm_brain.LLMBrain.think","title":"think(stimuli, tools=None, stream=False, use_async=False)
async
","text":"Asynchronously process input stimuli to update the cognitive state.
Parameters:
Name Type Description Default stimuli
Any
The input stimuli to process.
required stream
bool
Whether to use streaming mode. Default is False.
False
Returns:
Name Type Description str
AsyncGenerator[str, None]
The full response in both streaming and non-streaming modes.
Source code in src/aeiva/cognition/brain/llm_brain.py
async def think(\n self,\n stimuli: Any,\n tools: List[Dict[str, Any]] = None,\n stream: bool = False,\n use_async: bool = False\n ) -> AsyncGenerator[str, None]:\n \"\"\"\n Asynchronously process input stimuli to update the cognitive state.\n\n Args:\n stimuli (Any): The input stimuli to process.\n stream (bool): Whether to use streaming mode. Default is False.\n\n Returns:\n str: The full response in both streaming and non-streaming modes.\n \"\"\"\n try:\n # Assume stimuli is a list of messages (conversation context)\n if not isinstance(stimuli, list):\n raise ValueError(\"Stimuli must be a list of messages.\")\n\n self.state[\"conversation\"] += stimuli #!! NOTE: to let LLM remember the history. \n\n if not use_async: # NOTE: stream mode only works when use_async!!!\n response = self.llm_client(self.state[\"conversation\"], tools=tools, stream=stream) #!! NOTE: llm client will update conversation\n # self.state[\"conversation\"] += [{\"role\": \"assistant\", \"content\": response}]\n self.state[\"cognitive_state\"] = response\n yield response\n elif stream:\n # Stream mode: collect all parts of the streamed response\n response = \"\"\n # messages = self.state[\"conversation\"].copy()\n async for delta in self.llm_client(self.state[\"conversation\"], tools=tools, stream=stream): #!! NOTE: llm client will update conversation\n response += delta # Collect the streamed content\n yield delta\n # self.state[\"conversation\"] += [{\"role\": \"assistant\", \"content\": response}]\n self.state[\"cognitive_state\"] = response\n #return response\n else:\n # messages = self.state[\"conversation\"].copy()\n response = await self.llm_client(self.state[\"conversation\"], tools=tools, stream=stream) #!! NOTE: llm client will update conversation\n # self.state[\"conversation\"] += [{\"role\": \"assistant\", \"content\": response}]\n self.state[\"cognitive_state\"] = response\n yield response\n #return response\n\n except Exception as e:\n self.handle_error(e)\n raise\n
"},{"location":"reference/#src.aeiva.cognition.cognition_system","title":"cognition_system
","text":""},{"location":"reference/#src.aeiva.cognition.cognition_system.CognitionSystem","title":"CognitionSystem
","text":"Processes Stimuli into Observations, uses the Brain to generate Thoughts, and orchestrates output into Plans.
Source code in src/aeiva/cognition/cognition_system.py
class CognitionSystem:\n \"\"\"\n Processes Stimuli into Observations, uses the Brain to generate Thoughts, and orchestrates output into Plans.\n \"\"\"\n def __init__(self, config: Dict):\n self.config_dict = config\n self.config = None\n self.input_interpreter = None\n self.brain = None\n self.output_orchestrator = None\n self.memory = None\n self.emotion = None\n self.world_model = None\n self.state = self.init_state()\n\n def init_state(self) -> Dict[str, Any]:\n return {\n \"cognitive_state\": None,\n \"last_input\": None,\n \"last_output\": None\n }\n\n def setup(self) -> None:\n \"\"\"\n Set up the cognition system's components.\n \"\"\"\n self.brain = LLMBrain(config=self.config_dict)\n self.memory = MemoryPalace(config=self.config_dict)\n self.emotion = SimpleEmotion() # TODO: replace\n self.world_model = SimpleWorldModel() # TODO: replace\n self.input_interpreter = SimpleInputInterpreter() # TODO: replace\n self.output_orchestrator = SimpleOutputOrchestrator() # TODO: replace\n\n self.brain.setup()\n self.memory.setup()\n self.world_model.setup()\n self.emotion.setup()\n self.input_interpreter.setup()\n self.output_orchestrator.setup()\n\n def handle_error(self, error: Exception) -> None:\n print(f\"CognitionSystem encountered an error: {error}\")\n\n async def think(\n self,\n stimuli: Stimuli,\n tools: List[Dict[str, Any]] = None,\n stream: bool=False,\n use_async: bool=False\n ) -> AsyncGenerator[str, None]:\n \"\"\"\n Processes stimuli and produces a thought or plan.\n\n Args:\n stimuli (Stimuli): The input stimuli.\n stream (bool): Whether to use streaming mode.\n tools (List[Dict[str, Any]]): Optional tools for function calls.\n\n Yields:\n str: Chunks of the assistant's response.\n \"\"\"\n self.state[\"last_input\"] = stimuli\n\n # Step 1: Use InputInterpreter to process stimuli into observation\n if self.input_interpreter.gate(stimuli):\n observation = await self.input_interpreter.interpret(stimuli)\n else:\n # Directly pass stimuli as observation (assuming it's acceptable)\n observation = Observation(data=stimuli.to_dict())\n\n # Step 2: Brain processes the observation into a thought or plan\n brain_input = [{\"role\": \"user\", \"content\": observation.data}]\n # Initiate brain processing\n response_gen = self.brain.think(brain_input, tools=tools, stream=stream, use_async=use_async)\n\n async for chunk in response_gen:\n if isinstance(chunk, str):\n # Streaming chunk or full response in non-streaming mode\n yield chunk\n elif isinstance(chunk, Thought):\n thought = chunk\n self.state[\"cognitive_state\"] = thought\n\n # Step 3: Use OutputOrchestrator if applicable\n if self.output_orchestrator.gate(thought):\n plan = await self.output_orchestrator.orchestrate(thought)\n self.state[\"last_output\"] = plan\n yield plan.content if hasattr(plan, 'content') else str(plan)\n else:\n self.state[\"last_output\"] = thought\n yield thought.content\n elif isinstance(chunk, Plan):\n plan = chunk\n self.state[\"last_output\"] = plan\n yield plan.content if hasattr(plan, 'content') else str(plan)\n else:\n # Handle unexpected chunk types\n #logger.warning(f\"Unexpected chunk type: {type(chunk)}\")\n yield str(chunk)\n
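The three-step flow above (interpret stimuli into an observation, let the brain think, orchestrate the output) can be sketched with stub stages. All names below are simplified stand-ins for the aeiva components, not the library's actual classes:

```python
import asyncio
from typing import AsyncGenerator, List

def interpret(stimuli: str) -> str:
    # Stand-in for InputInterpreter: normalize raw stimuli into an observation.
    return stimuli.strip().lower()

async def brain_think(observation: str) -> AsyncGenerator[str, None]:
    # Stand-in for Brain.think: yield one response chunk per word.
    for word in observation.split():
        yield word.upper()

def orchestrate(thought: str) -> str:
    # Stand-in for OutputOrchestrator: wrap a thought into a "plan".
    return f"plan({thought})"

async def cognition(stimuli: str) -> List[str]:
    observation = interpret(stimuli)              # Step 1
    outputs = []
    async for chunk in brain_think(observation):  # Step 2
        outputs.append(orchestrate(chunk))        # Step 3
    return outputs

plans = asyncio.run(cognition("  Hello World  "))
print(plans)  # ['plan(HELLO)', 'plan(WORLD)']
```

In the real class, `gate` methods decide per item whether the interpreter or orchestrator applies; the sketch simply applies both unconditionally.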
"},{"location":"reference/#src.aeiva.cognition.cognition_system.CognitionSystem.setup","title":"setup()
","text":"Set up the cognition system's components.
Source code in src/aeiva/cognition/cognition_system.py
def setup(self) -> None:\n \"\"\"\n Set up the cognition system's components.\n \"\"\"\n self.brain = LLMBrain(config=self.config_dict)\n self.memory = MemoryPalace(config=self.config_dict)\n self.emotion = SimpleEmotion() # TODO: replace\n self.world_model = SimpleWorldModel() # TODO: replace\n self.input_interpreter = SimpleInputInterpreter() # TODO: replace\n self.output_orchestrator = SimpleOutputOrchestrator() # TODO: replace\n\n self.brain.setup()\n self.memory.setup()\n self.world_model.setup()\n self.emotion.setup()\n self.input_interpreter.setup()\n self.output_orchestrator.setup()\n
"},{"location":"reference/#src.aeiva.cognition.cognition_system.CognitionSystem.think","title":"think(stimuli, tools=None, stream=False, use_async=False)
async
","text":"Processes stimuli and produces a thought or plan.
Parameters:
Name Type Description Default stimuli
Stimuli
The input stimuli.
required stream
bool
Whether to use streaming mode.
False
tools
List[Dict[str, Any]]
Optional tools for function calls.
None
Yields:
Name Type Description str
AsyncGenerator[str, None]
Chunks of the assistant's response.
Source code in src/aeiva/cognition/cognition_system.py
async def think(\n self,\n stimuli: Stimuli,\n tools: List[Dict[str, Any]] = None,\n stream: bool=False,\n use_async: bool=False\n ) -> AsyncGenerator[str, None]:\n \"\"\"\n Processes stimuli and produces a thought or plan.\n\n Args:\n stimuli (Stimuli): The input stimuli.\n stream (bool): Whether to use streaming mode.\n tools (List[Dict[str, Any]]): Optional tools for function calls.\n\n Yields:\n str: Chunks of the assistant's response.\n \"\"\"\n self.state[\"last_input\"] = stimuli\n\n # Step 1: Use InputInterpreter to process stimuli into observation\n if self.input_interpreter.gate(stimuli):\n observation = await self.input_interpreter.interpret(stimuli)\n else:\n # Directly pass stimuli as observation (assuming it's acceptable)\n observation = Observation(data=stimuli.to_dict())\n\n # Step 2: Brain processes the observation into a thought or plan\n brain_input = [{\"role\": \"user\", \"content\": observation.data}]\n # Initiate brain processing\n response_gen = self.brain.think(brain_input, tools=tools, stream=stream, use_async=use_async)\n\n async for chunk in response_gen:\n if isinstance(chunk, str):\n # Streaming chunk or full response in non-streaming mode\n yield chunk\n elif isinstance(chunk, Thought):\n thought = chunk\n self.state[\"cognitive_state\"] = thought\n\n # Step 3: Use OutputOrchestrator if applicable\n if self.output_orchestrator.gate(thought):\n plan = await self.output_orchestrator.orchestrate(thought)\n self.state[\"last_output\"] = plan\n yield plan.content if hasattr(plan, 'content') else str(plan)\n else:\n self.state[\"last_output\"] = thought\n yield thought.content\n elif isinstance(chunk, Plan):\n plan = chunk\n self.state[\"last_output\"] = plan\n yield plan.content if hasattr(plan, 'content') else str(plan)\n else:\n # Handle unexpected chunk types\n #logger.warning(f\"Unexpected chunk type: {type(chunk)}\")\n yield str(chunk)\n
"},{"location":"reference/#src.aeiva.cognition.emotion","title":"emotion
","text":""},{"location":"reference/#src.aeiva.cognition.emotion.emotion","title":"emotion
","text":""},{"location":"reference/#src.aeiva.cognition.emotion.emotion.ConfigurationError","title":"ConfigurationError
","text":" Bases: Exception
Exception raised for errors in the configuration.
Source code in src/aeiva/cognition/emotion/emotion.py
class ConfigurationError(Exception):\n \"\"\"Exception raised for errors in the configuration.\"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.emotion.emotion.Emotion","title":"Emotion
","text":" Bases: ABC
, Generic[T]
Abstract base class representing the Emotion system of an agent with generic state type.
The Emotion system manages the agent's emotional states, allowing it to respond to various stimuli in an emotionally coherent manner.
Attributes:
Name Type Description config
Dict[str, Any]
Configuration settings for the Emotion system.
state
T
The internal emotional state of the agent, defined by subclasses.
Source code in src/aeiva/cognition/emotion/emotion.py
class Emotion(ABC, Generic[T]):\n \"\"\"\n Abstract base class representing the Emotion system of an agent with generic state type.\n\n The Emotion system manages the agent's emotional states, allowing it to respond\n to various stimuli in an emotionally coherent manner.\n\n Attributes:\n config (Dict[str, Any]): Configuration settings for the Emotion system.\n state (T): The internal emotional state of the agent, defined by subclasses.\n \"\"\"\n\n def __init__(self, config: Dict[str, Any]):\n \"\"\"\n Initialize the Emotion system with the provided configuration.\n\n Args:\n config (Dict[str, Any]): Configuration settings for the Emotion system.\n \"\"\"\n self.config = config\n self.state = self.init_state()\n\n @abstractmethod\n def init_state(self) -> T:\n \"\"\"\n Initialize the internal emotional state of the Emotion system.\n\n This method should set up the initial emotional state required for the\n Emotion system's operations.\n\n Returns:\n T: The initial emotional state of the agent.\n \"\"\"\n pass\n\n @abstractmethod\n async def setup(self) -> None:\n \"\"\"\n Asynchronously set up the Emotion system's components.\n\n This method should initialize any necessary components or resources\n based on the provided configuration.\n\n Raises:\n ConfigurationError: If the configuration is invalid or incomplete.\n \"\"\"\n pass\n\n @abstractmethod\n async def update(self, input_data: Dict[str, Any]) -> None:\n \"\"\"\n Asynchronously update the emotional state based on input data.\n\n Args:\n input_data (Dict[str, Any]): The data or stimuli that influence the emotional state.\n\n Raises:\n UpdateError: If updating the emotional state fails.\n \"\"\"\n pass\n\n @abstractmethod\n def regulate(self, strategy: str) -> None:\n \"\"\"\n Regulate the emotional state using a specified strategy.\n\n Args:\n strategy (str): The regulation strategy to apply (e.g., 'suppression', 'amplification').\n\n Raises:\n RegulationError: If the regulation strategy is invalid or fails.\n \"\"\"\n pass\n\n @abstractmethod\n def express(self) -> str:\n \"\"\"\n Generate a representation of the current emotional state.\n\n Returns:\n str: A string describing the current emotion (e.g., \"I feel happy!\").\n \"\"\"\n pass\n\n @abstractmethod\n def serialize(self) -> str:\n \"\"\"\n Serialize the current emotional state into a string format.\n\n Returns:\n str: Serialized emotional state.\n \"\"\"\n pass\n\n @abstractmethod\n def deserialize(self, data: str) -> None:\n \"\"\"\n Deserialize the emotional state from a string format.\n\n Args:\n data (str): Serialized emotional state.\n \"\"\"\n pass\n\n def get_current_state(self) -> T:\n \"\"\"\n Retrieve the current emotional state of the agent.\n\n Returns:\n T: The current emotional state.\n \"\"\"\n return self.state\n\n def handle_error(self, error: Exception) -> None:\n \"\"\"\n Handle errors that occur during emotional processing.\n\n This method can be overridden to implement custom error handling logic.\n\n Args:\n error (Exception): The exception that was raised.\n \"\"\"\n pass\n
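To illustrate the generic-state interface, here is a toy subclass whose state is a single emotion label. The sketch inlines a trimmed stand-in for the `Emotion` ABC (omitting `setup` and `regulate` for brevity) so it runs without the `aeiva` package; `MoodEmotion` and its update rule are hypothetical:

```python
import asyncio
import json
from abc import ABC, abstractmethod
from typing import Any, Dict

class Emotion(ABC):
    """Trimmed stand-in for aeiva's Emotion ABC (string state for brevity)."""
    def __init__(self, config: Dict[str, Any]):
        self.config = config
        self.state = self.init_state()

    @abstractmethod
    def init_state(self) -> str: ...
    @abstractmethod
    async def update(self, input_data: Dict[str, Any]) -> None: ...
    @abstractmethod
    def express(self) -> str: ...
    @abstractmethod
    def serialize(self) -> str: ...
    @abstractmethod
    def deserialize(self, data: str) -> None: ...

class MoodEmotion(Emotion):
    """Hypothetical subclass: state is one emotion label."""
    def init_state(self) -> str:
        return "neutral"

    async def update(self, input_data: Dict[str, Any]) -> None:
        # Trivial update rule: adopt the label carried by the stimulus.
        self.state = input_data.get("emotion", self.state)

    def express(self) -> str:
        return f"I feel {self.state}!"

    def serialize(self) -> str:
        return json.dumps({"emotion_label": self.state})

    def deserialize(self, data: str) -> None:
        self.state = json.loads(data)["emotion_label"]

mood = MoodEmotion(config={})
asyncio.run(mood.update({"emotion": "happy"}))
print(mood.express())    # I feel happy!
print(mood.serialize())  # {"emotion_label": "happy"}
```

The `serialize`/`deserialize` pair lets an agent persist its emotional state between sessions, whatever concrete state type `T` a subclass chooses.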
"},{"location":"reference/#src.aeiva.cognition.emotion.emotion.Emotion.__init__","title":"__init__(config)
","text":"Initialize the Emotion system with the provided configuration.
Parameters:
Name Type Description Default config
Dict[str, Any]
Configuration settings for the Emotion system.
required Source code in src/aeiva/cognition/emotion/emotion.py
def __init__(self, config: Dict[str, Any]):\n \"\"\"\n Initialize the Emotion system with the provided configuration.\n\n Args:\n config (Dict[str, Any]): Configuration settings for the Emotion system.\n \"\"\"\n self.config = config\n self.state = self.init_state()\n
"},{"location":"reference/#src.aeiva.cognition.emotion.emotion.Emotion.deserialize","title":"deserialize(data)
abstractmethod
","text":"Deserialize the emotional state from a string format.
Parameters:
Name Type Description Default data
str
Serialized emotional state.
required Source code in src/aeiva/cognition/emotion/emotion.py
@abstractmethod\ndef deserialize(self, data: str) -> None:\n \"\"\"\n Deserialize the emotional state from a string format.\n\n Args:\n data (str): Serialized emotional state.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.emotion.emotion.Emotion.express","title":"express()
abstractmethod
","text":"Generate a representation of the current emotional state.
Returns:
Name Type Description str
str
A string describing the current emotion (e.g., \"I feel happy!\").
Source code in src/aeiva/cognition/emotion/emotion.py
@abstractmethod\ndef express(self) -> str:\n \"\"\"\n Generate a representation of the current emotional state.\n\n Returns:\n str: A string describing the current emotion (e.g., \"I feel happy!\").\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.emotion.emotion.Emotion.get_current_state","title":"get_current_state()
","text":"Retrieve the current emotional state of the agent.
Returns:
Name Type Description T
T
The current emotional state.
Source code in src/aeiva/cognition/emotion/emotion.py
def get_current_state(self) -> T:\n \"\"\"\n Retrieve the current emotional state of the agent.\n\n Returns:\n T: The current emotional state.\n \"\"\"\n return self.state\n
"},{"location":"reference/#src.aeiva.cognition.emotion.emotion.Emotion.handle_error","title":"handle_error(error)
","text":"Handle errors that occur during emotional processing.
This method can be overridden to implement custom error handling logic.
Parameters:
Name Type Description Default error
Exception
The exception that was raised.
required Source code in src/aeiva/cognition/emotion/emotion.py
def handle_error(self, error: Exception) -> None:\n \"\"\"\n Handle errors that occur during emotional processing.\n\n This method can be overridden to implement custom error handling logic.\n\n Args:\n error (Exception): The exception that was raised.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.emotion.emotion.Emotion.init_state","title":"init_state()
abstractmethod
","text":"Initialize the internal emotional state of the Emotion system.
This method should set up the initial emotional state required for the Emotion system's operations.
Returns:
Name Type Description T
T
The initial emotional state of the agent.
Source code in src/aeiva/cognition/emotion/emotion.py
@abstractmethod\ndef init_state(self) -> T:\n \"\"\"\n Initialize the internal emotional state of the Emotion system.\n\n This method should set up the initial emotional state required for the\n Emotion system's operations.\n\n Returns:\n T: The initial emotional state of the agent.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.emotion.emotion.Emotion.regulate","title":"regulate(strategy)
abstractmethod
","text":"Regulate the emotional state using a specified strategy.
Parameters:
Name Type Description Default strategy
str
The regulation strategy to apply (e.g., 'suppression', 'amplification').
required Raises:
Type Description RegulationError
If the regulation strategy is invalid or fails.
Source code in src/aeiva/cognition/emotion/emotion.py
@abstractmethod\ndef regulate(self, strategy: str) -> None:\n \"\"\"\n Regulate the emotional state using a specified strategy.\n\n Args:\n strategy (str): The regulation strategy to apply (e.g., 'suppression', 'amplification').\n\n Raises:\n RegulationError: If the regulation strategy is invalid or fails.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.emotion.emotion.Emotion.serialize","title":"serialize()
abstractmethod
","text":"Serialize the current emotional state into a string format.
Returns:
Name Type Description str
str
Serialized emotional state.
Source code in src/aeiva/cognition/emotion/emotion.py
@abstractmethod\ndef serialize(self) -> str:\n \"\"\"\n Serialize the current emotional state into a string format.\n\n Returns:\n str: Serialized emotional state.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.emotion.emotion.Emotion.setup","title":"setup()
abstractmethod
async
","text":"Asynchronously set up the Emotion system's components.
This method should initialize any necessary components or resources based on the provided configuration.
Raises:
Type Description ConfigurationError
If the configuration is invalid or incomplete.
Source code in src/aeiva/cognition/emotion/emotion.py
@abstractmethod\nasync def setup(self) -> None:\n \"\"\"\n Asynchronously set up the Emotion system's components.\n\n This method should initialize any necessary components or resources\n based on the provided configuration.\n\n Raises:\n ConfigurationError: If the configuration is invalid or incomplete.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.emotion.emotion.Emotion.update","title":"update(input_data)
abstractmethod
async
","text":"Asynchronously update the emotional state based on input data.
Parameters:
Name Type Description Default input_data
Dict[str, Any]
The data or stimuli that influence the emotional state.
required Raises:
Type Description UpdateError
If updating the emotional state fails.
Source code in src/aeiva/cognition/emotion/emotion.py
@abstractmethod\nasync def update(self, input_data: Dict[str, Any]) -> None:\n \"\"\"\n Asynchronously update the emotional state based on input data.\n\n Args:\n input_data (Dict[str, Any]): The data or stimuli that influence the emotional state.\n\n Raises:\n UpdateError: If updating the emotional state fails.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.emotion.emotion.RegulationError","title":"RegulationError
","text":" Bases: Exception
Exception raised for errors during emotion regulation.
Source code in src/aeiva/cognition/emotion/emotion.py
class RegulationError(Exception):\n \"\"\"Exception raised for errors during emotion regulation.\"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.emotion.emotion.UpdateError","title":"UpdateError
","text":" Bases: Exception
Exception raised for errors during emotion state updates.
Source code in src/aeiva/cognition/emotion/emotion.py
class UpdateError(Exception):\n \"\"\"Exception raised for errors during emotion state updates.\"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.emotion.emotion_categorical","title":"emotion_categorical
","text":""},{"location":"reference/#src.aeiva.cognition.emotion.emotion_categorical.CategoricalEmotionState","title":"CategoricalEmotionState
","text":"Represents the emotional state in a Categorical Model.
Source code in src/aeiva/cognition/emotion/emotion_categorical.py
class CategoricalEmotionState:\n \"\"\"\n Represents the emotional state in a Categorical Model.\n \"\"\"\n def __init__(self, emotion_label: str = \"neutral\"):\n self.emotion_label = emotion_label\n\n def to_dict(self) -> Dict[str, Any]:\n return {\n 'emotion_label': self.emotion_label\n }\n\n @staticmethod\n def from_dict(data: Dict[str, Any]):\n return CategoricalEmotionState(\n emotion_label=data.get('emotion_label', 'neutral')\n )\n
"},{"location":"reference/#src.aeiva.cognition.emotion.emotion_category","title":"emotion_category
","text":""},{"location":"reference/#src.aeiva.cognition.emotion.emotion_category.CategoryEmotionState","title":"CategoryEmotionState
dataclass
","text":"Represents the emotional state in a Category-Based Model with extensive categories.
Attributes:
Name Type Description emotion_label
str
The current emotion category.
intensity
float
The intensity of the current emotion (range: 0.0 to 1.0).
Source code in src/aeiva/cognition/emotion/emotion_category.py
@dataclass\nclass CategoryEmotionState:\n \"\"\"\n Represents the emotional state in a Category-Based Model with extensive categories.\n\n Attributes:\n emotion_label (str): The current emotion category.\n intensity (float): The intensity of the current emotion (range: 0.0 to 1.0).\n \"\"\"\n emotion_label: str = \"neutral\"\n intensity: float = 0.0 # Optional: Represents the strength of the emotion\n\n def to_dict(self) -> Dict[str, Any]:\n return {\n 'emotion_label': self.emotion_label,\n 'intensity': self.intensity\n }\n\n @staticmethod\n def from_dict(data: Dict[str, Any]):\n return CategoryEmotionState(\n emotion_label=data.get('emotion_label', 'neutral'),\n intensity=data.get('intensity', 0.0)\n )\n
"},{"location":"reference/#src.aeiva.cognition.emotion.emotion_circumplex","title":"emotion_circumplex
","text":""},{"location":"reference/#src.aeiva.cognition.emotion.emotion_circumplex.CircumplexEmotionState","title":"CircumplexEmotionState
","text":"Represents the emotional state in the Circumplex Model.
Source code in src/aeiva/cognition/emotion/emotion_circumplex.py
class CircumplexEmotionState:\n \"\"\"\n Represents the emotional state in the Circumplex Model.\n \"\"\"\n def __init__(self, valence: float = 0.0, arousal: float = 0.0):\n self.valence = valence # Range: [-1.0, 1.0]\n self.arousal = arousal # Range: [-1.0, 1.0]\n\n def to_dict(self) -> Dict[str, Any]:\n return {\n 'valence': self.valence,\n 'arousal': self.arousal\n }\n\n @staticmethod\n def from_dict(data: Dict[str, Any]):\n return CircumplexEmotionState(\n valence=data.get('valence', 0.0),\n arousal=data.get('arousal', 0.0)\n )\n
"},{"location":"reference/#src.aeiva.cognition.emotion.emotion_componential","title":"emotion_componential
","text":""},{"location":"reference/#src.aeiva.cognition.emotion.emotion_componential.ComponentialEmotionState","title":"ComponentialEmotionState
dataclass
","text":"Represents the emotional state based on the Componential Model.
Attributes:
Name Type Description emotion_label
str
Current emotion category.
intensity
float
Intensity of the emotion (0.0 to 1.0).
Source code in src/aeiva/cognition/emotion/emotion_componential.py
@dataclass\nclass ComponentialEmotionState:\n \"\"\"\n Represents the emotional state based on the Componential Model.\n\n Attributes:\n emotion_label (str): Current emotion category.\n intensity (float): Intensity of the emotion (0.0 to 1.0).\n \"\"\"\n emotion_label: str = \"neutral\"\n intensity: float = 0.0\n\n def to_dict(self) -> Dict[str, Any]:\n return {\n 'emotion_label': self.emotion_label,\n 'intensity': self.intensity\n }\n\n @staticmethod\n def from_dict(data: Dict[str, Any]):\n return ComponentialEmotionState(\n emotion_label=data.get('emotion_label', 'neutral'),\n intensity=data.get('intensity', 0.0)\n )\n
"},{"location":"reference/#src.aeiva.cognition.emotion.emotion_hybrid","title":"emotion_hybrid
","text":""},{"location":"reference/#src.aeiva.cognition.emotion.emotion_hybrid.HybridEmotionState","title":"HybridEmotionState
","text":"Represents the emotional state in the Hybrid Categorical-Dimensional Model.
Source code in src/aeiva/cognition/emotion/emotion_hybrid.py
class HybridEmotionState:\n \"\"\"\n Represents the emotional state in the Hybrid Categorical-Dimensional Model.\n \"\"\"\n def __init__(self, emotion_label: str = \"neutral\", valence: float = 0.0, arousal: float = 0.0):\n self.emotion_label = emotion_label # Categorical label\n self.valence = valence # Dimensional valence\n self.arousal = arousal # Dimensional arousal\n\n def to_dict(self) -> Dict[str, Any]:\n return {\n 'emotion_label': self.emotion_label,\n 'valence': self.valence,\n 'arousal': self.arousal\n }\n\n @staticmethod\n def from_dict(data: Dict[str, Any]):\n return HybridEmotionState(\n emotion_label=data.get('emotion_label', 'neutral'),\n valence=data.get('valence', 0.0),\n arousal=data.get('arousal', 0.0)\n )\n
"},{"location":"reference/#src.aeiva.cognition.emotion.emotion_occ","title":"emotion_occ
","text":""},{"location":"reference/#src.aeiva.cognition.emotion.emotion_occ.OCCEmotionState","title":"OCCEmotionState
","text":"Represents the emotional state in the OCC Appraisal-Based Model.
Source code in src/aeiva/cognition/emotion/emotion_occ.py
class OCCEmotionState:\n \"\"\"\n Represents the emotional state in the OCC Appraisal-Based Model.\n \"\"\"\n def __init__(self, emotion_categories: Dict[str, float] = None):\n \"\"\"\n Initialize the OCC emotion state with emotion categories and their intensities.\n \"\"\"\n # Initialize with zero intensities if not provided\n self.emotion_categories = emotion_categories if emotion_categories else {\n 'joy': 0.0,\n 'sadness': 0.0,\n 'anger': 0.0,\n 'fear': 0.0,\n 'surprise': 0.0,\n 'disgust': 0.0\n }\n\n def to_dict(self) -> Dict[str, Any]:\n return {\n 'emotion_categories': self.emotion_categories\n }\n\n @staticmethod\n def from_dict(data: Dict[str, Any]):\n return OCCEmotionState(\n emotion_categories=data.get('emotion_categories', {})\n )\n
"},{"location":"reference/#src.aeiva.cognition.emotion.emotion_occ.OCCEmotionState.__init__","title":"__init__(emotion_categories=None)
","text":"Initialize the OCC emotion state with emotion categories and their intensities.
Source code in src/aeiva/cognition/emotion/emotion_occ.py
def __init__(self, emotion_categories: Dict[str, float] = None):\n \"\"\"\n Initialize the OCC emotion state with emotion categories and their intensities.\n \"\"\"\n # Initialize with zero intensities if not provided\n self.emotion_categories = emotion_categories if emotion_categories else {\n 'joy': 0.0,\n 'sadness': 0.0,\n 'anger': 0.0,\n 'fear': 0.0,\n 'surprise': 0.0,\n 'disgust': 0.0\n }\n
"},{"location":"reference/#src.aeiva.cognition.emotion.emotion_pad","title":"emotion_pad
","text":""},{"location":"reference/#src.aeiva.cognition.emotion.emotion_pad.PADEmotionState","title":"PADEmotionState
","text":"Represents the emotional state in the PAD Model.
Source code in src/aeiva/cognition/emotion/emotion_pad.py
class PADEmotionState:\n \"\"\"\n Represents the emotional state in the PAD Model.\n \"\"\"\n def __init__(self, pleasure: float = 0.0, arousal: float = 0.0, dominance: float = 0.0):\n self.pleasure = pleasure # Range: [-1.0, 1.0]\n self.arousal = arousal # Range: [-1.0, 1.0]\n self.dominance = dominance # Range: [-1.0, 1.0]\n\n def to_dict(self) -> Dict[str, Any]:\n return {\n 'pleasure': self.pleasure,\n 'arousal': self.arousal,\n 'dominance': self.dominance\n }\n\n @staticmethod\n def from_dict(data: Dict[str, Any]):\n return PADEmotionState(\n pleasure=data.get('pleasure', 0.0),\n arousal=data.get('arousal', 0.0),\n dominance=data.get('dominance', 0.0)\n )\n
"},{"location":"reference/#src.aeiva.cognition.emotion.emotion_plutchik","title":"emotion_plutchik
","text":""},{"location":"reference/#src.aeiva.cognition.emotion.emotion_plutchik.PlutchikEmotionState","title":"PlutchikEmotionState
dataclass
","text":"Represents the emotional state in Plutchik's Wheel of Emotions.
Attributes:
Name Type Description joy
float
Intensity of Joy.
trust
float
Intensity of Trust.
fear
float
Intensity of Fear.
surprise
float
Intensity of Surprise.
sadness
float
Intensity of Sadness.
disgust
float
Intensity of Disgust.
anger
float
Intensity of Anger.
anticipation
float
Intensity of Anticipation.
Source code in src/aeiva/cognition/emotion/emotion_plutchik.py
@dataclass\nclass PlutchikEmotionState:\n \"\"\"\n Represents the emotional state in Plutchik's Wheel of Emotions.\n\n Attributes:\n joy (float): Intensity of Joy.\n trust (float): Intensity of Trust.\n fear (float): Intensity of Fear.\n surprise (float): Intensity of Surprise.\n sadness (float): Intensity of Sadness.\n disgust (float): Intensity of Disgust.\n anger (float): Intensity of Anger.\n anticipation (float): Intensity of Anticipation.\n \"\"\"\n joy: float = 0.0\n trust: float = 0.0\n fear: float = 0.0\n surprise: float = 0.0\n sadness: float = 0.0\n disgust: float = 0.0\n anger: float = 0.0\n anticipation: float = 0.0\n\n def to_dict(self) -> Dict[str, Any]:\n return {\n 'joy': self.joy,\n 'trust': self.trust,\n 'fear': self.fear,\n 'surprise': self.surprise,\n 'sadness': self.sadness,\n 'disgust': self.disgust,\n 'anger': self.anger,\n 'anticipation': self.anticipation\n }\n\n @staticmethod\n def from_dict(data: Dict[str, Any]):\n return PlutchikEmotionState(\n joy=data.get('joy', 0.0),\n trust=data.get('trust', 0.0),\n fear=data.get('fear', 0.0),\n surprise=data.get('surprise', 0.0),\n sadness=data.get('sadness', 0.0),\n disgust=data.get('disgust', 0.0),\n anger=data.get('anger', 0.0),\n anticipation=data.get('anticipation', 0.0)\n )\n
"},{"location":"reference/#src.aeiva.cognition.emotion.exceptions","title":"exceptions
","text":""},{"location":"reference/#src.aeiva.cognition.emotion.exceptions.ConfigurationError","title":"ConfigurationError
","text":" Bases: Exception
Exception raised for errors in the configuration.
Source code in src/aeiva/cognition/emotion/exceptions.py
class ConfigurationError(Exception):\n \"\"\"Exception raised for errors in the configuration.\"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.emotion.exceptions.RegulationError","title":"RegulationError
","text":" Bases: Exception
Exception raised for errors during emotion regulation.
Source code in src/aeiva/cognition/emotion/exceptions.py
class RegulationError(Exception):\n \"\"\"Exception raised for errors during emotion regulation.\"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.emotion.exceptions.UpdateError","title":"UpdateError
","text":" Bases: Exception
Exception raised for errors during emotion state updates.
Source code in src/aeiva/cognition/emotion/exceptions.py
class UpdateError(Exception):\n \"\"\"Exception raised for errors during emotion state updates.\"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.memory","title":"memory
","text":""},{"location":"reference/#src.aeiva.cognition.memory.memory","title":"memory
","text":""},{"location":"reference/#src.aeiva.cognition.memory.memory.Memory","title":"Memory
","text":" Bases: ABC
Abstract base class for memory operations in the intelligent agent.
This class defines methods corresponding to different layers of memory processing, such as creating, filtering, grouping, deriving, structuring, skillizing, embedding, and parameterizing memory units.
Source code in src/aeiva/cognition/memory/memory.py
class Memory(ABC):\n \"\"\"\n Abstract base class for memory operations in the intelligent agent.\n\n This class defines methods corresponding to different layers of memory processing,\n such as creating, filtering, grouping, deriving, structuring, skillizing, embedding,\n and parameterizing memory units.\n \"\"\"\n\n def __init__(self, config: Any):\n \"\"\"\n Initialize the Memory system with the provided configuration.\n\n Args:\n config (Any): Configuration settings for the Memory system.\n \"\"\"\n self.config = config\n\n @abstractmethod\n def setup(self) -> None:\n \"\"\"\n Asynchronously set up the Memory system's components.\n\n This method should initialize any necessary components or resources based on the provided configuration.\n\n Raises:\n ConfigurationError: If the configuration is invalid or incomplete.\n \"\"\"\n pass\n\n @abstractmethod\n def create(self, content: Any, **kwargs) -> MemoryUnit:\n \"\"\"\n Creates a new memory unit with the given content and metadata.\n\n Args:\n content (Any): The core content of the memory unit.\n **kwargs: Additional metadata for the memory unit.\n\n Returns:\n MemoryUnit: The created memory unit.\n \"\"\"\n pass\n\n @abstractmethod\n def get(self, unit_id: str) -> MemoryUnit:\n \"\"\"\n Retrieves a memory unit by its unique identifier.\n\n Args:\n unit_id (str): The unique identifier of the memory unit.\n\n Returns:\n MemoryUnit: The retrieved memory unit.\n \"\"\"\n pass\n\n @abstractmethod\n def update(self, unit_id: str, updates: Dict[str, Any]) -> None:\n \"\"\"\n Updates a memory unit with the given updates.\n\n Args:\n unit_id (str): The unique identifier of the memory unit.\n updates (Dict[str, Any]): A dictionary of fields to update.\n \"\"\"\n pass\n\n @abstractmethod\n def delete(self, unit_id: str) -> None:\n \"\"\"\n Deletes a memory unit by its unique identifier.\n\n Args:\n unit_id (str): The unique identifier of the memory unit.\n \"\"\"\n pass\n\n @abstractmethod\n def get_all(self) -> 
List[MemoryUnit]:\n \"\"\"\n Retrieves all memory units.\n\n Returns:\n List[MemoryUnit]: A list of all memory units.\n \"\"\"\n pass\n\n @abstractmethod\n def delete_all(self) -> None:\n \"\"\"\n Deletes all memory units.\n \"\"\"\n pass\n\n @abstractmethod\n def load(self) -> None:\n \"\"\"\n Loads the memory from file. The path is specified in config.\n \"\"\"\n pass\n\n @abstractmethod\n def save(self) -> None:\n \"\"\"\n Save the memory to database or file. The path is specified in config.\n \"\"\"\n pass\n\n @abstractmethod\n def filter(self, criteria: Dict[str, Any]) -> List[MemoryUnit]:\n \"\"\"\n Filters memory units based on the given criteria.\n\n Args:\n criteria (Dict[str, Any]): A dictionary of filter conditions.\n\n Returns:\n List[MemoryUnit]: A list of memory units matching the criteria.\n \"\"\"\n pass\n\n @abstractmethod\n def organize(self, unit_ids: List[str], organize_type: str, metadata: Optional[Dict[str, Any]] = None) -> str:\n \"\"\"\n Groups memory units into a meaningful group.\n\n Args:\n unit_ids (List[str]): A list of memory unit IDs to group.\n organize_type (str): The type of group (e.g., 'dialogue_session', 'procedure').\n metadata (Optional[Dict[str, Any]]): Additional metadata for the group.\n\n Returns:\n str: A unique identifier for the created group.\n \"\"\"\n pass\n\n # @abstractmethod\n # def derive(self, unit_ids: List[str], derivation_type: str, **kwargs) -> MemoryUnit:\n # \"\"\"\n # Derives a new memory unit from existing ones.\n\n # Args:\n # unit_ids (List[str]): A list of memory unit IDs to derive from.\n # derivation_type (str): The type of derivation (e.g., 'summary', 'transformation').\n # **kwargs: Additional parameters for the derivation process.\n\n # Returns:\n # MemoryUnit: The derived memory unit.\n # \"\"\"\n # pass\n\n @abstractmethod\n def structurize(self, unit_ids: List[str], structure_type: str, **kwargs) -> None:\n \"\"\"\n Structures memory units into a knowledge graph or other structures.\n\n 
Args:\n unit_ids (List[str]): A list of memory unit IDs to structurize.\n structure_type (str): The type of structure (e.g., 'knowledge_graph').\n **kwargs: Additional parameters for the structuring process.\n \"\"\"\n pass\n\n @abstractmethod\n def skillize(self, unit_ids: List[str], skill_name: str, **kwargs) -> str:\n \"\"\"\n Converts memory units into a reusable skill.\n\n Args:\n unit_ids (List[str]): A list of memory unit IDs to skillize.\n skill_name (str): The name of the skill to create.\n **kwargs: Additional parameters for skill creation.\n\n Returns:\n str: The unique identifier of the created skill.\n \"\"\"\n pass\n\n @abstractmethod\n def embed(self, unit_id: str) -> None:\n \"\"\"\n Generates an embedding for a memory unit.\n\n Args:\n unit_id (str): The unique identifier of the memory unit.\n \"\"\"\n pass\n\n @abstractmethod\n def parameterize(self, **kwargs) -> None:\n \"\"\"\n Trains a parametric model using the memory data.\n\n Args:\n **kwargs: Additional parameters for the training process.\n \"\"\"\n pass\n\n @abstractmethod\n def retrieve(self, query: Any, retrieve_type: str, **kwargs) -> List[MemoryUnit]:\n \"\"\"\n Asynchronously retrieve data from memory based on a query.\n\n Args:\n query (Any): The query or criteria to retrieve specific memory data.\n retrieve_type (str): The type of retrieval (e.g., 'retrieve_related', 'retrieve_similar').\n **kwargs: Additional parameters for the structuring process.\n\n Returns:\n Any: The retrieved memory data.\n\n Raises:\n RetrievalError: If the retrieval process fails.\n \"\"\"\n pass\n\n def handle_error(self, error: Exception) -> None:\n \"\"\"\n Handle errors that occur during memory operations.\n\n This method can be overridden to implement custom error handling logic.\n\n Args:\n error (Exception): The exception that was raised.\n \"\"\"\n # Default error handling: log the error\n print(f\"Memory system encountered an error: {error}\")\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory.Memory.__init__","title":"__init__(config)
","text":"Initialize the Memory system with the provided configuration.
Parameters:
Name Type Description Default config
Any
Configuration settings for the Memory system.
required Source code in src/aeiva/cognition/memory/memory.py
def __init__(self, config: Any):\n \"\"\"\n Initialize the Memory system with the provided configuration.\n\n Args:\n config (Any): Configuration settings for the Memory system.\n \"\"\"\n self.config = config\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory.Memory.create","title":"create(content, **kwargs)
abstractmethod
","text":"Creates a new memory unit with the given content and metadata.
Parameters:
Name Type Description Default content
Any
The core content of the memory unit.
required **kwargs
Additional metadata for the memory unit.
{}
Returns:
Name Type Description MemoryUnit
MemoryUnit
The created memory unit.
Source code in src/aeiva/cognition/memory/memory.py
@abstractmethod\ndef create(self, content: Any, **kwargs) -> MemoryUnit:\n \"\"\"\n Creates a new memory unit with the given content and metadata.\n\n Args:\n content (Any): The core content of the memory unit.\n **kwargs: Additional metadata for the memory unit.\n\n Returns:\n MemoryUnit: The created memory unit.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory.Memory.delete","title":"delete(unit_id)
abstractmethod
","text":"Deletes a memory unit by its unique identifier.
Parameters:
Name Type Description Default unit_id
str
The unique identifier of the memory unit.
required Source code in src/aeiva/cognition/memory/memory.py
@abstractmethod\ndef delete(self, unit_id: str) -> None:\n \"\"\"\n Deletes a memory unit by its unique identifier.\n\n Args:\n unit_id (str): The unique identifier of the memory unit.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory.Memory.delete_all","title":"delete_all()
abstractmethod
","text":"Deletes all memory units.
Source code in src/aeiva/cognition/memory/memory.py
@abstractmethod\ndef delete_all(self) -> None:\n \"\"\"\n Deletes all memory units.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory.Memory.embed","title":"embed(unit_id)
abstractmethod
","text":"Generates an embedding for a memory unit.
Parameters:
Name Type Description Default unit_id
str
The unique identifier of the memory unit.
required Source code in src/aeiva/cognition/memory/memory.py
@abstractmethod\ndef embed(self, unit_id: str) -> None:\n \"\"\"\n Generates an embedding for a memory unit.\n\n Args:\n unit_id (str): The unique identifier of the memory unit.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory.Memory.filter","title":"filter(criteria)
abstractmethod
","text":"Filters memory units based on the given criteria.
Parameters:
Name Type Description Default criteria
Dict[str, Any]
A dictionary of filter conditions.
required Returns:
Type Description List[MemoryUnit]
List[MemoryUnit]: A list of memory units matching the criteria.
Source code in src/aeiva/cognition/memory/memory.py
@abstractmethod\ndef filter(self, criteria: Dict[str, Any]) -> List[MemoryUnit]:\n \"\"\"\n Filters memory units based on the given criteria.\n\n Args:\n criteria (Dict[str, Any]): A dictionary of filter conditions.\n\n Returns:\n List[MemoryUnit]: A list of memory units matching the criteria.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory.Memory.get","title":"get(unit_id)
abstractmethod
","text":"Retrieves a memory unit by its unique identifier.
Parameters:
Name Type Description Default unit_id
str
The unique identifier of the memory unit.
required Returns:
Name Type Description MemoryUnit
MemoryUnit
The retrieved memory unit.
Source code in src/aeiva/cognition/memory/memory.py
@abstractmethod\ndef get(self, unit_id: str) -> MemoryUnit:\n \"\"\"\n Retrieves a memory unit by its unique identifier.\n\n Args:\n unit_id (str): The unique identifier of the memory unit.\n\n Returns:\n MemoryUnit: The retrieved memory unit.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory.Memory.get_all","title":"get_all()
abstractmethod
","text":"Retrieves all memory units.
Returns:
Type Description List[MemoryUnit]
List[MemoryUnit]: A list of all memory units.
Source code in src/aeiva/cognition/memory/memory.py
@abstractmethod\ndef get_all(self) -> List[MemoryUnit]:\n \"\"\"\n Retrieves all memory units.\n\n Returns:\n List[MemoryUnit]: A list of all memory units.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory.Memory.handle_error","title":"handle_error(error)
","text":"Handle errors that occur during memory operations.
This method can be overridden to implement custom error handling logic.
Parameters:
Name Type Description Default error
Exception
The exception that was raised.
required Source code in src/aeiva/cognition/memory/memory.py
def handle_error(self, error: Exception) -> None:\n \"\"\"\n Handle errors that occur during memory operations.\n\n This method can be overridden to implement custom error handling logic.\n\n Args:\n error (Exception): The exception that was raised.\n \"\"\"\n # Default error handling: log the error\n print(f\"Memory system encountered an error: {error}\")\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory.Memory.load","title":"load()
abstractmethod
","text":"Loads the memory from file. The path is specified in config.
Source code in src/aeiva/cognition/memory/memory.py
@abstractmethod\ndef load(self) -> None:\n \"\"\"\n Loads the memory from file. The path is specified in config.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory.Memory.organize","title":"organize(unit_ids, organize_type, metadata=None)
abstractmethod
","text":"Groups memory units into a meaningful group.
Parameters:
Name Type Description Default unit_ids
List[str]
A list of memory unit IDs to group.
required organize_type
str
The type of group (e.g., 'dialogue_session', 'procedure').
required metadata
Optional[Dict[str, Any]]
Additional metadata for the group.
None
Returns:
Name Type Description str
str
A unique identifier for the created group.
Source code in src/aeiva/cognition/memory/memory.py
@abstractmethod\ndef organize(self, unit_ids: List[str], organize_type: str, metadata: Optional[Dict[str, Any]] = None) -> str:\n \"\"\"\n Groups memory units into a meaningful group.\n\n Args:\n unit_ids (List[str]): A list of memory unit IDs to group.\n organize_type (str): The type of group (e.g., 'dialogue_session', 'procedure').\n metadata (Optional[Dict[str, Any]]): Additional metadata for the group.\n\n Returns:\n str: A unique identifier for the created group.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory.Memory.parameterize","title":"parameterize(**kwargs)
abstractmethod
","text":"Trains a parametric model using the memory data.
Parameters:
Name Type Description Default **kwargs
Additional parameters for the training process.
{}
Source code in src/aeiva/cognition/memory/memory.py
@abstractmethod\ndef parameterize(self, **kwargs) -> None:\n \"\"\"\n Trains a parametric model using the memory data.\n\n Args:\n **kwargs: Additional parameters for the training process.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory.Memory.retrieve","title":"retrieve(query, retrieve_type, **kwargs)
abstractmethod
","text":"Asynchronously retrieve data from memory based on a query.
Parameters:
Name Type Description Default query
Any
The query or criteria to retrieve specific memory data.
required retrieve_type
str
The type of retrieval (e.g., 'retrieve_related', 'retrieve_similar').
required **kwargs
Additional parameters for the retrieval process.
{}
Returns:
Name Type Description Any
List[MemoryUnit]
The retrieved memory data.
Raises:
Type Description RetrievalError
If the retrieval process fails.
Source code in src/aeiva/cognition/memory/memory.py
@abstractmethod\ndef retrieve(self, query: Any, retrieve_type: str, **kwargs) -> List[MemoryUnit]:\n    \"\"\"\n    Retrieve data from memory based on a query.\n\n    Args:\n        query (Any): The query or criteria to retrieve specific memory data.\n        retrieve_type (str): The type of retrieval (e.g., 'retrieve_related', 'retrieve_similar').\n        **kwargs: Additional parameters for the retrieval process.\n\n    Returns:\n        Any: The retrieved memory data.\n\n    Raises:\n        RetrievalError: If the retrieval process fails.\n    \"\"\"\n    pass\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory.Memory.save","title":"save()
abstractmethod
","text":"Save the memory to database or file. The path is specified in config.
Source code in src/aeiva/cognition/memory/memory.py
@abstractmethod\ndef save(self) -> None:\n \"\"\"\n Save the memory to database or file. The path is specified in config.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory.Memory.setup","title":"setup()
abstractmethod
","text":"Asynchronously set up the Memory system's components.
This method should initialize any necessary components or resources based on the provided configuration.
Raises:
Type Description ConfigurationError
If the configuration is invalid or incomplete.
Source code in src/aeiva/cognition/memory/memory.py
@abstractmethod\ndef setup(self) -> None:\n    \"\"\"\n    Set up the Memory system's components.\n\n    This method should initialize any necessary components or resources based on the provided configuration.\n\n    Raises:\n        ConfigurationError: If the configuration is invalid or incomplete.\n    \"\"\"\n    pass\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory.Memory.skillize","title":"skillize(unit_ids, skill_name, **kwargs)
abstractmethod
","text":"Converts memory units into a reusable skill.
Parameters:
Name Type Description Default unit_ids
List[str]
A list of memory unit IDs to skillize.
required skill_name
str
The name of the skill to create.
required **kwargs
Additional parameters for skill creation.
{}
Returns:
Name Type Description str
str
The unique identifier of the created skill.
Source code in src/aeiva/cognition/memory/memory.py
@abstractmethod\ndef skillize(self, unit_ids: List[str], skill_name: str, **kwargs) -> str:\n \"\"\"\n Converts memory units into a reusable skill.\n\n Args:\n unit_ids (List[str]): A list of memory unit IDs to skillize.\n skill_name (str): The name of the skill to create.\n **kwargs: Additional parameters for skill creation.\n\n Returns:\n str: The unique identifier of the created skill.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory.Memory.structurize","title":"structurize(unit_ids, structure_type, **kwargs)
abstractmethod
","text":"Structures memory units into a knowledge graph or other structures.
Parameters:
Name Type Description Default unit_ids
List[str]
A list of memory unit IDs to structurize.
required structure_type
str
The type of structure (e.g., 'knowledge_graph').
required **kwargs
Additional parameters for the structuring process.
{}
Source code in src/aeiva/cognition/memory/memory.py
@abstractmethod\ndef structurize(self, unit_ids: List[str], structure_type: str, **kwargs) -> None:\n \"\"\"\n Structures memory units into a knowledge graph or other structures.\n\n Args:\n unit_ids (List[str]): A list of memory unit IDs to structurize.\n structure_type (str): The type of structure (e.g., 'knowledge_graph').\n **kwargs: Additional parameters for the structuring process.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory.Memory.update","title":"update(unit_id, updates)
abstractmethod
","text":"Updates a memory unit with the given updates.
Parameters:
Name Type Description Default unit_id
str
The unique identifier of the memory unit.
required updates
Dict[str, Any]
A dictionary of fields to update.
required Source code in src/aeiva/cognition/memory/memory.py
@abstractmethod\ndef update(self, unit_id: str, updates: Dict[str, Any]) -> None:\n \"\"\"\n Updates a memory unit with the given updates.\n\n Args:\n unit_id (str): The unique identifier of the memory unit.\n updates (Dict[str, Any]): A dictionary of fields to update.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_cleaner","title":"memory_cleaner
","text":""},{"location":"reference/#src.aeiva.cognition.memory.memory_cleaner.MemoryCleaner","title":"MemoryCleaner
","text":"A class to clean memory units based on various filtering algorithms.
Supported filter types - 'time': Removes memory units older than a specified threshold.
- 'modality': Keeps only memory units matching specified modalities.
- 'type': Keeps only memory units matching specified types.
Source code in src/aeiva/cognition/memory/memory_cleaner.py
class MemoryCleaner:\n \"\"\"\n A class to clean memory units based on various filtering algorithms.\n\n Supported filter types:\n - 'time': Removes memory units older than a specified threshold.\n - 'modality': Keeps only memory units matching specified modalities.\n - 'type': Keeps only memory units matching specified types.\n \"\"\"\n\n def __init__(self):\n \"\"\"\n Initializes the MemoryCleaner.\n\n Currently, no initialization parameters are required.\n \"\"\"\n self.logger = logging.getLogger(self.__class__.__name__)\n self.logger.debug(\"Initialized MemoryCleaner without default parameters.\")\n\n def filter(\n self,\n memory_units: List[MemoryUnit],\n filter_type: str,\n **kwargs\n ) -> List[MemoryUnit]:\n \"\"\"\n Filters the provided memory units based on the specified filter type.\n\n Args:\n memory_units (List[MemoryUnit]): The list of memory units to be filtered.\n filter_type (str): The type of filtering algorithm to use ('time', 'modality', 'type').\n **kwargs: Additional parameters required for specific filters.\n For 'time' filter:\n - threshold_days (int): Number of days beyond which memory units are removed.\n For 'modality' filter:\n - modalities (List[str]): List of modalities to retain (e.g., ['text', 'image']).\n For 'type' filter:\n - types (List[str]): List of types to retain (e.g., ['dialogue', 'summary']).\n\n Returns:\n List[MemoryUnit]: The list of memory units after filtering.\n\n Raises:\n MemoryCleanerError: If an unknown filter_type is provided or if required parameters are missing.\n \"\"\"\n self.logger.debug(f\"Filtering memory units using filter_type='{filter_type}' with kwargs={kwargs}\")\n try:\n if filter_type == 'time':\n threshold_days = kwargs.get('threshold_days')\n if threshold_days is None:\n self.logger.error(\"Missing 'threshold_days' parameter for time-based filtering.\")\n raise MemoryCleanerError(\"Missing 'threshold_days' parameter for time-based filtering.\")\n return self.filter_by_time(memory_units, 
threshold_days)\n elif filter_type == 'modality':\n modalities = kwargs.get('modalities')\n if not modalities:\n self.logger.error(\"Missing 'modalities' parameter for modality-based filtering.\")\n raise MemoryCleanerError(\"Missing 'modalities' parameter for modality-based filtering.\")\n return self.filter_by_modality(memory_units, modalities)\n elif filter_type == 'type':\n types = kwargs.get('types')\n if not types:\n self.logger.error(\"Missing 'types' parameter for type-based filtering.\")\n raise MemoryCleanerError(\"Missing 'types' parameter for type-based filtering.\")\n return self.filter_by_type(memory_units, types)\n else:\n self.logger.error(f\"Unknown filter_type: {filter_type}\")\n raise MemoryCleanerError(f\"Unknown filter_type: {filter_type}\")\n except MemoryCleanerError:\n # Re-raise custom errors without modification\n raise\n except Exception as e:\n self.logger.error(f\"Failed to filter memory units: {e}\")\n raise MemoryCleanerError(f\"Failed to filter memory units: {e}\")\n # TODO: more filter options\n\n def filter_by_time(self, memory_units: List[MemoryUnit], threshold_days: int) -> List[MemoryUnit]:\n \"\"\"\n Removes memory units older than the specified threshold_days.\n\n Args:\n memory_units (List[MemoryUnit]): The list of memory units to be filtered.\n threshold_days (int): Number of days beyond which memory units are removed.\n\n Returns:\n List[MemoryUnit]: The list of memory units after time-based filtering.\n\n Raises:\n MemoryCleanerError: If filtering fails.\n \"\"\"\n self.logger.debug(f\"Applying time-based filtering with threshold_days={threshold_days}\")\n try:\n current_time = datetime.now(UTC)\n threshold = timedelta(days=threshold_days)\n filtered_memory = [\n mu for mu in memory_units\n if (current_time - mu.timestamp) <= threshold\n ]\n removed_count = len(memory_units) - len(filtered_memory)\n self.logger.info(\n f\"Time-based filter: Removed {removed_count} memory units older than {threshold_days} days.\"\n )\n 
return filtered_memory\n except Exception as e:\n self.logger.error(f\"Time-based filtering failed: {e}\")\n raise MemoryCleanerError(f\"Time-based filtering failed: {e}\")\n\n def filter_by_modality(self, memory_units: List[MemoryUnit], modalities: List[str]) -> List[MemoryUnit]:\n \"\"\"\n Keeps only memory units that match the specified modalities.\n\n Args:\n memory_units (List[MemoryUnit]): The list of memory units to be filtered.\n modalities (List[str]): List of modalities to retain (e.g., ['text', 'image']).\n\n Returns:\n List[MemoryUnit]: The list of memory units after modality-based filtering.\n\n Raises:\n MemoryCleanerError: If filtering fails.\n \"\"\"\n self.logger.debug(f\"Applying modality-based filtering with modalities={modalities}\")\n try:\n if not modalities:\n self.logger.warning(\"No modalities specified for modality-based filtering. Returning original memory units.\")\n return memory_units\n\n filtered_memory = [\n mu for mu in memory_units\n if mu.modality in modalities\n ]\n removed_count = len(memory_units) - len(filtered_memory)\n self.logger.info(\n f\"Modality-based filter: Removed {removed_count} memory units not in modalities {modalities}.\"\n )\n return filtered_memory\n except Exception as e:\n self.logger.error(f\"Modality-based filtering failed: {e}\")\n raise MemoryCleanerError(f\"Modality-based filtering failed: {e}\")\n\n def filter_by_type(self, memory_units: List[MemoryUnit], types: List[str]) -> List[MemoryUnit]:\n \"\"\"\n Keeps only memory units that match the specified types.\n\n Args:\n memory_units (List[MemoryUnit]): The list of memory units to be filtered.\n types (List[str]): List of types to retain (e.g., ['dialogue', 'summary']).\n\n Returns:\n List[MemoryUnit]: The list of memory units after type-based filtering.\n\n Raises:\n MemoryCleanerError: If filtering fails.\n \"\"\"\n self.logger.debug(f\"Applying type-based filtering with types={types}\")\n try:\n if not types:\n self.logger.warning(\"No types 
specified for type-based filtering. Returning original memory units.\")\n return memory_units\n\n filtered_memory = [\n mu for mu in memory_units\n if mu.type in types\n ]\n removed_count = len(memory_units) - len(filtered_memory)\n self.logger.info(\n f\"Type-based filter: Removed {removed_count} memory units not in types {types}.\"\n )\n return filtered_memory\n except Exception as e:\n self.logger.error(f\"Type-based filtering failed: {e}\")\n raise MemoryCleanerError(f\"Type-based filtering failed: {e}\")\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_cleaner.MemoryCleaner.__init__","title":"__init__()
","text":"Initializes the MemoryCleaner.
Currently, no initialization parameters are required.
Source code in src/aeiva/cognition/memory/memory_cleaner.py
def __init__(self):\n \"\"\"\n Initializes the MemoryCleaner.\n\n Currently, no initialization parameters are required.\n \"\"\"\n self.logger = logging.getLogger(self.__class__.__name__)\n self.logger.debug(\"Initialized MemoryCleaner without default parameters.\")\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_cleaner.MemoryCleaner.filter","title":"filter(memory_units, filter_type, **kwargs)
","text":"Filters the provided memory units based on the specified filter type.
Parameters:
Name Type Description Default memory_units
List[MemoryUnit]
The list of memory units to be filtered.
required filter_type
str
The type of filtering algorithm to use ('time', 'modality', 'type').
required **kwargs
Additional parameters required for specific filters. For 'time' filter: - threshold_days (int): Number of days beyond which memory units are removed. For 'modality' filter: - modalities (List[str]): List of modalities to retain (e.g., ['text', 'image']). For 'type' filter: - types (List[str]): List of types to retain (e.g., ['dialogue', 'summary']).
{}
Returns:
Type Description List[MemoryUnit]
List[MemoryUnit]: The list of memory units after filtering.
Raises:
Type Description MemoryCleanerError
If an unknown filter_type is provided or if required parameters are missing.
Source code in src/aeiva/cognition/memory/memory_cleaner.py
def filter(\n self,\n memory_units: List[MemoryUnit],\n filter_type: str,\n **kwargs\n) -> List[MemoryUnit]:\n \"\"\"\n Filters the provided memory units based on the specified filter type.\n\n Args:\n memory_units (List[MemoryUnit]): The list of memory units to be filtered.\n filter_type (str): The type of filtering algorithm to use ('time', 'modality', 'type').\n **kwargs: Additional parameters required for specific filters.\n For 'time' filter:\n - threshold_days (int): Number of days beyond which memory units are removed.\n For 'modality' filter:\n - modalities (List[str]): List of modalities to retain (e.g., ['text', 'image']).\n For 'type' filter:\n - types (List[str]): List of types to retain (e.g., ['dialogue', 'summary']).\n\n Returns:\n List[MemoryUnit]: The list of memory units after filtering.\n\n Raises:\n MemoryCleanerError: If an unknown filter_type is provided or if required parameters are missing.\n \"\"\"\n self.logger.debug(f\"Filtering memory units using filter_type='{filter_type}' with kwargs={kwargs}\")\n try:\n if filter_type == 'time':\n threshold_days = kwargs.get('threshold_days')\n if threshold_days is None:\n self.logger.error(\"Missing 'threshold_days' parameter for time-based filtering.\")\n raise MemoryCleanerError(\"Missing 'threshold_days' parameter for time-based filtering.\")\n return self.filter_by_time(memory_units, threshold_days)\n elif filter_type == 'modality':\n modalities = kwargs.get('modalities')\n if not modalities:\n self.logger.error(\"Missing 'modalities' parameter for modality-based filtering.\")\n raise MemoryCleanerError(\"Missing 'modalities' parameter for modality-based filtering.\")\n return self.filter_by_modality(memory_units, modalities)\n elif filter_type == 'type':\n types = kwargs.get('types')\n if not types:\n self.logger.error(\"Missing 'types' parameter for type-based filtering.\")\n raise MemoryCleanerError(\"Missing 'types' parameter for type-based filtering.\")\n return 
self.filter_by_type(memory_units, types)\n else:\n self.logger.error(f\"Unknown filter_type: {filter_type}\")\n raise MemoryCleanerError(f\"Unknown filter_type: {filter_type}\")\n except MemoryCleanerError:\n # Re-raise custom errors without modification\n raise\n except Exception as e:\n self.logger.error(f\"Failed to filter memory units: {e}\")\n raise MemoryCleanerError(f\"Failed to filter memory units: {e}\")\n
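\n# Example usage (illustrative sketch, not part of the source; 'units' is assumed to be a List[MemoryUnit]):\n# cleaner = MemoryCleaner()\n# recent = cleaner.filter(units, filter_type='time', threshold_days=30)\n# texts = cleaner.filter(units, filter_type='modality', modalities=['text'])\n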
"},{"location":"reference/#src.aeiva.cognition.memory.memory_cleaner.MemoryCleaner.filter_by_modality","title":"filter_by_modality(memory_units, modalities)
","text":"Keeps only memory units that match the specified modalities.
Parameters:
Name Type Description Default memory_units
List[MemoryUnit]
The list of memory units to be filtered.
required modalities
List[str]
List of modalities to retain (e.g., ['text', 'image']).
required Returns:
Type Description List[MemoryUnit]
List[MemoryUnit]: The list of memory units after modality-based filtering.
Raises:
Type Description MemoryCleanerError
If filtering fails.
Source code in src/aeiva/cognition/memory/memory_cleaner.py
def filter_by_modality(self, memory_units: List[MemoryUnit], modalities: List[str]) -> List[MemoryUnit]:\n \"\"\"\n Keeps only memory units that match the specified modalities.\n\n Args:\n memory_units (List[MemoryUnit]): The list of memory units to be filtered.\n modalities (List[str]): List of modalities to retain (e.g., ['text', 'image']).\n\n Returns:\n List[MemoryUnit]: The list of memory units after modality-based filtering.\n\n Raises:\n MemoryCleanerError: If filtering fails.\n \"\"\"\n self.logger.debug(f\"Applying modality-based filtering with modalities={modalities}\")\n try:\n if not modalities:\n self.logger.warning(\"No modalities specified for modality-based filtering. Returning original memory units.\")\n return memory_units\n\n filtered_memory = [\n mu for mu in memory_units\n if mu.modality in modalities\n ]\n removed_count = len(memory_units) - len(filtered_memory)\n self.logger.info(\n f\"Modality-based filter: Removed {removed_count} memory units not in modalities {modalities}.\"\n )\n return filtered_memory\n except Exception as e:\n self.logger.error(f\"Modality-based filtering failed: {e}\")\n raise MemoryCleanerError(f\"Modality-based filtering failed: {e}\")\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_cleaner.MemoryCleaner.filter_by_time","title":"filter_by_time(memory_units, threshold_days)
","text":"Removes memory units older than the specified threshold_days.
Parameters:
Name Type Description Default memory_units
List[MemoryUnit]
The list of memory units to be filtered.
required threshold_days
int
Number of days beyond which memory units are removed.
required Returns:
Type Description List[MemoryUnit]
List[MemoryUnit]: The list of memory units after time-based filtering.
Raises:
Type Description MemoryCleanerError
If filtering fails.
Source code in src/aeiva/cognition/memory/memory_cleaner.py
def filter_by_time(self, memory_units: List[MemoryUnit], threshold_days: int) -> List[MemoryUnit]:\n \"\"\"\n Removes memory units older than the specified threshold_days.\n\n Args:\n memory_units (List[MemoryUnit]): The list of memory units to be filtered.\n threshold_days (int): Number of days beyond which memory units are removed.\n\n Returns:\n List[MemoryUnit]: The list of memory units after time-based filtering.\n\n Raises:\n MemoryCleanerError: If filtering fails.\n \"\"\"\n self.logger.debug(f\"Applying time-based filtering with threshold_days={threshold_days}\")\n try:\n current_time = datetime.now(UTC)\n threshold = timedelta(days=threshold_days)\n filtered_memory = [\n mu for mu in memory_units\n if (current_time - mu.timestamp) <= threshold\n ]\n removed_count = len(memory_units) - len(filtered_memory)\n self.logger.info(\n f\"Time-based filter: Removed {removed_count} memory units older than {threshold_days} days.\"\n )\n return filtered_memory\n except Exception as e:\n self.logger.error(f\"Time-based filtering failed: {e}\")\n raise MemoryCleanerError(f\"Time-based filtering failed: {e}\")\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_cleaner.MemoryCleaner.filter_by_type","title":"filter_by_type(memory_units, types)
","text":"Keeps only memory units that match the specified types.
Parameters:
Name Type Description Default memory_units
List[MemoryUnit]
The list of memory units to be filtered.
required types
List[str]
List of types to retain (e.g., ['dialogue', 'summary']).
required Returns:
Type Description List[MemoryUnit]
List[MemoryUnit]: The list of memory units after type-based filtering.
Raises:
Type Description MemoryCleanerError
If filtering fails.
Source code in src/aeiva/cognition/memory/memory_cleaner.py
def filter_by_type(self, memory_units: List[MemoryUnit], types: List[str]) -> List[MemoryUnit]:\n \"\"\"\n Keeps only memory units that match the specified types.\n\n Args:\n memory_units (List[MemoryUnit]): The list of memory units to be filtered.\n types (List[str]): List of types to retain (e.g., ['dialogue', 'summary']).\n\n Returns:\n List[MemoryUnit]: The list of memory units after type-based filtering.\n\n Raises:\n MemoryCleanerError: If filtering fails.\n \"\"\"\n self.logger.debug(f\"Applying type-based filtering with types={types}\")\n try:\n if not types:\n self.logger.warning(\"No types specified for type-based filtering. Returning original memory units.\")\n return memory_units\n\n filtered_memory = [\n mu for mu in memory_units\n if mu.type in types\n ]\n removed_count = len(memory_units) - len(filtered_memory)\n self.logger.info(\n f\"Type-based filter: Removed {removed_count} memory units not in types {types}.\"\n )\n return filtered_memory\n except Exception as e:\n self.logger.error(f\"Type-based filtering failed: {e}\")\n raise MemoryCleanerError(f\"Type-based filtering failed: {e}\")\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_cleaner.MemoryCleanerError","title":"MemoryCleanerError
","text":" Bases: Exception
Exception raised when an error occurs in the MemoryCleaner.
Source code in src/aeiva/cognition/memory/memory_cleaner.py
class MemoryCleanerError(Exception):\n \"\"\"Exception raised when an error occurs in the MemoryCleaner.\"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_config","title":"memory_config
","text":""},{"location":"reference/#src.aeiva.cognition.memory.memory_config.MemoryConfig","title":"MemoryConfig
dataclass
","text":" Bases: BaseConfig
Configuration class for the Memory system.
Attributes:
Name Type Description embedder_config
EmbedderConfig
Configuration for the embedding model.
storage_config
StorageConfig
Configuration for the storage system.
Source code in src/aeiva/cognition/memory/memory_config.py
@dataclass\nclass MemoryConfig(BaseConfig):\n \"\"\"\n Configuration class for the Memory system.\n\n Attributes:\n embedder_config (EmbedderConfig): Configuration for the embedding model.\n storage_config (StorageConfig): Configuration for the storage system.\n \"\"\"\n\n embedder_config: EmbedderConfig = field(\n metadata={\"help\": \"Configuration for the embedding model.\"}\n )\n storage_config: StorageConfig = field(\n metadata={\"help\": \"Configuration for the storage system.\"}\n )\n\n def __post_init__(self):\n super().__post_init__()\n # Perform any necessary validation\n if not self.embedder_config:\n raise ValueError(\"Embedder configuration must be provided.\")\n if not self.storage_config:\n raise ValueError(\"Storage configuration must be provided.\")\n
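\n# Illustrative construction (hypothetical argument values; see EmbedderConfig and StorageConfig for the real fields):\n# config = MemoryConfig(\n#     embedder_config=EmbedderConfig(),\n#     storage_config=StorageConfig(),\n# )\n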
"},{"location":"reference/#src.aeiva.cognition.memory.memory_link","title":"memory_link
","text":""},{"location":"reference/#src.aeiva.cognition.memory.memory_link.MemoryLink","title":"MemoryLink
","text":" Bases: BaseModel
MemoryLink represents a relationship between two memory units, allowing complex structures to be built by linking individual memory units.
Attributes:
Name Type Description id
str
Unique identifier for the edge, generated as a UUID string by default.
source_id
str
Unique identifier of the source memory unit.
target_id
str
Unique identifier of the target memory unit.
relationship
str
Type of relationship between memory units, such as 'causal' or 'association'.
metadata
Optional[Dict[str, Any]]
Additional metadata for the edge.
Source code in src/aeiva/cognition/memory/memory_link.py
class MemoryLink(BaseModel):\n \"\"\"\n MemoryLink represents a relationship between two memory units, allowing\n complex structures to be built by linking individual memory units.\n\n Attributes:\n id (str): Unique identifier for the edge, generated as a UUID string by default.\n source_id (str): Unique identifier of the source memory unit.\n target_id (str): Unique identifier of the target memory unit.\n relationship (str): Type of relationship between memory units, such as 'causal' or 'association'.\n metadata (Optional[Dict[str, Any]]): Additional metadata for the edge.\n \"\"\"\n id: str = Field(default_factory=lambda: uuid4().hex, description=\"Unique identifier for the edge.\")\n source_id: str = Field(..., description=\"Unique identifier of the source memory unit.\")\n target_id: str = Field(..., description=\"Unique identifier of the target memory unit.\")\n relationship: str = Field(\"\", description=\"Type of relationship, e.g., 'causal', 'temporal'.\")\n metadata: Optional[Dict[str, Any]] = Field(default_factory=dict, description=\"Additional metadata for the edge.\")\n\n def to_dict(self) -> dict:\n \"\"\"Converts the MemoryLink instance to a dictionary format for serialization.\"\"\"\n return self.dict()\n\n @classmethod\n def from_dict(cls, data: dict) -> \"MemoryLink\":\n \"\"\"Creates a MemoryLink instance from a dictionary.\"\"\"\n return cls(**data)\n
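\n# Example (illustrative; 'mu_a' and 'mu_b' are assumed MemoryUnit instances):\n# link = MemoryLink(source_id=mu_a.id, target_id=mu_b.id, relationship='causal')\n# restored = MemoryLink.from_dict(link.to_dict())  # serialization round-trip\n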
"},{"location":"reference/#src.aeiva.cognition.memory.memory_link.MemoryLink.from_dict","title":"from_dict(data)
classmethod
","text":"Creates a MemoryLink instance from a dictionary.
Source code in src/aeiva/cognition/memory/memory_link.py
@classmethod\ndef from_dict(cls, data: dict) -> \"MemoryLink\":\n \"\"\"Creates a MemoryLink instance from a dictionary.\"\"\"\n return cls(**data)\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_link.MemoryLink.to_dict","title":"to_dict()
","text":"Converts the MemoryLink instance to a dictionary format for serialization.
Source code in src/aeiva/cognition/memory/memory_link.py
def to_dict(self) -> dict:\n \"\"\"Converts the MemoryLink instance to a dictionary format for serialization.\"\"\"\n return self.dict()\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_organizer","title":"memory_organizer
","text":""},{"location":"reference/#src.aeiva.cognition.memory.memory_organizer.MemoryOrganizer","title":"MemoryOrganizer
","text":"A class to organize memory units based on various organizing algorithms.
Supported organize types: 'dialogue' (groups memory units by 'dialogue_session_id').
Source code in src/aeiva/cognition/memory/memory_organizer.py
class MemoryOrganizer:\n \"\"\"\n A class to organize memory units based on various organizing algorithms.\n\n Supported organize types:\n - 'dialogue': Groups memory units by 'dialogue_session_id'.\n # Future organize types can be added here.\n \"\"\"\n\n def __init__(self):\n \"\"\"\n Initializes the MemoryOrganizer.\n\n Currently, no initialization parameters are required.\n \"\"\"\n self.logger = logging.getLogger(self.__class__.__name__)\n self.logger.debug(\"Initialized MemoryOrganizer without default parameters.\")\n\n def organize(\n self,\n memory_units: List[MemoryUnit],\n organize_type: str,\n **kwargs\n ) -> List[MemoryUnit]:\n \"\"\"\n Organizes the provided memory units based on the specified organize type.\n\n Args:\n memory_units (List[MemoryUnit]): The list of memory units to be organized.\n organize_type (str): The type of organizing algorithm to use ('dialogue').\n **kwargs: Additional parameters required for specific organizers.\n For 'dialogue' organize:\n - group_field (str): The metadata field to group by (default: 'dialogue_session_id').\n - derive_content (bool): Whether to derive content for the group (default: True).\n - derivation_type (str): The type of derivation to perform ('summary', etc.).\n\n Returns:\n List[MemoryUnit]: The list of memory units after organizing.\n\n Raises:\n MemoryOrganizerError: If an unknown organize_type is provided or if required parameters are missing.\n \"\"\"\n self.logger.debug(f\"Organizing memory units using organize_type='{organize_type}' with kwargs={kwargs}\")\n try:\n if organize_type == 'dialogue':\n group_field = kwargs.get('group_field', 'dialogue_session_id')\n derive_content = kwargs.get('derive_content', True)\n derivation_type = kwargs.get('derivation_type', 'summary')\n return self.organize_by_dialogue(memory_units, group_field, derive_content, derivation_type)\n else:\n self.logger.error(f\"Unknown organize_type: {organize_type}\")\n raise MemoryOrganizerError(f\"Unknown organize_type: 
{organize_type}\")\n except MemoryOrganizerError:\n # Re-raise custom errors without modification\n raise\n except Exception as e:\n self.logger.error(f\"Failed to organize memory units: {e}\")\n raise MemoryOrganizerError(f\"Failed to organize memory units: {e}\")\n\n def organize_by_dialogue(\n self,\n memory_units: List[MemoryUnit],\n group_field: str = 'dialogue_session_id', # NOTE: here we assume the meta data field of dialogue memory units has a dialogue_session_id\n derive_content: bool = False,\n derivation_type: str = 'summary'\n ) -> List[MemoryUnit]:\n \"\"\"\n Organizes memory units into dialogue sessions based on a common 'dialogue_session_id'.\n\n Args:\n memory_units (List[MemoryUnit]): The list of memory units to be organized.\n group_field (str): The metadata field to group by (default: 'dialogue_session_id').\n derive_content (bool): Whether to derive content for the group (default: True).\n derivation_type (str): The type of derivation to perform ('summary', etc.).\n\n Returns:\n List[MemoryUnit]: The list of memory units after organizing, including new dialogue groups.\n\n Raises:\n MemoryOrganizerError: If organizing fails.\n \"\"\"\n self.logger.debug(f\"Organizing by dialogue with group_field='{group_field}', derive_content={derive_content}, derivation_type='{derivation_type}'\")\n try:\n # Group memory units by the specified group_field\n groups = defaultdict(list)\n for mu in memory_units:\n group_id = mu.metadata.get(group_field)\n if group_id:\n groups[group_id].append(mu)\n else:\n self.logger.debug(f\"MemoryUnit '{mu.id}' does not have '{group_field}'. 
Skipping grouping.\")\n\n self.logger.info(f\"Found {len(groups)} dialogue groups based on '{group_field}'.\")\n\n # Create new MemoryUnit for each group\n new_memory_units = []\n for group_id, group_mus in groups.items():\n self.logger.debug(f\"Creating DialogueGroup for group_id='{group_id}' with {len(group_mus)} memory units.\")\n\n # Create a new MemoryUnit to represent the DialogueGroup\n dialogue_group = MemoryUnit(\n content=\"\", # Content to be derived\n type=\"dialogue_session\",\n metadata={\n \"organized_at\": datetime.now(timezone.utc).isoformat(),\n \"member_ids\": [mu.id for mu in group_mus],\n \"derivation_type\": derivation_type\n }\n )\n\n # Link each memory unit to the DialogueGroup\n for mu in group_mus:\n link = MemoryLink(\n source_id=mu.id,\n target_id=dialogue_group.id,\n relationship='part_of'\n )\n mu.edges.append(link)\n self.logger.debug(f\"Linked MemoryUnit '{mu.id}' to DialogueGroup '{dialogue_group.id}'.\")\n\n # Optionally, derive content for the group\n if derive_content:\n if derivation_type == 'summary':\n derived_content = self.derive_summary(group_mus)\n elif derivation_type == 'reflection':\n derived_content = self.derive_reflection(group_mus)\n else:\n self.logger.warning(f\"Unknown derivation_type '{derivation_type}'. Skipping content derivation.\")\n derived_content = \"\"\n dialogue_group.content = derived_content\n dialogue_group.status = 'derived'\n self.logger.debug(f\"Derived content for DialogueGroup '{dialogue_group.id}': {derived_content}\")\n\n new_memory_units.append(dialogue_group)\n self.logger.info(f\"DialogueGroup '{dialogue_group.id}' created for group_id='{group_id}'.\")\n\n # Return the original memory units plus the new dialogue groups\n organized_memory = memory_units + new_memory_units\n self.logger.debug(f\"Organizing by dialogue completed. 
Total memory units after organizing: {len(organized_memory)}\")\n return organized_memory\n\n except Exception as e:\n self.logger.error(f\"Error organizing by dialogue: {e}\")\n raise MemoryOrganizerError(f\"Error organizing by dialogue: {e}\")\n\n def derive_summary(self, memory_units: List[MemoryUnit]) -> str: # TODO: replace with lmp implementation\n \"\"\"\n Derives a summary from the given memory units.\n\n Args:\n memory_units (List[MemoryUnit]): The list of memory units to summarize.\n\n Returns:\n str: A summary string.\n \"\"\"\n self.logger.debug(f\"Deriving summary from {len(memory_units)} memory units.\")\n try:\n summary = \"Summary of dialogue session:\\n\"\n for mu in memory_units:\n summary += f\"- {mu.content}\\n\"\n derived_summary = summary.strip()\n self.logger.debug(f\"Derived summary: {derived_summary}\")\n return derived_summary\n except Exception as e:\n self.logger.error(f\"Failed to derive summary: {e}\")\n raise MemoryOrganizerError(f\"Failed to derive summary: {e}\")\n\n def derive_reflection(self, memory_units: List[MemoryUnit]) -> str: # TODO: replace with lmp implementation\n \"\"\"\n Derives a reflection from the given memory units.\n\n Args:\n memory_units (List[MemoryUnit]): The list of memory units to reflect upon.\n\n Returns:\n str: A reflection string.\n \"\"\"\n self.logger.debug(f\"Deriving reflection from {len(memory_units)} memory units.\")\n try:\n reflection = \"Reflection on dialogue session:\\n\"\n for mu in memory_units:\n reflection += f\"- {mu.content}\\n\"\n derived_reflection = reflection.strip()\n self.logger.debug(f\"Derived reflection: {derived_reflection}\")\n return derived_reflection\n except Exception as e:\n self.logger.error(f\"Failed to derive reflection: {e}\")\n raise MemoryOrganizerError(f\"Failed to derive reflection: {e}\")\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_organizer.MemoryOrganizer--future-organize-types-can-be-added-here","title":"Future organize types can be added here.","text":""},{"location":"reference/#src.aeiva.cognition.memory.memory_organizer.MemoryOrganizer.__init__","title":"__init__()
","text":"Initializes the MemoryOrganizer.
Currently, no initialization parameters are required.
Source code in src/aeiva/cognition/memory/memory_organizer.py
def __init__(self):\n \"\"\"\n Initializes the MemoryOrganizer.\n\n Currently, no initialization parameters are required.\n \"\"\"\n self.logger = logging.getLogger(self.__class__.__name__)\n self.logger.debug(\"Initialized MemoryOrganizer without default parameters.\")\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_organizer.MemoryOrganizer.derive_reflection","title":"derive_reflection(memory_units)
","text":"Derives a reflection from the given memory units.
Parameters:
Name Type Description Default memory_units
List[MemoryUnit]
The list of memory units to reflect upon.
required Returns:
Name Type Description str
str
A reflection string.
Source code in src/aeiva/cognition/memory/memory_organizer.py
def derive_reflection(self, memory_units: List[MemoryUnit]) -> str: # TODO: replace with lmp implementation\n \"\"\"\n Derives a reflection from the given memory units.\n\n Args:\n memory_units (List[MemoryUnit]): The list of memory units to reflect upon.\n\n Returns:\n str: A reflection string.\n \"\"\"\n self.logger.debug(f\"Deriving reflection from {len(memory_units)} memory units.\")\n try:\n reflection = \"Reflection on dialogue session:\\n\"\n for mu in memory_units:\n reflection += f\"- {mu.content}\\n\"\n derived_reflection = reflection.strip()\n self.logger.debug(f\"Derived reflection: {derived_reflection}\")\n return derived_reflection\n except Exception as e:\n self.logger.error(f\"Failed to derive reflection: {e}\")\n raise MemoryOrganizerError(f\"Failed to derive reflection: {e}\")\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_organizer.MemoryOrganizer.derive_summary","title":"derive_summary(memory_units)
","text":"Derives a summary from the given memory units.
Parameters:
Name Type Description Default memory_units
List[MemoryUnit]
The list of memory units to summarize.
required Returns:
Name Type Description str
str
A summary string.
Source code in src/aeiva/cognition/memory/memory_organizer.py
def derive_summary(self, memory_units: List[MemoryUnit]) -> str: # TODO: replace with lmp implementation\n \"\"\"\n Derives a summary from the given memory units.\n\n Args:\n memory_units (List[MemoryUnit]): The list of memory units to summarize.\n\n Returns:\n str: A summary string.\n \"\"\"\n self.logger.debug(f\"Deriving summary from {len(memory_units)} memory units.\")\n try:\n summary = \"Summary of dialogue session:\\n\"\n for mu in memory_units:\n summary += f\"- {mu.content}\\n\"\n derived_summary = summary.strip()\n self.logger.debug(f\"Derived summary: {derived_summary}\")\n return derived_summary\n except Exception as e:\n self.logger.error(f\"Failed to derive summary: {e}\")\n raise MemoryOrganizerError(f\"Failed to derive summary: {e}\")\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_organizer.MemoryOrganizer.organize","title":"organize(memory_units, organize_type, **kwargs)
","text":"Organizes the provided memory units based on the specified organize type.
Parameters:
Name Type Description Default memory_units
List[MemoryUnit]
The list of memory units to be organized.
required organize_type
str
The type of organizing algorithm to use ('dialogue').
required **kwargs
Additional parameters required for specific organizers. For 'dialogue' organize: - group_field (str): The metadata field to group by (default: 'dialogue_session_id'). - derive_content (bool): Whether to derive content for the group (default: True). - derivation_type (str): The type of derivation to perform ('summary', etc.).
{}
Returns:
Type Description List[MemoryUnit]
List[MemoryUnit]: The list of memory units after organizing.
Raises:
Type Description MemoryOrganizerError
If an unknown organize_type is provided or if required parameters are missing.
Source code in src/aeiva/cognition/memory/memory_organizer.py
def organize(\n self,\n memory_units: List[MemoryUnit],\n organize_type: str,\n **kwargs\n) -> List[MemoryUnit]:\n \"\"\"\n Organizes the provided memory units based on the specified organize type.\n\n Args:\n memory_units (List[MemoryUnit]): The list of memory units to be organized.\n organize_type (str): The type of organizing algorithm to use ('dialogue').\n **kwargs: Additional parameters required for specific organizers.\n For 'dialogue' organize:\n - group_field (str): The metadata field to group by (default: 'dialogue_session_id').\n - derive_content (bool): Whether to derive content for the group (default: True).\n - derivation_type (str): The type of derivation to perform ('summary', etc.).\n\n Returns:\n List[MemoryUnit]: The list of memory units after organizing.\n\n Raises:\n MemoryOrganizerError: If an unknown organize_type is provided or if required parameters are missing.\n \"\"\"\n self.logger.debug(f\"Organizing memory units using organize_type='{organize_type}' with kwargs={kwargs}\")\n try:\n if organize_type == 'dialogue':\n group_field = kwargs.get('group_field', 'dialogue_session_id')\n derive_content = kwargs.get('derive_content', True)\n derivation_type = kwargs.get('derivation_type', 'summary')\n return self.organize_by_dialogue(memory_units, group_field, derive_content, derivation_type)\n else:\n self.logger.error(f\"Unknown organize_type: {organize_type}\")\n raise MemoryOrganizerError(f\"Unknown organize_type: {organize_type}\")\n except MemoryOrganizerError:\n # Re-raise custom errors without modification\n raise\n except Exception as e:\n self.logger.error(f\"Failed to organize memory units: {e}\")\n raise MemoryOrganizerError(f\"Failed to organize memory units: {e}\")\n
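\n# Example usage (illustrative; assumes each unit's metadata carries 'dialogue_session_id'):\n# organizer = MemoryOrganizer()\n# organized = organizer.organize(units, organize_type='dialogue', derive_content=True, derivation_type='summary')\n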
"},{"location":"reference/#src.aeiva.cognition.memory.memory_organizer.MemoryOrganizer.organize_by_dialogue","title":"organize_by_dialogue(memory_units, group_field='dialogue_session_id', derive_content=False, derivation_type='summary')
","text":"Organizes memory units into dialogue sessions based on a common 'dialogue_session_id'.
Parameters:
Name Type Description Default memory_units
List[MemoryUnit]
The list of memory units to be organized.
required group_field
str
The metadata field to group by (default: 'dialogue_session_id').
'dialogue_session_id'
derive_content
bool
Whether to derive content for the group (default: False).
False
derivation_type
str
The type of derivation to perform ('summary', etc.).
'summary'
Returns:
Type Description List[MemoryUnit]
List[MemoryUnit]: The list of memory units after organizing, including new dialogue groups.
Raises:
Type Description MemoryOrganizerError
If organizing fails.
Source code in src/aeiva/cognition/memory/memory_organizer.py
def organize_by_dialogue(\n    self,\n    memory_units: List[MemoryUnit],\n    group_field: str = 'dialogue_session_id',  # NOTE: here we assume the metadata field of dialogue memory units has a dialogue_session_id\n    derive_content: bool = False,\n    derivation_type: str = 'summary'\n) -> List[MemoryUnit]:\n    \"\"\"\n    Organizes memory units into dialogue sessions based on a common 'dialogue_session_id'.\n\n    Args:\n        memory_units (List[MemoryUnit]): The list of memory units to be organized.\n        group_field (str): The metadata field to group by (default: 'dialogue_session_id').\n        derive_content (bool): Whether to derive content for the group (default: False).\n        derivation_type (str): The type of derivation to perform ('summary', etc.).\n\n    Returns:\n        List[MemoryUnit]: The list of memory units after organizing, including new dialogue groups.\n\n    Raises:\n        MemoryOrganizerError: If organizing fails.\n    \"\"\"\n    self.logger.debug(f\"Organizing by dialogue with group_field='{group_field}', derive_content={derive_content}, derivation_type='{derivation_type}'\")\n    try:\n        # Group memory units by the specified group_field\n        groups = defaultdict(list)\n        for mu in memory_units:\n            group_id = mu.metadata.get(group_field)\n            if group_id:\n                groups[group_id].append(mu)\n            else:\n                self.logger.debug(f\"MemoryUnit '{mu.id}' does not have '{group_field}'. Skipping grouping.\")\n\n        self.logger.info(f\"Found {len(groups)} dialogue groups based on '{group_field}'.\")\n\n        # Create new MemoryUnit for each group\n        new_memory_units = []\n        for group_id, group_mus in groups.items():\n            self.logger.debug(f\"Creating DialogueGroup for group_id='{group_id}' with {len(group_mus)} memory units.\")\n\n            # Create a new MemoryUnit to represent the DialogueGroup\n            dialogue_group = MemoryUnit(\n                content=\"\",  # Content to be derived\n                type=\"dialogue_session\",\n                metadata={\n                    \"organized_at\": datetime.now(timezone.utc).isoformat(),\n                    \"member_ids\": [mu.id for mu in group_mus],\n                    \"derivation_type\": derivation_type\n                }\n            )\n\n            # Link each memory unit to the DialogueGroup\n            for mu in group_mus:\n                link = MemoryLink(\n                    source_id=mu.id,\n                    target_id=dialogue_group.id,\n                    relationship='part_of'\n                )\n                mu.edges.append(link)\n                self.logger.debug(f\"Linked MemoryUnit '{mu.id}' to DialogueGroup '{dialogue_group.id}'.\")\n\n            # Optionally, derive content for the group\n            if derive_content:\n                if derivation_type == 'summary':\n                    derived_content = self.derive_summary(group_mus)\n                elif derivation_type == 'reflection':\n                    derived_content = self.derive_reflection(group_mus)\n                else:\n                    self.logger.warning(f\"Unknown derivation_type '{derivation_type}'. Skipping content derivation.\")\n                    derived_content = \"\"\n                dialogue_group.content = derived_content\n                dialogue_group.status = 'derived'\n                self.logger.debug(f\"Derived content for DialogueGroup '{dialogue_group.id}': {derived_content}\")\n\n            new_memory_units.append(dialogue_group)\n            self.logger.info(f\"DialogueGroup '{dialogue_group.id}' created for group_id='{group_id}'.\")\n\n        # Return the original memory units plus the new dialogue groups\n        organized_memory = memory_units + new_memory_units\n        self.logger.debug(f\"Organizing by dialogue completed. Total memory units after organizing: {len(organized_memory)}\")\n        return organized_memory\n\n    except Exception as e:\n        self.logger.error(f\"Error organizing by dialogue: {e}\")\n        raise MemoryOrganizerError(f\"Error organizing by dialogue: {e}\")\n
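The first pass of organize_by_dialogue — bucketing memory units by a metadata field and skipping units that lack it — can be sketched standalone. This is a minimal illustration using plain dicts in place of MemoryUnit objects; the helper name group_by_field is not part of the AEIVA API.

```python
from collections import defaultdict

def group_by_field(units, group_field="dialogue_session_id"):
    # Mirrors the grouping pass of organize_by_dialogue: bucket units by a metadata field.
    groups = defaultdict(list)
    for unit in units:
        group_id = unit.get("metadata", {}).get(group_field)
        if group_id:  # units missing the field are skipped, as in the original
            groups[group_id].append(unit)
    return dict(groups)

units = [
    {"id": "m1", "metadata": {"dialogue_session_id": "s1"}},
    {"id": "m2", "metadata": {"dialogue_session_id": "s1"}},
    {"id": "m3", "metadata": {}},  # no session id, so left out of any group
]
groups = group_by_field(units)
print(sorted(groups))                   # ['s1']
print([u["id"] for u in groups["s1"]])  # ['m1', 'm2']
```

In the real method, each resulting bucket then becomes a new dialogue_session MemoryUnit linked to its members via part_of edges.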
"},{"location":"reference/#src.aeiva.cognition.memory.memory_organizer.MemoryOrganizerError","title":"MemoryOrganizerError
","text":" Bases: Exception
Exception raised when an error occurs in the MemoryOrganizer.
Source code in src/aeiva/cognition/memory/memory_organizer.py
class MemoryOrganizerError(Exception):\n \"\"\"Exception raised when an error occurs in the MemoryOrganizer.\"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_palace","title":"memory_palace
","text":""},{"location":"reference/#src.aeiva.cognition.memory.memory_palace.MemoryPalace","title":"MemoryPalace
","text":" Bases: Memory
Concrete implementation of the Memory abstract base class.
This class provides methods to manage memory units, including creation, retrieval, updating, deletion, filtering, grouping, structurizing, skillizing, parameterizing, and more. It delegates specific operations to specialized components like MemoryCleaner, MemoryOrganizer, MemoryRetriever, MemoryStructurer, MemorySkillizer, and MemoryParameterizer.
Source code in src/aeiva/cognition/memory/memory_palace.py
class MemoryPalace(Memory):\n \"\"\"\n Concrete implementation of the Memory abstract base class.\n\n This class provides methods to manage memory units, including creation, retrieval,\n updating, deletion, filtering, grouping, structurizing, skillizing, parameterizing,\n and more. It delegates specific operations to specialized components like\n MemoryCleaner, MemoryOrganizer, MemoryRetriever, MemoryStructurer, MemorySkillizer,\n and MemoryParameterizer.\n \"\"\"\n\n def __init__(self, config: Dict):\n \"\"\"\n Initialize the MemoryPalace with the provided configuration.\n\n Args:\n config (MemoryConfig): Configuration settings for the MemoryPalace.\n \"\"\"\n self.config_dict = config\n self.config = None\n self.storage = None\n self.embedder = None\n self.cleaner = None\n self.organizer = None\n self.retriever = None\n self.structurer = None\n self.skillizer = None\n self.parameterizer = None\n self.setup()\n\n def setup(self):\n \"\"\"\n Setup the MemoryPalace by initializing all components.\n \"\"\"\n try:\n # Initialize EmbedderConfig\n embedder_config_dict = self.config_dict.get('embedder_config', {})\n self.embedder = Embedder(embedder_config_dict)\n\n storage_config_dict = self.config_dict.get('storage_config', {})\n self.storage = MemoryStorage(storage_config_dict) \n\n # Initialize Memory Configuration\n self.config = MemoryConfig(\n embedder_config=self.embedder.config,\n storage_config=self.storage.config\n )\n\n logger.info(\"MemoryPalace: MemoryStorage and Embedder initialized successfully.\")\n\n # Initialize specialized components\n self.cleaner = MemoryCleaner()\n self.organizer = MemoryOrganizer()\n self.retriever = MemoryRetriever(embedder=self.embedder, storage=self.storage)\n self.structurer = MemoryStructurer()\n self.skillizer = MemorySkillizer()\n self.parameterizer = MemoryParameterizer()\n logger.info(\"MemoryPalace: Specialized components initialized successfully.\")\n\n except Exception as e:\n logger.error(f\"MemoryPalace setup failed: 
{e}\")\n self.handle_error(e)\n raise\n\n # CRUD Operations\n\n def create(self, content: Any, **kwargs) -> MemoryUnit:\n \"\"\"\n Creates a new memory unit with the given content and metadata.\n\n Args:\n content (Any): The core content of the memory unit.\n **kwargs: Additional metadata for the memory unit.\n\n Returns:\n MemoryUnit: The created memory unit.\n \"\"\"\n try:\n # Instantiate MemoryUnit\n memory_unit = MemoryUnit(content=content, **kwargs)\n\n # Generate embedding\n embedding_response = self.embedder.embed(content)\n if embedding_response.get(\"data\"):\n memory_unit.embedding = embedding_response[\"data\"][0].get(\"embedding\")\n else:\n raise ValueError(\"Failed to generate embedding for the content.\")\n\n # Delegate storage operations to MemoryStorage\n self.storage.add_memory_unit(memory_unit)\n\n logger.info(f\"Created new MemoryUnit with ID: {memory_unit.id}\")\n return memory_unit\n except Exception as e:\n logger.error(f\"Error creating MemoryUnit: {e}\")\n self.handle_error(e)\n raise\n\n def get(self, unit_id: str) -> MemoryUnit:\n \"\"\"\n Retrieves a memory unit by its unique identifier.\n\n Args:\n unit_id (str): The unique identifier of the memory unit.\n\n Returns:\n MemoryUnit: The retrieved memory unit.\n \"\"\"\n try:\n memory_unit = self.storage.get_memory_unit(unit_id)\n logger.info(f\"Retrieved MemoryUnit with ID: {unit_id}\")\n return memory_unit\n except Exception as e:\n logger.error(f\"Error retrieving MemoryUnit with ID {unit_id}: {e}\")\n self.handle_error(e)\n raise\n\n def update(self, unit_id: str, updates: Dict[str, Any]) -> None:\n \"\"\"\n Updates a memory unit with the given updates.\n\n Args:\n unit_id (str): The unique identifier of the memory unit.\n updates (Dict[str, Any]): A dictionary of fields to update.\n \"\"\"\n try:\n # Delegate update operations to MemoryStorage\n self.storage.update_memory_unit(unit_id, updates)\n logger.info(f\"Updated MemoryUnit with ID: {unit_id}\")\n except Exception as e:\n 
logger.error(f\"Error updating MemoryUnit with ID {unit_id}: {e}\")\n self.handle_error(e)\n raise\n\n def delete(self, unit_id: str) -> None:\n \"\"\"\n Deletes a memory unit by its unique identifier.\n\n Args:\n unit_id (str): The unique identifier of the memory unit.\n \"\"\"\n try:\n # Delegate deletion to MemoryStorage\n self.storage.delete_memory_unit(unit_id)\n logger.info(f\"Deleted MemoryUnit with ID: {unit_id}\")\n except Exception as e:\n logger.error(f\"Error deleting MemoryUnit with ID {unit_id}: {e}\")\n self.handle_error(e)\n raise\n\n def get_all(self) -> List[MemoryUnit]:\n \"\"\"\n Retrieves all memory units.\n\n Returns:\n List[MemoryUnit]: A list of all memory units.\n \"\"\"\n try:\n memory_units = self.storage.get_all_memory_units()\n logger.info(f\"Retrieved all MemoryUnits. Total count: {len(memory_units)}\")\n return memory_units\n except Exception as e:\n logger.error(f\"Error retrieving all MemoryUnits: {e}\")\n self.handle_error(e)\n raise\n\n def delete_all(self) -> None:\n \"\"\"\n Deletes all memory units.\n \"\"\"\n try:\n self.storage.delete_all_memory_units() # TODO: seems no work correctly, need to check\n logger.info(\"Deleted all MemoryUnits.\")\n except Exception as e:\n logger.error(f\"Error deleting all MemoryUnits: {e}\")\n self.handle_error(e)\n raise\n\n def load(self) -> List[MemoryUnit]:\n \"\"\"\n Loads all memory units from the storage.\n\n Returns:\n List[MemoryUnit]: A list of all loaded memory units.\n \"\"\"\n try:\n # Retrieve all memory units from storage\n memory_units = self.get_all()\n logger.info(f\"Loaded {len(memory_units)} MemoryUnits from storage.\")\n return memory_units\n except Exception as e:\n logger.error(f\"Error loading MemoryUnits: {e}\")\n self.handle_error(e)\n raise\n\n def save(self, export_path: Optional[str] = None) -> None:\n \"\"\"\n Saves all memory units to the storage or exports them to a specified path.\n\n Args:\n export_path (Optional[str]): The file path to export memory units as 
JSON.\n If None, saves are handled by MemoryStorage.\n \"\"\"\n try:\n if export_path:\n # Export memory units to a JSON file\n memory_units = self.get_all()\n export_data = [mu.to_dict() for mu in memory_units]\n with open(export_path, 'w', encoding='utf-8') as f:\n json.dump(export_data, f, ensure_ascii=False, indent=4)\n logger.info(f\"Exported {len(memory_units)} MemoryUnits to {export_path}.\")\n else:\n # If no export path is provided, assume that MemoryStorage handles persistence\n logger.info(\"Save operation delegated to MemoryStorage.\")\n # Example: self.storage.persist_changes()\n except Exception as e:\n logger.error(f\"Error saving MemoryUnits: {e}\")\n self.handle_error(e)\n raise\n\n # Delegated Operations\n\n def filter(self, criteria: Dict[str, Any]) -> List[MemoryUnit]:\n \"\"\"\n Filters memory units based on the given criteria.\n\n Args:\n criteria (Dict[str, Any]): A dictionary of filter conditions.\n\n Returns:\n List[MemoryUnit]: A list of memory units matching the criteria.\n \"\"\"\n try:\n memory_units = self.get_all()\n filter_type = criteria.get('filter_type')\n if not filter_type:\n raise ValueError(\"Missing 'filter_type' in criteria.\")\n\n # Delegate filtering to MemoryCleaner\n filtered_memories = self.cleaner.filter(memory_units, filter_type, **criteria)\n logger.info(f\"Filtered memories based on criteria: {criteria}\")\n return filtered_memories\n except Exception as e:\n logger.error(f\"Error filtering memories: {e}\")\n self.handle_error(e)\n raise\n\n def organize(self, unit_ids: List[str], organize_type: str, metadata: Optional[Dict[str, Any]] = None) -> str:\n \"\"\"\n Groups memory units into a meaningful group.\n\n Args:\n unit_ids (List[str]): A list of memory unit IDs to group.\n organize_type (str): The type of group (e.g., 'dialogue_session', 'procedure').\n metadata (Optional[Dict[str, Any]]): Additional metadata for the group.\n\n Returns:\n str: A unique identifier for the created group.\n \"\"\"\n try:\n # 
Retrieve the memory units to group\n memory_units = [self.get(unit_id) for unit_id in unit_ids]\n logger.debug(f\"Grouping {len(memory_units)} MemoryUnits into group_type='{organize_type}'.\")\n\n # Delegate grouping to MemoryOrganizer\n organized_memories = self.organizer.organize(memory_units, organize_type, metadata=metadata)\n logger.info(f\"Grouped memories into '{organize_type}'. Total memory units after grouping: {len(organized_memories)}\")\n return \"group_id_placeholder\" # Replace with actual group ID if applicable\n except Exception as e:\n logger.error(f\"Error grouping memories: {e}\")\n self.handle_error(e)\n raise\n\n def structurize(self, unit_ids: List[str], structure_type: str, **kwargs) -> None:\n \"\"\"\n Structures memory units into a knowledge graph or other structures.\n\n Args:\n unit_ids (List[str]): A list of memory unit IDs to structurize.\n structure_type (str): The type of structure (e.g., 'knowledge_graph').\n **kwargs: Additional parameters for the structuring process.\n \"\"\"\n try:\n # Retrieve the memory units to structurize\n memory_units = [self.get(uid) for uid in unit_ids]\n logger.debug(f\"Structurizing {len(memory_units)} MemoryUnits with structure_type='{structure_type}'.\")\n\n # Delegate structuring to MemoryStructurer\n self.structurer.structure(memory_units, structure_type, **kwargs)\n logger.info(f\"Structurized memories with structure_type='{structure_type}'.\")\n except Exception as e:\n logger.error(f\"Error structurizing memories: {e}\")\n self.handle_error(e)\n raise\n\n def skillize(self, unit_ids: List[str], skill_name: str, **kwargs) -> str:\n \"\"\"\n Converts memory units into a reusable skill.\n\n Args:\n unit_ids (List[str]): A list of memory unit IDs to skillize.\n skill_name (str): The name of the skill to create.\n **kwargs: Additional parameters for skill creation.\n\n Returns:\n str: The unique identifier of the created skill.\n \"\"\"\n try:\n # Retrieve the memory units to skillize\n memory_units = 
[self.get(uid) for uid in unit_ids]\n logger.debug(f\"Skillizing {len(memory_units)} MemoryUnits into skill_name='{skill_name}'.\")\n\n # Delegate skillizing to MemorySkillizer\n skill_id = self.skillizer.skillize(memory_units, skill_name, **kwargs)\n logger.info(f\"Skillized memories into skill with ID: {skill_id}\")\n return skill_id\n except Exception as e:\n logger.error(f\"Error skillizing memories: {e}\")\n self.handle_error(e)\n raise\n\n def parameterize(self, **kwargs) -> None:\n \"\"\"\n Trains a parametric model using the memory data.\n\n Args:\n **kwargs: Additional parameters for the training process.\n \"\"\"\n try:\n # Retrieve all memory units\n memory_units = self.get_all()\n logger.debug(f\"Parameterizing {len(memory_units)} MemoryUnits.\")\n\n # Delegate parameterizing to MemoryParameterizer\n self.parameterizer.parameterize(memory_units, **kwargs)\n logger.info(\"Parameterized memories successfully.\")\n except Exception as e:\n logger.error(f\"Error parameterizing memories: {e}\")\n self.handle_error(e)\n raise\n\n def retrieve(self, query: Any, retrieve_type: str, **kwargs) -> List[MemoryUnit]:\n \"\"\"\n Retrieve data from memory based on a query.\n\n Args:\n query (Any): The query or criteria to retrieve specific memory data.\n retrieve_type (str): The type of retrieval (e.g., 'similar', 'related').\n **kwargs: Additional parameters for the retrieval process.\n\n Returns:\n List[MemoryUnit]: The retrieved memory data.\n \"\"\"\n try:\n # Delegate retrieval to MemoryRetriever\n memories = self.retriever.retrieve(query=query, retrieve_type=retrieve_type, **kwargs)\n logger.info(f\"Retrieved {len(memories)} memories using retrieve_type='{retrieve_type}'.\")\n return memories\n except Exception as e:\n logger.error(f\"Error retrieving MemoryUnits: {e}\")\n self.handle_error(e)\n raise\n\n def embed(self, unit_id: str) -> None:\n \"\"\"\n Generates an embedding for a memory unit.\n\n Args:\n unit_id (str): The unique identifier of the memory 
unit.\n \"\"\"\n try:\n # Delegate embedding to MemoryRetriever\n memory_units = self.retriever.retrieve(query=unit_id, retrieve_type='similar', top_k=1)\n if not memory_units:\n raise ValueError(f\"No MemoryUnit found with ID {unit_id} to embed.\")\n\n memory_unit = memory_units[0]\n\n # Generate embedding using the embedder\n embedding_response = self.embedder.embed(memory_unit.content)\n if embedding_response.get(\"data\") and len(embedding_response[\"data\"]) > 0:\n memory_unit.embedding = embedding_response[\"data\"][0].get(\"embedding\")\n else:\n raise ValueError(\"Failed to generate embedding for the content.\")\n\n # Update the memory unit with the new embedding\n self.update(unit_id, {'embedding': memory_unit.embedding})\n\n logger.info(f\"Generated embedding for MemoryUnit ID: {unit_id}\")\n except Exception as e:\n logger.error(f\"Error generating embedding for MemoryUnit ID {unit_id}: {e}\")\n self.handle_error(e)\n raise\n\n # Error Handling\n\n def handle_error(self, error: Exception) -> None:\n \"\"\"\n Handle errors that occur during memory operations.\n\n Args:\n error (Exception): The exception that was raised.\n \"\"\"\n logger.error(f\"MemoryPalace encountered an error: {error}\")\n # Additional error handling can be implemented here\n\n @staticmethod\n def get_api_key(self, config_section: Dict[str, Any], key_field: str, env_var_field: str) -> Optional[str]:\n \"\"\"\n Retrieve an API key from the configuration section.\n\n Args:\n config_section (Dict[str, Any]): The configuration section (e.g., embedder_config).\n key_field (str): The key in the config_section that may contain the API key directly.\n env_var_field (str): The key in the config_section that specifies the environment variable name.\n\n Returns:\n Optional[str]: The API key if found, else None.\n\n Raises:\n EnvironmentError: If the environment variable is specified but not set.\n \"\"\"\n # Check if API key is provided directly\n api_key = config_section.get(key_field)\n if 
api_key:\n logger.info(f\"Using provided API key for '{key_field}'.\")\n return api_key\n\n # Else, check if an environment variable is specified\n env_var = config_section.get(env_var_field)\n if env_var:\n api_key = os.getenv(env_var)\n if api_key:\n logger.info(f\"Retrieved API key for '{key_field}' from environment variable '{env_var}'.\")\n return api_key\n else:\n logger.error(f\"Environment variable '{env_var}' for '{key_field}' is not set.\")\n raise EnvironmentError(f\"Environment variable '{env_var}' for '{key_field}' is not set.\")\n\n logger.warning(f\"No API key provided for '{key_field}'.\")\n return None\n
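The create path above follows an embed-then-store order: the unit is only persisted once an embedding was produced. A minimal sketch of that flow, with stand-in embedder and storage objects (MemoryPalace itself requires a full configuration; the stub classes and the dict-based unit here are illustrative, not AEIVA types):

```python
class StubEmbedder:
    # Returns the response shape MemoryPalace.create reads: {"data": [{"embedding": [...]}]}.
    def embed(self, content):
        return {"data": [{"embedding": [0.1, 0.2, 0.3]}]}

class StubStorage:
    def __init__(self):
        self.units = {}

    def add_memory_unit(self, unit):
        self.units[unit["id"]] = unit

def create(embedder, storage, content, unit_id):
    # Embed first; refuse to store a unit whose embedding could not be generated.
    unit = {"id": unit_id, "content": content, "embedding": None}
    response = embedder.embed(content)
    if response.get("data"):
        unit["embedding"] = response["data"][0].get("embedding")
    else:
        raise ValueError("Failed to generate embedding for the content.")
    storage.add_memory_unit(unit)
    return unit

storage = StubStorage()
unit = create(StubEmbedder(), storage, "hello world", "u1")
print(unit["embedding"])      # [0.1, 0.2, 0.3]
print("u1" in storage.units)  # True
```

This ordering means a failed embedding raises before anything reaches storage, which is why create never leaves a unit without an embedding behind.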
"},{"location":"reference/#src.aeiva.cognition.memory.memory_palace.MemoryPalace.__init__","title":"__init__(config)
","text":"Initialize the MemoryPalace with the provided configuration.
Parameters:
Name Type Description Default config
MemoryConfig
Configuration settings for the MemoryPalace.
required Source code in src/aeiva/cognition/memory/memory_palace.py
def __init__(self, config: Dict):\n \"\"\"\n Initialize the MemoryPalace with the provided configuration.\n\n Args:\n config (MemoryConfig): Configuration settings for the MemoryPalace.\n \"\"\"\n self.config_dict = config\n self.config = None\n self.storage = None\n self.embedder = None\n self.cleaner = None\n self.organizer = None\n self.retriever = None\n self.structurer = None\n self.skillizer = None\n self.parameterizer = None\n self.setup()\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_palace.MemoryPalace.create","title":"create(content, **kwargs)
","text":"Creates a new memory unit with the given content and metadata.
Parameters:
Name Type Description Default content
Any
The core content of the memory unit.
required **kwargs
Additional metadata for the memory unit.
{}
Returns:
Name Type Description MemoryUnit
MemoryUnit
The created memory unit.
Source code in src/aeiva/cognition/memory/memory_palace.py
def create(self, content: Any, **kwargs) -> MemoryUnit:\n \"\"\"\n Creates a new memory unit with the given content and metadata.\n\n Args:\n content (Any): The core content of the memory unit.\n **kwargs: Additional metadata for the memory unit.\n\n Returns:\n MemoryUnit: The created memory unit.\n \"\"\"\n try:\n # Instantiate MemoryUnit\n memory_unit = MemoryUnit(content=content, **kwargs)\n\n # Generate embedding\n embedding_response = self.embedder.embed(content)\n if embedding_response.get(\"data\"):\n memory_unit.embedding = embedding_response[\"data\"][0].get(\"embedding\")\n else:\n raise ValueError(\"Failed to generate embedding for the content.\")\n\n # Delegate storage operations to MemoryStorage\n self.storage.add_memory_unit(memory_unit)\n\n logger.info(f\"Created new MemoryUnit with ID: {memory_unit.id}\")\n return memory_unit\n except Exception as e:\n logger.error(f\"Error creating MemoryUnit: {e}\")\n self.handle_error(e)\n raise\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_palace.MemoryPalace.delete","title":"delete(unit_id)
","text":"Deletes a memory unit by its unique identifier.
Parameters:
Name Type Description Default unit_id
str
The unique identifier of the memory unit.
required Source code in src/aeiva/cognition/memory/memory_palace.py
def delete(self, unit_id: str) -> None:\n \"\"\"\n Deletes a memory unit by its unique identifier.\n\n Args:\n unit_id (str): The unique identifier of the memory unit.\n \"\"\"\n try:\n # Delegate deletion to MemoryStorage\n self.storage.delete_memory_unit(unit_id)\n logger.info(f\"Deleted MemoryUnit with ID: {unit_id}\")\n except Exception as e:\n logger.error(f\"Error deleting MemoryUnit with ID {unit_id}: {e}\")\n self.handle_error(e)\n raise\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_palace.MemoryPalace.delete_all","title":"delete_all()
","text":"Deletes all memory units.
Source code in src/aeiva/cognition/memory/memory_palace.py
def delete_all(self) -> None:\n    \"\"\"\n    Deletes all memory units.\n    \"\"\"\n    try:\n        self.storage.delete_all_memory_units()  # TODO: does not seem to work correctly; needs checking\n        logger.info(\"Deleted all MemoryUnits.\")\n    except Exception as e:\n        logger.error(f\"Error deleting all MemoryUnits: {e}\")\n        self.handle_error(e)\n        raise\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_palace.MemoryPalace.embed","title":"embed(unit_id)
","text":"Generates an embedding for a memory unit.
Parameters:
Name Type Description Default unit_id
str
The unique identifier of the memory unit.
required Source code in src/aeiva/cognition/memory/memory_palace.py
def embed(self, unit_id: str) -> None:\n \"\"\"\n Generates an embedding for a memory unit.\n\n Args:\n unit_id (str): The unique identifier of the memory unit.\n \"\"\"\n try:\n # Delegate embedding to MemoryRetriever\n memory_units = self.retriever.retrieve(query=unit_id, retrieve_type='similar', top_k=1)\n if not memory_units:\n raise ValueError(f\"No MemoryUnit found with ID {unit_id} to embed.\")\n\n memory_unit = memory_units[0]\n\n # Generate embedding using the embedder\n embedding_response = self.embedder.embed(memory_unit.content)\n if embedding_response.get(\"data\") and len(embedding_response[\"data\"]) > 0:\n memory_unit.embedding = embedding_response[\"data\"][0].get(\"embedding\")\n else:\n raise ValueError(\"Failed to generate embedding for the content.\")\n\n # Update the memory unit with the new embedding\n self.update(unit_id, {'embedding': memory_unit.embedding})\n\n logger.info(f\"Generated embedding for MemoryUnit ID: {unit_id}\")\n except Exception as e:\n logger.error(f\"Error generating embedding for MemoryUnit ID {unit_id}: {e}\")\n self.handle_error(e)\n raise\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_palace.MemoryPalace.filter","title":"filter(criteria)
","text":"Filters memory units based on the given criteria.
Parameters:
Name Type Description Default criteria
Dict[str, Any]
A dictionary of filter conditions.
required Returns:
Type Description List[MemoryUnit]
List[MemoryUnit]: A list of memory units matching the criteria.
Source code in src/aeiva/cognition/memory/memory_palace.py
def filter(self, criteria: Dict[str, Any]) -> List[MemoryUnit]:\n \"\"\"\n Filters memory units based on the given criteria.\n\n Args:\n criteria (Dict[str, Any]): A dictionary of filter conditions.\n\n Returns:\n List[MemoryUnit]: A list of memory units matching the criteria.\n \"\"\"\n try:\n memory_units = self.get_all()\n filter_type = criteria.get('filter_type')\n if not filter_type:\n raise ValueError(\"Missing 'filter_type' in criteria.\")\n\n # Delegate filtering to MemoryCleaner\n filtered_memories = self.cleaner.filter(memory_units, filter_type, **criteria)\n logger.info(f\"Filtered memories based on criteria: {criteria}\")\n return filtered_memories\n except Exception as e:\n logger.error(f\"Error filtering memories: {e}\")\n self.handle_error(e)\n raise\n
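As the listing shows, filter rejects any criteria dict that lacks a 'filter_type' key before delegating to MemoryCleaner. That guard can be sketched in isolation (the 'modality' criteria key below is purely illustrative):

```python
def extract_filter_type(criteria):
    # Mirrors the validation in MemoryPalace.filter before delegation.
    filter_type = criteria.get("filter_type")
    if not filter_type:
        raise ValueError("Missing 'filter_type' in criteria.")
    return filter_type

print(extract_filter_type({"filter_type": "modality", "modality": "text"}))  # modality

try:
    extract_filter_type({"modality": "text"})
except ValueError as e:
    print(e)  # Missing 'filter_type' in criteria.
```

Note that the whole criteria dict is forwarded as keyword arguments, so any extra keys must match what the chosen filter type expects.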
"},{"location":"reference/#src.aeiva.cognition.memory.memory_palace.MemoryPalace.get","title":"get(unit_id)
","text":"Retrieves a memory unit by its unique identifier.
Parameters:
Name Type Description Default unit_id
str
The unique identifier of the memory unit.
required Returns:
Name Type Description MemoryUnit
MemoryUnit
The retrieved memory unit.
Source code in src/aeiva/cognition/memory/memory_palace.py
def get(self, unit_id: str) -> MemoryUnit:\n \"\"\"\n Retrieves a memory unit by its unique identifier.\n\n Args:\n unit_id (str): The unique identifier of the memory unit.\n\n Returns:\n MemoryUnit: The retrieved memory unit.\n \"\"\"\n try:\n memory_unit = self.storage.get_memory_unit(unit_id)\n logger.info(f\"Retrieved MemoryUnit with ID: {unit_id}\")\n return memory_unit\n except Exception as e:\n logger.error(f\"Error retrieving MemoryUnit with ID {unit_id}: {e}\")\n self.handle_error(e)\n raise\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_palace.MemoryPalace.get_all","title":"get_all()
","text":"Retrieves all memory units.
Returns:
Type Description List[MemoryUnit]
List[MemoryUnit]: A list of all memory units.
Source code in src/aeiva/cognition/memory/memory_palace.py
def get_all(self) -> List[MemoryUnit]:\n \"\"\"\n Retrieves all memory units.\n\n Returns:\n List[MemoryUnit]: A list of all memory units.\n \"\"\"\n try:\n memory_units = self.storage.get_all_memory_units()\n logger.info(f\"Retrieved all MemoryUnits. Total count: {len(memory_units)}\")\n return memory_units\n except Exception as e:\n logger.error(f\"Error retrieving all MemoryUnits: {e}\")\n self.handle_error(e)\n raise\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_palace.MemoryPalace.get_api_key","title":"get_api_key(config_section, key_field, env_var_field)
staticmethod
","text":"Retrieve an API key from the configuration section.
Parameters:
Name Type Description Default config_section
Dict[str, Any]
The configuration section (e.g., embedder_config).
required key_field
str
The key in the config_section that may contain the API key directly.
required env_var_field
str
The key in the config_section that specifies the environment variable name.
required Returns:
Type Description Optional[str]
Optional[str]: The API key if found, else None.
Raises:
Type Description EnvironmentError
If the environment variable is specified but not set.
Source code in src/aeiva/cognition/memory/memory_palace.py
@staticmethod\ndef get_api_key(config_section: Dict[str, Any], key_field: str, env_var_field: str) -> Optional[str]:\n    \"\"\"\n    Retrieve an API key from the configuration section.\n\n    Args:\n        config_section (Dict[str, Any]): The configuration section (e.g., embedder_config).\n        key_field (str): The key in the config_section that may contain the API key directly.\n        env_var_field (str): The key in the config_section that specifies the environment variable name.\n\n    Returns:\n        Optional[str]: The API key if found, else None.\n\n    Raises:\n        EnvironmentError: If the environment variable is specified but not set.\n    \"\"\"\n    # Check if API key is provided directly\n    api_key = config_section.get(key_field)\n    if api_key:\n        logger.info(f\"Using provided API key for '{key_field}'.\")\n        return api_key\n\n    # Else, check if an environment variable is specified\n    env_var = config_section.get(env_var_field)\n    if env_var:\n        api_key = os.getenv(env_var)\n        if api_key:\n            logger.info(f\"Retrieved API key for '{key_field}' from environment variable '{env_var}'.\")\n            return api_key\n        else:\n            logger.error(f\"Environment variable '{env_var}' for '{key_field}' is not set.\")\n            raise EnvironmentError(f\"Environment variable '{env_var}' for '{key_field}' is not set.\")\n\n    logger.warning(f\"No API key provided for '{key_field}'.\")\n    return None\n
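The resolution order — a key given directly in the config wins, otherwise the named environment variable is consulted, otherwise None — can be reproduced standalone. The function name resolve_api_key and the config keys 'api_key', 'api_key_env_var', and 'DEMO_API_KEY' below are illustrative stand-ins, not AEIVA's actual configuration fields:

```python
import os

def resolve_api_key(config_section, key_field, env_var_field):
    # Direct key in the config section takes precedence.
    api_key = config_section.get(key_field)
    if api_key:
        return api_key
    # Otherwise look up the environment variable named in the config.
    env_var = config_section.get(env_var_field)
    if env_var:
        api_key = os.getenv(env_var)
        if api_key:
            return api_key
        raise EnvironmentError(f"Environment variable '{env_var}' for '{key_field}' is not set.")
    # Neither a direct key nor an env-var name was configured.
    return None

print(resolve_api_key({"api_key": "sk-123"}, "api_key", "api_key_env_var"))  # sk-123
os.environ["DEMO_API_KEY"] = "sk-456"
print(resolve_api_key({"api_key_env_var": "DEMO_API_KEY"}, "api_key", "api_key_env_var"))  # sk-456
print(resolve_api_key({}, "api_key", "api_key_env_var"))  # None
```

The only error case is a configured environment-variable name that is not actually set in the environment; an entirely absent key is a soft failure returning None.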
"},{"location":"reference/#src.aeiva.cognition.memory.memory_palace.MemoryPalace.handle_error","title":"handle_error(error)
","text":"Handle errors that occur during memory operations.
Parameters:
Name Type Description Default error
Exception
The exception that was raised.
required Source code in src/aeiva/cognition/memory/memory_palace.py
def handle_error(self, error: Exception) -> None:\n \"\"\"\n Handle errors that occur during memory operations.\n\n Args:\n error (Exception): The exception that was raised.\n \"\"\"\n logger.error(f\"MemoryPalace encountered an error: {error}\")\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_palace.MemoryPalace.load","title":"load()
","text":"Loads all memory units from the storage.
Returns:
Type Description List[MemoryUnit]
List[MemoryUnit]: A list of all loaded memory units.
Source code in src/aeiva/cognition/memory/memory_palace.py
def load(self) -> List[MemoryUnit]:\n \"\"\"\n Loads all memory units from the storage.\n\n Returns:\n List[MemoryUnit]: A list of all loaded memory units.\n \"\"\"\n try:\n # Retrieve all memory units from storage\n memory_units = self.get_all()\n logger.info(f\"Loaded {len(memory_units)} MemoryUnits from storage.\")\n return memory_units\n except Exception as e:\n logger.error(f\"Error loading MemoryUnits: {e}\")\n self.handle_error(e)\n raise\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_palace.MemoryPalace.organize","title":"organize(unit_ids, organize_type, metadata=None)
","text":"Groups memory units into a meaningful group.
Parameters:
Name Type Description Default unit_ids
List[str]
A list of memory unit IDs to group.
required organize_type
str
The type of group (e.g., 'dialogue_session', 'procedure').
required metadata
Optional[Dict[str, Any]]
Additional metadata for the group.
None
Returns:
Name Type Description str
str
A unique identifier for the created group.
Source code in src/aeiva/cognition/memory/memory_palace.py
def organize(self, unit_ids: List[str], organize_type: str, metadata: Optional[Dict[str, Any]] = None) -> str:\n \"\"\"\n Groups memory units into a meaningful group.\n\n Args:\n unit_ids (List[str]): A list of memory unit IDs to group.\n organize_type (str): The type of group (e.g., 'dialogue_session', 'procedure').\n metadata (Optional[Dict[str, Any]]): Additional metadata for the group.\n\n Returns:\n str: A unique identifier for the created group.\n \"\"\"\n try:\n # Retrieve the memory units to group\n memory_units = [self.get(unit_id) for unit_id in unit_ids]\n logger.debug(f\"Grouping {len(memory_units)} MemoryUnits into group_type='{organize_type}'.\")\n\n # Delegate grouping to MemoryOrganizer\n organized_memories = self.organizer.organize(memory_units, organize_type, metadata=metadata)\n logger.info(f\"Grouped memories into '{organize_type}'. Total memory units after grouping: {len(organized_memories)}\")\n return \"group_id_placeholder\" # Replace with actual group ID if applicable\n except Exception as e:\n logger.error(f\"Error grouping memories: {e}\")\n self.handle_error(e)\n raise\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_palace.MemoryPalace.parameterize","title":"parameterize(**kwargs)
","text":"Trains a parametric model using the memory data.
Parameters:
Name Type Description Default **kwargs
Additional parameters for the training process.
{}
Source code in src/aeiva/cognition/memory/memory_palace.py
def parameterize(self, **kwargs) -> None:\n \"\"\"\n Trains a parametric model using the memory data.\n\n Args:\n **kwargs: Additional parameters for the training process.\n \"\"\"\n try:\n # Retrieve all memory units\n memory_units = self.get_all()\n logger.debug(f\"Parameterizing {len(memory_units)} MemoryUnits.\")\n\n # Delegate parameterizing to MemoryParameterizer\n self.parameterizer.parameterize(memory_units, **kwargs)\n logger.info(\"Parameterized memories successfully.\")\n except Exception as e:\n logger.error(f\"Error parameterizing memories: {e}\")\n self.handle_error(e)\n raise\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_palace.MemoryPalace.retrieve","title":"retrieve(query, retrieve_type, **kwargs)
","text":"Retrieve data from memory based on a query.
Parameters:
Name Type Description Default query
Any
The query or criteria to retrieve specific memory data.
required retrieve_type
str
The type of retrieval (e.g., 'similar', 'related').
required **kwargs
Additional parameters for the retrieval process.
{}
Returns:
Type Description List[MemoryUnit]
List[MemoryUnit]: The retrieved memory data.
Source code in src/aeiva/cognition/memory/memory_palace.py
def retrieve(self, query: Any, retrieve_type: str, **kwargs) -> List[MemoryUnit]:\n \"\"\"\n Retrieve data from memory based on a query.\n\n Args:\n query (Any): The query or criteria to retrieve specific memory data.\n retrieve_type (str): The type of retrieval (e.g., 'similar', 'related').\n **kwargs: Additional parameters for the retrieval process.\n\n Returns:\n List[MemoryUnit]: The retrieved memory data.\n \"\"\"\n try:\n # Delegate retrieval to MemoryRetriever\n memories = self.retriever.retrieve(query=query, retrieve_type=retrieve_type, **kwargs)\n logger.info(f\"Retrieved {len(memories)} memories using retrieve_type='{retrieve_type}'.\")\n return memories\n except Exception as e:\n logger.error(f\"Error retrieving MemoryUnits: {e}\")\n self.handle_error(e)\n raise\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_palace.MemoryPalace.save","title":"save(export_path=None)
","text":"Saves all memory units to the storage or exports them to a specified path.
Parameters:
Name Type Description Default export_path
Optional[str]
The file path to export memory units as JSON. If None, saves are handled by MemoryStorage.
None
Source code in src/aeiva/cognition/memory/memory_palace.py
def save(self, export_path: Optional[str] = None) -> None:\n \"\"\"\n Saves all memory units to the storage or exports them to a specified path.\n\n Args:\n export_path (Optional[str]): The file path to export memory units as JSON.\n If None, saves are handled by MemoryStorage.\n \"\"\"\n try:\n if export_path:\n # Export memory units to a JSON file\n memory_units = self.get_all()\n export_data = [mu.to_dict() for mu in memory_units]\n with open(export_path, 'w', encoding='utf-8') as f:\n json.dump(export_data, f, ensure_ascii=False, indent=4)\n logger.info(f\"Exported {len(memory_units)} MemoryUnits to {export_path}.\")\n else:\n # If no export path is provided, assume that MemoryStorage handles persistence\n logger.info(\"Save operation delegated to MemoryStorage.\")\n # Example: self.storage.persist_changes()\n except Exception as e:\n logger.error(f\"Error saving MemoryUnits: {e}\")\n self.handle_error(e)\n raise\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_palace.MemoryPalace.setup","title":"setup()
","text":"Setup the MemoryPalace by initializing all components.
Source code in src/aeiva/cognition/memory/memory_palace.py
def setup(self):\n \"\"\"\n Setup the MemoryPalace by initializing all components.\n \"\"\"\n try:\n # Initialize EmbedderConfig\n embedder_config_dict = self.config_dict.get('embedder_config', {})\n self.embedder = Embedder(embedder_config_dict)\n\n storage_config_dict = self.config_dict.get('storage_config', {})\n self.storage = MemoryStorage(storage_config_dict) \n\n # Initialize Memory Configuration\n self.config = MemoryConfig(\n embedder_config=self.embedder.config,\n storage_config=self.storage.config\n )\n\n logger.info(\"MemoryPalace: MemoryStorage and Embedder initialized successfully.\")\n\n # Initialize specialized components\n self.cleaner = MemoryCleaner()\n self.organizer = MemoryOrganizer()\n self.retriever = MemoryRetriever(embedder=self.embedder, storage=self.storage)\n self.structurer = MemoryStructurer()\n self.skillizer = MemorySkillizer()\n self.parameterizer = MemoryParameterizer()\n logger.info(\"MemoryPalace: Specialized components initialized successfully.\")\n\n except Exception as e:\n logger.error(f\"MemoryPalace setup failed: {e}\")\n self.handle_error(e)\n raise\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_palace.MemoryPalace.skillize","title":"skillize(unit_ids, skill_name, **kwargs)
","text":"Converts memory units into a reusable skill.
Parameters:
Name Type Description Default unit_ids
List[str]
A list of memory unit IDs to skillize.
required skill_name
str
The name of the skill to create.
required **kwargs
Additional parameters for skill creation.
{}
Returns:
Name Type Description str
str
The unique identifier of the created skill.
Source code in src/aeiva/cognition/memory/memory_palace.py
def skillize(self, unit_ids: List[str], skill_name: str, **kwargs) -> str:\n \"\"\"\n Converts memory units into a reusable skill.\n\n Args:\n unit_ids (List[str]): A list of memory unit IDs to skillize.\n skill_name (str): The name of the skill to create.\n **kwargs: Additional parameters for skill creation.\n\n Returns:\n str: The unique identifier of the created skill.\n \"\"\"\n try:\n # Retrieve the memory units to skillize\n memory_units = [self.get(uid) for uid in unit_ids]\n logger.debug(f\"Skillizing {len(memory_units)} MemoryUnits into skill_name='{skill_name}'.\")\n\n # Delegate skillizing to MemorySkillizer\n skill_id = self.skillizer.skillize(memory_units, skill_name, **kwargs)\n logger.info(f\"Skillized memories into skill with ID: {skill_id}\")\n return skill_id\n except Exception as e:\n logger.error(f\"Error skillizing memories: {e}\")\n self.handle_error(e)\n raise\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_palace.MemoryPalace.structurize","title":"structurize(unit_ids, structure_type, **kwargs)
","text":"Structures memory units into a knowledge graph or other structures.
Parameters:
Name Type Description Default unit_ids
List[str]
A list of memory unit IDs to structurize.
required structure_type
str
The type of structure (e.g., 'knowledge_graph').
required **kwargs
Additional parameters for the structuring process.
{}
Source code in src/aeiva/cognition/memory/memory_palace.py
def structurize(self, unit_ids: List[str], structure_type: str, **kwargs) -> None:\n \"\"\"\n Structures memory units into a knowledge graph or other structures.\n\n Args:\n unit_ids (List[str]): A list of memory unit IDs to structurize.\n structure_type (str): The type of structure (e.g., 'knowledge_graph').\n **kwargs: Additional parameters for the structuring process.\n \"\"\"\n try:\n # Retrieve the memory units to structurize\n memory_units = [self.get(uid) for uid in unit_ids]\n logger.debug(f\"Structurizing {len(memory_units)} MemoryUnits with structure_type='{structure_type}'.\")\n\n # Delegate structuring to MemoryStructurer\n self.structurer.structure(memory_units, structure_type, **kwargs)\n logger.info(f\"Structurized memories with structure_type='{structure_type}'.\")\n except Exception as e:\n logger.error(f\"Error structurizing memories: {e}\")\n self.handle_error(e)\n raise\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_palace.MemoryPalace.update","title":"update(unit_id, updates)
","text":"Updates a memory unit with the given updates.
Parameters:
Name Type Description Default unit_id
str
The unique identifier of the memory unit.
required updates
Dict[str, Any]
A dictionary of fields to update.
required Source code in src/aeiva/cognition/memory/memory_palace.py
def update(self, unit_id: str, updates: Dict[str, Any]) -> None:\n \"\"\"\n Updates a memory unit with the given updates.\n\n Args:\n unit_id (str): The unique identifier of the memory unit.\n updates (Dict[str, Any]): A dictionary of fields to update.\n \"\"\"\n try:\n # Delegate update operations to MemoryStorage\n self.storage.update_memory_unit(unit_id, updates)\n logger.info(f\"Updated MemoryUnit with ID: {unit_id}\")\n except Exception as e:\n logger.error(f\"Error updating MemoryUnit with ID {unit_id}: {e}\")\n self.handle_error(e)\n raise\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_parameterizer","title":"memory_parameterizer
","text":""},{"location":"reference/#src.aeiva.cognition.memory.memory_parameterizer.MemoryParameterizer","title":"MemoryParameterizer
","text":"A class to parameterize memory units based on various parameterizing algorithms.
Supported parameterize types: - 'parameterize_type_example': Placeholder for future parameterizing algorithms.
Source code in src/aeiva/cognition/memory/memory_parameterizer.py
class MemoryParameterizer:\n \"\"\"\n A class to parameterize memory units based on various parameterizing algorithms.\n\n Supported parameterize types:\n - 'parameterize_type_example': Placeholder for future parameterizing algorithms.\n \"\"\"\n\n def __init__(self):\n \"\"\"\n Initializes the MemoryParameterizer.\n\n Currently, no initialization parameters are required.\n \"\"\"\n self.logger = logging.getLogger(self.__class__.__name__)\n self.logger.debug(\"Initialized MemoryParameterizer without default parameters.\")\n\n def parameterize(\n self,\n memory_units: List[MemoryUnit],\n parameterize_type: str,\n **kwargs\n ) -> List[MemoryUnit]:\n \"\"\"\n Parameterizes the provided memory units based on the specified parameterize type.\n\n Args:\n memory_units (List[MemoryUnit]): The list of memory units to be parameterized.\n parameterize_type (str): The type of parameterizing algorithm to use ('parameterize_type_example').\n **kwargs: Additional parameters required for specific parameterizers.\n\n Returns:\n List[MemoryUnit]: The list of memory units after parameterization.\n\n Raises:\n MemoryParameterizerError: If an unknown parameterize_type is provided or if parameterizing fails.\n \"\"\"\n self.logger.debug(f\"Parameterizing memory units using parameterize_type='{parameterize_type}' with kwargs={kwargs}\")\n try:\n if parameterize_type == 'parameterize_type_example':\n # Placeholder for actual parameterizing logic\n return self.parameterize_example(memory_units, **kwargs)\n else:\n self.logger.error(f\"Unknown parameterize_type: {parameterize_type}\")\n raise MemoryParameterizerError(f\"Unknown parameterize_type: {parameterize_type}\")\n except MemoryParameterizerError:\n # Re-raise custom errors without modification\n raise\n except Exception as e:\n self.logger.error(f\"Failed to parameterize memory units: {e}\")\n raise MemoryParameterizerError(f\"Failed to parameterize memory units: {e}\")\n\n def parameterize_example(\n self,\n memory_units: 
List[MemoryUnit],\n **kwargs\n ) -> List[MemoryUnit]:\n \"\"\"\n Example parameterizing method. Currently a placeholder that returns memory units unchanged.\n\n Args:\n memory_units (List[MemoryUnit]): The list of memory units to be parameterized.\n **kwargs: Additional parameters (currently unused).\n\n Returns:\n List[MemoryUnit]: The original list of memory units, unchanged.\n \"\"\"\n self.logger.debug(\"Executing parameterize_example: No changes applied to memory units.\")\n # Placeholder: No operation performed\n return memory_units\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_parameterizer.MemoryParameterizer.__init__","title":"__init__()
","text":"Initializes the MemoryParameterizer.
Currently, no initialization parameters are required.
Source code in src/aeiva/cognition/memory/memory_parameterizer.py
def __init__(self):\n \"\"\"\n Initializes the MemoryParameterizer.\n\n Currently, no initialization parameters are required.\n \"\"\"\n self.logger = logging.getLogger(self.__class__.__name__)\n self.logger.debug(\"Initialized MemoryParameterizer without default parameters.\")\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_parameterizer.MemoryParameterizer.parameterize","title":"parameterize(memory_units, parameterize_type, **kwargs)
","text":"Parameterizes the provided memory units based on the specified parameterize type.
Parameters:
Name Type Description Default memory_units
List[MemoryUnit]
The list of memory units to be parameterized.
required parameterize_type
str
The type of parameterizing algorithm to use ('parameterize_type_example').
required **kwargs
Additional parameters required for specific parameterizers.
{}
Returns:
Type Description List[MemoryUnit]
List[MemoryUnit]: The list of memory units after parameterization.
Raises:
Type Description MemoryParameterizerError
If an unknown parameterize_type is provided or if parameterizing fails.
Source code in src/aeiva/cognition/memory/memory_parameterizer.py
def parameterize(\n self,\n memory_units: List[MemoryUnit],\n parameterize_type: str,\n **kwargs\n) -> List[MemoryUnit]:\n \"\"\"\n Parameterizes the provided memory units based on the specified parameterize type.\n\n Args:\n memory_units (List[MemoryUnit]): The list of memory units to be parameterized.\n parameterize_type (str): The type of parameterizing algorithm to use ('parameterize_type_example').\n **kwargs: Additional parameters required for specific parameterizers.\n\n Returns:\n List[MemoryUnit]: The list of memory units after parameterization.\n\n Raises:\n MemoryParameterizerError: If an unknown parameterize_type is provided or if parameterizing fails.\n \"\"\"\n self.logger.debug(f\"Parameterizing memory units using parameterize_type='{parameterize_type}' with kwargs={kwargs}\")\n try:\n if parameterize_type == 'parameterize_type_example':\n # Placeholder for actual parameterizing logic\n return self.parameterize_example(memory_units, **kwargs)\n else:\n self.logger.error(f\"Unknown parameterize_type: {parameterize_type}\")\n raise MemoryParameterizerError(f\"Unknown parameterize_type: {parameterize_type}\")\n except MemoryParameterizerError:\n # Re-raise custom errors without modification\n raise\n except Exception as e:\n self.logger.error(f\"Failed to parameterize memory units: {e}\")\n raise MemoryParameterizerError(f\"Failed to parameterize memory units: {e}\")\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_parameterizer.MemoryParameterizer.parameterize_example","title":"parameterize_example(memory_units, **kwargs)
","text":"Example parameterizing method. Currently a placeholder that returns memory units unchanged.
Parameters:
Name Type Description Default memory_units
List[MemoryUnit]
The list of memory units to be parameterized.
required **kwargs
Additional parameters (currently unused).
{}
Returns:
Type Description List[MemoryUnit]
List[MemoryUnit]: The original list of memory units, unchanged.
Source code in src/aeiva/cognition/memory/memory_parameterizer.py
def parameterize_example(\n self,\n memory_units: List[MemoryUnit],\n **kwargs\n) -> List[MemoryUnit]:\n \"\"\"\n Example parameterizing method. Currently a placeholder that returns memory units unchanged.\n\n Args:\n memory_units (List[MemoryUnit]): The list of memory units to be parameterized.\n **kwargs: Additional parameters (currently unused).\n\n Returns:\n List[MemoryUnit]: The original list of memory units, unchanged.\n \"\"\"\n self.logger.debug(\"Executing parameterize_example: No changes applied to memory units.\")\n # Placeholder: No operation performed\n return memory_units\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_parameterizer.MemoryParameterizerError","title":"MemoryParameterizerError
","text":" Bases: Exception
Exception raised when an error occurs in the MemoryParameterizer.
Source code in src/aeiva/cognition/memory/memory_parameterizer.py
class MemoryParameterizerError(Exception):\n \"\"\"Exception raised when an error occurs in the MemoryParameterizer.\"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_retriever","title":"memory_retriever
","text":""},{"location":"reference/#src.aeiva.cognition.memory.memory_retriever.MemoryRetriever","title":"MemoryRetriever
","text":"A class to retrieve memory units based on various retrieval algorithms.
Supported retrieval types: - 'similar': Retrieves memory units similar to a given query based on embeddings.
- 'related': Retrieves memory units related to a specified query based on relationships.
Source code in src/aeiva/cognition/memory/memory_retriever.py
class MemoryRetriever:\n \"\"\"\n A class to retrieve memory units based on various retrieval algorithms.\n\n Supported retrieval types:\n - 'similar': Retrieves memory units similar to a given query based on embeddings.\n - 'related': Retrieves memory units related to a specified query based on relationships.\n \"\"\"\n\n def __init__(self, embedder: Embedder, storage: MemoryStorage):\n \"\"\"\n Initializes the MemoryRetriever.\n\n Args:\n embedder (Embedder): An instance responsible for generating embeddings.\n storage (MemoryStorage): An instance managing data storage and retrieval.\n \"\"\"\n self.embedder = embedder\n self.storage = storage\n self.logger = logging.getLogger(self.__class__.__name__)\n self.logger.debug(\"Initialized MemoryRetriever with provided embedder and storage.\")\n\n def retrieve(\n self,\n query: Any,\n retrieve_type: str,\n **kwargs\n ) -> List[MemoryUnit]:\n \"\"\"\n Factory method to retrieve memory units based on the specified retrieval type.\n\n Args:\n query (Any): The query for retrieval.\n retrieve_type (str): The type of retrieval ('similar' or 'related').\n **kwargs: Additional parameters required for specific retrieval types.\n For 'similar' retrieval:\n - top_k (int): The number of similar units to retrieve.\n For 'related' retrieval:\n - relationship (Optional[str]): The type of relationship to filter by.\n\n Returns:\n List[MemoryUnit]: A list of retrieved memory units.\n\n Raises:\n MemoryRetrieverError: If an unknown retrieval_type is provided or if retrieval fails.\n \"\"\"\n self.logger.info(f\"Initiating retrieval of type '{retrieve_type}' with query: {query}\")\n try:\n if retrieve_type == 'similar':\n top_k = kwargs.get('top_k', 5)\n self.logger.debug(f\"Retrieval Type: 'similar' with top_k={top_k}\")\n return self.retrieve_similar(query, top_k)\n elif retrieve_type == 'related':\n relationship = kwargs.get('relationship')\n self.logger.debug(f\"Retrieval Type: 'related' with relationship='{relationship}'\")\n 
return self.retrieve_related(query, relationship)\n else:\n self.logger.error(f\"Unknown retrieve_type: {retrieve_type}\")\n raise MemoryRetrieverError(f\"Unknown retrieve_type: {retrieve_type}\")\n except MemoryRetrieverError:\n # Re-raise custom errors without modification\n raise\n except Exception as e:\n self.logger.error(f\"Failed to retrieve memory units: {e}\")\n raise MemoryRetrieverError(f\"Failed to retrieve memory units: {e}\") from e\n\n def retrieve_similar(self, query: Any, top_k: int = 5) -> List[MemoryUnit]:\n \"\"\"\n Retrieves memory units similar to the given input based on embeddings.\n\n Args:\n query (Any): The query for retrieval.\n top_k (int): The number of similar units to retrieve.\n\n Returns:\n List[MemoryUnit]: A list of similar memory units.\n\n Raises:\n MemoryRetrieverError: If retrieval fails due to embedding generation or storage issues.\n \"\"\"\n self.logger.info(f\"Retrieving top {top_k} similar MemoryUnits based on the query.\")\n try:\n # Generate embedding for the query\n self.logger.debug(\"Generating embedding for the query.\")\n embedding_response = self.embedder.embed(query)\n if not embedding_response.get(\"data\"):\n self.logger.error(\"Failed to generate embedding for the query.\")\n raise MemoryRetrieverError(\"Failed to generate embedding for the query.\")\n\n query_embedding = embedding_response[\"data\"][0].get(\"embedding\")\n if not query_embedding:\n self.logger.error(\"Embedding data is missing in the response.\")\n raise MemoryRetrieverError(\"Embedding data is missing in the response.\")\n\n self.logger.debug(f\"Embedding generated successfully: {query_embedding}\")\n\n # Perform similarity search via MemoryStorage\n self.logger.debug(\"Performing similarity search in the vector database.\")\n similar_units = self.storage.retrieve_similar_memory_units(query_embedding, top_k)\n self.logger.info(f\"Retrieved {len(similar_units)} similar MemoryUnits.\")\n return similar_units\n\n except MemoryRetrieverError:\n 
# Re-raise custom errors without modification\n raise\n except Exception as e:\n self.logger.error(f\"Unexpected error during retrieve_similar: {e}\")\n raise MemoryRetrieverError(f\"Unexpected error during retrieve_similar: {e}\") from e\n\n def retrieve_related(\n self,\n query: Any,\n relationship: Optional[str] = None\n ) -> List[MemoryUnit]: # TODO: revise the method later\n \"\"\"\n Retrieves memory units related to the given query based on relationships.\n\n Args:\n query (Any): The query for retrieval. Expected to be a MemoryUnit ID or similar identifier.\n relationship (Optional[str]): The type of relationship to filter by.\n\n Returns:\n List[MemoryUnit]: A list of related memory units.\n\n Raises:\n MemoryRetrieverError: If retrieval fails due to storage issues or invalid queries.\n \"\"\"\n self.logger.info(f\"Retrieving memories related to the query with relationship: {relationship}\")\n try:\n # Assuming 'query' is a MemoryUnit ID or can be used to fetch a MemoryUnit\n self.logger.debug(\"Fetching the target MemoryUnit from storage.\")\n target_memory_unit = self.storage.get_memory_unit(query)\n if not target_memory_unit:\n self.logger.error(f\"MemoryUnit with ID '{query}' not found.\")\n raise MemoryRetrieverError(f\"MemoryUnit with ID '{query}' not found.\")\n\n self.logger.debug(f\"MemoryUnit fetched successfully: {target_memory_unit}\")\n\n # Perform related retrieval via MemoryStorage\n self.logger.debug(\"Retrieving related MemoryUnits from the graph database.\")\n related_units = self.storage.retrieve_related_memory_units(target_memory_unit.id, relationship)\n self.logger.info(f\"Retrieved {len(related_units)} related MemoryUnits.\")\n return related_units\n\n except MemoryRetrieverError:\n # Re-raise custom errors without modification\n raise\n except Exception as e:\n self.logger.error(f\"Unexpected error during retrieve_related: {e}\")\n raise MemoryRetrieverError(f\"Unexpected error during retrieve_related: {e}\") from e\n\n def 
handle_error(self, error: Exception):\n \"\"\"\n Handles errors by logging or performing other necessary actions.\n\n Args:\n error (Exception): The exception to handle.\n \"\"\"\n # Implement any error handling logic here\n # For now, we'll just log the error\n self.logger.error(f\"An error occurred: {error}\")\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_retriever.MemoryRetriever.__init__","title":"__init__(embedder, storage)
","text":"Initializes the MemoryRetriever.
Parameters:
Name Type Description Default embedder
Embedder
An instance responsible for generating embeddings.
required storage
MemoryStorage
An instance managing data storage and retrieval.
required Source code in src/aeiva/cognition/memory/memory_retriever.py
def __init__(self, embedder: Embedder, storage: MemoryStorage):\n \"\"\"\n Initializes the MemoryRetriever.\n\n Args:\n embedder (Embedder): An instance responsible for generating embeddings.\n storage (MemoryStorage): An instance managing data storage and retrieval.\n \"\"\"\n self.embedder = embedder\n self.storage = storage\n self.logger = logging.getLogger(self.__class__.__name__)\n self.logger.debug(\"Initialized MemoryRetriever with provided embedder and storage.\")\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_retriever.MemoryRetriever.handle_error","title":"handle_error(error)
","text":"Handles errors by logging or performing other necessary actions.
Parameters:
Name Type Description Default error
Exception
The exception to handle.
required Source code in src/aeiva/cognition/memory/memory_retriever.py
def handle_error(self, error: Exception):\n \"\"\"\n Handles errors by logging or performing other necessary actions.\n\n Args:\n error (Exception): The exception to handle.\n \"\"\"\n # Implement any error handling logic here\n # For now, we'll just log the error\n self.logger.error(f\"An error occurred: {error}\")\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_retriever.MemoryRetriever.retrieve","title":"retrieve(query, retrieve_type, **kwargs)
","text":"Factory method to retrieve memory units based on the specified retrieval type.
Parameters:
Name Type Description Default query
Any
The query for retrieval.
required retrieve_type
str
The type of retrieval ('similar' or 'related').
required **kwargs
Additional parameters required for specific retrieval types. For 'similar' retrieval: - top_k (int): The number of similar units to retrieve. For 'related' retrieval: - relationship (Optional[str]): The type of relationship to filter by.
{}
Returns:
Type Description List[MemoryUnit]
List[MemoryUnit]: A list of retrieved memory units.
Raises:
Type Description MemoryRetrieverError
If an unknown retrieve_type is provided or if retrieval fails.
Source code in src/aeiva/cognition/memory/memory_retriever.py
def retrieve(\n self,\n query: Any,\n retrieve_type: str,\n **kwargs\n) -> List[MemoryUnit]:\n \"\"\"\n Factory method to retrieve memory units based on the specified retrieval type.\n\n Args:\n query (Any): The query for retrieval.\n retrieve_type (str): The type of retrieval ('similar' or 'related').\n **kwargs: Additional parameters required for specific retrieval types.\n For 'similar' retrieval:\n - top_k (int): The number of similar units to retrieve.\n For 'related' retrieval:\n - relationship (Optional[str]): The type of relationship to filter by.\n\n Returns:\n List[MemoryUnit]: A list of retrieved memory units.\n\n Raises:\n MemoryRetrieverError: If an unknown retrieval_type is provided or if retrieval fails.\n \"\"\"\n self.logger.info(f\"Initiating retrieval of type '{retrieve_type}' with query: {query}\")\n try:\n if retrieve_type == 'similar':\n top_k = kwargs.get('top_k', 5)\n self.logger.debug(f\"Retrieval Type: 'similar' with top_k={top_k}\")\n return self.retrieve_similar(query, top_k)\n elif retrieve_type == 'related':\n relationship = kwargs.get('relationship')\n self.logger.debug(f\"Retrieval Type: 'related' with relationship='{relationship}'\")\n return self.retrieve_related(query, relationship)\n else:\n self.logger.error(f\"Unknown retrieve_type: {retrieve_type}\")\n raise MemoryRetrieverError(f\"Unknown retrieve_type: {retrieve_type}\")\n except MemoryRetrieverError:\n # Re-raise custom errors without modification\n raise\n except Exception as e:\n self.logger.error(f\"Failed to retrieve memory units: {e}\")\n raise MemoryRetrieverError(f\"Failed to retrieve memory units: {e}\") from e\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_retriever.MemoryRetriever.retrieve_related","title":"retrieve_related(query, relationship=None)
","text":"Retrieves memory units related to the given query based on relationships.
Parameters:
Name Type Description Default query
Any
The query for retrieval. Expected to be a MemoryUnit ID or similar identifier.
required relationship
Optional[str]
The type of relationship to filter by.
None
Returns:
Type Description List[MemoryUnit]
List[MemoryUnit]: A list of related memory units.
Raises:
Type Description MemoryRetrieverError
If retrieval fails due to storage issues or invalid queries.
Source code in src/aeiva/cognition/memory/memory_retriever.py
def retrieve_related(\n self,\n query: Any,\n relationship: Optional[str] = None\n) -> List[MemoryUnit]: # TODO: revise the method later\n \"\"\"\n Retrieves memory units related to the given query based on relationships.\n\n Args:\n query (Any): The query for retrieval. Expected to be a MemoryUnit ID or similar identifier.\n relationship (Optional[str]): The type of relationship to filter by.\n\n Returns:\n List[MemoryUnit]: A list of related memory units.\n\n Raises:\n MemoryRetrieverError: If retrieval fails due to storage issues or invalid queries.\n \"\"\"\n self.logger.info(f\"Retrieving memories related to the query with relationship: {relationship}\")\n try:\n # Assuming 'query' is a MemoryUnit ID or can be used to fetch a MemoryUnit\n self.logger.debug(\"Fetching the target MemoryUnit from storage.\")\n target_memory_unit = self.storage.get_memory_unit(query)\n if not target_memory_unit:\n self.logger.error(f\"MemoryUnit with ID '{query}' not found.\")\n raise MemoryRetrieverError(f\"MemoryUnit with ID '{query}' not found.\")\n\n self.logger.debug(f\"MemoryUnit fetched successfully: {target_memory_unit}\")\n\n # Perform related retrieval via MemoryStorage\n self.logger.debug(\"Retrieving related MemoryUnits from the graph database.\")\n related_units = self.storage.retrieve_related_memory_units(target_memory_unit.id, relationship)\n self.logger.info(f\"Retrieved {len(related_units)} related MemoryUnits.\")\n return related_units\n\n except MemoryRetrieverError:\n # Re-raise custom errors without modification\n raise\n except Exception as e:\n self.logger.error(f\"Unexpected error during retrieve_related: {e}\")\n raise MemoryRetrieverError(f\"Unexpected error during retrieve_related: {e}\") from e\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_retriever.MemoryRetriever.retrieve_similar","title":"retrieve_similar(query, top_k=5)
","text":"Retrieves memory units similar to the given input based on embeddings.
Parameters:
Name Type Description Default query
Any
The query for retrieval.
required top_k
int
The number of similar units to retrieve.
5
Returns:
Type Description List[MemoryUnit]
List[MemoryUnit]: A list of similar memory units.
Raises:
Type Description MemoryRetrieverError
If retrieval fails due to embedding generation or storage issues.
Source code in src/aeiva/cognition/memory/memory_retriever.py
def retrieve_similar(self, query: Any, top_k: int = 5) -> List[MemoryUnit]:\n \"\"\"\n Retrieves memory units similar to the given input based on embeddings.\n\n Args:\n query (Any): The query for retrieval.\n top_k (int): The number of similar units to retrieve.\n\n Returns:\n List[MemoryUnit]: A list of similar memory units.\n\n Raises:\n MemoryRetrieverError: If retrieval fails due to embedding generation or storage issues.\n \"\"\"\n self.logger.info(f\"Retrieving top {top_k} similar MemoryUnits based on the query.\")\n try:\n # Generate embedding for the query\n self.logger.debug(\"Generating embedding for the query.\")\n embedding_response = self.embedder.embed(query)\n if not embedding_response.get(\"data\"):\n self.logger.error(\"Failed to generate embedding for the query.\")\n raise MemoryRetrieverError(\"Failed to generate embedding for the query.\")\n\n query_embedding = embedding_response[\"data\"][0].get(\"embedding\")\n if not query_embedding:\n self.logger.error(\"Embedding data is missing in the response.\")\n raise MemoryRetrieverError(\"Embedding data is missing in the response.\")\n\n self.logger.debug(f\"Embedding generated successfully: {query_embedding}\")\n\n # Perform similarity search via MemoryStorage\n self.logger.debug(\"Performing similarity search in the vector database.\")\n similar_units = self.storage.retrieve_similar_memory_units(query_embedding, top_k)\n self.logger.info(f\"Retrieved {len(similar_units)} similar MemoryUnits.\")\n return similar_units\n\n except MemoryRetrieverError:\n # Re-raise custom errors without modification\n raise\n except Exception as e:\n self.logger.error(f\"Unexpected error during retrieve_similar: {e}\")\n raise MemoryRetrieverError(f\"Unexpected error during retrieve_similar: {e}\") from e\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_retriever.MemoryRetrieverError","title":"MemoryRetrieverError
","text":" Bases: Exception
Exception raised when an error occurs in the MemoryRetriever.
Source code in src/aeiva/cognition/memory/memory_retriever.py
class MemoryRetrieverError(Exception):\n \"\"\"Exception raised when an error occurs in the MemoryRetriever.\"\"\"\n pass\n
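The retriever's error policy, shown in isolation: domain errors (`MemoryRetrieverError`) pass through untouched, while anything unexpected is wrapped in the domain error with the original exception chained via `raise ... from e`. The `guarded` helper below is illustrative only, not part of aeiva.

```python
class MemoryRetrieverError(Exception):
    """Exception raised when an error occurs in the MemoryRetriever."""
    pass


def guarded(fn):
    try:
        return fn()
    except MemoryRetrieverError:
        raise  # already a domain error: re-raise without modification
    except Exception as e:
        raise MemoryRetrieverError(f"Unexpected error: {e}") from e


def fails():
    raise ValueError("bad input")


# An unexpected ValueError is wrapped, and the cause is preserved:
try:
    guarded(fails)
except MemoryRetrieverError as err:
    assert isinstance(err.__cause__, ValueError)
```

Chaining with `from e` keeps the original traceback attached, so logs show both the low-level failure and the domain-level wrapper.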
"},{"location":"reference/#src.aeiva.cognition.memory.memory_skillizer","title":"memory_skillizer
","text":""},{"location":"reference/#src.aeiva.cognition.memory.memory_skillizer.MemorySkillizer","title":"MemorySkillizer
","text":"A class to skillize memory units based on various skillizing algorithms.
Supported skill types: 'skill_type_example' (placeholder for future skillizing algorithms).
Source code in src/aeiva/cognition/memory/memory_skillizer.py
class MemorySkillizer:\n \"\"\"\n A class to skillize memory units based on various skillizing algorithms.\n\n Supported skill types:\n - 'skill_type_example': Placeholder for future skillizing algorithms.\n \"\"\"\n\n def __init__(self):\n \"\"\"\n Initializes the MemorySkillizer.\n\n Currently, no initialization parameters are required.\n \"\"\"\n self.logger = logging.getLogger(self.__class__.__name__)\n self.logger.debug(\"Initialized MemorySkillizer without default parameters.\")\n\n def skillize(\n self,\n memory_units: List[MemoryUnit],\n skill_type: str,\n **kwargs\n ) -> List[MemoryUnit]:\n \"\"\"\n Skillizes the provided memory units based on the specified skill type.\n\n Args:\n memory_units (List[MemoryUnit]): The list of memory units to be skillized.\n skill_type (str): The type of skillizing algorithm to use ('skill_type_example').\n **kwargs: Additional parameters required for specific skillizers.\n\n Returns:\n List[MemoryUnit]: The list of memory units after skillizing.\n\n Raises:\n MemorySkillizerError: If an unknown skill_type is provided or if skillizing fails.\n \"\"\"\n self.logger.debug(f\"Skillizing memory units using skill_type='{skill_type}' with kwargs={kwargs}\")\n try:\n if skill_type == 'skill_type_example':\n # Placeholder for actual skillizing logic\n return self.skillize_example(memory_units, **kwargs)\n else:\n self.logger.error(f\"Unknown skill_type: {skill_type}\")\n raise MemorySkillizerError(f\"Unknown skill_type: {skill_type}\")\n except MemorySkillizerError:\n # Re-raise custom errors without modification\n raise\n except Exception as e:\n self.logger.error(f\"Failed to skillize memory units: {e}\")\n raise MemorySkillizerError(f\"Failed to skillize memory units: {e}\")\n\n def skillize_example(\n self,\n memory_units: List[MemoryUnit],\n **kwargs\n ) -> List[MemoryUnit]:\n \"\"\"\n Example skillizing method. 
Currently a placeholder that returns memory units unchanged.\n\n Args:\n memory_units (List[MemoryUnit]): The list of memory units to be skillized.\n **kwargs: Additional parameters (currently unused).\n\n Returns:\n List[MemoryUnit]: The original list of memory units, unchanged.\n \"\"\"\n self.logger.debug(\"Executing skillize_example: No changes applied to memory units.\")\n # Placeholder: No operation performed\n return memory_units\n
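The `skillize` method above is a dispatch on `skill_type`. The same pattern can be sketched with a handler table, which makes adding new skill types a one-line registration; this table-based variant is an assumption of this sketch, not how aeiva's source is written (the source uses an `if`/`else` chain).

```python
from typing import Any, Callable, Dict, List


class MemorySkillizerError(Exception):
    """Exception raised when an error occurs in the MemorySkillizer."""
    pass


def skillize_example(memory_units: List[Any], **kwargs) -> List[Any]:
    # Placeholder, as in the source above: returns the units unchanged.
    return memory_units


SKILLIZERS: Dict[str, Callable[..., List[Any]]] = {
    "skill_type_example": skillize_example,
}


def skillize(memory_units: List[Any], skill_type: str, **kwargs) -> List[Any]:
    handler = SKILLIZERS.get(skill_type)
    if handler is None:
        raise MemorySkillizerError(f"Unknown skill_type: {skill_type}")
    return handler(memory_units, **kwargs)


print(skillize(["unit-a", "unit-b"], "skill_type_example"))  # ['unit-a', 'unit-b']
```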
"},{"location":"reference/#src.aeiva.cognition.memory.memory_skillizer.MemorySkillizer.__init__","title":"__init__()
","text":"Initializes the MemorySkillizer.
Currently, no initialization parameters are required.
Source code in src/aeiva/cognition/memory/memory_skillizer.py
def __init__(self):\n \"\"\"\n Initializes the MemorySkillizer.\n\n Currently, no initialization parameters are required.\n \"\"\"\n self.logger = logging.getLogger(self.__class__.__name__)\n self.logger.debug(\"Initialized MemorySkillizer without default parameters.\")\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_skillizer.MemorySkillizer.skillize","title":"skillize(memory_units, skill_type, **kwargs)
","text":"Skillizes the provided memory units based on the specified skill type.
Parameters:
Name Type Description Default memory_units
List[MemoryUnit]
The list of memory units to be skillized.
required skill_type
str
The type of skillizing algorithm to use ('skill_type_example').
required **kwargs
Additional parameters required for specific skillizers.
{}
Returns:
Type Description List[MemoryUnit]
List[MemoryUnit]: The list of memory units after skillizing.
Raises:
Type Description MemorySkillizerError
If an unknown skill_type is provided or if skillizing fails.
Source code in src/aeiva/cognition/memory/memory_skillizer.py
def skillize(\n self,\n memory_units: List[MemoryUnit],\n skill_type: str,\n **kwargs\n) -> List[MemoryUnit]:\n \"\"\"\n Skillizes the provided memory units based on the specified skill type.\n\n Args:\n memory_units (List[MemoryUnit]): The list of memory units to be skillized.\n skill_type (str): The type of skillizing algorithm to use ('skill_type_example').\n **kwargs: Additional parameters required for specific skillizers.\n\n Returns:\n List[MemoryUnit]: The list of memory units after skillizing.\n\n Raises:\n MemorySkillizerError: If an unknown skill_type is provided or if skillizing fails.\n \"\"\"\n self.logger.debug(f\"Skillizing memory units using skill_type='{skill_type}' with kwargs={kwargs}\")\n try:\n if skill_type == 'skill_type_example':\n # Placeholder for actual skillizing logic\n return self.skillize_example(memory_units, **kwargs)\n else:\n self.logger.error(f\"Unknown skill_type: {skill_type}\")\n raise MemorySkillizerError(f\"Unknown skill_type: {skill_type}\")\n except MemorySkillizerError:\n # Re-raise custom errors without modification\n raise\n except Exception as e:\n self.logger.error(f\"Failed to skillize memory units: {e}\")\n raise MemorySkillizerError(f\"Failed to skillize memory units: {e}\")\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_skillizer.MemorySkillizer.skillize_example","title":"skillize_example(memory_units, **kwargs)
","text":"Example skillizing method. Currently a placeholder that returns memory units unchanged.
Parameters:
Name Type Description Default memory_units
List[MemoryUnit]
The list of memory units to be skillized.
required **kwargs
Additional parameters (currently unused).
{}
Returns:
Type Description List[MemoryUnit]
List[MemoryUnit]: The original list of memory units, unchanged.
Source code in src/aeiva/cognition/memory/memory_skillizer.py
def skillize_example(\n self,\n memory_units: List[MemoryUnit],\n **kwargs\n) -> List[MemoryUnit]:\n \"\"\"\n Example skillizing method. Currently a placeholder that returns memory units unchanged.\n\n Args:\n memory_units (List[MemoryUnit]): The list of memory units to be skillized.\n **kwargs: Additional parameters (currently unused).\n\n Returns:\n List[MemoryUnit]: The original list of memory units, unchanged.\n \"\"\"\n self.logger.debug(\"Executing skillize_example: No changes applied to memory units.\")\n # Placeholder: No operation performed\n return memory_units\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_skillizer.MemorySkillizerError","title":"MemorySkillizerError
","text":" Bases: Exception
Exception raised when an error occurs in the MemorySkillizer.
Source code in src/aeiva/cognition/memory/memory_skillizer.py
class MemorySkillizerError(Exception):\n \"\"\"Exception raised when an error occurs in the MemorySkillizer.\"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_storage","title":"memory_storage
","text":""},{"location":"reference/#src.aeiva.cognition.memory.memory_storage.MemoryEventRepository","title":"MemoryEventRepository
","text":"Repository for MemoryEvent to handle CRUD operations without SQLAlchemy.
Source code in src/aeiva/cognition/memory/memory_storage.py
class MemoryEventRepository:\n \"\"\"\n Repository for MemoryEvent to handle CRUD operations without SQLAlchemy.\n \"\"\"\n\n def __init__(self, db: Any):\n \"\"\"\n Initialize the repository with a DatabaseFactory instance.\n\n Args:\n db (Any): An instance of DatabaseFactory for relational databases.\n \"\"\"\n self.db = db\n self.table_name = 'memory_events'\n self._create_table()\n\n def _create_table(self):\n \"\"\"\n Creates the memory_events table if it does not exist.\n \"\"\"\n create_table_query = f\"\"\"\n CREATE TABLE IF NOT EXISTS {self.table_name} (\n id TEXT PRIMARY KEY,\n memory_id TEXT NOT NULL,\n event_type TEXT NOT NULL,\n timestamp TEXT NOT NULL,\n memory_data TEXT,\n previous_data TEXT\n );\n \"\"\"\n self.db.execute_sql(create_table_query)\n\n def add(self, event: Dict[str, Any]) -> None:\n \"\"\"\n Adds a MemoryEvent to the relational database.\n\n Args:\n event (Dict[str, Any]): The event data to add.\n \"\"\"\n insert_query = f\"\"\"\n INSERT INTO {self.table_name} (id, memory_id, event_type, timestamp, memory_data, previous_data)\n VALUES (?, ?, ?, ?, ?, ?);\n \"\"\"\n data = (\n event.get('id', uuid4().hex),\n event['memory_id'],\n event['event_type'],\n datetime.utcnow().isoformat(), # TODO: revise utcnow.\n event.get('memory_data'),\n event.get('previous_data')\n )\n self.db.execute_sql(insert_query, data)\n\n def get(self, event_id: str) -> Optional[Dict[str, Any]]:\n \"\"\"\n Retrieves a MemoryEvent by its ID.\n\n Args:\n event_id (str): The unique identifier of the event.\n\n Returns:\n Optional[Dict[str, Any]]: The event data or None if not found.\n \"\"\"\n select_query = f\"SELECT * FROM {self.table_name} WHERE id = ?;\"\n result = self.db.execute_sql(select_query, (event_id,))\n row = result.fetchone()\n if row:\n return self._row_to_event(row)\n return None\n\n def delete_all(self) -> None:\n \"\"\"\n Deletes all MemoryEvents from the relational database.\n \"\"\"\n delete_query = f\"DELETE FROM {self.table_name};\"\n 
self.db.execute_sql(delete_query)\n\n def list_all(self) -> List[Dict[str, Any]]:\n \"\"\"\n Retrieves all MemoryEvents from the relational database.\n\n Returns:\n List[Dict[str, Any]]: A list of all events.\n \"\"\"\n select_query = f\"SELECT * FROM {self.table_name};\"\n results = self.db.execute_sql(select_query)\n return [self._row_to_event(row) for row in results.fetchall()]\n\n def _row_to_event(self, row: Any) -> Dict[str, Any]:\n \"\"\"\n Converts a database row to an event dictionary.\n\n Args:\n row (Any): A row fetched from the database.\n\n Returns:\n Dict[str, Any]: The corresponding event data.\n \"\"\"\n return {\n \"id\": row['id'],\n \"memory_id\": row['memory_id'],\n \"event_type\": row['event_type'],\n \"timestamp\": datetime.fromisoformat(row['timestamp']),\n \"memory_data\": json.loads(row['memory_data']) if row['memory_data'] else None,\n \"previous_data\": json.loads(row['previous_data']) if row['previous_data'] else None\n }\n
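A runnable, self-contained version of the event-log pattern above, using the stdlib `sqlite3` module in place of aeiva's `DatabaseFactory` (so the `db` object and its calls here are stdlib, not aeiva APIs). One deliberate difference: it uses `datetime.now(timezone.utc)` rather than the source's `datetime.utcnow()`, which is deprecated since Python 3.12 and flagged with a TODO in the source.

```python
import json
import sqlite3
from datetime import datetime, timezone
from uuid import uuid4

db = sqlite3.connect(":memory:")
db.row_factory = sqlite3.Row  # rows addressable by column name
db.execute("""
    CREATE TABLE IF NOT EXISTS memory_events (
        id TEXT PRIMARY KEY,
        memory_id TEXT NOT NULL,
        event_type TEXT NOT NULL,
        timestamp TEXT NOT NULL,
        memory_data TEXT,
        previous_data TEXT
    );
""")

def add_event(event):
    # Parameterized insert, mirroring MemoryEventRepository.add
    db.execute(
        "INSERT INTO memory_events VALUES (?, ?, ?, ?, ?, ?);",
        (
            event.get("id", uuid4().hex),
            event["memory_id"],
            event["event_type"],
            datetime.now(timezone.utc).isoformat(),
            event.get("memory_data"),
            event.get("previous_data"),
        ),
    )

add_event({"memory_id": "m1", "event_type": "CREATE",
           "memory_data": json.dumps({"content": "hello"})})
row = db.execute("SELECT * FROM memory_events WHERE memory_id = 'm1';").fetchone()
print(row["event_type"])  # CREATE
```

Storing `memory_data` and `previous_data` as JSON text keeps the schema flat while still allowing the `UPDATE` event to carry the prior state for auditing.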
"},{"location":"reference/#src.aeiva.cognition.memory.memory_storage.MemoryEventRepository.__init__","title":"__init__(db)
","text":"Initialize the repository with a DatabaseFactory instance.
Parameters:
Name Type Description Default db
Any
An instance of DatabaseFactory for relational databases.
required Source code in src/aeiva/cognition/memory/memory_storage.py
def __init__(self, db: Any):\n \"\"\"\n Initialize the repository with a DatabaseFactory instance.\n\n Args:\n db (Any): An instance of DatabaseFactory for relational databases.\n \"\"\"\n self.db = db\n self.table_name = 'memory_events'\n self._create_table()\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_storage.MemoryEventRepository.add","title":"add(event)
","text":"Adds a MemoryEvent to the relational database.
Parameters:
Name Type Description Default event
Dict[str, Any]
The event data to add.
required Source code in src/aeiva/cognition/memory/memory_storage.py
def add(self, event: Dict[str, Any]) -> None:\n \"\"\"\n Adds a MemoryEvent to the relational database.\n\n Args:\n event (Dict[str, Any]): The event data to add.\n \"\"\"\n insert_query = f\"\"\"\n INSERT INTO {self.table_name} (id, memory_id, event_type, timestamp, memory_data, previous_data)\n VALUES (?, ?, ?, ?, ?, ?);\n \"\"\"\n data = (\n event.get('id', uuid4().hex),\n event['memory_id'],\n event['event_type'],\n datetime.utcnow().isoformat(), # TODO: revise utcnow.\n event.get('memory_data'),\n event.get('previous_data')\n )\n self.db.execute_sql(insert_query, data)\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_storage.MemoryEventRepository.delete_all","title":"delete_all()
","text":"Deletes all MemoryEvents from the relational database.
Source code in src/aeiva/cognition/memory/memory_storage.py
def delete_all(self) -> None:\n \"\"\"\n Deletes all MemoryEvents from the relational database.\n \"\"\"\n delete_query = f\"DELETE FROM {self.table_name};\"\n self.db.execute_sql(delete_query)\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_storage.MemoryEventRepository.get","title":"get(event_id)
","text":"Retrieves a MemoryEvent by its ID.
Parameters:
Name Type Description Default event_id
str
The unique identifier of the event.
required Returns:
Type Description Optional[Dict[str, Any]]
Optional[Dict[str, Any]]: The event data or None if not found.
Source code in src/aeiva/cognition/memory/memory_storage.py
def get(self, event_id: str) -> Optional[Dict[str, Any]]:\n \"\"\"\n Retrieves a MemoryEvent by its ID.\n\n Args:\n event_id (str): The unique identifier of the event.\n\n Returns:\n Optional[Dict[str, Any]]: The event data or None if not found.\n \"\"\"\n select_query = f\"SELECT * FROM {self.table_name} WHERE id = ?;\"\n result = self.db.execute_sql(select_query, (event_id,))\n row = result.fetchone()\n if row:\n return self._row_to_event(row)\n return None\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_storage.MemoryEventRepository.list_all","title":"list_all()
","text":"Retrieves all MemoryEvents from the relational database.
Returns:
Type Description List[Dict[str, Any]]
List[Dict[str, Any]]: A list of all events.
Source code in src/aeiva/cognition/memory/memory_storage.py
def list_all(self) -> List[Dict[str, Any]]:\n \"\"\"\n Retrieves all MemoryEvents from the relational database.\n\n Returns:\n List[Dict[str, Any]]: A list of all events.\n \"\"\"\n select_query = f\"SELECT * FROM {self.table_name};\"\n results = self.db.execute_sql(select_query)\n return [self._row_to_event(row) for row in results.fetchall()]\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_storage.MemoryStorage","title":"MemoryStorage
","text":"Handles storage operations for MemoryPalace, including interactions with vector, graph, and relational databases.
Source code in src/aeiva/cognition/memory/memory_storage.py
class MemoryStorage:\n \"\"\"\n Handles storage operations for MemoryPalace, including interactions with vector,\n graph, and relational databases.\n \"\"\"\n\n def __init__(self, config: Dict):\n \"\"\"\n Initialize the MemoryStorage with the provided configuration.\n\n Args:\n config (Any): Configuration settings for MemoryStorage.\n \"\"\"\n self.config_dict = config\n self.config = None\n self.setup()\n\n def setup(self) -> None:\n \"\"\"\n Set up the MemoryStorage's components based on the provided configuration.\n \"\"\"\n try:\n # Initialize Vector Database Configuration\n vector_db_conf_dict = self.config_dict.get('vector_db_config', {})\n vector_db_provider_name = vector_db_conf_dict.get('provider_name', 'milvus')\n vector_db_config = DatabaseConfigFactory.create(\n provider_name=vector_db_conf_dict.get('provider_name', 'milvus'),\n uri=vector_db_conf_dict.get('uri', 'storage/milvus_demo.db'),\n collection_name=vector_db_conf_dict.get('collection_name', 'test_collection'),\n embedding_model_dims=vector_db_conf_dict.get('embedding_model_dims', 1536), # 'text-embedding-ada-002': 1536,\n metric_type=vector_db_conf_dict.get('metric_type', 'COSINE')\n )\n\n # Initialize Graph Database Configuration\n graph_db_conf_dict = self.config_dict.get('graph_db_config', {})\n graph_db_provider_name = graph_db_conf_dict.get('provider_name', 'neo4j')\n graph_db_password = graph_db_conf_dict.get('password')\n graph_db_config = DatabaseConfigFactory.create(\n provider_name=graph_db_conf_dict.get('provider_name', 'neo4j'),\n uri=graph_db_conf_dict.get('uri', 'bolt://localhost:7687'),\n user=graph_db_conf_dict.get('user', 'neo4j'),\n password=graph_db_password,\n database=graph_db_conf_dict.get('database', 'neo4j'),\n encrypted=graph_db_conf_dict.get('encrypted', False)\n )\n\n # Initialize Relational Database Configuration\n relational_db_conf_dict = self.config_dict.get('relational_db_config', {})\n relational_db_provider_name = relational_db_conf_dict.get('provider_name', 
'sqlite')\n relational_db_config = DatabaseConfigFactory.create(\n provider_name=relational_db_conf_dict.get('provider_name', 'sqlite'),\n database=relational_db_conf_dict.get('database', 'storage/test_database.db')\n )\n\n self.config = StorageConfig(\n vector_db_provider=self.config_dict.get('vector_db_provider', 'milvus'),\n vector_db_config=vector_db_config,\n graph_db_provider=self.config_dict.get('graph_db_provider', 'neo4j'),\n graph_db_config=graph_db_config,\n relational_db_provider=self.config_dict.get('relational_db_provider', 'sqlite'),\n relational_db_config=relational_db_config,\n )\n\n # Initialize the vector database\n self.vector_db = DatabaseFactory.create(\n provider_name=vector_db_provider_name,\n config=self.config.vector_db_config\n )\n\n # Initialize the graph database if provided\n if graph_db_provider_name and self.config.graph_db_config:\n self.graph_db = DatabaseFactory.create(\n provider_name=graph_db_provider_name,\n config=self.config.graph_db_config\n )\n else:\n self.graph_db = None\n\n # Initialize the relational database if provided\n if relational_db_provider_name and self.config.relational_db_config:\n self.relational_db = DatabaseFactory.create(\n provider_name=relational_db_provider_name,\n config=self.config.relational_db_config\n )\n self.memory_unit_repo = MemoryUnitRepository(self.relational_db)\n self.memory_event_repo = MemoryEventRepository(self.relational_db)\n else:\n self.relational_db = None\n self.memory_unit_repo = None\n self.memory_event_repo = None\n\n logger.info(\"MemoryStorage setup completed successfully.\")\n except Exception as e:\n logger.error(f\"Error during MemoryStorage setup: {e}\")\n self.handle_error(e)\n raise # Re-raise the exception after logging\n\n def handle_error(self, error: Exception) -> None:\n \"\"\"\n Handle errors that occur during storage operations.\n\n Args:\n error (Exception): The exception that was raised.\n \"\"\"\n logger.error(f\"MemoryStorage encountered an error: 
{error}\")\n # Additional error handling can be implemented here\n\n def add_memory_unit(self, memory_unit: MemoryUnit) -> None:\n \"\"\"\n Adds a MemoryUnit to all configured databases.\n\n Args:\n memory_unit (MemoryUnit): The memory unit to add.\n \"\"\"\n try:\n # Add to vector database\n self._add_to_vector_db(memory_unit)\n\n # Add to graph database\n if self.graph_db:\n self._add_to_graph_db(memory_unit)\n\n # Add to relational database\n if self.relational_db and self.memory_unit_repo:\n self._add_to_relational_db(memory_unit)\n\n # Record creation event\n if self.relational_db and self.memory_event_repo:\n self._record_event(\n event_type=\"CREATE\",\n memory_unit=memory_unit\n )\n\n logger.info(f\"Added MemoryUnit with ID: {memory_unit.id} to all databases.\")\n except Exception as e:\n logger.error(f\"Error adding MemoryUnit to databases: {e}\")\n self.handle_error(e)\n raise\n\n def get_memory_unit(self, unit_id: str) -> MemoryUnit:\n \"\"\"\n Retrieves a MemoryUnit by its unique identifier from the relational database.\n\n Args:\n unit_id (str): The unique identifier of the memory unit.\n\n Returns:\n MemoryUnit: The retrieved memory unit.\n \"\"\"\n try:\n if not self.relational_db or not self.memory_unit_repo:\n raise ValueError(\"Relational database is not configured.\")\n\n memory_unit = self.memory_unit_repo.get(unit_id)\n if not memory_unit:\n raise ValueError(f\"MemoryUnit with ID {unit_id} does not exist.\")\n\n logger.info(f\"Retrieved MemoryUnit with ID: {unit_id} from Relational DB.\")\n return memory_unit\n except Exception as e:\n logger.error(f\"Error retrieving MemoryUnit with ID {unit_id}: {e}\")\n self.handle_error(e)\n raise\n\n def update_memory_unit(self, unit_id: str, updates: Dict[str, Any]) -> None:\n \"\"\"\n Updates a MemoryUnit in all configured databases.\n\n Args:\n unit_id (str): The unique identifier of the memory unit.\n updates (Dict[str, Any]): The updates to apply.\n \"\"\"\n try:\n # Retrieve existing MemoryUnit\n 
memory_unit = self.get_memory_unit(unit_id)\n previous_state = memory_unit.to_dict()\n\n # Apply updates\n for key, value in updates.items():\n setattr(memory_unit, key, value)\n\n # Update in vector database\n self._update_vector_db(memory_unit)\n\n # Update in graph database\n if self.graph_db:\n self._update_graph_db(memory_unit)\n\n # Update in relational database\n if self.relational_db and self.memory_unit_repo:\n self._update_relational_db(memory_unit)\n\n # Record update event\n if self.relational_db and self.memory_event_repo:\n self._record_event(\n event_type=\"UPDATE\",\n memory_unit=memory_unit,\n previous_state=previous_state\n )\n\n logger.info(f\"Updated MemoryUnit with ID: {unit_id} in all databases.\")\n except Exception as e:\n logger.error(f\"Error updating MemoryUnit with ID {unit_id}: {e}\")\n self.handle_error(e)\n raise\n\n def delete_memory_unit(self, unit_id: str) -> None:\n \"\"\"\n Deletes a MemoryUnit from all configured databases.\n\n Args:\n unit_id (str): The unique identifier of the memory unit.\n \"\"\"\n try:\n # Retrieve existing MemoryUnit\n memory_unit = self.get_memory_unit(unit_id)\n\n # Delete from vector database\n self._delete_from_vector_db(unit_id)\n\n # Delete from graph database\n if self.graph_db:\n self._delete_from_graph_db(unit_id)\n\n # Delete from relational database\n if self.relational_db and self.memory_unit_repo:\n self._delete_relational_db(unit_id)\n\n # Record deletion event\n if self.relational_db and self.memory_event_repo:\n self._record_event(\n event_type=\"DELETE\",\n memory_unit=memory_unit\n )\n\n logger.info(f\"Deleted MemoryUnit with ID: {unit_id} from all databases.\")\n except Exception as e:\n logger.error(f\"Error deleting MemoryUnit with ID {unit_id}: {e}\")\n self.handle_error(e)\n raise\n\n def get_all_memory_units(self) -> List[MemoryUnit]:\n \"\"\"\n Retrieves all MemoryUnits from the relational database.\n\n Returns:\n List[MemoryUnit]: A list of all memory units.\n \"\"\"\n try:\n if 
not self.relational_db or not self.memory_unit_repo:\n raise ValueError(\"Relational database is not configured.\")\n\n memory_units = self.memory_unit_repo.list_all()\n logger.info(f\"Retrieved all MemoryUnits from Relational DB. Total count: {len(memory_units)}\")\n return memory_units\n except Exception as e:\n logger.error(f\"Error retrieving all MemoryUnits: {e}\")\n self.handle_error(e)\n raise\n\n def delete_all_memory_units(self) -> None:\n \"\"\"\n Deletes all MemoryUnits from all configured databases.\n \"\"\"\n try:\n # Delete from vector database\n self.vector_db.delete_collection(\n collection_name=self.config.vector_db_config.collection_name\n )\n\n # Delete all nodes from graph database\n if self.graph_db:\n self.graph_db.delete_all()\n\n # Delete all records from relational database\n if self.relational_db and self.memory_unit_repo and self.memory_event_repo:\n self.memory_unit_repo.delete_all()\n self.memory_event_repo.delete_all()\n\n logger.info(\"Deleted all MemoryUnits from all databases.\")\n except Exception as e:\n logger.error(f\"Error deleting all MemoryUnits: {e}\")\n self.handle_error(e)\n raise\n\n # Internal helper methods\n\n def _add_to_vector_db(self, memory_unit: MemoryUnit) -> None:\n \"\"\"\n Adds the embedding vector of a MemoryUnit to the vector database.\n\n Args:\n memory_unit (MemoryUnit): The memory unit to add.\n \"\"\"\n try:\n # Ensure embedding exists\n if not memory_unit.embedding:\n raise ValueError(\"MemoryUnit does not have an embedding.\")\n\n # Prepare payload with essential fields\n payload = {\n \"id\": memory_unit.id,\n \"type\": memory_unit.type,\n \"modality\": memory_unit.modality\n }\n\n # Insert into vector database\n self.vector_db.insert_vectors(\n collection_name=self.config.vector_db_config.collection_name,\n vectors=[memory_unit.embedding],\n payloads=[payload],\n ids=[memory_unit.id]\n )\n\n logger.info(f\"Inserted embedding for MemoryUnit ID: {memory_unit.id} into Vector DB.\")\n except Exception as 
e:\n logger.error(f\"Error adding MemoryUnit to Vector DB: {e}\")\n self.handle_error(e)\n raise\n\n def _update_vector_db(self, memory_unit: MemoryUnit) -> None:\n \"\"\"\n Updates the embedding vector of a MemoryUnit in the vector database.\n\n Args:\n memory_unit (MemoryUnit): The memory unit to update.\n \"\"\"\n try:\n if not memory_unit.embedding:\n raise ValueError(\"MemoryUnit does not have an embedding.\")\n\n payload = {\n \"type\": memory_unit.type,\n \"modality\": memory_unit.modality\n }\n\n self.vector_db.update_vector(\n collection_name=self.config.vector_db_config.collection_name,\n vector_id=memory_unit.id,\n vector=memory_unit.embedding,\n payload=payload\n )\n\n logger.info(f\"Updated embedding for MemoryUnit ID: {memory_unit.id} in Vector DB.\")\n except Exception as e:\n logger.error(f\"Error updating MemoryUnit in Vector DB: {e}\")\n self.handle_error(e)\n raise\n\n def _delete_from_vector_db(self, unit_id: str) -> None:\n \"\"\"\n Deletes a MemoryUnit's embedding from the vector database.\n\n Args:\n unit_id (str): The unique identifier of the memory unit.\n \"\"\"\n try:\n self.vector_db.delete_vector(\n collection_name=self.config.vector_db_config.collection_name,\n vector_id=unit_id\n )\n\n logger.info(f\"Deleted embedding for MemoryUnit ID: {unit_id} from Vector DB.\")\n except Exception as e:\n logger.error(f\"Error deleting MemoryUnit from Vector DB: {e}\")\n self.handle_error(e)\n raise\n\n def _add_to_graph_db(self, memory_unit: MemoryUnit) -> None:\n \"\"\"\n Adds a MemoryUnit as a node in the graph database and establishes relationships.\n\n Args:\n memory_unit (MemoryUnit): The memory unit to add.\n \"\"\"\n try:\n # Serialize complex fields\n properties = {\n \"id\": memory_unit.id,\n \"content\": memory_unit.content,\n \"timestamp\": memory_unit.timestamp.isoformat(),\n \"modality\": memory_unit.modality,\n \"type\": memory_unit.type,\n \"status\": memory_unit.status,\n \"tags\": memory_unit.tags,\n \"embedding\": 
memory_unit.embedding,\n \"location\": json.dumps(memory_unit.location) if memory_unit.location else None, # Serialized\n \"source_role\": memory_unit.source_role,\n \"source_name\": memory_unit.source_name,\n \"source_id\": memory_unit.source_id,\n \"metadata\": json.dumps(memory_unit.metadata) if memory_unit.metadata else None # Serialized\n }\n\n # Add node to graph database\n self.graph_db.add_node(\n node_id=memory_unit.id,\n properties=properties,\n labels=[memory_unit.type or 'MemoryUnit']\n )\n\n logger.info(f\"Added MemoryUnit ID: {memory_unit.id} to Graph DB.\")\n\n # Add relationships (edges) if any\n for link in memory_unit.edges:\n # Serialize edge metadata if necessary\n edge_properties = {}\n if link.metadata:\n edge_properties['metadata'] = json.dumps(link.metadata)\n\n self.graph_db.add_edge(\n source_id=link.source_id,\n target_id=link.target_id,\n relationship=link.relationship,\n properties=edge_properties\n )\n\n logger.info(f\"Added {len(memory_unit.edges)} edges for MemoryUnit ID: {memory_unit.id} in Graph DB.\")\n except Exception as e:\n logger.error(f\"Error adding MemoryUnit to Graph DB: {e}\")\n self.handle_error(e)\n raise\n\n def _update_graph_db(self, memory_unit: MemoryUnit) -> None:\n \"\"\"\n Updates a MemoryUnit in the graph database.\n\n Args:\n memory_unit (MemoryUnit): The memory unit to update.\n \"\"\"\n try:\n # Update node properties\n properties = {\n \"content\": memory_unit.content,\n \"timestamp\": memory_unit.timestamp.isoformat(),\n \"modality\": memory_unit.modality,\n \"type\": memory_unit.type,\n \"status\": memory_unit.status,\n \"tags\": memory_unit.tags,\n \"embedding\": memory_unit.embedding,\n \"location\": json.dumps(memory_unit.location) if memory_unit.location else None, # Serialized\n \"source_role\": memory_unit.source_role,\n \"source_name\": memory_unit.source_name,\n \"source_id\": memory_unit.source_id,\n \"metadata\": json.dumps(memory_unit.metadata) if memory_unit.metadata else None # Serialized\n 
}\n\n self.graph_db.update_node(\n node_id=memory_unit.id,\n properties=properties\n )\n\n # Handle edges updates as needed\n # This can be complex and depends on your specific requirements\n\n logger.info(f\"Updated MemoryUnit ID: {memory_unit.id} in Graph DB.\")\n except Exception as e:\n logger.error(f\"Error updating MemoryUnit in Graph DB: {e}\")\n self.handle_error(e)\n raise\n\n def _delete_from_graph_db(self, unit_id: str) -> None:\n \"\"\"\n Deletes a MemoryUnit from the graph database.\n\n Args:\n unit_id (str): The unique identifier of the memory unit.\n \"\"\"\n try:\n self.graph_db.delete_node(node_id=unit_id)\n logger.info(f\"Deleted MemoryUnit ID: {unit_id} from Graph DB.\")\n except Exception as e:\n logger.error(f\"Error deleting MemoryUnit from Graph DB: {e}\")\n self.handle_error(e)\n raise\n\n def _add_to_relational_db(self, memory_unit: MemoryUnit) -> None:\n \"\"\"\n Adds a MemoryUnit to the relational database.\n\n Args:\n memory_unit (MemoryUnit): The memory unit to add.\n \"\"\"\n try:\n self.memory_unit_repo.add(memory_unit)\n logger.info(f\"Inserted MemoryUnit ID: {memory_unit.id} into Relational DB.\")\n except Exception as e:\n logger.error(f\"Error adding MemoryUnit to Relational DB: {e}\")\n raise\n\n def _update_relational_db(self, memory_unit: MemoryUnit) -> None:\n \"\"\"\n Updates a MemoryUnit in the relational database.\n\n Args:\n memory_unit (MemoryUnit): The memory unit to update.\n \"\"\"\n try:\n self.memory_unit_repo.update(memory_unit)\n logger.info(f\"Updated MemoryUnit ID: {memory_unit.id} in Relational DB.\")\n except Exception as e:\n logger.error(f\"Error updating MemoryUnit in Relational DB: {e}\")\n raise\n\n def _delete_relational_db(self, unit_id: str) -> None:\n \"\"\"\n Deletes a MemoryUnit from the relational database.\n\n Args:\n unit_id (str): The unique identifier of the memory unit.\n \"\"\"\n try:\n self.memory_unit_repo.delete(unit_id)\n logger.info(f\"Deleted MemoryUnit ID: {unit_id} from Relational 
DB.\")\n except Exception as e:\n logger.error(f\"Error deleting MemoryUnit from Relational DB: {e}\")\n raise\n\n def _record_event(self, event_type: str, memory_unit: MemoryUnit, previous_state: Optional[Dict[str, Any]] = None) -> None:\n \"\"\"\n Records an event in the relational database.\n\n Args:\n event_type (str): The type of event ('CREATE', 'UPDATE', 'DELETE').\n memory_unit (MemoryUnit): The memory unit involved in the event.\n previous_state (Optional[Dict[str, Any]]): The previous state of the memory unit (for updates).\n \"\"\"\n try:\n event_record = {\n \"memory_id\": memory_unit.id,\n \"event_type\": event_type,\n \"memory_data\": json.dumps(memory_unit.to_dict()),\n \"previous_data\": json.dumps(previous_state) if previous_state else None\n }\n\n self.memory_event_repo.add(event_record)\n logger.info(f\"Recorded event '{event_type}' for MemoryUnit ID: {memory_unit.id}.\")\n except Exception as e:\n logger.error(f\"Error recording event in Relational DB: {e}\")\n raise\n\n def retrieve_similar_memory_units(self, query_embedding: List[float], top_k: int) -> List[MemoryUnit]:\n \"\"\"\n Retrieves memory units similar to the given embedding.\n\n Args:\n query_embedding (List[float]): The embedding vector of the query.\n top_k (int): The number of similar units to retrieve.\n\n Returns:\n List[MemoryUnit]: A list of similar memory units.\n \"\"\"\n try:\n # Perform similarity search\n results = self.vector_db.search_vectors(\n collection_name=self.config.vector_db_config.collection_name,\n query_vector=query_embedding,\n top_k=top_k\n )\n\n memory_units = []\n for result in results:\n unit_id = result['id']\n memory_unit = self.get_memory_unit(unit_id)\n memory_units.append(memory_unit)\n\n logger.info(f\"Retrieved {len(memory_units)} similar MemoryUnits.\")\n return memory_units\n except Exception as e:\n logger.error(f\"Error retrieving similar MemoryUnits: {e}\")\n self.handle_error(e)\n raise\n\n def retrieve_related_memory_units(self, unit_id: 
str, relationship: Optional[str] = None) -> List[MemoryUnit]:\n \"\"\"\n Retrieves memory units related to the given one based on relationships.\n\n Args:\n unit_id (str): The unique identifier of the memory unit.\n relationship (Optional[str]): Filter by relationship type.\n\n Returns:\n List[MemoryUnit]: A list of related memory units.\n \"\"\"\n try:\n if not self.graph_db:\n raise ValueError(\"Graph database is not configured.\")\n\n # Retrieve related nodes from graph database\n neighbors = self.graph_db.get_neighbors(\n node_id=unit_id,\n relationship=relationship\n )\n\n related_units = []\n for neighbor in neighbors:\n related_unit = self.get_memory_unit(neighbor['id'])\n related_units.append(related_unit)\n\n logger.info(f\"Retrieved {len(related_units)} related MemoryUnits.\")\n return related_units\n except Exception as e:\n logger.error(f\"Error retrieving related MemoryUnits: {e}\")\n self.handle_error(e)\n raise\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_storage.MemoryStorage.__init__","title":"__init__(config)
","text":"Initialize the MemoryStorage with the provided configuration.
Parameters:
Name Type Description Default config
Dict
Configuration settings for MemoryStorage.
required Source code in src/aeiva/cognition/memory/memory_storage.py
def __init__(self, config: Dict):\n    \"\"\"\n    Initialize the MemoryStorage with the provided configuration.\n\n    Args:\n        config (Dict): Configuration settings for MemoryStorage.\n    \"\"\"\n    self.config_dict = config\n    self.config = None\n    self.setup()\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_storage.MemoryStorage.add_memory_unit","title":"add_memory_unit(memory_unit)
","text":"Adds a MemoryUnit to all configured databases.
Parameters:
Name Type Description Default memory_unit
MemoryUnit
The memory unit to add.
required Source code in src/aeiva/cognition/memory/memory_storage.py
def add_memory_unit(self, memory_unit: MemoryUnit) -> None:\n \"\"\"\n Adds a MemoryUnit to all configured databases.\n\n Args:\n memory_unit (MemoryUnit): The memory unit to add.\n \"\"\"\n try:\n # Add to vector database\n self._add_to_vector_db(memory_unit)\n\n # Add to graph database\n if self.graph_db:\n self._add_to_graph_db(memory_unit)\n\n # Add to relational database\n if self.relational_db and self.memory_unit_repo:\n self._add_to_relational_db(memory_unit)\n\n # Record creation event\n if self.relational_db and self.memory_event_repo:\n self._record_event(\n event_type=\"CREATE\",\n memory_unit=memory_unit\n )\n\n logger.info(f\"Added MemoryUnit with ID: {memory_unit.id} to all databases.\")\n except Exception as e:\n logger.error(f\"Error adding MemoryUnit to databases: {e}\")\n self.handle_error(e)\n raise\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_storage.MemoryStorage.delete_all_memory_units","title":"delete_all_memory_units()
","text":"Deletes all MemoryUnits from all configured databases.
Source code in src/aeiva/cognition/memory/memory_storage.py
def delete_all_memory_units(self) -> None:\n \"\"\"\n Deletes all MemoryUnits from all configured databases.\n \"\"\"\n try:\n # Delete from vector database\n self.vector_db.delete_collection(\n collection_name=self.config.vector_db_config.collection_name\n )\n\n # Delete all nodes from graph database\n if self.graph_db:\n self.graph_db.delete_all()\n\n # Delete all records from relational database\n if self.relational_db and self.memory_unit_repo and self.memory_event_repo:\n self.memory_unit_repo.delete_all()\n self.memory_event_repo.delete_all()\n\n logger.info(\"Deleted all MemoryUnits from all databases.\")\n except Exception as e:\n logger.error(f\"Error deleting all MemoryUnits: {e}\")\n self.handle_error(e)\n raise\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_storage.MemoryStorage.delete_memory_unit","title":"delete_memory_unit(unit_id)
","text":"Deletes a MemoryUnit from all configured databases.
Parameters:
Name Type Description Default unit_id
str
The unique identifier of the memory unit.
required Source code in src/aeiva/cognition/memory/memory_storage.py
def delete_memory_unit(self, unit_id: str) -> None:\n \"\"\"\n Deletes a MemoryUnit from all configured databases.\n\n Args:\n unit_id (str): The unique identifier of the memory unit.\n \"\"\"\n try:\n # Retrieve existing MemoryUnit\n memory_unit = self.get_memory_unit(unit_id)\n\n # Delete from vector database\n self._delete_from_vector_db(unit_id)\n\n # Delete from graph database\n if self.graph_db:\n self._delete_from_graph_db(unit_id)\n\n # Delete from relational database\n if self.relational_db and self.memory_unit_repo:\n self._delete_relational_db(unit_id)\n\n # Record deletion event\n if self.relational_db and self.memory_event_repo:\n self._record_event(\n event_type=\"DELETE\",\n memory_unit=memory_unit\n )\n\n logger.info(f\"Deleted MemoryUnit with ID: {unit_id} from all databases.\")\n except Exception as e:\n logger.error(f\"Error deleting MemoryUnit with ID {unit_id}: {e}\")\n self.handle_error(e)\n raise\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_storage.MemoryStorage.get_all_memory_units","title":"get_all_memory_units()
","text":"Retrieves all MemoryUnits from the relational database.
Returns:
Type Description List[MemoryUnit]
List[MemoryUnit]: A list of all memory units.
Source code in src/aeiva/cognition/memory/memory_storage.py
def get_all_memory_units(self) -> List[MemoryUnit]:\n \"\"\"\n Retrieves all MemoryUnits from the relational database.\n\n Returns:\n List[MemoryUnit]: A list of all memory units.\n \"\"\"\n try:\n if not self.relational_db or not self.memory_unit_repo:\n raise ValueError(\"Relational database is not configured.\")\n\n memory_units = self.memory_unit_repo.list_all()\n logger.info(f\"Retrieved all MemoryUnits from Relational DB. Total count: {len(memory_units)}\")\n return memory_units\n except Exception as e:\n logger.error(f\"Error retrieving all MemoryUnits: {e}\")\n self.handle_error(e)\n raise\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_storage.MemoryStorage.get_memory_unit","title":"get_memory_unit(unit_id)
","text":"Retrieves a MemoryUnit by its unique identifier from the relational database.
Parameters:
Name Type Description Default unit_id
str
The unique identifier of the memory unit.
required Returns:
Name Type Description MemoryUnit
MemoryUnit
The retrieved memory unit.
Source code in src/aeiva/cognition/memory/memory_storage.py
def get_memory_unit(self, unit_id: str) -> MemoryUnit:\n \"\"\"\n Retrieves a MemoryUnit by its unique identifier from the relational database.\n\n Args:\n unit_id (str): The unique identifier of the memory unit.\n\n Returns:\n MemoryUnit: The retrieved memory unit.\n \"\"\"\n try:\n if not self.relational_db or not self.memory_unit_repo:\n raise ValueError(\"Relational database is not configured.\")\n\n memory_unit = self.memory_unit_repo.get(unit_id)\n if not memory_unit:\n raise ValueError(f\"MemoryUnit with ID {unit_id} does not exist.\")\n\n logger.info(f\"Retrieved MemoryUnit with ID: {unit_id} from Relational DB.\")\n return memory_unit\n except Exception as e:\n logger.error(f\"Error retrieving MemoryUnit with ID {unit_id}: {e}\")\n self.handle_error(e)\n raise\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_storage.MemoryStorage.handle_error","title":"handle_error(error)
","text":"Handle errors that occur during storage operations.
Parameters:
Name Type Description Default error
Exception
The exception that was raised.
required Source code in src/aeiva/cognition/memory/memory_storage.py
def handle_error(self, error: Exception) -> None:\n \"\"\"\n Handle errors that occur during storage operations.\n\n Args:\n error (Exception): The exception that was raised.\n \"\"\"\n logger.error(f\"MemoryStorage encountered an error: {error}\")\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_storage.MemoryStorage.retrieve_related_memory_units","title":"retrieve_related_memory_units(unit_id, relationship=None)
","text":"Retrieves memory units related to the given one based on relationships.
Parameters:
Name Type Description Default unit_id
str
The unique identifier of the memory unit.
required relationship
Optional[str]
Filter by relationship type.
None
Returns:
Type Description List[MemoryUnit]
List[MemoryUnit]: A list of related memory units.
Source code in src/aeiva/cognition/memory/memory_storage.py
def retrieve_related_memory_units(self, unit_id: str, relationship: Optional[str] = None) -> List[MemoryUnit]:\n \"\"\"\n Retrieves memory units related to the given one based on relationships.\n\n Args:\n unit_id (str): The unique identifier of the memory unit.\n relationship (Optional[str]): Filter by relationship type.\n\n Returns:\n List[MemoryUnit]: A list of related memory units.\n \"\"\"\n try:\n if not self.graph_db:\n raise ValueError(\"Graph database is not configured.\")\n\n # Retrieve related nodes from graph database\n neighbors = self.graph_db.get_neighbors(\n node_id=unit_id,\n relationship=relationship\n )\n\n related_units = []\n for neighbor in neighbors:\n related_unit = self.get_memory_unit(neighbor['id'])\n related_units.append(related_unit)\n\n logger.info(f\"Retrieved {len(related_units)} related MemoryUnits.\")\n return related_units\n except Exception as e:\n logger.error(f\"Error retrieving related MemoryUnits: {e}\")\n self.handle_error(e)\n raise\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_storage.MemoryStorage.retrieve_similar_memory_units","title":"retrieve_similar_memory_units(query_embedding, top_k)
","text":"Retrieves memory units similar to the given embedding.
Parameters:
Name Type Description Default query_embedding
List[float]
The embedding vector of the query.
required top_k
int
The number of similar units to retrieve.
required Returns:
Type Description List[MemoryUnit]
List[MemoryUnit]: A list of similar memory units.
Source code in src/aeiva/cognition/memory/memory_storage.py
def retrieve_similar_memory_units(self, query_embedding: List[float], top_k: int) -> List[MemoryUnit]:\n \"\"\"\n Retrieves memory units similar to the given embedding.\n\n Args:\n query_embedding (List[float]): The embedding vector of the query.\n top_k (int): The number of similar units to retrieve.\n\n Returns:\n List[MemoryUnit]: A list of similar memory units.\n \"\"\"\n try:\n # Perform similarity search\n results = self.vector_db.search_vectors(\n collection_name=self.config.vector_db_config.collection_name,\n query_vector=query_embedding,\n top_k=top_k\n )\n\n memory_units = []\n for result in results:\n unit_id = result['id']\n memory_unit = self.get_memory_unit(unit_id)\n memory_units.append(memory_unit)\n\n logger.info(f\"Retrieved {len(memory_units)} similar MemoryUnits.\")\n return memory_units\n except Exception as e:\n logger.error(f\"Error retrieving similar MemoryUnits: {e}\")\n self.handle_error(e)\n raise\n
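retrieve_similar_memory_units delegates the nearest-neighbour search to the vector database and then hydrates each hit from the relational store by ID. The search step can be sketched in pure Python as a cosine top-k over stored embeddings; this `search_vectors` stand-in is an assumption for illustration, not the real vector-database API, but it returns results in the same `{'id': ...}` shape the loop above consumes:

```python
import math

def cosine(a, b):
    # Cosine similarity between two equal-length vectors.
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(x * x for x in b))
    return dot / (na * nb)

def search_vectors(stored, query, top_k):
    # stored maps unit_id -> embedding; return the top_k ids by cosine score,
    # mirroring the result shape consumed by retrieve_similar_memory_units.
    ranked = sorted(stored, key=lambda uid: cosine(stored[uid], query), reverse=True)
    return [{'id': uid} for uid in ranked[:top_k]]

stored = {'a': [1.0, 0.0], 'b': [0.0, 1.0], 'c': [0.9, 0.1]}
print(search_vectors(stored, [1.0, 0.0], 2))  # [{'id': 'a'}, {'id': 'c'}]
```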
"},{"location":"reference/#src.aeiva.cognition.memory.memory_storage.MemoryStorage.setup","title":"setup()
","text":"Set up the MemoryStorage's components based on the provided configuration.
Source code in src/aeiva/cognition/memory/memory_storage.py
def setup(self) -> None:\n \"\"\"\n Set up the MemoryStorage's components based on the provided configuration.\n \"\"\"\n try:\n # Initialize Vector Database Configuration\n vector_db_conf_dict = self.config_dict.get('vector_db_config', {})\n vector_db_provider_name = vector_db_conf_dict.get('provider_name', 'milvus')\n vector_db_config = DatabaseConfigFactory.create(\n provider_name=vector_db_conf_dict.get('provider_name', 'milvus'),\n uri=vector_db_conf_dict.get('uri', 'storage/milvus_demo.db'),\n collection_name=vector_db_conf_dict.get('collection_name', 'test_collection'),\n embedding_model_dims=vector_db_conf_dict.get('embedding_model_dims', 1536), # 'text-embedding-ada-002': 1536,\n metric_type=vector_db_conf_dict.get('metric_type', 'COSINE')\n )\n\n # Initialize Graph Database Configuration\n graph_db_conf_dict = self.config_dict.get('graph_db_config', {})\n graph_db_provider_name = graph_db_conf_dict.get('provider_name', 'neo4j')\n graph_db_password = graph_db_conf_dict.get('password')\n graph_db_config = DatabaseConfigFactory.create(\n provider_name=graph_db_conf_dict.get('provider_name', 'neo4j'),\n uri=graph_db_conf_dict.get('uri', 'bolt://localhost:7687'),\n user=graph_db_conf_dict.get('user', 'neo4j'),\n password=graph_db_password,\n database=graph_db_conf_dict.get('database', 'neo4j'),\n encrypted=graph_db_conf_dict.get('encrypted', False)\n )\n\n # Initialize Relational Database Configuration\n relational_db_conf_dict = self.config_dict.get('relational_db_config', {})\n relational_db_provider_name = relational_db_conf_dict.get('provider_name', 'sqlite')\n relational_db_config = DatabaseConfigFactory.create(\n provider_name=relational_db_conf_dict.get('provider_name', 'sqlite'),\n database=relational_db_conf_dict.get('database', 'storage/test_database.db')\n )\n\n self.config = StorageConfig(\n vector_db_provider=self.config_dict.get('vector_db_provider', 'milvus'),\n vector_db_config=vector_db_config,\n 
graph_db_provider=self.config_dict.get('graph_db_provider', 'neo4j'),\n graph_db_config=graph_db_config,\n relational_db_provider=self.config_dict.get('relational_db_provider', 'sqlite'),\n relational_db_config=relational_db_config,\n )\n\n # Initialize the vector database\n self.vector_db = DatabaseFactory.create(\n provider_name=vector_db_provider_name,\n config=self.config.vector_db_config\n )\n\n # Initialize the graph database if provided\n if graph_db_provider_name and self.config.graph_db_config:\n self.graph_db = DatabaseFactory.create(\n provider_name=graph_db_provider_name,\n config=self.config.graph_db_config\n )\n else:\n self.graph_db = None\n\n # Initialize the relational database if provided\n if relational_db_provider_name and self.config.relational_db_config:\n self.relational_db = DatabaseFactory.create(\n provider_name=relational_db_provider_name,\n config=self.config.relational_db_config\n )\n self.memory_unit_repo = MemoryUnitRepository(self.relational_db)\n self.memory_event_repo = MemoryEventRepository(self.relational_db)\n else:\n self.relational_db = None\n self.memory_unit_repo = None\n self.memory_event_repo = None\n\n logger.info(\"MemoryStorage setup completed successfully.\")\n except Exception as e:\n logger.error(f\"Error during MemoryStorage setup: {e}\")\n self.handle_error(e)\n raise # Re-raise the exception after logging\n
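setup() reads each nested config dict with `.get` fallbacks, so a partial user config still yields a complete configuration. A minimal, self-contained sketch of that defaulting pattern for the vector-database section (a plain dict stands in for the object DatabaseConfigFactory would build; the default values are taken from the source above):

```python
def build_vector_db_config(config_dict: dict) -> dict:
    # Mirror of the .get()-with-fallback pattern in MemoryStorage.setup();
    # a plain dict stands in for the DatabaseConfigFactory product.
    conf = config_dict.get('vector_db_config', {})
    return {
        'provider_name': conf.get('provider_name', 'milvus'),
        'uri': conf.get('uri', 'storage/milvus_demo.db'),
        'collection_name': conf.get('collection_name', 'test_collection'),
        'embedding_model_dims': conf.get('embedding_model_dims', 1536),
        'metric_type': conf.get('metric_type', 'COSINE'),
    }

# A partial user config: only the collection name is overridden.
user_config = {'vector_db_config': {'collection_name': 'my_memories'}}
cfg = build_vector_db_config(user_config)
print(cfg['collection_name'])  # my_memories
print(cfg['metric_type'])      # COSINE (fallback default)
```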
"},{"location":"reference/#src.aeiva.cognition.memory.memory_storage.MemoryStorage.update_memory_unit","title":"update_memory_unit(unit_id, updates)
","text":"Updates a MemoryUnit in all configured databases.
Parameters:
Name Type Description Default unit_id
str
The unique identifier of the memory unit.
required updates
Dict[str, Any]
The updates to apply.
required Source code in src/aeiva/cognition/memory/memory_storage.py
def update_memory_unit(self, unit_id: str, updates: Dict[str, Any]) -> None:\n \"\"\"\n Updates a MemoryUnit in all configured databases.\n\n Args:\n unit_id (str): The unique identifier of the memory unit.\n updates (Dict[str, Any]): The updates to apply.\n \"\"\"\n try:\n # Retrieve existing MemoryUnit\n memory_unit = self.get_memory_unit(unit_id)\n previous_state = memory_unit.to_dict()\n\n # Apply updates\n for key, value in updates.items():\n setattr(memory_unit, key, value)\n\n # Update in vector database\n self._update_vector_db(memory_unit)\n\n # Update in graph database\n if self.graph_db:\n self._update_graph_db(memory_unit)\n\n # Update in relational database\n if self.relational_db and self.memory_unit_repo:\n self._update_relational_db(memory_unit)\n\n # Record update event\n if self.relational_db and self.memory_event_repo:\n self._record_event(\n event_type=\"UPDATE\",\n memory_unit=memory_unit,\n previous_state=previous_state\n )\n\n logger.info(f\"Updated MemoryUnit with ID: {unit_id} in all databases.\")\n except Exception as e:\n logger.error(f\"Error updating MemoryUnit with ID {unit_id}: {e}\")\n self.handle_error(e)\n raise\n
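update_memory_unit snapshots the unit with to_dict() before mutating it, so the recorded UPDATE event carries both the new and the previous state. The core pattern in isolation, using a minimal stand-in class rather than the real MemoryUnit:

```python
import json

class Unit:
    # Minimal stand-in for MemoryUnit, just enough to show the pattern.
    def __init__(self, id, content, status='active'):
        self.id, self.content, self.status = id, content, status

    def to_dict(self):
        return {'id': self.id, 'content': self.content, 'status': self.status}

def update_unit(unit, updates, events):
    previous_state = unit.to_dict()      # snapshot BEFORE mutation
    for key, value in updates.items():   # apply updates in place via setattr
        setattr(unit, key, value)
    events.append({                      # record the UPDATE event with both states
        'memory_id': unit.id,
        'event_type': 'UPDATE',
        'memory_data': json.dumps(unit.to_dict()),
        'previous_data': json.dumps(previous_state),
    })

events = []
u = Unit('m1', 'hello')
update_unit(u, {'content': 'hello, world'}, events)
print(json.loads(events[0]['previous_data'])['content'])  # hello
print(u.content)  # hello, world
```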
"},{"location":"reference/#src.aeiva.cognition.memory.memory_storage.MemoryUnitRepository","title":"MemoryUnitRepository
","text":"Repository for MemoryUnit to handle CRUD operations without SQLAlchemy.
Source code in src/aeiva/cognition/memory/memory_storage.py
class MemoryUnitRepository:\n \"\"\"\n Repository for MemoryUnit to handle CRUD operations without SQLAlchemy.\n \"\"\"\n\n def __init__(self, db: Any):\n \"\"\"\n Initialize the repository with a DatabaseFactory instance.\n\n Args:\n db (Any): An instance of DatabaseFactory for relational databases.\n \"\"\"\n self.db = db\n self.table_name = 'memory_units'\n self._create_table()\n\n def _create_table(self):\n \"\"\"\n Creates the memory_units table if it does not exist.\n \"\"\"\n create_table_query = f\"\"\"\n CREATE TABLE IF NOT EXISTS {self.table_name} (\n id TEXT PRIMARY KEY,\n content TEXT NOT NULL,\n timestamp TEXT NOT NULL,\n modality TEXT,\n type TEXT,\n status TEXT,\n tags TEXT,\n embedding TEXT,\n location TEXT,\n source_role TEXT,\n source_name TEXT,\n source_id TEXT,\n edges TEXT,\n metadata TEXT\n );\n \"\"\"\n self.db.execute_sql(create_table_query)\n\n def add(self, memory_unit: MemoryUnit) -> None:\n \"\"\"\n Adds a MemoryUnit to the relational database.\n\n Args:\n memory_unit (MemoryUnit): The memory unit to add.\n \"\"\"\n insert_query = f\"\"\"\n INSERT INTO {self.table_name} (id, content, timestamp, modality, type, status, tags, embedding, location, \n source_role, source_name, source_id, edges, metadata)\n VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);\n \"\"\"\n data = (\n memory_unit.id,\n memory_unit.content,\n memory_unit.timestamp.isoformat(),\n memory_unit.modality,\n memory_unit.type,\n memory_unit.status,\n json.dumps(memory_unit.tags),\n json.dumps(memory_unit.embedding) if memory_unit.embedding else None,\n json.dumps(memory_unit.location) if memory_unit.location else None,\n memory_unit.source_role,\n memory_unit.source_name,\n memory_unit.source_id,\n json.dumps([link.to_dict() for link in memory_unit.edges]),\n json.dumps(memory_unit.metadata) if memory_unit.metadata else None\n )\n self.db.execute_sql(insert_query, data)\n\n def get(self, unit_id: str) -> Optional[MemoryUnit]:\n \"\"\"\n Retrieves a MemoryUnit by its 
ID.\n\n Args:\n unit_id (str): The unique identifier of the memory unit.\n\n Returns:\n Optional[MemoryUnit]: The retrieved memory unit or None if not found.\n \"\"\"\n select_query = f\"SELECT * FROM {self.table_name} WHERE id = ?;\"\n result = self.db.execute_sql(select_query, (unit_id,))\n row = result.fetchone()\n if row:\n return self._row_to_memory_unit(row)\n return None\n\n def update(self, memory_unit: MemoryUnit) -> None:\n \"\"\"\n Updates an existing MemoryUnit in the relational database.\n\n Args:\n memory_unit (MemoryUnit): The memory unit with updated data.\n \"\"\"\n update_query = f\"\"\"\n UPDATE {self.table_name}\n SET content = ?, timestamp = ?, modality = ?, type = ?, status = ?, tags = ?, embedding = ?, \n location = ?, source_role = ?, source_name = ?, source_id = ?, edges = ?, metadata = ?\n WHERE id = ?;\n \"\"\"\n data = (\n memory_unit.content,\n memory_unit.timestamp.isoformat(),\n memory_unit.modality,\n memory_unit.type,\n memory_unit.status,\n json.dumps(memory_unit.tags),\n json.dumps(memory_unit.embedding) if memory_unit.embedding else None,\n json.dumps(memory_unit.location) if memory_unit.location else None,\n memory_unit.source_role,\n memory_unit.source_name,\n memory_unit.source_id,\n json.dumps([link.to_dict() for link in memory_unit.edges]),\n json.dumps(memory_unit.metadata) if memory_unit.metadata else None,\n memory_unit.id\n )\n self.db.execute_sql(update_query, data)\n\n def delete(self, unit_id: str) -> None:\n \"\"\"\n Deletes a MemoryUnit from the relational database.\n\n Args:\n unit_id (str): The unique identifier of the memory unit to delete.\n \"\"\"\n delete_query = f\"DELETE FROM {self.table_name} WHERE id = ?;\"\n self.db.execute_sql(delete_query, (unit_id,))\n\n def list_all(self) -> List[MemoryUnit]:\n \"\"\"\n Retrieves all MemoryUnits from the relational database.\n\n Returns:\n List[MemoryUnit]: A list of all memory units.\n \"\"\"\n select_query = f\"SELECT * FROM {self.table_name};\"\n results = 
self.db.execute_sql(select_query)\n return [self._row_to_memory_unit(row) for row in results.fetchall()]\n\n def delete_all(self) -> None:\n \"\"\"\n Deletes all MemoryUnits from the relational database.\n \"\"\"\n delete_query = f\"DELETE FROM {self.table_name};\"\n self.db.execute_sql(delete_query)\n\n def _row_to_memory_unit(self, row: Any) -> MemoryUnit:\n \"\"\"\n Converts a database row to a MemoryUnit instance.\n\n Args:\n row (Any): A row fetched from the database.\n\n Returns:\n MemoryUnit: The corresponding MemoryUnit instance.\n \"\"\"\n return MemoryUnit(\n id=row['id'],\n content=row['content'],\n timestamp=datetime.fromisoformat(row['timestamp']),\n modality=row['modality'],\n type=row['type'],\n status=row['status'],\n tags=json.loads(row['tags']) if row['tags'] else [],\n embedding=json.loads(row['embedding']) if row['embedding'] else [],\n location=json.loads(row['location']) if row['location'] else {},\n source_role=row['source_role'],\n source_name=row['source_name'],\n source_id=row['source_id'],\n edges=[MemoryLink.from_dict(link) for link in json.loads(row['edges'])] if row['edges'] else [],\n metadata=json.loads(row['metadata']) if row['metadata'] else {}\n )\n
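MemoryUnitRepository stores list- and dict-valued fields (tags, embedding, edges, metadata) as JSON text columns and decodes them again in _row_to_memory_unit, with empty-value fallbacks for NULL columns. That round trip can be exercised directly with the standard sqlite3 module; the two-column table here is a deliberate simplification of the real fourteen-column schema:

```python
import json
import sqlite3

conn = sqlite3.connect(':memory:')
conn.row_factory = sqlite3.Row  # read columns by name, like the repository's rows
conn.execute("CREATE TABLE memory_units (id TEXT PRIMARY KEY, content TEXT, tags TEXT);")

# Insert: encode the list as JSON text, as MemoryUnitRepository.add does.
conn.execute(
    "INSERT INTO memory_units (id, content, tags) VALUES (?, ?, ?);",
    ('m1', 'hello', json.dumps(['greeting', 'demo'])),
)

# Read back: decode JSON, falling back to [] for NULL, as _row_to_memory_unit does.
row = conn.execute("SELECT * FROM memory_units WHERE id = ?;", ('m1',)).fetchone()
tags = json.loads(row['tags']) if row['tags'] else []
print(tags)  # ['greeting', 'demo']
```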
"},{"location":"reference/#src.aeiva.cognition.memory.memory_storage.MemoryUnitRepository.__init__","title":"__init__(db)
","text":"Initialize the repository with a DatabaseFactory instance.
Parameters:
Name Type Description Default db
Any
An instance of DatabaseFactory for relational databases.
required Source code in src/aeiva/cognition/memory/memory_storage.py
def __init__(self, db: Any):\n \"\"\"\n Initialize the repository with a DatabaseFactory instance.\n\n Args:\n db (Any): An instance of DatabaseFactory for relational databases.\n \"\"\"\n self.db = db\n self.table_name = 'memory_units'\n self._create_table()\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_storage.MemoryUnitRepository.add","title":"add(memory_unit)
","text":"Adds a MemoryUnit to the relational database.
Parameters:
Name Type Description Default memory_unit
MemoryUnit
The memory unit to add.
required Source code in src/aeiva/cognition/memory/memory_storage.py
def add(self, memory_unit: MemoryUnit) -> None:\n \"\"\"\n Adds a MemoryUnit to the relational database.\n\n Args:\n memory_unit (MemoryUnit): The memory unit to add.\n \"\"\"\n insert_query = f\"\"\"\n INSERT INTO {self.table_name} (id, content, timestamp, modality, type, status, tags, embedding, location, \n source_role, source_name, source_id, edges, metadata)\n VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);\n \"\"\"\n data = (\n memory_unit.id,\n memory_unit.content,\n memory_unit.timestamp.isoformat(),\n memory_unit.modality,\n memory_unit.type,\n memory_unit.status,\n json.dumps(memory_unit.tags),\n json.dumps(memory_unit.embedding) if memory_unit.embedding else None,\n json.dumps(memory_unit.location) if memory_unit.location else None,\n memory_unit.source_role,\n memory_unit.source_name,\n memory_unit.source_id,\n json.dumps([link.to_dict() for link in memory_unit.edges]),\n json.dumps(memory_unit.metadata) if memory_unit.metadata else None\n )\n self.db.execute_sql(insert_query, data)\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_storage.MemoryUnitRepository.delete","title":"delete(unit_id)
","text":"Deletes a MemoryUnit from the relational database.
Parameters:
Name Type Description Default unit_id
str
The unique identifier of the memory unit to delete.
required Source code in src/aeiva/cognition/memory/memory_storage.py
def delete(self, unit_id: str) -> None:\n \"\"\"\n Deletes a MemoryUnit from the relational database.\n\n Args:\n unit_id (str): The unique identifier of the memory unit to delete.\n \"\"\"\n delete_query = f\"DELETE FROM {self.table_name} WHERE id = ?;\"\n self.db.execute_sql(delete_query, (unit_id,))\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_storage.MemoryUnitRepository.delete_all","title":"delete_all()
","text":"Deletes all MemoryUnits from the relational database.
Source code in src/aeiva/cognition/memory/memory_storage.py
def delete_all(self) -> None:\n \"\"\"\n Deletes all MemoryUnits from the relational database.\n \"\"\"\n delete_query = f\"DELETE FROM {self.table_name};\"\n self.db.execute_sql(delete_query)\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_storage.MemoryUnitRepository.get","title":"get(unit_id)
","text":"Retrieves a MemoryUnit by its ID.
Parameters:
Name Type Description Default unit_id
str
The unique identifier of the memory unit.
required Returns:
Type Description Optional[MemoryUnit]
Optional[MemoryUnit]: The retrieved memory unit or None if not found.
Source code in src/aeiva/cognition/memory/memory_storage.py
def get(self, unit_id: str) -> Optional[MemoryUnit]:\n \"\"\"\n Retrieves a MemoryUnit by its ID.\n\n Args:\n unit_id (str): The unique identifier of the memory unit.\n\n Returns:\n Optional[MemoryUnit]: The retrieved memory unit or None if not found.\n \"\"\"\n select_query = f\"SELECT * FROM {self.table_name} WHERE id = ?;\"\n result = self.db.execute_sql(select_query, (unit_id,))\n row = result.fetchone()\n if row:\n return self._row_to_memory_unit(row)\n return None\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_storage.MemoryUnitRepository.list_all","title":"list_all()
","text":"Retrieves all MemoryUnits from the relational database.
Returns:
Type Description List[MemoryUnit]
List[MemoryUnit]: A list of all memory units.
Source code in src/aeiva/cognition/memory/memory_storage.py
def list_all(self) -> List[MemoryUnit]:\n \"\"\"\n Retrieves all MemoryUnits from the relational database.\n\n Returns:\n List[MemoryUnit]: A list of all memory units.\n \"\"\"\n select_query = f\"SELECT * FROM {self.table_name};\"\n results = self.db.execute_sql(select_query)\n return [self._row_to_memory_unit(row) for row in results.fetchall()]\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_storage.MemoryUnitRepository.update","title":"update(memory_unit)
","text":"Updates an existing MemoryUnit in the relational database.
Parameters:
Name Type Description Default memory_unit
MemoryUnit
The memory unit with updated data.
required Source code in src/aeiva/cognition/memory/memory_storage.py
def update(self, memory_unit: MemoryUnit) -> None:\n \"\"\"\n Updates an existing MemoryUnit in the relational database.\n\n Args:\n memory_unit (MemoryUnit): The memory unit with updated data.\n \"\"\"\n update_query = f\"\"\"\n UPDATE {self.table_name}\n SET content = ?, timestamp = ?, modality = ?, type = ?, status = ?, tags = ?, embedding = ?, \n location = ?, source_role = ?, source_name = ?, source_id = ?, edges = ?, metadata = ?\n WHERE id = ?;\n \"\"\"\n data = (\n memory_unit.content,\n memory_unit.timestamp.isoformat(),\n memory_unit.modality,\n memory_unit.type,\n memory_unit.status,\n json.dumps(memory_unit.tags),\n json.dumps(memory_unit.embedding) if memory_unit.embedding else None,\n json.dumps(memory_unit.location) if memory_unit.location else None,\n memory_unit.source_role,\n memory_unit.source_name,\n memory_unit.source_id,\n json.dumps([link.to_dict() for link in memory_unit.edges]),\n json.dumps(memory_unit.metadata) if memory_unit.metadata else None,\n memory_unit.id\n )\n self.db.execute_sql(update_query, data)\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_structurer","title":"memory_structurer
","text":""},{"location":"reference/#src.aeiva.cognition.memory.memory_structurer.MemoryStructurer","title":"MemoryStructurer
","text":"A class to structure memory units based on various structuring algorithms.
Supported structure types:
- 'structure_type_example': Placeholder for future structuring algorithms.
Source code in src/aeiva/cognition/memory/memory_structurer.py
class MemoryStructurer:\n \"\"\"\n A class to structure memory units based on various structuring algorithms.\n\n Supported structure types:\n - 'structure_type_example': Placeholder for future structuring algorithms.\n \"\"\"\n\n def __init__(self):\n \"\"\"\n Initializes the MemoryStructurer.\n\n Currently, no initialization parameters are required.\n \"\"\"\n self.logger = logging.getLogger(self.__class__.__name__)\n self.logger.debug(\"Initialized MemoryStructurer without default parameters.\")\n\n def structure(\n self,\n memory_units: List[MemoryUnit],\n structure_type: str,\n **kwargs\n ) -> List[MemoryUnit]:\n \"\"\"\n Structures the provided memory units based on the specified structure type.\n\n Args:\n memory_units (List[MemoryUnit]): The list of memory units to be structured.\n structure_type (str): The type of structuring algorithm to use ('structure_type_example').\n **kwargs: Additional parameters required for specific structurers.\n\n Returns:\n List[MemoryUnit]: The list of memory units after structuring.\n\n Raises:\n MemoryStructurerError: If an unknown structure_type is provided or if structuring fails.\n \"\"\"\n self.logger.debug(f\"Structuring memory units using structure_type='{structure_type}' with kwargs={kwargs}\")\n try:\n if structure_type == 'structure_type_example':\n # Placeholder for actual structuring logic\n return self.structure_example(memory_units, **kwargs)\n else:\n self.logger.error(f\"Unknown structure_type: {structure_type}\")\n raise MemoryStructurerError(f\"Unknown structure_type: {structure_type}\")\n except MemoryStructurerError:\n # Re-raise custom errors without modification\n raise\n except Exception as e:\n self.logger.error(f\"Failed to structure memory units: {e}\")\n raise MemoryStructurerError(f\"Failed to structure memory units: {e}\")\n\n def structure_example(\n self,\n memory_units: List[MemoryUnit],\n **kwargs\n ) -> List[MemoryUnit]:\n \"\"\"\n Example structuring method. 
Currently a placeholder that returns memory units unchanged.\n\n Args:\n memory_units (List[MemoryUnit]): The list of memory units to be structured.\n **kwargs: Additional parameters (currently unused).\n\n Returns:\n List[MemoryUnit]: The original list of memory units, unchanged.\n \"\"\"\n self.logger.debug(\"Executing structure_example: No changes applied to memory units.\")\n # Placeholder: No operation performed\n return memory_units\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_structurer.MemoryStructurer.__init__","title":"__init__()
","text":"Initializes the MemoryStructurer.
Currently, no initialization parameters are required.
Source code in src/aeiva/cognition/memory/memory_structurer.py
def __init__(self):\n \"\"\"\n Initializes the MemoryStructurer.\n\n Currently, no initialization parameters are required.\n \"\"\"\n self.logger = logging.getLogger(self.__class__.__name__)\n self.logger.debug(\"Initialized MemoryStructurer without default parameters.\")\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_structurer.MemoryStructurer.structure","title":"structure(memory_units, structure_type, **kwargs)
","text":"Structures the provided memory units based on the specified structure type.
Parameters:
Name Type Description Default memory_units
List[MemoryUnit]
The list of memory units to be structured.
required structure_type
str
The type of structuring algorithm to use ('structure_type_example').
required **kwargs
Additional parameters required for specific structurers.
{}
Returns:
Type Description List[MemoryUnit]
List[MemoryUnit]: The list of memory units after structuring.
Raises:
Type Description MemoryStructurerError
If an unknown structure_type is provided or if structuring fails.
Source code in src/aeiva/cognition/memory/memory_structurer.py
def structure(\n self,\n memory_units: List[MemoryUnit],\n structure_type: str,\n **kwargs\n) -> List[MemoryUnit]:\n \"\"\"\n Structures the provided memory units based on the specified structure type.\n\n Args:\n memory_units (List[MemoryUnit]): The list of memory units to be structured.\n structure_type (str): The type of structuring algorithm to use ('structure_type_example').\n **kwargs: Additional parameters required for specific structurers.\n\n Returns:\n List[MemoryUnit]: The list of memory units after structuring.\n\n Raises:\n MemoryStructurerError: If an unknown structure_type is provided or if structuring fails.\n \"\"\"\n self.logger.debug(f\"Structuring memory units using structure_type='{structure_type}' with kwargs={kwargs}\")\n try:\n if structure_type == 'structure_type_example':\n # Placeholder for actual structuring logic\n return self.structure_example(memory_units, **kwargs)\n else:\n self.logger.error(f\"Unknown structure_type: {structure_type}\")\n raise MemoryStructurerError(f\"Unknown structure_type: {structure_type}\")\n except MemoryStructurerError:\n # Re-raise custom errors without modification\n raise\n except Exception as e:\n self.logger.error(f\"Failed to structure memory units: {e}\")\n raise MemoryStructurerError(f\"Failed to structure memory units: {e}\")\n
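structure() dispatches on structure_type, re-raises its own MemoryStructurerError unchanged, and wraps any other failure in a new MemoryStructurerError. That control flow, reduced to a self-contained sketch of the method above:

```python
class MemoryStructurerError(Exception):
    """Raised when structuring fails or the structure_type is unknown."""

def structure(units, structure_type):
    try:
        if structure_type == 'structure_type_example':
            return units  # placeholder structurer: pass-through
        raise MemoryStructurerError(f"Unknown structure_type: {structure_type}")
    except MemoryStructurerError:
        raise  # re-raise custom errors without wrapping them again
    except Exception as e:
        raise MemoryStructurerError(f"Failed to structure memory units: {e}")

print(structure(['u1'], 'structure_type_example'))  # ['u1']
try:
    structure(['u1'], 'unknown')
except MemoryStructurerError as e:
    print(type(e).__name__)  # MemoryStructurerError
```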
"},{"location":"reference/#src.aeiva.cognition.memory.memory_structurer.MemoryStructurer.structure_example","title":"structure_example(memory_units, **kwargs)
","text":"Example structuring method. Currently a placeholder that returns memory units unchanged.
Parameters:
Name Type Description Default memory_units
List[MemoryUnit]
The list of memory units to be structured.
required **kwargs
Additional parameters (currently unused).
{}
Returns:
Type Description List[MemoryUnit]
List[MemoryUnit]: The original list of memory units, unchanged.
Source code in src/aeiva/cognition/memory/memory_structurer.py
def structure_example(\n self,\n memory_units: List[MemoryUnit],\n **kwargs\n) -> List[MemoryUnit]:\n \"\"\"\n Example structuring method. Currently a placeholder that returns memory units unchanged.\n\n Args:\n memory_units (List[MemoryUnit]): The list of memory units to be structured.\n **kwargs: Additional parameters (currently unused).\n\n Returns:\n List[MemoryUnit]: The original list of memory units, unchanged.\n \"\"\"\n self.logger.debug(\"Executing structure_example: No changes applied to memory units.\")\n # Placeholder: No operation performed\n return memory_units\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_structurer.MemoryStructurerError","title":"MemoryStructurerError
","text":" Bases: Exception
Exception raised when an error occurs in the MemoryStructurer.
Source code in src/aeiva/cognition/memory/memory_structurer.py
class MemoryStructurerError(Exception):\n \"\"\"Exception raised when an error occurs in the MemoryStructurer.\"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_unit","title":"memory_unit
","text":""},{"location":"reference/#src.aeiva.cognition.memory.memory_unit.MemoryUnit","title":"MemoryUnit
","text":" Bases: BaseModel
MemoryUnit represents a single unit of memory with core content and rich metadata. It includes fields for tracking information about the memory\u2019s source, modality, temporal and spatial attributes, and its connections to other memory units.
Essential Fields id (str): Unique identifier for the memory unit, generated as a UUID string by default. content (Any): Core content of the memory, which is convertible to a string.
Metadata timestamp (datetime): Creation timestamp, defaulting to the current time. modality (Optional[str]): Modality type, such as 'text', 'image', 'audio'. type (Optional[str]): Semantic type, such as 'dialogue', 'summary', 'document'. status (Optional[str]): Processing status, e.g., 'raw', 'cleaned', 'processed'. tags (Optional[List[str]]): Tags for categorization and filtering. embedding (Optional[List[float]]): Vector embedding for retrieval. location (Optional[Union[str, Dict]]): Spatial location data.
Source Information source_role (Optional[str]): Role of the source, e.g., 'user', 'agent'. source_name (Optional[str]): Descriptive name of the source. source_id (Optional[str]): Unique identifier for the memory source, generated as a UUID string.
Connections edges (List[MemoryLink]): List of edges connecting this memory unit to others.
Additional Metadata metadata (Optional[Dict[str, Any]]): Dictionary for extensible metadata.
Source code in src/aeiva/cognition/memory/memory_unit.py
class MemoryUnit(BaseModel):\n \"\"\"\n MemoryUnit represents a single unit of memory with core content and rich metadata.\n It includes fields for tracking information about the memory\u2019s source, modality,\n temporal and spatial attributes, and its connections to other memory units.\n\n Essential Fields:\n id (str): Unique identifier for the memory unit, generated as a UUID string by default.\n content (Any): Core content of the memory, which is convertible to a string.\n\n Metadata:\n timestamp (datetime): Creation timestamp, defaulting to the current time.\n modality (Optional[str]): Modality type, such as 'text', 'image', 'audio'.\n type (Optional[str]): Semantic type, such as 'dialogue', 'summary', 'document'.\n status (Optional[str]): Processing status, e.g., 'raw', 'cleaned', 'processed'.\n tags (Optional[List[str]]): Tags for categorization and filtering.\n embedding (Optional[List[float]]): Vector embedding for retrieval.\n location (Optional[Union[str, Dict]]): Spatial location data.\n\n Source Information:\n source_role (Optional[str]): Role of the source, e.g., 'user', 'agent'.\n source_name (Optional[str]): Descriptive name of the source.\n source_id (Optional[str]): Unique identifier for the memory source, generated as a UUID string.\n\n Connections:\n edges (List[MemoryLink]): List of edges connecting this memory unit to others.\n\n Additional Metadata:\n metadata (Optional[Dict[str, Any]]): Dictionary for extensible metadata.\n \"\"\"\n\n # Essential Fields\n id: str = Field(default_factory=lambda: uuid4().hex, description=\"Unique identifier for the memory unit.\")\n content: Any = Field(\"\", description=\"Core content of the memory unit, convertible to a string.\")\n\n # Metadata Fields\n timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc), description=\"Creation timestamp of the memory.\")\n modality: Optional[str] = Field(None, description=\"Modality type, e.g., 'text', 'image', 'audio'.\")\n type: Optional[str] = 
Field(None, description=\"Semantic type, e.g., 'dialogue', 'summary'.\")\n status: Optional[str] = Field(None, description=\"Processing status, e.g., 'raw', 'cleaned', 'derived', 'grouped', 'structured', 'indexed'.\")\n tags: Optional[List[str]] = Field(default_factory=list, description=\"Tags for categorization or filtering.\")\n embedding: Optional[List[float]] = Field(None, description=\"Embedding vector for memory.\")\n location: Optional[Union[str, Dict]] = Field(None, description=\"Location data as a string or structured dictionary.\")\n\n # Source Information\n source_role: Optional[str] = Field(None, description=\"Role of the memory source, e.g., 'user', 'agent'.\")\n source_name: Optional[str] = Field(None, description=\"Descriptive name of the source, e.g., 'User123'.\")\n source_id: Optional[str] = Field(default_factory=lambda: uuid4().hex, description=\"Unique identifier associated with the source.\")\n\n # Connections\n edges: List[MemoryLink] = Field(default_factory=list, description=\"List of edges linking this memory unit to others.\")\n\n # Additional Metadata\n metadata: Optional[Dict[str, Any]] = Field(default_factory=dict, description=\"Dictionary for extensible metadata.\")\n\n def to_dict(self) -> Dict[str, Any]:\n \"\"\"\n Converts the MemoryUnit instance to a dictionary format for serialization.\n Each field is handled explicitly to ensure proper serialization.\n\n Returns:\n Dict[str, Any]: A dictionary representation of the MemoryUnit.\n \"\"\"\n return {\n \"id\": self.id,\n \"content\": self.content,\n \"timestamp\": self.timestamp.isoformat(), # Convert datetime to string\n \"modality\": self.modality,\n \"type\": self.type,\n \"status\": self.status,\n \"tags\": self.tags,\n \"embedding\": self.embedding,\n \"location\": self.location,\n \"source_role\": self.source_role,\n \"source_name\": self.source_name,\n \"source_id\": self.source_id,\n \"edges\": [edge.to_dict() for edge in self.edges], # Serialize each MemoryLink\n 
\"metadata\": self.metadata\n }\n\n @classmethod\n def from_dict(cls, data: dict) -> \"MemoryUnit\":\n \"\"\"\n Creates a MemoryUnit instance from a dictionary.\n Each field is handled explicitly to ensure proper deserialization.\n\n Args:\n data (dict): A dictionary containing MemoryUnit data.\n\n Returns:\n MemoryUnit: The created MemoryUnit instance.\n \"\"\"\n try:\n return cls(\n id=data.get('id', uuid4().hex),\n content=data.get('content', \"\"),\n timestamp=datetime.fromisoformat(data['timestamp']) if 'timestamp' in data else datetime.now(timezone.utc),\n modality=data.get('modality'),\n type=data.get('type'),\n status=data.get('status'),\n tags=data.get('tags', []),\n embedding=data.get('embedding'),\n location=data.get('location'),\n source_role=data.get('source_role'),\n source_name=data.get('source_name'),\n source_id=data.get('source_id', uuid4().hex),\n edges=[MemoryLink.from_dict(edge) for edge in data.get('edges', [])],\n metadata=data.get('metadata', {})\n )\n except Exception as e:\n # logger.error(f\"Error deserializing MemoryUnit from dict: {e}\")\n raise e\n
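The only field that needs explicit conversion in the to_dict/from_dict round trip above is timestamp (datetime to ISO-8601 string and back). A minimal self-contained sketch of that round trip, using a plain dict instead of the pydantic model so it carries no package dependency:

```python
from datetime import datetime, timezone
from uuid import uuid4

# Serialize the way MemoryUnit.to_dict does: datetime -> ISO-8601 string.
unit = {
    "id": uuid4().hex,
    "content": "User asked about the weather.",
    "timestamp": datetime.now(timezone.utc),
}
serialized = {**unit, "timestamp": unit["timestamp"].isoformat()}

# Deserialize the way from_dict does: ISO-8601 string -> aware datetime.
restored = {**serialized, "timestamp": datetime.fromisoformat(serialized["timestamp"])}
```

datetime.fromisoformat accepts exactly what isoformat produces, so the round trip is lossless as long as the timestamp is timezone-aware.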
"},{"location":"reference/#src.aeiva.cognition.memory.memory_unit.MemoryUnit.from_dict","title":"from_dict(data)
classmethod
","text":"Creates a MemoryUnit instance from a dictionary. Each field is handled explicitly to ensure proper deserialization.
Parameters:
Name Type Description Default data
dict
A dictionary containing MemoryUnit data.
required Returns:
Name Type Description MemoryUnit
MemoryUnit
The created MemoryUnit instance.
Source code in src/aeiva/cognition/memory/memory_unit.py
@classmethod\ndef from_dict(cls, data: dict) -> \"MemoryUnit\":\n \"\"\"\n Creates a MemoryUnit instance from a dictionary.\n Each field is handled explicitly to ensure proper deserialization.\n\n Args:\n data (dict): A dictionary containing MemoryUnit data.\n\n Returns:\n MemoryUnit: The created MemoryUnit instance.\n \"\"\"\n try:\n return cls(\n id=data.get('id', uuid4().hex),\n content=data.get('content', \"\"),\n timestamp=datetime.fromisoformat(data['timestamp']) if 'timestamp' in data else datetime.now(timezone.utc),\n modality=data.get('modality'),\n type=data.get('type'),\n status=data.get('status'),\n tags=data.get('tags', []),\n embedding=data.get('embedding'),\n location=data.get('location'),\n source_role=data.get('source_role'),\n source_name=data.get('source_name'),\n source_id=data.get('source_id', uuid4().hex),\n edges=[MemoryLink.from_dict(edge) for edge in data.get('edges', [])],\n metadata=data.get('metadata', {})\n )\n except Exception as e:\n # logger.error(f\"Error deserializing MemoryUnit from dict: {e}\")\n raise e\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_unit.MemoryUnit.to_dict","title":"to_dict()
","text":"Converts the MemoryUnit instance to a dictionary format for serialization. Each field is handled explicitly to ensure proper serialization.
Returns:
Type Description Dict[str, Any]
Dict[str, Any]: A dictionary representation of the MemoryUnit.
Source code in src/aeiva/cognition/memory/memory_unit.py
def to_dict(self) -> Dict[str, Any]:\n \"\"\"\n Converts the MemoryUnit instance to a dictionary format for serialization.\n Each field is handled explicitly to ensure proper serialization.\n\n Returns:\n Dict[str, Any]: A dictionary representation of the MemoryUnit.\n \"\"\"\n return {\n \"id\": self.id,\n \"content\": self.content,\n \"timestamp\": self.timestamp.isoformat(), # Convert datetime to string\n \"modality\": self.modality,\n \"type\": self.type,\n \"status\": self.status,\n \"tags\": self.tags,\n \"embedding\": self.embedding,\n \"location\": self.location,\n \"source_role\": self.source_role,\n \"source_name\": self.source_name,\n \"source_id\": self.source_id,\n \"edges\": [edge.to_dict() for edge in self.edges], # Serialize each MemoryLink\n \"metadata\": self.metadata\n }\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_utils","title":"memory_utils
","text":""},{"location":"reference/#src.aeiva.cognition.memory.memory_utils.derive_content","title":"derive_content(derivation_type, data)
","text":"You are a creative assistant capable of deriving new content based on specified types. Your task is to derive a {derivation_type} from the provided combined content.
Source code in src/aeiva/cognition/memory/memory_utils.py
@simple(model='gpt-4', temperature=0.7)\ndef derive_content(derivation_type: str, data: str) -> str:\n \"\"\"\n You are a creative assistant capable of deriving new content based on specified types.\n Your task is to derive a {derivation_type} from the provided combined content.\n \"\"\"\n result = f\"Derive a {derivation_type} from the following content:\\n{data}\"\n return result\n
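The function body itself only builds the request text; the @simple decorator (whose LLM dispatch is not reproduced here) presumably sends that text, together with the docstring, to the model. A decorator-free sketch of what the body computes:

```python
def derive_content(derivation_type: str, data: str) -> str:
    # Build the request text that the decorated version hands to the model.
    return f"Derive a {derivation_type} from the following content:\n{data}"


prompt = derive_content("summary", "Meeting notes: shipped v0.2, discussed roadmap.")
```

extract_entities_relationships below follows the same prompt-template pattern with a fixed instruction.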
"},{"location":"reference/#src.aeiva.cognition.memory.memory_utils.extract_entities_relationships","title":"extract_entities_relationships(data)
","text":"You are an intelligent assistant skilled in natural language processing. Your task is to extract entities and the relationships between them from the provided content.
Source code in src/aeiva/cognition/memory/memory_utils.py
@simple(model='gpt-4', temperature=0.7)\ndef extract_entities_relationships(data: Any) -> str:\n \"\"\"\n You are an intelligent assistant skilled in natural language processing.\n Your task is to extract entities and the relationships between them from the provided content.\n \"\"\"\n result = f\"Extract entities and relationships from the following content:\\n{data}\"\n return result\n
"},{"location":"reference/#src.aeiva.cognition.memory.storage_config","title":"storage_config
","text":""},{"location":"reference/#src.aeiva.cognition.memory.storage_config.StorageConfig","title":"StorageConfig
dataclass
","text":" Bases: BaseConfig
Configuration class for the Memory storage.
Attributes:
Name Type Description vector_db_config
DatabaseConfig
Configuration for the vector database.
graph_db_config
Optional[DatabaseConfig]
Configuration for the graph database.
relational_db_config
Optional[DatabaseConfig]
Configuration for the relational database.
Source code in src/aeiva/cognition/memory/storage_config.py
@dataclass\nclass StorageConfig(BaseConfig):\n \"\"\"\n Configuration class for the Memory storage.\n\n Attributes:\n vector_db_config (DatabaseConfig): Configuration for the vector database.\n graph_db_config (Optional[DatabaseConfig]): Configuration for the graph database.\n relational_db_config (Optional[DatabaseConfig]): Configuration for the relational database.\n \"\"\"\n vector_db_provider: str = field(\n metadata={\"help\": \"Vector database provider name.\"}\n )\n vector_db_config: BaseConfig = field(\n metadata={\"help\": \"Configuration for the vector database.\"}\n )\n graph_db_provider: Optional[str] = field(\n default=None,\n metadata={\"help\": \"Graph database provider name.\"}\n )\n graph_db_config: Optional[BaseConfig] = field(\n default=None,\n metadata={\"help\": \"Configuration for the graph database.\"}\n )\n relational_db_provider: Optional[str] = field(\n default=None,\n metadata={\"help\": \"Relational database provider name.\"}\n )\n relational_db_config: Optional[BaseConfig] = field(\n default=None,\n metadata={\"help\": \"Configuration for the relational database.\"}\n )\n\n def __post_init__(self):\n super().__post_init__()\n # Perform any necessary validation\n if not self.vector_db_config:\n raise ValueError(\"Vector database configuration must be provided.\")\n
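The required/optional split above (vector store mandatory, graph and relational stores optional) can be exercised with a self-contained stand-in; the snippet uses a plain dataclass instead of BaseConfig, and the provider name and config values are hypothetical:

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class StorageConfig:
    vector_db_provider: str
    vector_db_config: Any
    graph_db_provider: Optional[str] = None
    graph_db_config: Optional[Any] = None
    relational_db_provider: Optional[str] = None
    relational_db_config: Optional[Any] = None

    def __post_init__(self):
        # Mirrors the documented validation: a vector store is mandatory.
        if not self.vector_db_config:
            raise ValueError("Vector database configuration must be provided.")


cfg = StorageConfig(vector_db_provider="milvus", vector_db_config={"uri": "localhost:19530"})
```

Omitting the graph and relational settings leaves them as None, while an empty vector_db_config fails fast in __post_init__.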
"},{"location":"reference/#src.aeiva.cognition.observation","title":"observation
","text":""},{"location":"reference/#src.aeiva.cognition.observation.Observation","title":"Observation
","text":"Represents a processed input from the PerceptionSystem.
Source code in src/aeiva/cognition/observation.py
class Observation:\n \"\"\"\n Represents a processed input from the PerceptionSystem.\n \"\"\"\n def __init__(self, data: Any, modality: str = 'text', timestamp: Optional[datetime] = None, metadata: Optional[Dict[str, Any]] = None):\n self.data = data # The processed data (e.g., text)\n self.modality = modality\n self.timestamp = timestamp or datetime.now()\n self.metadata = metadata or {}\n\n def to_dict(self) -> Dict[str, Any]:\n return {\n 'data': self.data,\n 'modality': self.modality,\n 'timestamp': self.timestamp.isoformat(),\n 'metadata': self.metadata\n }\n
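The class is small enough to exercise directly; the snippet below reproduces it from the source above (so it runs without the package installed) and shows the serialized shape an Observation takes:

```python
from datetime import datetime
from typing import Any, Dict, Optional


class Observation:
    """Represents a processed input from the PerceptionSystem."""

    def __init__(self, data: Any, modality: str = 'text',
                 timestamp: Optional[datetime] = None,
                 metadata: Optional[Dict[str, Any]] = None):
        self.data = data                              # the processed input (e.g. text)
        self.modality = modality
        self.timestamp = timestamp or datetime.now()  # defaults to creation time
        self.metadata = metadata or {}

    def to_dict(self) -> Dict[str, Any]:
        return {
            'data': self.data,
            'modality': self.modality,
            'timestamp': self.timestamp.isoformat(),
            'metadata': self.metadata,
        }


obs = Observation("What's the weather today?", metadata={"source": "user"})
d = obs.to_dict()
```

Thought, documented next, is the mirror image of this class on the output side of the Brain.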
"},{"location":"reference/#src.aeiva.cognition.thought","title":"thought
","text":""},{"location":"reference/#src.aeiva.cognition.thought.Thought","title":"Thought
","text":"Represents the output from the Brain after processing an Observation.
Source code in src/aeiva/cognition/thought.py
class Thought:\n \"\"\"\n Represents the output from the Brain after processing an Observation.\n \"\"\"\n def __init__(self, content: Any, modality: str = 'text', timestamp: Optional[datetime] = None, metadata: Optional[Dict[str, Any]] = None):\n self.content = content # The thought content (e.g., text)\n self.modality = modality\n self.timestamp = timestamp or datetime.now()\n self.metadata = metadata or {}\n\n def to_dict(self) -> Dict[str, Any]:\n return {\n 'content': self.content,\n 'modality': self.modality,\n 'timestamp': self.timestamp.isoformat(),\n 'metadata': self.metadata\n }\n
"},{"location":"reference/#src.aeiva.cognition.world_model","title":"world_model
","text":""},{"location":"reference/#src.aeiva.cognition.world_model.world_model","title":"world_model
","text":""},{"location":"reference/#src.aeiva.cognition.world_model.world_model.WorldModel","title":"WorldModel
","text":" Bases: ABC
Abstract base class representing the World Model system of an agent.
The World Model maintains an internal representation of the environment, enabling the agent to understand, predict, and interact with its surroundings effectively.
Attributes:
Name Type Description config
Any
Configuration settings for the World Model system.
state
Any
The internal state of the World Model system.
Source code in src/aeiva/cognition/world_model/world_model.py
class WorldModel(ABC):\n \"\"\"\n Abstract base class representing the World Model system of an agent.\n\n The World Model maintains an internal representation of the environment, enabling the agent\n to understand, predict, and interact with its surroundings effectively.\n\n Attributes:\n config (Any): Configuration settings for the World Model system.\n state (Any): The internal state of the World Model system.\n \"\"\"\n\n def __init__(self, config: Any):\n \"\"\"\n Initialize the World Model system with the provided configuration.\n\n Args:\n config (Any): Configuration settings for the World Model system.\n \"\"\"\n self.config = config\n self.state = self.init_state()\n\n @abstractmethod\n def init_state(self) -> Any:\n \"\"\"\n Initialize the internal state of the World Model system.\n\n This method should set up the initial state required for the World Model system's operations.\n\n Returns:\n Any: The initial state of the World Model system.\n \"\"\"\n pass\n\n @abstractmethod\n def setup(self) -> None:\n \"\"\"\n Asynchronously set up the World Model system's components.\n\n This method should initialize any necessary components or resources based on the provided configuration.\n\n Raises:\n ConfigurationError: If the configuration is invalid or incomplete.\n \"\"\"\n pass\n\n @abstractmethod\n async def update(self, observation: Any) -> None:\n \"\"\"\n Asynchronously update the world model based on new observations.\n\n Args:\n observation (Any): The new observation to incorporate into the world model.\n\n Raises:\n UpdateError: If updating the world model fails.\n \"\"\"\n pass\n\n @abstractmethod\n async def query(self, query: Any) -> Any:\n \"\"\"\n Asynchronously query the world model for specific information.\n\n Args:\n query (Any): The query or criteria to retrieve specific information from the world model.\n\n Returns:\n Any: The information retrieved from the world model.\n\n Raises:\n QueryError: If the query process fails.\n \"\"\"\n pass\n\n 
def get_current_state(self) -> Any:\n \"\"\"\n Retrieve the current internal state of the World Model system.\n\n Returns:\n Any: The current internal state.\n \"\"\"\n return self.state\n\n def handle_error(self, error: Exception) -> None:\n \"\"\"\n Handle errors that occur during world model operations.\n\n This method can be overridden to implement custom error handling logic.\n\n Args:\n error (Exception): The exception that was raised.\n \"\"\"\n # Default error handling: log the error\n print(f\"WorldModel system encountered an error: {error}\")\n
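A concrete world model only has to fill in init_state, setup, update, and query. The sketch below defines a toy key-value world model against a local mirror of the documented abstract interface (it does not import the package, and DictWorldModel is a hypothetical example, not part of Aeiva):

```python
import asyncio
from abc import ABC, abstractmethod
from typing import Any


class WorldModel(ABC):
    # Local mirror of the documented abstract interface.
    def __init__(self, config: Any):
        self.config = config
        self.state = self.init_state()

    @abstractmethod
    def init_state(self) -> Any: ...

    @abstractmethod
    def setup(self) -> None: ...

    @abstractmethod
    async def update(self, observation: Any) -> None: ...

    @abstractmethod
    async def query(self, query: Any) -> Any: ...


class DictWorldModel(WorldModel):
    """Toy world model: stores observations as key-value facts."""

    def init_state(self) -> dict:
        return {}

    def setup(self) -> None:
        pass  # nothing to initialize for an in-memory store

    async def update(self, observation: Any) -> None:
        key, value = observation  # expects (key, value) pairs
        self.state[key] = value

    async def query(self, query: Any) -> Any:
        return self.state.get(query)


async def demo() -> Any:
    wm = DictWorldModel(config={})
    wm.setup()
    await wm.update(("door", "open"))
    return await wm.query("door")


result = asyncio.run(demo())
```

Note that __init__ calls init_state before setup runs, so the initial state must not depend on resources that setup creates.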
"},{"location":"reference/#src.aeiva.cognition.world_model.world_model.WorldModel.__init__","title":"__init__(config)
","text":"Initialize the World Model system with the provided configuration.
Parameters:
Name Type Description Default config
Any
Configuration settings for the World Model system.
required Source code in src/aeiva/cognition/world_model/world_model.py
def __init__(self, config: Any):\n \"\"\"\n Initialize the World Model system with the provided configuration.\n\n Args:\n config (Any): Configuration settings for the World Model system.\n \"\"\"\n self.config = config\n self.state = self.init_state()\n
"},{"location":"reference/#src.aeiva.cognition.world_model.world_model.WorldModel.get_current_state","title":"get_current_state()
","text":"Retrieve the current internal state of the World Model system.
Returns:
Name Type Description Any
Any
The current internal state.
Source code in src/aeiva/cognition/world_model/world_model.py
def get_current_state(self) -> Any:\n \"\"\"\n Retrieve the current internal state of the World Model system.\n\n Returns:\n Any: The current internal state.\n \"\"\"\n return self.state\n
"},{"location":"reference/#src.aeiva.cognition.world_model.world_model.WorldModel.handle_error","title":"handle_error(error)
","text":"Handle errors that occur during world model operations.
This method can be overridden to implement custom error handling logic.
Parameters:
Name Type Description Default error
Exception
The exception that was raised.
required Source code in src/aeiva/cognition/world_model/world_model.py
def handle_error(self, error: Exception) -> None:\n \"\"\"\n Handle errors that occur during world model operations.\n\n This method can be overridden to implement custom error handling logic.\n\n Args:\n error (Exception): The exception that was raised.\n \"\"\"\n # Default error handling: log the error\n print(f\"WorldModel system encountered an error: {error}\")\n
"},{"location":"reference/#src.aeiva.cognition.world_model.world_model.WorldModel.init_state","title":"init_state()
abstractmethod
","text":"Initialize the internal state of the World Model system.
This method should set up the initial state required for the World Model system's operations.
Returns:
Name Type Description Any
Any
The initial state of the World Model system.
Source code in src/aeiva/cognition/world_model/world_model.py
@abstractmethod\ndef init_state(self) -> Any:\n \"\"\"\n Initialize the internal state of the World Model system.\n\n This method should set up the initial state required for the World Model system's operations.\n\n Returns:\n Any: The initial state of the World Model system.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.world_model.world_model.WorldModel.query","title":"query(query)
abstractmethod
async
","text":"Asynchronously query the world model for specific information.
Parameters:
Name Type Description Default query
Any
The query or criteria to retrieve specific information from the world model.
required Returns:
Name Type Description Any
Any
The information retrieved from the world model.
Raises:
Type Description QueryError
If the query process fails.
Source code in src/aeiva/cognition/world_model/world_model.py
@abstractmethod\nasync def query(self, query: Any) -> Any:\n \"\"\"\n Asynchronously query the world model for specific information.\n\n Args:\n query (Any): The query or criteria to retrieve specific information from the world model.\n\n Returns:\n Any: The information retrieved from the world model.\n\n Raises:\n QueryError: If the query process fails.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.world_model.world_model.WorldModel.setup","title":"setup()
abstractmethod
","text":"Asynchronously set up the World Model system's components.
This method should initialize any necessary components or resources based on the provided configuration.
Raises:
Type Description ConfigurationError
If the configuration is invalid or incomplete.
Source code in src/aeiva/cognition/world_model/world_model.py
@abstractmethod\ndef setup(self) -> None:\n \"\"\"\n Asynchronously set up the World Model system's components.\n\n This method should initialize any necessary components or resources based on the provided configuration.\n\n Raises:\n ConfigurationError: If the configuration is invalid or incomplete.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.world_model.world_model.WorldModel.update","title":"update(observation)
abstractmethod
async
","text":"Asynchronously update the world model based on new observations.
Parameters:
Name Type Description Default observation
Any
The new observation to incorporate into the world model.
required Raises:
Type Description UpdateError
If updating the world model fails.
Source code in src/aeiva/cognition/world_model/world_model.py
@abstractmethod\nasync def update(self, observation: Any) -> None:\n \"\"\"\n Asynchronously update the world model based on new observations.\n\n Args:\n observation (Any): The new observation to incorporate into the world model.\n\n Raises:\n UpdateError: If updating the world model fails.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.command","title":"command
","text":""},{"location":"reference/#src.aeiva.command.aeiva_chat_gradio","title":"aeiva_chat_gradio
","text":"Run the command as follows (specify your own config file path):
aeiva-chat-gradio --config configs/agent_config.yaml
"},{"location":"reference/#src.aeiva.command.aeiva_chat_gradio.run","title":"run(config, verbose)
","text":"Starts the Aeiva chat Gradio interface with the provided configuration.
Source code in src/aeiva/command/aeiva_chat_gradio.py
@click.command(name=\"aeiva-chat-gradio\")\n@click.option('--config', '-c', default=str(DEFAULT_CONFIG_PATH),\n help='Path to the configuration file (YAML or JSON).',\n type=click.Path(exists=True, dir_okay=False))\n@click.option('--verbose', '-v', is_flag=True, help='Enable verbose logging.')\ndef run(config, verbose):\n \"\"\"\n Starts the Aeiva chat Gradio interface with the provided configuration.\n \"\"\"\n # Setup logging\n logger = setup_logging(DEFAULT_LOG_PATH, verbose)\n\n # Load environment variables (API keys, etc.)\n load_dotenv()\n\n logger.info(f\"Loading configuration from {config}\")\n config_dict = from_json_or_yaml(config)\n\n # Initialize the Agent\n try:\n agent = Agent(config_dict)\n agent.setup()\n logger.info(\"Agent initialized successfully.\")\n except Exception as e:\n logger.error(f\"Failed to initialize Agent: {e}\")\n click.echo(f\"Error: Failed to initialize Agent: {e}\")\n sys.exit(1)\n\n # Function to run the Agent's run method in a separate thread\n def run_agent(agent_instance):\n try:\n asyncio.run(agent_instance.run())\n except Exception as e:\n logger.error(f\"Error running Agent: {e}\")\n\n # Start the Agent in a separate daemon thread\n agent_thread = threading.Thread(target=run_agent, args=(agent,), daemon=True)\n agent_thread.start()\n logger.info(\"Agent run thread started.\")\n\n # Initialize a thread-safe queue to receive responses from the Agent\n response_queue = queue.Queue()\n\n # Define a handler for 'response.gradio' events\n def handle_response_gradio(event: Event):\n response = event.payload\n response_queue.put_nowait(response) # Put response into the thread-safe queue\n logger.info(f\"Received 'response.gradio' event: {response}\")\n\n # Register the handler with the Agent's EventBus\n agent.event_bus.on('response.gradio')(handle_response_gradio)\n logger.info(\"Registered handler for 'response.gradio' events.\")\n\n # Validate and start Neo4j\n neo4j_home = os.getenv('NEO4J_HOME')\n if not neo4j_home:\n 
logger.error(\"NEO4J_HOME environment variable is not set.\")\n click.echo(\"Error: NEO4J_HOME environment variable is not set.\")\n sys.exit(1)\n\n validate_neo4j_home(logger, neo4j_home)\n neo4j_process = start_neo4j(logger, neo4j_home)\n\n # Register signal handlers to ensure Neo4j stops gracefully\n for sig in [signal.SIGINT, signal.SIGTERM]:\n signal.signal(sig, lambda s, f: handle_exit(s, f, logger, neo4j_process))\n\n # Define handlers for multimodal inputs\n\n def handle_image_upload(image: Image.Image):\n if image is not None:\n timestamp = datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n image_path = f\"uploads/uploaded_image_{timestamp}.jpg\"\n try:\n image.save(image_path)\n logger.info(f\"Image uploaded and saved to {image_path}\")\n return \"User uploaded an image.\"\n except Exception as e:\n logger.error(f\"Error saving uploaded image: {e}\")\n return \"Failed to upload image.\"\n return \"\"\n\n def handle_video_upload(video):\n if video is not None:\n timestamp = datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n video_path = f\"uploads/uploaded_video_{timestamp}.mp4\"\n try:\n with open(video_path, \"wb\") as f:\n f.write(video.read())\n logger.info(f\"Video uploaded and saved to {video_path}\")\n return \"User uploaded a video.\"\n except Exception as e:\n logger.error(f\"Error saving uploaded video: {e}\")\n return \"Failed to upload video.\"\n return \"\"\n\n def handle_audio_upload(audio):\n if audio is not None:\n try:\n sample_rate, audio_data = audio\n # Normalize audio_data to float32 in the range -1.0 to 1.0\n audio_data_normalized = audio_data.astype(np.float32) / np.abs(audio_data).max()\n timestamp = datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n audio_path = f\"uploads/uploaded_audio_{timestamp}.wav\"\n sf.write(audio_path, audio_data_normalized, sample_rate, subtype='PCM_16')\n logger.info(f\"Audio uploaded and saved to {audio_path}\")\n return \"User uploaded an audio file.\"\n except Exception as e:\n logger.error(f\"Error saving uploaded 
audio: {e}\")\n return \"Failed to upload audio.\"\n return \"\"\n\n def handle_upload(file):\n \"\"\"\n Handles file uploads and delegates to specific handlers based on file type.\n\n Args:\n file: Uploaded file object.\n\n Returns:\n str: Message indicating the upload status.\n \"\"\"\n if file is None:\n return \"\"\n if file.type.startswith(\"image\"):\n return handle_image_upload(file)\n elif file.type.startswith(\"video\"):\n return handle_video_upload(file)\n elif file.type.startswith(\"audio\"):\n return handle_audio_upload(file)\n else:\n logger.warning(f\"Unsupported file type uploaded: {file.type}\")\n return \"Unsupported file type uploaded.\"\n\n def clear_media():\n \"\"\"\n Clears the uploaded media paths.\n \"\"\"\n # Implement any necessary logic to clear media paths or data\n logger.info(\"Cleared uploaded media paths.\")\n return \"\"\n\n async def bot(user_input, history):\n \"\"\"\n Handles chatbot logic by emitting perception.gradio events to the Agent and retrieving responses.\n \"\"\"\n if agent is None:\n logger.error(\"Agent is not initialized.\")\n history.append({\"role\": \"assistant\", \"content\": \"Agent is not initialized.\"})\n yield history, ''\n return\n\n try:\n # Append user's message to history\n history.append({\"role\": \"user\", \"content\": user_input})\n # Append an empty assistant response\n history.append({\"role\": \"assistant\", \"content\": \"\"})\n yield history, '' # Display the user's message\n logger.info(f\"User input appended to history: {user_input}\")\n\n stream = config_dict[\"llm_gateway_config\"][\"llm_stream\"]\n use_async = config_dict[\"llm_gateway_config\"][\"llm_use_async\"]\n\n # Emit the 'perception.gradio' event with stream=True\n emit_future = asyncio.run_coroutine_threadsafe(\n agent.event_bus.emit('perception.gradio', payload=user_input),\n agent.event_bus.loop\n )\n emit_future.result() # Ensure the event is emitted\n logger.info(f\"Emitted 'perception.gradio' event with payload: {user_input} | 
Stream: {stream}\")\n\n assistant_message = ''\n if stream:\n while True:\n try:\n # Non-blocking response retrieval from the thread-safe queue with timeout\n response = await asyncio.wait_for(\n asyncio.to_thread(response_queue.get, True, 30),\n timeout=30\n )\n logger.info(f\"Retrieved response from queue: {response}\")\n if response == \"<END_OF_RESPONSE>\":\n logger.info(\"Received end of response signal.\")\n break\n assistant_message += response\n # Create a new history list to ensure Gradio detects the update\n new_history = history.copy()\n new_history[-1][\"content\"] = assistant_message\n logger.info(f\"Yielding updated history: {new_history}\")\n yield new_history, ''\n except asyncio.TimeoutError:\n logger.warning(\"Timeout: No response received from Agent.\")\n # Create a new history list to ensure Gradio detects the update\n new_history = history.copy()\n new_history[-1][\"content\"] = \"I'm sorry, I didn't receive a response in time.\"\n yield new_history, ''\n break\n else:\n try:\n # Non-blocking response retrieval from the thread-safe queue with timeout\n response = await asyncio.wait_for(\n asyncio.to_thread(response_queue.get, True, 30),\n timeout=30\n )\n logger.info(f\"Retrieved response from queue: {response}\")\n assistant_message += response\n # Create a new history list to ensure Gradio detects the update\n new_history = history.copy()\n new_history[-1][\"content\"] = assistant_message\n logger.info(f\"Yielding updated history: {new_history}\")\n yield new_history, ''\n except asyncio.TimeoutError:\n logger.warning(\"Timeout: No response received from Agent.\")\n # Create a new history list to ensure Gradio detects the update\n new_history = history.copy()\n new_history[-1][\"content\"] = \"I'm sorry, I didn't receive a response in time.\"\n yield new_history, ''\n\n except Exception as e:\n logger.error(f\"Unexpected Error in bot function: {e}\")\n # Create a new history list to ensure Gradio detects the update\n new_history = 
history.copy()\n new_history[-1][\"content\"] = \"An unexpected error occurred.\"\n yield new_history, ''\n\n def launch_gradio_interface():\n \"\"\"\n Main gradio interface.\n \"\"\"\n with gr.Blocks(title=\"Multimodal LLM Chatbot with Tools\") as demo:\n # Header Section\n gr.Markdown(\"\"\"\n <h1 align=\"center\">\n <a href=\"https://github.com/chatsci/Aeiva\">\n <img src=\"https://i.ibb.co/P4zQHDk/aeiva-1024.png\",\n alt=\"Aeiva\" border=\"0\" style=\"margin: 0 auto; height: 200px;\" />\n </a>\n </h1>\n\n <h2 align=\"center\">\n AEIVA: An Evolving Intelligent Virtual Assistant\n </h2>\n\n <h5 align=\"center\">\n If you like our project, please give us a star \u2728 on Github for the latest update.\n </h5>\n\n <div align=\"center\">\n <div style=\"display:flex; gap: 0.25rem;\" align=\"center\">\n <a href='https://github.com/chatsci/Aeiva'><img src='https://img.shields.io/badge/Github-Code-blue'></a>\n <a href=\"https://arxiv.org/abs/2304.14178\"><img src=\"https://img.shields.io/badge/Arxiv-2304.14178-red\"></a>\n <a href='https://github.com/chatsci/Aeiva/stargazers'><img src='https://img.shields.io/github/stars/chatsci/Aeiva.svg?style=social'></a>\n </div>\n </div>\n \"\"\")\n\n # Main Layout: Two Columns\n with gr.Row():\n # Left Column: Parameter Settings and Multimodal Inputs\n with gr.Column(scale=1, min_width=700):\n # Parameter Settings Tab\n with gr.Tab(label=\"Parameter Setting\"):\n gr.Markdown(\"# Parameters\")\n top_p = gr.Slider(\n minimum=0,\n maximum=1.0,\n value=0.95,\n step=0.05,\n interactive=True,\n label=\"Top-p\"\n )\n temperature = gr.Slider(\n minimum=0.1,\n maximum=2.0,\n value=1.0,\n step=0.1,\n interactive=True,\n label=\"Temperature\"\n )\n max_length_tokens = gr.Slider(\n minimum=0,\n maximum=512,\n value=512,\n step=8,\n interactive=True,\n label=\"Max Generation Tokens\"\n )\n max_context_length_tokens = gr.Slider(\n minimum=0,\n maximum=4096,\n value=2048,\n step=128,\n interactive=True,\n label=\"Max History Tokens\"\n )\n\n # 
Multimodal Inputs Section\n with gr.Row():\n imagebox = gr.Image(type=\"pil\", label=\"Upload Image\")\n videobox = gr.File(label=\"Upload Video\", file_types=[\"video\"])\n audiobox = gr.Audio(label=\"Upload Audio\", type=\"numpy\")\n\n with gr.Row():\n record_videobox = gr.Video(label=\"Record Video\")\n record_audiobox = gr.Audio(label=\"Record Audio\")\n\n # Clear Media Button\n with gr.Row():\n clear_media_btn = gr.Button(\"\ud83e\uddf9 Clear Media\", variant=\"secondary\")\n\n # Right Column: Chat Interface and Action Buttons\n with gr.Column(scale=1, min_width=700):\n # Chatbot Component\n chatbot = gr.Chatbot(\n [],\n type=\"messages\", # Specify type as 'messages'\n elem_id=\"chatbot\",\n height=730\n )\n\n # Input Textbox and Upload Button\n with gr.Row():\n with gr.Column(scale=4, min_width=300):\n txt = gr.Textbox(\n show_label=False,\n placeholder=\"Enter text and press enter, or upload an image/video/audio\",\n lines=1,\n elem_classes=[\"input-textbox\"] # Assign a CSS class for styling\n )\n with gr.Column(scale=1, min_width=100):\n btn = gr.UploadButton(\"\ud83d\udcc1\", file_types=[\"image\", \"video\", \"audio\"], elem_classes=[\"upload-button\"])\n # Changed the button label to an icon for a more compact look\n\n # Action Buttons Placed Below the Input Box\n with gr.Row():\n upvote_btn = gr.Button(\"\ud83d\udc4d Upvote\", interactive=True)\n downvote_btn = gr.Button(\"\ud83d\udc4e Downvote\", interactive=True)\n flag_btn = gr.Button(\"\u26a0\ufe0f Flag\", interactive=True)\n regenerate_btn = gr.Button(\"\ud83d\udd04 Regenerate\", interactive=True)\n clear_history_btn = gr.Button(\"\ud83d\uddd1\ufe0f Clear History\", interactive=True)\n new_conv_btn = gr.Button(\"\ud83e\uddf9 New Conversation\", interactive=True)\n del_last_turn_btn = gr.Button(\"\ud83d\uddd1\ufe0f Remove Last Turn\", interactive=True)\n\n # Define interactions\n\n # Text input submission with streaming\n txt.submit(\n bot,\n inputs=[txt, chatbot],\n outputs=[chatbot, txt],\n 
queue=True, # Enable queue for better performance\n # stream=True # Enable streaming (already handled in the bot function)\n )\n # Removed the .then callback to prevent layout shifts\n\n # File upload (image/video/audio)\n btn.upload(\n handle_upload,\n inputs=btn,\n outputs=txt, # Set message in textbox to trigger bot\n queue=True\n )\n\n # Image upload\n imagebox.upload(\n handle_image_upload,\n inputs=imagebox,\n outputs=txt, # Set message in textbox to trigger bot\n queue=True\n )\n\n # Video upload\n videobox.upload(\n handle_video_upload,\n inputs=videobox,\n outputs=txt, # Set message in textbox to trigger bot\n queue=True\n )\n\n # Audio upload\n audiobox.upload(\n handle_audio_upload,\n inputs=audiobox,\n outputs=txt, # Set message in textbox to trigger bot\n queue=True\n )\n\n # Record Video\n record_videobox.change(\n handle_video_upload,\n inputs=record_videobox,\n outputs=txt, # Set message in textbox to trigger bot\n queue=True\n )\n\n # Record Audio\n record_audiobox.change(\n handle_audio_upload,\n inputs=record_audiobox,\n outputs=txt, # Set message in textbox to trigger bot\n queue=True\n )\n\n # Clear Media Button\n clear_media_btn.click(\n clear_media,\n inputs=None,\n outputs=None,\n queue=False\n )\n\n # Action Buttons Functionality\n\n # Clear History\n clear_history_btn.click(\n lambda: ([], \"\"),\n inputs=None,\n outputs=[chatbot, txt],\n queue=False\n )\n\n # New Conversation\n new_conv_btn.click(\n lambda: ([], \"\"),\n inputs=None,\n outputs=[chatbot, txt],\n queue=False\n )\n\n # Remove Last Turn (Removes the last user and assistant messages)\n del_last_turn_btn.click(\n lambda history: history[:-2] if len(history) >= 2 else history,\n inputs=chatbot,\n outputs=chatbot,\n queue=False\n )\n\n # Launch the Gradio interface\n demo.launch(share=True)\n\n # Launch aeiva chat gradio\n launch_gradio_interface()\n
"},{"location":"reference/#src.aeiva.command.aeiva_chat_terminal","title":"aeiva_chat_terminal
","text":"We can run the command like below: (specify your own config file path)
aeiva-chat-terminal --config configs/agent_config.yaml
"},{"location":"reference/#src.aeiva.command.aeiva_chat_terminal.run","title":"run(config, verbose)
","text":"Starts the Aeiva chat terminal with the provided configuration.
Source code in src/aeiva/command/aeiva_chat_terminal.py
@click.command()\n@click.option('--config', '-c', default=str(DEFAULT_CONFIG_PATH),\n help='Path to the configuration file (YAML or JSON).',\n type=click.Path(exists=True, dir_okay=False))\n@click.option('--verbose', '-v', is_flag=True, help='Enable verbose logging.')\ndef run(config, verbose):\n \"\"\"\n Starts the Aeiva chat terminal with the provided configuration.\n \"\"\"\n # Setup logging\n logger = setup_logging(DEFAULT_LOG_PATH, verbose)\n\n click.echo(f\"Loading configuration from {config}\")\n config_path = Path(config)\n\n # Parse the configuration file with error handling\n try:\n config_data = from_json_or_yaml(config_path)\n except Exception as e:\n logger.error(f\"Failed to parse configuration file: {e}\")\n click.echo(f\"Error: Failed to parse configuration file: {e}\")\n sys.exit(1)\n\n # Retrieve NEO4J_HOME from environment variables\n neo4j_home = os.getenv('NEO4J_HOME')\n if not neo4j_home:\n logger.error(\"NEO4J_HOME is not set in the environment.\")\n click.echo(\"Error: NEO4J_HOME is not set in the environment. 
Please set it in your shell configuration (e.g., .bashrc or .zshrc).\")\n sys.exit(1)\n\n # Validate NEO4J_HOME path\n validate_neo4j_home(logger, neo4j_home)\n\n # Start Neo4j\n neo4j_process = start_neo4j(logger, neo4j_home)\n\n # Register signal handlers to ensure Neo4j stops gracefully\n signal.signal(signal.SIGINT, lambda s, f: handle_exit(s, f, neo4j_process))\n signal.signal(signal.SIGTERM, lambda s, f: handle_exit(s, f, neo4j_process))\n\n # Start the Agent\n try:\n agent = Agent(config_data)\n agent.setup()\n asyncio.run(agent.run())\n except KeyboardInterrupt:\n logger.info(\"Agent execution interrupted by user.\")\n click.echo(\"\\nAgent execution interrupted by user.\")\n except Exception as e:\n logger.error(f\"An error occurred during agent execution: {e}\")\n click.echo(f\"An error occurred during agent execution: {e}\")\n finally:\n # # Perform any necessary cleanup\n # try:\n # agent.cognition_components['memory'].delete_all()\n # logger.info(\"All memory units deleted during cleanup.\")\n # except NotImplementedError as nie:\n # logger.warning(f\"Delete All feature not implemented: {nie}\")\n # except Exception as e:\n # logger.error(f\"Error during cleanup: {e}\")\n # click.echo(\"Failed to delete all memory units.\")\n\n # Stop Neo4j\n stop_neo4j(logger, neo4j_process)\n logger.info(\"Cleanup completed.\")\n
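The config-parsing step above wraps aeiva's `from_json_or_yaml` helper in error handling so a malformed file fails fast with a clear message. A stdlib-only sketch of the same pattern (`load_config` is a hypothetical stand-in, JSON-only since YAML support needs an extra dependency):

```python
# Hedged sketch of the "parse the configuration file with error handling"
# step above. load_config is a hypothetical stand-in for aeiva's
# from_json_or_yaml; it handles JSON only (YAML would need PyYAML).
import json
import sys
import tempfile
from pathlib import Path

def load_config(config_path: Path) -> dict:
    """Parse a JSON config file, exiting with a clear message on failure."""
    try:
        return json.loads(config_path.read_text())
    except Exception as e:
        print(f"Error: Failed to parse configuration file: {e}")
        sys.exit(1)

# Example: round-trip a minimal config
cfg_file = Path(tempfile.mkdtemp()) / "agent_config.json"
cfg_file.write_text(json.dumps({"llm_gateway_config": {"llm_stream": True}}))
config_data = load_config(cfg_file)
print(config_data["llm_gateway_config"]["llm_stream"])  # True
```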
"},{"location":"reference/#src.aeiva.command.aeiva_server","title":"aeiva_server
","text":""},{"location":"reference/#src.aeiva.command.aeiva_server.run","title":"run(config, host, port, verbose)
","text":"Starts the Aeiva Agent Server using FastAPI.
Source code in src/aeiva/command/aeiva_server.py
@click.command(name=\"aeiva-server\")\n@click.option(\n '--config', '-c',\n default=None,\n help='Path to the configuration file (YAML or JSON).',\n type=click.Path(exists=True, dir_okay=False)\n)\n@click.option(\n '--host', '-H',\n default=\"0.0.0.0\",\n help='Host address to run the server on.',\n show_default=True\n)\n@click.option(\n '--port', '-p',\n default=8000,\n help='Port number to run the server on.',\n show_default=True\n)\n@click.option(\n '--verbose', '-v',\n is_flag=True,\n help='Enable verbose logging.'\n)\ndef run(config, host, port, verbose):\n \"\"\"\n Starts the Aeiva Agent Server using FastAPI.\n \"\"\"\n # Setup logging\n logger = setup_logging(get_log_dir() / 'aeiva-server.log', verbose)\n\n # Load configuration\n if config is None:\n PACKAGE_ROOT = get_package_root()\n config_path = PACKAGE_ROOT / 'configs' / 'agent_config.yaml'\n else:\n config_path = Path(config)\n\n logger.info(f\"Loading configuration from {config_path}\")\n config_dict = from_json_or_yaml(config_path)\n\n # Validate and start Neo4j\n neo4j_home = os.getenv('NEO4J_HOME')\n if not neo4j_home:\n logger.error(\"NEO4J_HOME environment variable is not set.\")\n click.echo(\"Error: NEO4J_HOME environment variable is not set.\")\n sys.exit(1)\n\n validate_neo4j_home(logger, neo4j_home)\n neo4j_process = start_neo4j(logger, neo4j_home)\n\n # Initialize the Agent\n try:\n agent = Agent(config_dict)\n agent.setup()\n logger.info(\"Agent initialized successfully.\")\n except Exception as e:\n logger.error(f\"Failed to initialize Agent: {e}\")\n click.echo(f\"Error: Failed to initialize Agent: {e}\")\n stop_neo4j(logger, neo4j_process)\n sys.exit(1)\n\n # Define the FastAPI app with lifespan\n @asynccontextmanager\n async def lifespan(app: FastAPI):\n app.state.agent = agent\n logger.info(\"Agent has been initialized and is ready to receive messages.\")\n try:\n yield\n finally:\n logger.info(\"Shutting down the agent server.\")\n # If the Agent class has a shutdown method, call it 
here\n if hasattr(app.state.agent, 'shutdown'):\n await app.state.agent.shutdown()\n stop_neo4j(logger, neo4j_process)\n logger.info(\"Agent server shut down gracefully.\")\n\n app = FastAPI(lifespan=lifespan)\n\n # Enable CORS for all origins (for development purposes)\n app.add_middleware(\n CORSMiddleware,\n allow_origins=[\"*\"], # Adjust in production\n allow_credentials=True,\n allow_methods=[\"*\"],\n allow_headers=[\"*\"],\n )\n\n # Define the endpoint\n @app.post(\"/process_text\", response_model=MessageResponse)\n async def process_text(request: MessageRequest):\n if not request.message:\n raise HTTPException(status_code=400, detail=\"No message provided\")\n\n logger.info(f\"Received message: {request.message}\")\n\n # Process the message using the agent\n try:\n response_text = await app.state.agent.process_input(request.message)\n logger.info(f\"Agent response: {response_text}\")\n return MessageResponse(response=response_text)\n except Exception as e:\n logger.error(f\"Error processing input: {e}\")\n raise HTTPException(status_code=500, detail=\"Internal Server Error\")\n\n # Register signal handlers for graceful shutdown using handle_exit\n for sig in [signal.SIGINT, signal.SIGTERM]:\n signal.signal(sig, lambda s, f: handle_exit(s, f, logger, neo4j_process))\n\n # Run the FastAPI app using Uvicorn\n try:\n logger.info(f\"Starting server at http://{host}:{port}\")\n uvicorn.run(app, host=host, port=port)\n except Exception as e:\n logger.error(f\"Server encountered an error: {e}\")\n handle_exit(None, None, logger, neo4j_process) # Ensure cleanup on exception\n sys.exit(1)\n finally:\n logger.info(\"Server has been stopped.\")\n
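Once the server is running, the `/process_text` endpoint accepts a JSON body with a `message` field and returns a JSON body with a `response` field, as the source above shows. A hedged stdlib-only client sketch (host and port are the CLI defaults; `build_request` and `ask_agent` are illustrative helper names, not part of aeiva):

```python
# Hedged client sketch for the /process_text endpoint started by
# `aeiva-server`. build_request/ask_agent are illustrative helpers,
# not aeiva APIs; the request/response shape follows the source above.
import json
import urllib.request

def build_request(message: str, host: str = "127.0.0.1", port: int = 8000) -> urllib.request.Request:
    """Build a POST request matching the MessageRequest schema."""
    payload = json.dumps({"message": message}).encode("utf-8")
    return urllib.request.Request(
        f"http://{host}:{port}/process_text",
        data=payload,
        headers={"Content-Type": "application/json"},
    )

def ask_agent(message: str, host: str = "127.0.0.1", port: int = 8000) -> str:
    """Send a message and return the agent's reply (requires a running server)."""
    with urllib.request.urlopen(build_request(message, host, port)) as resp:
        return json.loads(resp.read())["response"]

req = build_request("Hello, Aeiva!")
print(req.full_url)  # http://127.0.0.1:8000/process_text
```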
"},{"location":"reference/#src.aeiva.command.command_utils","title":"command_utils
","text":"Here we put util functions related to database, logging and so on for different aeiva commands execution.
"},{"location":"reference/#src.aeiva.command.command_utils.get_log_dir","title":"get_log_dir()
","text":"Determines a suitable path for the log file. Logs are stored in the user's home directory under '.aeiva/logs/'.
Source code in src/aeiva/command/command_utils.py
def get_log_dir():\n \"\"\"\n Determines a suitable path for the log file.\n Logs are stored in the user's home directory under '.aeiva/logs/'.\n \"\"\"\n home_dir = Path.home()\n log_dir = home_dir / '.aeiva' / 'logs' # Log saved to `~/.aeiva/logs/`\n log_dir.mkdir(parents=True, exist_ok=True) # Ensure the log directory exists\n return log_dir\n
"},{"location":"reference/#src.aeiva.command.command_utils.get_package_root","title":"get_package_root()
","text":"Determines the root path of the 'aeiva' package.
Source code in src/aeiva/command/command_utils.py
def get_package_root():\n \"\"\"\n Determines the root path of the 'aeiva' package.\n \"\"\"\n aeiva_path = Path(importlib_resources.files(\"aeiva\"))\n package_root = aeiva_path.parents[1]\n return package_root.resolve()\n
"},{"location":"reference/#src.aeiva.command.command_utils.handle_exit","title":"handle_exit(signum, frame, logger, neo4j_process)
","text":"Handles termination signals to ensure Neo4j is stopped gracefully.
Source code in src/aeiva/command/command_utils.py
def handle_exit(signum, frame, logger, neo4j_process):\n \"\"\"\n Handles termination signals to ensure Neo4j is stopped gracefully.\n \"\"\"\n logger.info(f\"Received signal {signum}. Shutting down Neo4j.\")\n click.echo(f\"\\nReceived signal {signum}. Shutting down Neo4j.\")\n stop_neo4j(logger, neo4j_process)\n sys.exit(0)\n
"},{"location":"reference/#src.aeiva.command.command_utils.setup_logging","title":"setup_logging(log_file, verbose=False)
","text":"Sets up logging to both file and console.
Source code in src/aeiva/command/command_utils.py
def setup_logging(log_file, verbose=False):\n \"\"\"\n Sets up logging to both file and console.\n \"\"\"\n logger = get_logger(__name__, level=\"DEBUG\" if verbose else \"INFO\")\n\n # Create a file handler\n file_handler = logging.FileHandler(log_file, mode='a')\n file_handler.setLevel(logging.DEBUG if verbose else logging.INFO)\n\n # Create a console handler\n console_handler = logging.StreamHandler()\n console_handler.setLevel(logging.DEBUG if verbose else logging.INFO)\n\n # Create a logging format\n formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')\n file_handler.setFormatter(formatter)\n console_handler.setFormatter(formatter)\n\n # Add handlers to the logger\n logger.addHandler(file_handler)\n logger.addHandler(console_handler)\n\n return logger\n
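The function above attaches one file handler and one console handler with a shared formatter. A minimal self-contained sketch of the same dual-handler pattern, writing to a temporary file here instead of `~/.aeiva/logs/` and using the stdlib logger directly (aeiva's `get_logger` wrapper is not reproduced):

```python
# Minimal sketch of setup_logging's dual-handler pattern: a file
# handler plus a console handler sharing one formatter. Uses a temp
# file rather than ~/.aeiva/logs/ and the stdlib logging module only.
import logging
import os
import tempfile

log_file = os.path.join(tempfile.mkdtemp(), "aeiva-demo.log")
logger = logging.getLogger("aeiva.demo")
logger.setLevel(logging.DEBUG)

formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
for handler in (logging.FileHandler(log_file, mode="a"), logging.StreamHandler()):
    handler.setLevel(logging.DEBUG)
    handler.setFormatter(formatter)
    logger.addHandler(handler)

logger.info("logging configured")  # appears on the console and in log_file
```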
"},{"location":"reference/#src.aeiva.command.command_utils.start_neo4j","title":"start_neo4j(logger, neo4j_home)
","text":"Starts the Neo4j database as a subprocess.
Source code in src/aeiva/command/command_utils.py
def start_neo4j(logger, neo4j_home):\n \"\"\"\n Starts the Neo4j database as a subprocess.\n \"\"\"\n neo4j_command = [os.path.join(neo4j_home, 'bin', 'neo4j'), 'console']\n try:\n neo4j_process = subprocess.Popen(\n neo4j_command,\n stdout=subprocess.DEVNULL, # Suppress stdout\n stderr=subprocess.DEVNULL, # Suppress stderr\n stdin=subprocess.DEVNULL, # Prevent Neo4j from waiting for input\n preexec_fn=os.setsid # Start the process in a new session\n )\n logger.info(\"Neo4j database started successfully.\")\n click.echo(\"Neo4j database started successfully.\")\n return neo4j_process\n except FileNotFoundError:\n logger.error(f\"Neo4j executable not found in {neo4j_command}.\")\n click.echo(f\"Error: Neo4j executable not found in {neo4j_command}.\")\n sys.exit(1)\n except Exception as e:\n logger.error(f\"Failed to start Neo4j: {e}\")\n click.echo(f\"Error: Failed to start Neo4j: {e}\")\n sys.exit(1)\n
"},{"location":"reference/#src.aeiva.command.command_utils.stop_neo4j","title":"stop_neo4j(logger, neo4j_process)
","text":"Stops the Neo4j database subprocess gracefully.
Source code in src/aeiva/command/command_utils.py
def stop_neo4j(logger, neo4j_process):\n \"\"\"\n Stops the Neo4j database subprocess gracefully.\n \"\"\"\n try:\n # Check if the process is still running\n if neo4j_process.poll() is None:\n os.killpg(os.getpgid(neo4j_process.pid), signal.SIGINT) # Send SIGINT for graceful shutdown\n logger.info(\"Sent SIGINT to Neo4j subprocess.\")\n click.echo(\"Shutting down Neo4j...\")\n neo4j_process.wait(timeout=15) # Increased timeout to 15 seconds\n logger.info(\"Neo4j database stopped successfully.\")\n click.echo(\"Neo4j database stopped successfully.\")\n else:\n logger.warning(\"Neo4j subprocess is already terminated.\")\n click.echo(\"Warning: Neo4j subprocess is already terminated.\")\n except subprocess.TimeoutExpired:\n logger.error(\"Neo4j did not terminate within the timeout period.\")\n click.echo(\"Error: Neo4j did not terminate within the timeout period.\")\n # Optionally, force kill\n try:\n os.killpg(os.getpgid(neo4j_process.pid), signal.SIGKILL)\n neo4j_process.wait(timeout=5)\n logger.info(\"Neo4j database forcefully terminated.\")\n click.echo(\"Neo4j database forcefully terminated.\")\n except Exception as e:\n logger.error(f\"Failed to forcefully terminate Neo4j: {e}\")\n click.echo(f\"Error: Failed to forcefully terminate Neo4j: {e}\")\n except ProcessLookupError:\n logger.warning(\"Neo4j subprocess does not exist.\")\n click.echo(\"Warning: Neo4j subprocess does not exist. It may have already terminated.\")\n except Exception as e:\n logger.error(f\"Error stopping Neo4j: {e}\")\n click.echo(f\"Error: Failed to stop Neo4j: {e}\")\n
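The shutdown logic above follows a graceful-then-forceful pattern: SIGINT to the whole process group, a bounded wait, then SIGKILL as a fallback. A POSIX-only sketch of that pattern using a throwaway `sleep` process in place of Neo4j:

```python
# POSIX-only sketch of stop_neo4j's shutdown pattern, using a
# throwaway `sleep` process instead of Neo4j: signal the whole
# process group, wait with a timeout, escalate to SIGKILL if needed.
import os
import signal
import subprocess

proc = subprocess.Popen(["sleep", "30"], preexec_fn=os.setsid)  # new session, as in start_neo4j

if proc.poll() is None:  # still running?
    os.killpg(os.getpgid(proc.pid), signal.SIGINT)  # try graceful shutdown first
    try:
        proc.wait(timeout=5)
    except subprocess.TimeoutExpired:
        os.killpg(os.getpgid(proc.pid), signal.SIGKILL)  # then force-kill the group
        proc.wait(timeout=5)

print(proc.returncode is not None)  # True: the process has exited
```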
"},{"location":"reference/#src.aeiva.command.command_utils.validate_neo4j_home","title":"validate_neo4j_home(logger, neo4j_home)
","text":"Validates that the NEO4J_HOME path exists and contains the Neo4j executable.
Source code in src/aeiva/command/command_utils.py
def validate_neo4j_home(logger, neo4j_home):\n \"\"\"\n Validates that the NEO4J_HOME path exists and contains the Neo4j executable.\n \"\"\"\n if not os.path.isdir(neo4j_home):\n logger.error(f\"NEO4J_HOME path does not exist or is not a directory: {neo4j_home}\")\n click.echo(f\"Error: NEO4J_HOME path does not exist or is not a directory: {neo4j_home}\")\n sys.exit(1)\n\n neo4j_executable = os.path.join(neo4j_home, 'bin', 'neo4j')\n if not os.path.isfile(neo4j_executable) or not os.access(neo4j_executable, os.X_OK):\n logger.error(f\"Neo4j executable not found or not executable at: {neo4j_executable}\")\n click.echo(f\"Error: Neo4j executable not found or not executable at: {neo4j_executable}\")\n sys.exit(1)\n
"},{"location":"reference/#src.aeiva.command.maid_chat","title":"maid_chat
","text":""},{"location":"reference/#src.aeiva.command.maid_chat.run","title":"run(config, host, port, verbose)
","text":"Starts the Aeiva Agent Server and launches the Unity application.
Source code in src/aeiva/command/maid_chat.py
@click.command(name=\"maid-chat\")\n@click.option(\n '--config', '-c',\n default=None,\n help='Path to the configuration file (YAML or JSON).',\n type=click.Path(exists=True, dir_okay=False)\n)\n@click.option(\n '--host', '-H',\n default=\"0.0.0.0\",\n help='Host address to run the server on.',\n show_default=True\n)\n@click.option(\n '--port', '-p',\n default=8000,\n help='Port number to run the server on.',\n show_default=True\n)\n@click.option(\n '--verbose', '-v',\n is_flag=True,\n help='Enable verbose logging.'\n)\ndef run(config, host, port, verbose):\n \"\"\"\n Starts the Aeiva Agent Server and launches the Unity application.\n \"\"\"\n # Setup logging\n logger = setup_logging(get_log_dir() / 'maid-chat.log', verbose)\n\n # Load configuration\n if config is None:\n PACKAGE_ROOT = get_package_root()\n config_path = PACKAGE_ROOT / 'configs' / 'agent_config.yaml'\n else:\n config_path = Path(config)\n\n logger.info(f\"Loading configuration from {config_path}\")\n config_dict = from_json_or_yaml(config_path)\n\n # Validate and start Neo4j\n neo4j_home = os.getenv('NEO4J_HOME')\n if not neo4j_home:\n logger.error(\"NEO4J_HOME environment variable is not set.\")\n click.echo(\"Error: NEO4J_HOME environment variable is not set.\")\n sys.exit(1)\n\n validate_neo4j_home(logger, neo4j_home)\n neo4j_process = start_neo4j(logger, neo4j_home)\n\n # Initialize the Agent\n try:\n agent = Agent(config_dict)\n agent.setup()\n logger.info(\"Agent initialized successfully.\")\n except Exception as e:\n logger.error(f\"Failed to initialize Agent: {e}\")\n click.echo(f\"Error: Failed to initialize Agent: {e}\")\n stop_neo4j(logger, neo4j_process)\n sys.exit(1)\n\n # Read MAID_HOME environment variable\n maid_home = os.getenv('MAID_HOME')\n if not maid_home:\n logger.error(\"MAID_HOME environment variable is not set.\")\n click.echo(\"Error: MAID_HOME environment variable is not set.\")\n stop_neo4j(logger, neo4j_process)\n sys.exit(1)\n\n maid_home_path = Path(maid_home)\n if 
not maid_home_path.exists():\n logger.error(f\"Unity application not found at MAID_HOME: {maid_home}\")\n click.echo(f\"Error: Unity application not found at MAID_HOME: {maid_home}\")\n stop_neo4j(logger, neo4j_process)\n sys.exit(1)\n\n # Start the Unity application\n unity_process = start_unity_app(str(maid_home_path), logger)\n if unity_process is None:\n stop_neo4j(logger, neo4j_process)\n sys.exit(1)\n\n # Define the FastAPI app with lifespan\n @asynccontextmanager\n async def lifespan(app: FastAPI):\n app.state.agent = agent\n logger.info(\"Agent has been initialized and is ready to receive messages.\")\n try:\n yield\n finally:\n logger.info(\"Shutting down the agent server.\")\n # If the Agent class has a shutdown method, call it here\n if hasattr(app.state.agent, 'shutdown'):\n await app.state.agent.shutdown()\n stop_neo4j(logger, neo4j_process)\n # Terminate the Unity application\n stop_unity_app(unity_process, logger)\n logger.info(\"Agent server shut down gracefully.\")\n\n app = FastAPI(lifespan=lifespan)\n\n # Enable CORS for all origins (for development purposes)\n app.add_middleware(\n CORSMiddleware,\n allow_origins=[\"*\"], # Adjust in production\n allow_credentials=True,\n allow_methods=[\"*\"],\n allow_headers=[\"*\"],\n )\n\n # Define the endpoint\n @app.post(\"/process_text\", response_model=MessageResponse)\n async def process_text(request: MessageRequest):\n if not request.message:\n raise HTTPException(status_code=400, detail=\"No message provided\")\n\n logger.info(f\"Received message: {request.message}\")\n\n # Process the message using the agent\n try:\n response_text = await app.state.agent.process_input(request.message)\n logger.info(f\"Agent response: {response_text}\")\n return MessageResponse(response=response_text)\n except Exception as e:\n logger.error(f\"Error processing input: {e}\")\n raise HTTPException(status_code=500, detail=\"Internal Server Error\")\n\n # Register signal handlers for graceful shutdown using handle_exit\n 
for sig in [signal.SIGINT, signal.SIGTERM]:\n signal.signal(sig, lambda s, f: handle_exit(s, f, logger, neo4j_process, unity_process))\n\n # Run the FastAPI app using Uvicorn\n try:\n logger.info(f\"Starting server at http://{host}:{port}\")\n uvicorn.run(app, host=host, port=port)\n except Exception as e:\n logger.error(f\"Server encountered an error: {e}\")\n handle_exit(None, None, logger, neo4j_process, unity_process) # Ensure cleanup on exception\n sys.exit(1)\n finally:\n logger.info(\"Server has been stopped.\")\n
"},{"location":"reference/#src.aeiva.command.maid_chat.start_unity_app","title":"start_unity_app(maid_home, logger)
","text":"Starts the Unity application.
Parameters:
Name Type Description Default maid_home
str
Path to the Unity application executable.
required logger
Logger
Logger instance.
required Returns:
Type Description Optional[Popen]
Optional[subprocess.Popen]: The subprocess running the Unity application, or None if failed.
Source code in src/aeiva/command/maid_chat.py
def start_unity_app(maid_home: str, logger: logging.Logger) -> Optional[subprocess.Popen]:\n \"\"\"\n Starts the Unity application.\n\n Args:\n maid_home (str): Path to the Unity application executable.\n logger (logging.Logger): Logger instance.\n\n Returns:\n Optional[subprocess.Popen]: The subprocess running the Unity application, or None if failed.\n \"\"\"\n try:\n unity_process = subprocess.Popen(\n [maid_home],\n stdout=subprocess.DEVNULL,\n stderr=subprocess.DEVNULL,\n preexec_fn=os.setsid # Start the process in a new session\n )\n logger.info(f\"Unity application started from {maid_home}.\")\n click.echo(f\"Unity application started from {maid_home}.\")\n return unity_process\n except FileNotFoundError:\n logger.error(f\"Unity application not found at {maid_home}.\")\n click.echo(f\"Error: Unity application not found at {maid_home}.\")\n return None\n except Exception as e:\n logger.error(f\"Failed to start Unity application: {e}\")\n click.echo(f\"Error: Failed to start Unity application: {e}.\")\n return None\n
"},{"location":"reference/#src.aeiva.command.maid_chat.stop_unity_app","title":"stop_unity_app(unity_process, logger)
","text":"Stops the Unity application gracefully.
Parameters:
Name Type Description Default unity_process
Popen
The subprocess running the Unity application.
required logger
Logger
Logger instance.
required Source code in src/aeiva/command/maid_chat.py
def stop_unity_app(unity_process: subprocess.Popen, logger: logging.Logger):\n \"\"\"\n Stops the Unity application gracefully.\n\n Args:\n unity_process (subprocess.Popen): The subprocess running the Unity application.\n logger (logging.Logger): Logger instance.\n \"\"\"\n try:\n os.killpg(os.getpgid(unity_process.pid), signal.SIGTERM)\n unity_process.wait(timeout=10)\n logger.info(\"Unity application terminated gracefully.\")\n click.echo(\"Unity application terminated gracefully.\")\n except Exception as e:\n logger.error(f\"Error terminating Unity application: {e}\")\n click.echo(f\"Error: Failed to terminate Unity application: {e}.\")\n
"},{"location":"reference/#src.aeiva.common","title":"common
","text":""},{"location":"reference/#src.aeiva.common.decorators","title":"decorators
","text":""},{"location":"reference/#src.aeiva.common.decorators.import_submodules","title":"import_submodules(package, recursive=True)
","text":"Import all submodules of a module, recursively, including subpackages
Source code in src/aeiva/common/decorators.py
def import_submodules(package, recursive=True):\n \"\"\" Import all submodules of a module, recursively, including subpackages \"\"\"\n\n if isinstance(package, str):\n package = importlib.import_module(package)\n\n results = {}\n\n for loader, name, is_pkg in pkgutil.walk_packages(package.__path__):\n full_name = package.__name__ + \".\" + name\n results[full_name] = importlib.import_module(full_name)\n if recursive and is_pkg:\n results.update(import_submodules(full_name))\n\n return results\n
"},{"location":"reference/#src.aeiva.common.id_generator","title":"id_generator
","text":""},{"location":"reference/#src.aeiva.common.id_generator.IDGenerator","title":"IDGenerator
","text":"A simple class to generate unique IDs for distinct names.
Attributes:
Name Type Description name_to_id
dict
A dictionary to map names to IDs.
next_id
int
The next ID to be assigned.
Source code in src/aeiva/common/id_generator.py
class IDGenerator:\n \"\"\"\n A simple class to generate unique IDs for distinct names.\n\n Attributes:\n name_to_id (dict): A dictionary to map names to IDs.\n next_id (int): The next ID to be assigned.\n \"\"\"\n\n def __init__(self):\n \"\"\"\n Constructs all the necessary attributes for the IDGenerator object.\n\n Attributes:\n name_to_id (dict): Initializes an empty dictionary to map names to IDs.\n next_id (int): Initializes the next ID to be assigned as 0.\n \"\"\"\n self.name_to_id = {}\n self.next_id = 0\n\n def get_id(self, name: str) -> int:\n \"\"\"\n Returns the ID of the 'name'. If 'name' does not exist, assigns a new ID.\n\n Parameters:\n name (str): The name for which the ID is required.\n\n Returns:\n int: The ID associated with the 'name'.\n \"\"\"\n if name not in self.name_to_id:\n self.name_to_id[name] = self.next_id\n self.next_id += 1\n return self.name_to_id[name]\n
"},{"location":"reference/#src.aeiva.common.id_generator.IDGenerator.__init__","title":"__init__()
","text":"Constructs all the necessary attributes for the IDGenerator object.
Attributes:
Name Type Description name_to_id
dict
Initializes an empty dictionary to map names to IDs.
next_id
int
Initializes the next ID to be assigned as 0.
Source code in src/aeiva/common/id_generator.py
def __init__(self):\n \"\"\"\n Constructs all the necessary attributes for the IDGenerator object.\n\n Attributes:\n name_to_id (dict): Initializes an empty dictionary to map names to IDs.\n next_id (int): Initializes the next ID to be assigned as 0.\n \"\"\"\n self.name_to_id = {}\n self.next_id = 0\n
"},{"location":"reference/#src.aeiva.common.id_generator.IDGenerator.get_id","title":"get_id(name)
","text":"Returns the ID of the 'name'. If 'name' does not exist, assigns a new ID.
Parameters:
Name Type Description Default name
str
The name for which the ID is required.
required Returns:
Name Type Description int
int
The ID associated with the 'name'.
Source code in src/aeiva/common/id_generator.py
def get_id(self, name: str) -> int:\n \"\"\"\n Returns the ID of the 'name'. If 'name' does not exist, assigns a new ID.\n\n Parameters:\n name (str): The name for which the ID is required.\n\n Returns:\n int: The ID associated with the 'name'.\n \"\"\"\n if name not in self.name_to_id:\n self.name_to_id[name] = self.next_id\n self.next_id += 1\n return self.name_to_id[name]\n
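A usage sketch for `IDGenerator` (class reproduced from the source above): identical names always map back to the same ID, and new names receive consecutive integers starting at 0.

```python
# Usage sketch for IDGenerator, reproduced from the source above.
class IDGenerator:
    def __init__(self):
        self.name_to_id = {}  # maps names to IDs
        self.next_id = 0      # next ID to assign

    def get_id(self, name: str) -> int:
        if name not in self.name_to_id:
            self.name_to_id[name] = self.next_id
            self.next_id += 1
        return self.name_to_id[name]

gen = IDGenerator()
print(gen.get_id("alice"))  # 0
print(gen.get_id("bob"))    # 1
print(gen.get_id("alice"))  # 0 again: the mapping is stable
```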
"},{"location":"reference/#src.aeiva.common.pipeline","title":"pipeline
","text":""},{"location":"reference/#src.aeiva.common.pipeline.Pipeline","title":"Pipeline
","text":"This class is used to rurn a list of functions into a pipeline.
Source code in src/aeiva/common/pipeline.py
class Pipeline:\n r\"\"\"This class is used to turn a list of functions into a pipeline.\"\"\"\n def __init__(self, functions):\n self.functions = functions\n\n def run(self, *args, **kwargs):\n result = self.functions[0](*args, **kwargs)\n for f in self.functions[1:]:\n if isinstance(result, tuple):\n result = f(*result)\n else:\n result = f(result)\n return result\n\n def __call__(self, *args, **kwargs):\n return self.run(*args, **kwargs)\n
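One subtlety in `run` is the tuple rule: when a step returns a tuple, it is unpacked into the next step's positional arguments. A usage sketch (class reproduced from the source above; the demo functions are ours):

```python
# Usage sketch for Pipeline, reproduced from the source above.
# Note the tuple rule: a step returning a tuple feeds the next
# step's positional arguments.
class Pipeline:
    """Turn a list of functions into a pipeline."""
    def __init__(self, functions):
        self.functions = functions

    def run(self, *args, **kwargs):
        result = self.functions[0](*args, **kwargs)
        for f in self.functions[1:]:
            if isinstance(result, tuple):
                result = f(*result)   # unpack tuple results
            else:
                result = f(result)
        return result

    def __call__(self, *args, **kwargs):
        return self.run(*args, **kwargs)

def divide(x):
    return x // 2, x % 2              # tuple -> unpacked downstream

def describe(quotient, remainder):
    return f"{quotient} remainder {remainder}"

pipeline = Pipeline([divide, describe])
print(pipeline(7))  # 3 remainder 1
```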
"},{"location":"reference/#src.aeiva.common.types","title":"types
","text":""},{"location":"reference/#src.aeiva.common.types.DataBatch","title":"DataBatch
","text":" Bases: TypedDict
DataBatch is a batch of data items created by a dataloader.
Source code in src/aeiva/common/types.py
class DataBatch(TypedDict):\n r\"\"\"DataBatch is a batch of data items created by a dataloader.\n \"\"\"\n videos: Optional[torch.Tensor] # videos representation\n audios: Optional[torch.Tensor] # audios representation\n images: Optional[torch.Tensor] # images representation\n input_ids: Optional[torch.Tensor] # text token ids\n attention_mask: Optional[torch.Tensor] # attention mask\n image_starts: Optional[torch.Tensor] # image start token\n image_ends: Optional[torch.Tensor] # image end token\n audio_starts: Optional[torch.Tensor] # audio start token\n audio_ends: Optional[torch.Tensor] # audio end token\n video_starts: Optional[torch.Tensor] # video start token\n video_ends: Optional[torch.Tensor] # video end token\n labels: Optional[torch.Tensor] # labels\n
"},{"location":"reference/#src.aeiva.common.types.DataItem","title":"DataItem
","text":" Bases: TypedDict
DataItem is a dictionary that contains all the information for a single data item.
Source code in src/aeiva/common/types.py
class DataItem(TypedDict):\n r\"\"\"DataItem is a dictionary that contains all the information for a single data item.\n \"\"\"\n instruction: str # instruction text\n input: Optional[str] # input text\n output: Optional[str] # output text\n text: Optional[str] # text field. How it is formed depends on the task.\n\n image: Optional[str] # image name or path\n transformed_image: Optional[torch.Tensor] # transformed image tensor\n\n audio: Optional[str] # audio name or path\n audio_mels: Optional[torch.Tensor] # audio melspectrogram tensor\n\n video: Optional[str] # video name or path\n sampled_video_frame_indices: Optional[list[int]] # sampled video frame indices\n video_frames: Optional[torch.Tensor] # video frames tensor\n
"},{"location":"reference/#src.aeiva.common.types.DataSet","title":"DataSet
","text":" Bases: TypedDict
DataSet is a dictionary that contains data items and meta information.
Source code in src/aeiva/common/types.py
class DataSet(TypedDict):\n r\"\"\"DataSet is a dictionary that contains data items and meta information.\n \"\"\"\n data: list[DataItem]\n metadata: dict[str, Any]\n
"},{"location":"reference/#src.aeiva.common.types.ModelInput","title":"ModelInput
","text":" Bases: TypedDict
ModelInput is a dictionary that contains all the information for a model input. We use it to construct LEGO style models.
Source code in src/aeiva/common/types.py
class ModelInput(TypedDict):\n r\"\"\"ModelInput is a dictionary that contains all the information for a model input.\n We use it to construct LEGO style models.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.common.types.ModelOutput","title":"ModelOutput
","text":" Bases: TypedDict
ModelOutput is a dictionary that contains all the information for a model output. We use it to construct LEGO style models.
Source code in src/aeiva/common/types.py
class ModelOutput(TypedDict):\n r\"\"\"ModelOutput is a dictionary that contains all the information for a model output.\n We use it to construct LEGO style models.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.common.types.TaskContext","title":"TaskContext
","text":" Bases: TypedDict
TaskContext is a dictionary that contains all the information for a task.
Source code in src/aeiva/common/types.py
class TaskContext(TypedDict):\n r\"\"\"TaskContext is a dictionary that contains all the information for a task.\n \"\"\"\n config_path: Optional[str]\n config: Optional[OmniConfig]\n dataloader: Optional[torch.utils.data.DataLoader]\n tokenizer: Optional[Any]\n model: Optional[Any]\n logger: Optional[Any]\n trainer: Optional[Any]\n current_model_input: Optional[DataItem]\n current_model_output: Optional[Any]\n
"},{"location":"reference/#src.aeiva.config","title":"config
","text":""},{"location":"reference/#src.aeiva.config.DataConfig","title":"DataConfig
dataclass
","text":" Bases: BaseConfig
This class contains the data configuration.
Source code in src/aeiva/config/general_configs.py
@dataclass\nclass DataConfig(BaseConfig):\n \"\"\"This class contains the data configuration.\"\"\"\n dataset_path: Optional[str] = field(\n default=None, metadata={\"help\": \"The path of the dataset to use.\"}\n )\n dataset_name: Optional[str] = field(\n default=\"customized\", metadata={\"help\": \"Should be \\\"customized\\\"\"}\n )\n is_custom_dataset: Optional[bool] = field(\n default=False, metadata={\"help\": \"whether to use custom data\"}\n )\n customized_cache_dir: Optional[str] = field(\n default=\".cache/llm-ft/datasets\",\n metadata={\"help\": \"Where do you want to store the customized dataset caches\"},\n )\n dataset_config_name: Optional[str] = field(\n default=None, metadata={\"help\": \"The configuration name of the dataset to use (via the datasets library).\"}\n )\n train_file: Optional[str] = field(default=None, metadata={\"help\": \"The input training data file (a text file).\"})\n validation_file: Optional[str] = field(\n default=None,\n metadata={\"help\": \"An optional input evaluation data file to evaluate the perplexity on (a text file).\"},\n )\n max_train_samples: Optional[int] = field(\n default=None,\n metadata={\n \"help\": (\n \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n \"value if set.\"\n )\n },\n )\n max_eval_samples: Optional[int] = field(\n default=1e10,\n metadata={\n \"help\": (\n \"For debugging purposes or quicker training, truncate the number of evaluation examples to this \"\n \"value if set.\"\n )\n },\n )\n streaming: Optional[bool] = field(default=False, metadata={\"help\": \"Enable streaming mode\"})\n block_size: Optional[int] = field(\n default=512,\n metadata={\n \"help\": (\n \"Optional input sequence length after tokenization. \"\n \"The training dataset will be truncated in block of this size for training. 
\"\n \"Default to the model max input length for single sentence inputs (take into account special tokens).\"\n )\n },\n )\n overwrite_cache: Optional[bool] = field(\n default=False,\n metadata={\"help\": \"Overwrite the cached training and evaluation sets\"}\n )\n validation_split_percentage: Optional[int] = field(\n default=5,\n metadata={\n \"help\": \"The percentage of the train set used as validation set in case there's no validation split\"\n },\n )\n preprocessing_num_workers: Optional[int] = field(\n default=None,\n metadata={\"help\": \"The number of processes to use for the preprocessing.\"},\n )\n group_texts_batch_size: Optional[int] = field(\n default=1000,\n metadata={\n \"help\": (\n \"Number of samples that will be grouped together to go though\"\n \" `group_texts` operation. See `--disable_group_texts` for\"\n \" detailed explanation of this operation.\"\n )\n }\n )\n disable_group_texts: Optional[bool] = field(\n default=False,\n metadata={\n \"help\": (\n \"Whether we group original samples together to generate sample\"\n \" sequences of length `block_size`. By default, we group every\"\n \" 1000 tokenized sequences together, divide them into \"\n \" [{total_num_tokens} / {block_size}] sequences, each with\"\n \" `block_size` tokens (the remaining tokens are ommited.\"\n \" If this flag is set to True, we only group 1 tokenized\"\n \" sequence, i.e. 
cutting long sequence into chunks.\"\n )\n },\n )\n keep_linebreaks: Optional[bool] = field(\n default=True,\n metadata={\"help\": \"Whether to keep line breaks when using TXT files or not.\"}\n )\n test_file: Optional[str] = field(\n default=None,\n metadata={\"help\": \"Evaluation File Path\"},\n )\n\n def __post_init__(self):\n if self.streaming:\n require_version(\"datasets>=2.0.0\", \"The streaming feature requires `datasets>=2.0.0`\")\n\n if self.dataset_name is None and self.train_file is None and self.validation_file is None:\n raise ValueError(\"Need either a dataset name or a training/validation file.\")\n else:\n if self.train_file is not None:\n extension = self.train_file.split(\".\")[-1]\n assert extension in [\"csv\", \"json\", \"txt\"], \"`train_file` should be a csv, a json or a txt file.\"\n if self.validation_file is not None:\n extension = self.validation_file.split(\".\")[-1]\n assert extension in [\"csv\", \"json\", \"txt\"], \"`validation_file` should be a csv, a json or a txt file.\"\n
"},{"location":"reference/#src.aeiva.config.ExplicitEnum","title":"ExplicitEnum
","text":" Bases: str
, Enum
Enum with more explicit error message for missing values.
Source code in src/aeiva/config/general_configs.py
class ExplicitEnum(str, Enum):\n \"\"\"\n Enum with more explicit error message for missing values.\n \"\"\"\n @classmethod\n def _missing_(cls, value):\n raise ValueError(\n f\"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}\"\n )\n
"},{"location":"reference/#src.aeiva.config.ModelConfig","title":"ModelConfig
dataclass
","text":" Bases: BaseConfig
Model configuration class.
Source code in src/aeiva/config/general_configs.py
@dataclass\nclass ModelConfig(BaseConfig):\n \"\"\"Model configuration class.\"\"\"\n model_name_or_path: Optional[str] = field(\n default=None,\n metadata={\n \"help\": (\n \"The model checkpoint for weights initialization. Don't set if you want to train a model from scratch.\"\n )\n },\n )\n lora_model_path: Optional[str] = field(\n default=None,\n metadata={\n \"help\": (\n \"The incremental model diff introduced by LoRA finetuning.\"\n \" Along with the original non-finetuned model forms the whole\"\n \" finetuned model.\"\n )\n }\n )\n model_type: Optional[str] = field(\n default=None,\n metadata={\"help\": \"If training from scratch, pass a model type from the list: \" + \", \".join(MODEL_TYPES)},\n )\n arch_type: Optional[str] = field(\n default=\"decoder_only\",\n metadata={\"help\": \"The architecture type of the model. Currently supported decoder_only or encoder_decoder\"}\n )\n config_overrides: Optional[str] = field(\n default=None,\n metadata={\n \"help\": (\n \"Override some existing default config settings when a model is trained from scratch. Example: \"\n \"n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index\"\n )\n },\n )\n arch_type: Optional[str] = field(\n default=\"decoder_only\",\n metadata={\n \"help\": (\n \"Model architecture type, e.g. 
\\\"decoder_only\\\",\"\n \" \\\"encoder_decoder\\\"\"\n ),\n \"choices\": [\"decoder_only\", \"encoder_decoder\", \"text_regression\", \"vision_encoder_decoder\"],\n },\n )\n config_name: Optional[str] = field(\n default=None,\n metadata={\"help\": \"Pretrained config name or path if not the same as model_name\"}\n )\n tokenizer_name: Optional[str] = field(\n default=None,\n metadata={\"help\": \"Pretrained tokenizer name or path if not the same as model_name\"}\n )\n cache_dir: Optional[str] = field(\n default=None,\n metadata={\"help\": \"Where do you want to store the pretrained models downloaded from huggingface.co\"},\n )\n use_fast_tokenizer: Optional[bool] = field(\n default=True,\n metadata={\"help\": \"Whether to use one of the fast tokenizer (backed by the tokenizers library) or not.\"},\n )\n model_revision: Optional[str] = field(\n default=\"main\",\n metadata={\"help\": \"The specific model version to use (can be a branch name, tag name or commit id).\"},\n )\n use_auth_token: Optional[bool] = field(\n default=False,\n metadata={\n \"help\": (\n \"Will use the token generated when running `huggingface-cli login` (necessary to use this script \"\n \"with private models).\"\n )\n },\n )\n torch_dtype: Optional[str] = field(\n default=None,\n metadata={\n \"help\": (\n \"Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the \"\n \"dtype will be automatically derived from the model's weights.\"\n ),\n \"choices\": [\"auto\", \"bfloat16\", \"float16\", \"float32\"],\n },\n )\n use_lora: Optional[bool] = field(\n default=False,\n metadata={\"help\": \"Whether to lora.\"},\n )\n lora_r: Optional[int] = field(\n default=8,\n metadata={\"help\": \"the rank of the lora parameters. The smaller lora_r is , the fewer parameters lora has.\"},\n )\n lora_alpha: Optional[int] = field(\n default=32,\n metadata={\"help\": \"Merging ratio between the fine-tuned model and the original. 
This is controlled by a parameter called alpha in the paper.\"},\n )\n lora_target_modules: Optional[list[str]] = field(\n default=None,\n metadata={\"help\": \"Pretrained config name or path if not the same as model_name\",\n }\n )\n lora_dropout: Optional[float] = field(\n default=0.1,\n metadata={\"help\": \"The dropout rate in lora.linear.\"},\n )\n save_aggregated_lora: Optional[bool] = field(\n default=False,\n metadata={\"help\": \"Whether to save aggregated lora.\"},\n )\n use_ram_optimized_load: Optional[bool] = field(\n default=True,\n metadata={\"help\": \"Whether use disk mapping when memory is not enough.\"}\n )\n use_flash_attention: Optional[bool] = field(\n default=False,\n metadata={\n \"help\": (\n \"whether use flash attention layer to reduce GPU memory with\"\n \" higher time cost.\"\n )\n }\n )\n use_int8: Optional[bool] = field(\n default=False,\n metadata={\"help\": \"whether to load int8 quantization for inference\"}\n )\n custom_model: Optional[bool] = field(\n default=False,\n metadata={\"help\": \"flag for the model from huggingface or not\"}\n )\n # below is added for macaw model\n n_frames: Optional[int] = field(\n default=6,\n metadata={\n \"help\": \"The number of frames for encoding a video.\"\n },\n )\n attention_heads: Optional[int] = field(\n default=220,\n metadata={\n \"help\": \"The number of attention heads used in multi-head-attention.\"\n },\n )\n image_conv_kernel: Optional[int] = field(\n default=48,\n metadata={\n \"help\": \"The size of the convolutional kernel for the image stream.\"\n },\n )\n image_conv_stride: Optional[int] = field(\n default=36,\n metadata={\n \"help\": \"The stride of the convolutional kernel for the image stream.\"\n },\n )\n video_conv_kernel: Optional[int] = field(\n default=36,\n metadata={\n \"help\": \"The size of the convolutional kernel for the video stream.\"\n },\n )\n video_conv_stride: Optional[int] = field(\n default=30,\n metadata={\n \"help\": \"The stride of the convolutional kernel 
for the video stream.\"\n },\n )\n audio_conv_kernel: Optional[int] = field(\n default=240,\n metadata={\n \"help\": \"The size of the convolutional kernel for the audio stream.\"\n },\n )\n audio_conv_stride: Optional[int] = field(\n default=220,\n metadata={\n \"help\": \"The stride of the convolutional kernel for the audio stream.\"\n },\n )\n freeze_multi_modal_encoder: bool = field(\n default=False,\n metadata={\n \"help\": (\n \"Whether to freeze the parameters of multi-modal encoders during training.).\"\n )\n },\n )\n\n def __post_init__(self):\n if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):\n raise ValueError(\n \"--config_overrides can't be used in combination with --config_name or --model_name_or_path\"\n )\n
"},{"location":"reference/#src.aeiva.config.OptimizerNames","title":"OptimizerNames
","text":" Bases: ExplicitEnum
Stores the acceptable string identifiers for optimizers.
Source code in src/aeiva/config/general_configs.py
class OptimizerNames(ExplicitEnum):\n \"\"\"\n Stores the acceptable string identifiers for optimizers.\n \"\"\"\n ADAMW_HF = \"adamw_hf\"\n ADAMW_TORCH = \"adamw_torch\"\n ADAMW_TORCH_FUSED = \"adamw_torch_fused\"\n ADAMW_TORCH_XLA = \"adamw_torch_xla\"\n ADAMW_APEX_FUSED = \"adamw_apex_fused\"\n ADAFACTOR = \"adafactor\"\n ADAMW_ANYPRECISION = \"adamw_anyprecision\"\n SGD = \"sgd\"\n ADAGRAD = \"adagrad\"\n ADAMW_BNB = \"adamw_bnb_8bit\"\n ADAMW_8BIT = \"adamw_8bit\" # just an alias for adamw_bnb_8bit\n LION_8BIT = \"lion_8bit\"\n LION = \"lion_32bit\"\n PAGED_ADAMW = \"paged_adamw_32bit\"\n PAGED_ADAMW_8BIT = \"paged_adamw_8bit\"\n PAGED_LION = \"paged_lion_32bit\"\n PAGED_LION_8BIT = \"paged_lion_8bit\"\n
"},{"location":"reference/#src.aeiva.config.base_config","title":"base_config
","text":"This module contains the base config classes.
We can define separate config classes for different modules, e.g., data, model, trainer, llm, etc. They will be automatically registered in the BaseConfig class.
Copyright (C) 2023 Bang Liu - All Rights Reserved. This source code is licensed under the license found in the LICENSE file in the root directory of this source tree.
"},{"location":"reference/#src.aeiva.config.base_config.BaseConfig","title":"BaseConfig
dataclass
","text":"Base class for all configuration classes.
Source code in src/aeiva/config/base_config.py
@dataclass\nclass BaseConfig:\n \"\"\"\n Base class for all configuration classes.\n \"\"\"\n subclasses = {} # Dictionary to store subclasses\n\n def __init_subclass__(cls, **kwargs):\n \"\"\"\n This method is called when a subclass is created.\n \"\"\"\n super().__init_subclass__(**kwargs)\n BaseConfig.subclasses[cls.__name__] = cls\n\n def __post_init__(self):\n \"\"\"\n Empty post-init to allow subclasses to call super().__post_init__().\n \"\"\"\n pass\n\n @classmethod\n def from_dict(cls, data: dict):\n \"\"\"\n Create a new instance of the class from a dictionary.\n \"\"\"\n try:\n return cls(**data)\n except TypeError as e:\n invalid_keys = [key.strip(\"'\") for key in re.findall(r\"'(\\w+)'\", str(e))]\n raise ValueError(f\"Invalid config keys provided: {invalid_keys}. Details: {e}\")\n\n def to_dict(self):\n \"\"\"\n Convert the instance to a dictionary.\n \"\"\"\n return {k: v for k, v in self.__dict__.items() if not k.startswith('_')}\n\n @classmethod\n def from_json(cls, json_path: str):\n \"\"\"\n Create a new instance of the class from a JSON file.\n \"\"\"\n with open(json_path, \"r\") as json_file:\n data = json.load(json_file)\n return cls.from_dict(data)\n\n def to_json(self, filepath: str):\n \"\"\"\n Convert the instance to a JSON file.\n \"\"\"\n with open(filepath, 'w') as json_file:\n json.dump(self.to_dict(), json_file, indent=4)\n\n @classmethod\n def from_yaml(cls, yaml_path: str):\n \"\"\"\n Create a new instance of the class from a YAML file.\n \"\"\"\n with open(yaml_path, \"r\") as yaml_file:\n data = yaml.safe_load(yaml_file)\n return cls.from_dict(data)\n\n def to_yaml(self, filepath: str):\n \"\"\"\n Convert the instance to a YAML file.\n \"\"\"\n with open(filepath, 'w') as yaml_file:\n yaml.dump(self.to_dict(), yaml_file)\n\n @classmethod\n def from_json_or_yaml(cls, file_path: str):\n \"\"\"\n Create a new instance of the class from a JSON or YAML file.\n \"\"\"\n _, file_extension = os.path.splitext(file_path)\n if 
file_extension == \".json\":\n return cls.from_json(file_path)\n elif file_extension == \".yaml\" or file_extension == \".yml\":\n return cls.from_yaml(file_path)\n else:\n raise ValueError(f\"Unsupported file extension: {file_extension}. Please use .json or .yaml\")\n\n def __str__(self):\n \"\"\"\n Return a string representation of the instance.\n \"\"\"\n return pprint.pformat(self.to_dict(), indent=4)\n
"},{"location":"reference/#src.aeiva.config.base_config.BaseConfig.__init_subclass__","title":"__init_subclass__(**kwargs)
","text":"This method is called when a subclass is created.
Source code in src/aeiva/config/base_config.py
def __init_subclass__(cls, **kwargs):\n \"\"\"\n This method is called when a subclass is created.\n \"\"\"\n super().__init_subclass__(**kwargs)\n BaseConfig.subclasses[cls.__name__] = cls\n
"},{"location":"reference/#src.aeiva.config.base_config.BaseConfig.__post_init__","title":"__post_init__()
","text":"Empty post-init to allow subclasses to call super().post_init().
Source code in src/aeiva/config/base_config.py
def __post_init__(self):\n \"\"\"\n Empty post-init to allow subclasses to call super().__post_init__().\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.config.base_config.BaseConfig.__str__","title":"__str__()
","text":"Return a string representation of the instance.
Source code in src/aeiva/config/base_config.py
def __str__(self):\n \"\"\"\n Return a string representation of the instance.\n \"\"\"\n return pprint.pformat(self.to_dict(), indent=4)\n
"},{"location":"reference/#src.aeiva.config.base_config.BaseConfig.from_dict","title":"from_dict(data)
classmethod
","text":"Create a new instance of the class from a dictionary.
Source code in src/aeiva/config/base_config.py
@classmethod\ndef from_dict(cls, data: dict):\n \"\"\"\n Create a new instance of the class from a dictionary.\n \"\"\"\n try:\n return cls(**data)\n except TypeError as e:\n invalid_keys = [key.strip(\"'\") for key in re.findall(r\"'(\\w+)'\", str(e))]\n raise ValueError(f\"Invalid config keys provided: {invalid_keys}. Details: {e}\")\n
"},{"location":"reference/#src.aeiva.config.base_config.BaseConfig.from_json","title":"from_json(json_path)
classmethod
","text":"Create a new instance of the class from a JSON file.
Source code in src/aeiva/config/base_config.py
@classmethod\ndef from_json(cls, json_path: str):\n \"\"\"\n Create a new instance of the class from a JSON file.\n \"\"\"\n with open(json_path, \"r\") as json_file:\n data = json.load(json_file)\n return cls.from_dict(data)\n
"},{"location":"reference/#src.aeiva.config.base_config.BaseConfig.from_json_or_yaml","title":"from_json_or_yaml(file_path)
classmethod
","text":"Create a new instance of the class from a JSON or YAML file.
Source code in src/aeiva/config/base_config.py
@classmethod\ndef from_json_or_yaml(cls, file_path: str):\n \"\"\"\n Create a new instance of the class from a JSON or YAML file.\n \"\"\"\n _, file_extension = os.path.splitext(file_path)\n if file_extension == \".json\":\n return cls.from_json(file_path)\n elif file_extension == \".yaml\" or file_extension == \".yml\":\n return cls.from_yaml(file_path)\n else:\n raise ValueError(f\"Unsupported file extension: {file_extension}. Please use .json or .yaml\")\n
"},{"location":"reference/#src.aeiva.config.base_config.BaseConfig.from_yaml","title":"from_yaml(yaml_path)
classmethod
","text":"Create a new instance of the class from a YAML file.
Source code in src/aeiva/config/base_config.py
@classmethod\ndef from_yaml(cls, yaml_path: str):\n \"\"\"\n Create a new instance of the class from a YAML file.\n \"\"\"\n with open(yaml_path, \"r\") as yaml_file:\n data = yaml.safe_load(yaml_file)\n return cls.from_dict(data)\n
"},{"location":"reference/#src.aeiva.config.base_config.BaseConfig.to_dict","title":"to_dict()
","text":"Convert the instance to a dictionary.
Source code in src/aeiva/config/base_config.py
def to_dict(self):\n \"\"\"\n Convert the instance to a dictionary.\n \"\"\"\n return {k: v for k, v in self.__dict__.items() if not k.startswith('_')}\n
"},{"location":"reference/#src.aeiva.config.base_config.BaseConfig.to_json","title":"to_json(filepath)
","text":"Convert the instance to a JSON file.
Source code in src/aeiva/config/base_config.py
def to_json(self, filepath: str):\n \"\"\"\n Convert the instance to a JSON file.\n \"\"\"\n with open(filepath, 'w') as json_file:\n json.dump(self.to_dict(), json_file, indent=4)\n
"},{"location":"reference/#src.aeiva.config.base_config.BaseConfig.to_yaml","title":"to_yaml(filepath)
","text":"Convert the instance to a YAML file.
Source code in src/aeiva/config/base_config.py
def to_yaml(self, filepath: str):\n \"\"\"\n Convert the instance to a YAML file.\n \"\"\"\n with open(filepath, 'w') as yaml_file:\n yaml.dump(self.to_dict(), yaml_file)\n
"},{"location":"reference/#src.aeiva.config.custom_configs","title":"custom_configs
","text":""},{"location":"reference/#src.aeiva.config.custom_configs.macaw_config","title":"macaw_config
","text":"This module contains the config for macaw model.
We can define separate config classes for different specific models/datasets/tasks.
Copyright (C) 2023 Bang Liu - All Rights Reserved. This source code is licensed under the license found in the LICENSE file in the root directory of this source tree.
"},{"location":"reference/#src.aeiva.config.custom_configs.macaw_config.MacawConfig","title":"MacawConfig
dataclass
","text":" Bases: BaseConfig
Define user-customized config here.
Source code in src/aeiva/config/custom_configs/macaw_config.py
@dataclass\nclass MacawConfig(BaseConfig):\n \"\"\"\n Define user-customized config here.\n \"\"\"\n image_dir: Optional[str] = field(\n default=None,\n metadata={\"help\": \"The directory of image data\"}\n )\n video_dir: Optional[str] = field(\n default=None,\n metadata={\"help\": \"The directory of video data\"}\n )\n frame_dir: Optional[str] = field(\n default=None,\n metadata={\"help\": \"The directory to save video frames\"}\n )\n audio_dir: Optional[str] = field(\n default=None,\n metadata={\"help\": \"The directory to save video audios\"}\n )\n num_frames_to_sample: Optional[int] = field(\n default=120,\n metadata={\"help\": \"The number of frames to sample from a video\"}\n )\n num_frames_to_load: Optional[int] = field(\n default=6,\n metadata={\"help\": \"The number of frames to load as a part of model inputs\"}\n )\n num_samples_per_dataset: Optional[int] = field(\n default=100,\n metadata={\"help\": \"The number of samples to load from each dataset\"}\n )\n num_samples_per_merged_dataset: Optional[int] = field(\n default=20,\n metadata={\"help\": \"The number of samples to save after merging datasets\"}\n )\n batch_size: Optional[int] = field(\n default=1,\n metadata={\"help\": \"The batch size of model inputs\"}\n )\n max_seq_len_for_preprocess: Optional[int] = field(\n default=256,\n metadata={\"help\": \"The maximum sequence length for preprocess\"}\n )\n run_time_cache_dir: Optional[str] = field(\n default=None,\n metadata={\"help\": \"The directory to save running time data, such as video frames, audios, and so on.\"}\n )\n tokenizer_name_or_path: Optional[str] = field(\n default=None,\n metadata={\"help\": \"The name or path of tokenizer\"}\n )\n clip_model_name_or_path: Optional[str] = field(\n default=None,\n metadata={\"help\": \"The name or path of clip model\"}\n )\n whisper_model_name_or_path: Optional[str] = field(\n default=None,\n metadata={\"help\": \"The name or path of whisper model\"}\n )\n llama7b_model_name_or_path: Optional[str] = 
field(\n default=None,\n metadata={\"help\": \"The name or path of llama7b model\"}\n )\n macaw_model_name_or_path: Optional[str] = field(\n default=None,\n metadata={\"help\": \"The name or path of macaw model\"}\n )\n mode: Optional[str] = field(\n default=\"train\",\n metadata={\"help\": \"The mode of train, eval, or inference\"}\n )\n model_name: Optional[str] = field(\n default=\"macaw\",\n metadata={\"help\": \"The name of model\"}\n )\n resource_ready: Optional[bool] = field(\n default=True,\n metadata={\"help\": \"Whether the pre-requisite resource is ready, e.g., download pretrained models and datasets\"}\n )\n
"},{"location":"reference/#src.aeiva.config.general_configs","title":"general_configs
","text":"This module contains some general config classes that can be used in deep learning projects.
E.g., data config, model config, trainer config, etc.
Copyright (C) 2023 Bang Liu - All Rights Reserved. This source code is licensed under the license found in the LICENSE file in the root directory of this source tree.
"},{"location":"reference/#src.aeiva.config.general_configs.DataConfig","title":"DataConfig
dataclass
","text":" Bases: BaseConfig
This class contains the data configuration.
Source code in src/aeiva/config/general_configs.py
@dataclass\nclass DataConfig(BaseConfig):\n \"\"\"This class contains the data configuration.\"\"\"\n dataset_path: Optional[str] = field(\n default=None, metadata={\"help\": \"The path of the dataset to use.\"}\n )\n dataset_name: Optional[str] = field(\n default=\"customized\", metadata={\"help\": \"Should be \\\"customized\\\"\"}\n )\n is_custom_dataset: Optional[bool] = field(\n default=False, metadata={\"help\": \"whether to use custom data\"}\n )\n customized_cache_dir: Optional[str] = field(\n default=\".cache/llm-ft/datasets\",\n metadata={\"help\": \"Where do you want to store the customized dataset caches\"},\n )\n dataset_config_name: Optional[str] = field(\n default=None, metadata={\"help\": \"The configuration name of the dataset to use (via the datasets library).\"}\n )\n train_file: Optional[str] = field(default=None, metadata={\"help\": \"The input training data file (a text file).\"})\n validation_file: Optional[str] = field(\n default=None,\n metadata={\"help\": \"An optional input evaluation data file to evaluate the perplexity on (a text file).\"},\n )\n max_train_samples: Optional[int] = field(\n default=None,\n metadata={\n \"help\": (\n \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n \"value if set.\"\n )\n },\n )\n max_eval_samples: Optional[int] = field(\n default=1e10,\n metadata={\n \"help\": (\n \"For debugging purposes or quicker training, truncate the number of evaluation examples to this \"\n \"value if set.\"\n )\n },\n )\n streaming: Optional[bool] = field(default=False, metadata={\"help\": \"Enable streaming mode\"})\n block_size: Optional[int] = field(\n default=512,\n metadata={\n \"help\": (\n \"Optional input sequence length after tokenization. \"\n \"The training dataset will be truncated in block of this size for training. 
\"\n \"Default to the model max input length for single sentence inputs (take into account special tokens).\"\n )\n },\n )\n overwrite_cache: Optional[bool] = field(\n default=False,\n metadata={\"help\": \"Overwrite the cached training and evaluation sets\"}\n )\n validation_split_percentage: Optional[int] = field(\n default=5,\n metadata={\n \"help\": \"The percentage of the train set used as validation set in case there's no validation split\"\n },\n )\n preprocessing_num_workers: Optional[int] = field(\n default=None,\n metadata={\"help\": \"The number of processes to use for the preprocessing.\"},\n )\n group_texts_batch_size: Optional[int] = field(\n default=1000,\n metadata={\n \"help\": (\n \"Number of samples that will be grouped together to go though\"\n \" `group_texts` operation. See `--disable_group_texts` for\"\n \" detailed explanation of this operation.\"\n )\n }\n )\n disable_group_texts: Optional[bool] = field(\n default=False,\n metadata={\n \"help\": (\n \"Whether we group original samples together to generate sample\"\n \" sequences of length `block_size`. By default, we group every\"\n \" 1000 tokenized sequences together, divide them into \"\n \" [{total_num_tokens} / {block_size}] sequences, each with\"\n \" `block_size` tokens (the remaining tokens are ommited.\"\n \" If this flag is set to True, we only group 1 tokenized\"\n \" sequence, i.e. 
cutting long sequence into chunks.\"\n )\n },\n )\n keep_linebreaks: Optional[bool] = field(\n default=True,\n metadata={\"help\": \"Whether to keep line breaks when using TXT files or not.\"}\n )\n test_file: Optional[str] = field(\n default=None,\n metadata={\"help\": \"Evaluation File Path\"},\n )\n\n def __post_init__(self):\n if self.streaming:\n require_version(\"datasets>=2.0.0\", \"The streaming feature requires `datasets>=2.0.0`\")\n\n if self.dataset_name is None and self.train_file is None and self.validation_file is None:\n raise ValueError(\"Need either a dataset name or a training/validation file.\")\n else:\n if self.train_file is not None:\n extension = self.train_file.split(\".\")[-1]\n assert extension in [\"csv\", \"json\", \"txt\"], \"`train_file` should be a csv, a json or a txt file.\"\n if self.validation_file is not None:\n extension = self.validation_file.split(\".\")[-1]\n assert extension in [\"csv\", \"json\", \"txt\"], \"`validation_file` should be a csv, a json or a txt file.\"\n
"},{"location":"reference/#src.aeiva.config.general_configs.ExplicitEnum","title":"ExplicitEnum
","text":" Bases: str
, Enum
Enum with more explicit error message for missing values.
Source code in src/aeiva/config/general_configs.py
class ExplicitEnum(str, Enum):\n \"\"\"\n Enum with more explicit error message for missing values.\n \"\"\"\n @classmethod\n def _missing_(cls, value):\n raise ValueError(\n f\"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}\"\n )\n
"},{"location":"reference/#src.aeiva.config.general_configs.ModelConfig","title":"ModelConfig
dataclass
","text":" Bases: BaseConfig
Model configuration class.
Source code in src/aeiva/config/general_configs.py
@dataclass\nclass ModelConfig(BaseConfig):\n \"\"\"Model configuration class.\"\"\"\n model_name_or_path: Optional[str] = field(\n default=None,\n metadata={\n \"help\": (\n \"The model checkpoint for weights initialization. Don't set if you want to train a model from scratch.\"\n )\n },\n )\n lora_model_path: Optional[str] = field(\n default=None,\n metadata={\n \"help\": (\n \"The incremental model diff introduced by LoRA finetuning.\"\n \" Along with the original non-finetuned model forms the whole\"\n \" finetuned model.\"\n )\n }\n )\n model_type: Optional[str] = field(\n default=None,\n metadata={\"help\": \"If training from scratch, pass a model type from the list: \" + \", \".join(MODEL_TYPES)},\n )\n arch_type: Optional[str] = field(\n default=\"decoder_only\",\n metadata={\"help\": \"The architecture type of the model. Currently supported decoder_only or encoder_decoder\"}\n )\n config_overrides: Optional[str] = field(\n default=None,\n metadata={\n \"help\": (\n \"Override some existing default config settings when a model is trained from scratch. Example: \"\n \"n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index\"\n )\n },\n )\n arch_type: Optional[str] = field(\n default=\"decoder_only\",\n metadata={\n \"help\": (\n \"Model architecture type, e.g. 
\\\"decoder_only\\\",\"\n \" \\\"encoder_decoder\\\"\"\n ),\n \"choices\": [\"decoder_only\", \"encoder_decoder\", \"text_regression\", \"vision_encoder_decoder\"],\n },\n )\n config_name: Optional[str] = field(\n default=None,\n metadata={\"help\": \"Pretrained config name or path if not the same as model_name\"}\n )\n tokenizer_name: Optional[str] = field(\n default=None,\n metadata={\"help\": \"Pretrained tokenizer name or path if not the same as model_name\"}\n )\n cache_dir: Optional[str] = field(\n default=None,\n metadata={\"help\": \"Where do you want to store the pretrained models downloaded from huggingface.co\"},\n )\n use_fast_tokenizer: Optional[bool] = field(\n default=True,\n metadata={\"help\": \"Whether to use one of the fast tokenizer (backed by the tokenizers library) or not.\"},\n )\n model_revision: Optional[str] = field(\n default=\"main\",\n metadata={\"help\": \"The specific model version to use (can be a branch name, tag name or commit id).\"},\n )\n use_auth_token: Optional[bool] = field(\n default=False,\n metadata={\n \"help\": (\n \"Will use the token generated when running `huggingface-cli login` (necessary to use this script \"\n \"with private models).\"\n )\n },\n )\n torch_dtype: Optional[str] = field(\n default=None,\n metadata={\n \"help\": (\n \"Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the \"\n \"dtype will be automatically derived from the model's weights.\"\n ),\n \"choices\": [\"auto\", \"bfloat16\", \"float16\", \"float32\"],\n },\n )\n use_lora: Optional[bool] = field(\n default=False,\n metadata={\"help\": \"Whether to lora.\"},\n )\n lora_r: Optional[int] = field(\n default=8,\n metadata={\"help\": \"the rank of the lora parameters. The smaller lora_r is , the fewer parameters lora has.\"},\n )\n lora_alpha: Optional[int] = field(\n default=32,\n metadata={\"help\": \"Merging ratio between the fine-tuned model and the original. 
This is controlled by a parameter called alpha in the paper.\"},\n )\n lora_target_modules: Optional[list[str]] = field(\n default=None,\n metadata={\"help\": \"The names of the modules to apply LoRA to.\",\n }\n )\n lora_dropout: Optional[float] = field(\n default=0.1,\n metadata={\"help\": \"The dropout rate in lora.linear.\"},\n )\n save_aggregated_lora: Optional[bool] = field(\n default=False,\n metadata={\"help\": \"Whether to save the aggregated LoRA.\"},\n )\n use_ram_optimized_load: Optional[bool] = field(\n default=True,\n metadata={\"help\": \"Whether to use disk mapping when memory is not enough.\"}\n )\n use_flash_attention: Optional[bool] = field(\n default=False,\n metadata={\n \"help\": (\n \"Whether to use the flash attention layer to reduce GPU memory at\"\n \" a higher time cost.\"\n )\n }\n )\n use_int8: Optional[bool] = field(\n default=False,\n metadata={\"help\": \"Whether to load int8 quantization for inference.\"}\n )\n custom_model: Optional[bool] = field(\n default=False,\n metadata={\"help\": \"Flag indicating whether the model is a custom model (not from Hugging Face).\"}\n )\n # below is added for macaw model\n n_frames: Optional[int] = field(\n default=6,\n metadata={\n \"help\": \"The number of frames for encoding a video.\"\n },\n )\n attention_heads: Optional[int] = field(\n default=220,\n metadata={\n \"help\": \"The number of attention heads used in multi-head-attention.\"\n },\n )\n image_conv_kernel: Optional[int] = field(\n default=48,\n metadata={\n \"help\": \"The size of the convolutional kernel for the image stream.\"\n },\n )\n image_conv_stride: Optional[int] = field(\n default=36,\n metadata={\n \"help\": \"The stride of the convolutional kernel for the image stream.\"\n },\n )\n video_conv_kernel: Optional[int] = field(\n default=36,\n metadata={\n \"help\": \"The size of the convolutional kernel for the video stream.\"\n },\n )\n video_conv_stride: Optional[int] = field(\n default=30,\n metadata={\n \"help\": \"The stride of the convolutional kernel 
for the video stream.\"\n },\n )\n audio_conv_kernel: Optional[int] = field(\n default=240,\n metadata={\n \"help\": \"The size of the convolutional kernel for the audio stream.\"\n },\n )\n audio_conv_stride: Optional[int] = field(\n default=220,\n metadata={\n \"help\": \"The stride of the convolutional kernel for the audio stream.\"\n },\n )\n freeze_multi_modal_encoder: bool = field(\n default=False,\n metadata={\n \"help\": (\n \"Whether to freeze the parameters of multi-modal encoders during training.\"\n )\n },\n )\n\n def __post_init__(self):\n if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):\n raise ValueError(\n \"--config_overrides can't be used in combination with --config_name or --model_name_or_path\"\n )\n
"},{"location":"reference/#src.aeiva.config.general_configs.OptimizerNames","title":"OptimizerNames
","text":" Bases: ExplicitEnum
Stores the acceptable string identifiers for optimizers.
Source code in src/aeiva/config/general_configs.py
class OptimizerNames(ExplicitEnum):\n \"\"\"\n Stores the acceptable string identifiers for optimizers.\n \"\"\"\n ADAMW_HF = \"adamw_hf\"\n ADAMW_TORCH = \"adamw_torch\"\n ADAMW_TORCH_FUSED = \"adamw_torch_fused\"\n ADAMW_TORCH_XLA = \"adamw_torch_xla\"\n ADAMW_APEX_FUSED = \"adamw_apex_fused\"\n ADAFACTOR = \"adafactor\"\n ADAMW_ANYPRECISION = \"adamw_anyprecision\"\n SGD = \"sgd\"\n ADAGRAD = \"adagrad\"\n ADAMW_BNB = \"adamw_bnb_8bit\"\n ADAMW_8BIT = \"adamw_8bit\" # just an alias for adamw_bnb_8bit\n LION_8BIT = \"lion_8bit\"\n LION = \"lion_32bit\"\n PAGED_ADAMW = \"paged_adamw_32bit\"\n PAGED_ADAMW_8BIT = \"paged_adamw_8bit\"\n PAGED_LION = \"paged_lion_32bit\"\n PAGED_LION_8BIT = \"paged_lion_8bit\"\n
"},{"location":"reference/#src.aeiva.config.omni_config","title":"omni_config
","text":"This module contains the OmniConfig classes.
We can define separate config classes for different modules, e.g., data, model, trainer, etc. The OmniConfig class is the combination of all config classes. It can also accept command line arguments to update the config values.
Copyright (C) 2023 Bang Liu - All Rights Reserved. This source code is licensed under the license found in the LICENSE file in the root directory of this source tree.
"},{"location":"reference/#src.aeiva.config.omni_config.OmniConfig","title":"OmniConfig
dataclass
","text":" Bases: BaseConfig
Source code in src/aeiva/config/omni_config.py
@dataclass\nclass OmniConfig(BaseConfig):\n @staticmethod\n def create_omni_config():\n \"\"\"\n Initializes OmniConfig by aggregating all configuration classes.\n \"\"\"\n # Aggregating default values from all config classes\n defaults = {}\n for config_class_name, config_class in BaseConfig.subclasses.items():\n if config_class_name == \"OmniConfig\":\n continue\n for field_name, field_obj in config_class.__dataclass_fields__.items():\n if field_name in defaults:\n raise ValueError(f\"Overlapping config argument: '{field_name}' found in {config_class.__name__}\")\n default_value = getattr(config_class(), field_name, None)\n defaults[field_name] = default_value\n\n def __init__(self, **kwargs):\n for key, default_value in defaults.items():\n setattr(self, key, kwargs.get(key, default_value))\n\n OmniConfig.__init__ = __init__\n return OmniConfig\n\n def update_from_args(self, namespace_args: argparse.Namespace):\n \"\"\"\n Updates the configuration based on parsed command-line arguments.\n \"\"\"\n for key, value in vars(namespace_args).items():\n if hasattr(self, key) and value is not None:\n setattr(self, key, value)\n\n def get_argparse_parser(self):\n \"\"\"\n Creates an argument parser that can handle complex types.\n \"\"\"\n parser = argparse.ArgumentParser()\n for config_class_name, config_class in BaseConfig.subclasses.items():\n if config_class_name == \"OmniConfig\":\n continue\n for field_name, field_obj in config_class.__dataclass_fields__.items():\n field_type = field_obj.type\n\n # Handle Optional types\n if get_origin(field_type) is Union and type(None) in get_args(field_type):\n field_type = next(arg for arg in get_args(field_type) if arg is not type(None))\n\n arg_name = '--' + field_name\n help_msg = field_obj.metadata.get(\"help\", f\"{field_name} ({field_type})\")\n\n origin = get_origin(field_type)\n args = get_args(field_type)\n\n # Handle Enums\n if isinstance(field_type, type) and issubclass(field_type, enum.Enum):\n choices = [item.value 
for item in field_type]\n parser.add_argument(arg_name, type=str, choices=choices, help=help_msg)\n continue\n\n # Handle list types\n if origin is list:\n item_type = args[0]\n if item_type is str:\n parser.add_argument(arg_name, nargs='+', type=str, help=help_msg)\n elif item_type is int:\n parser.add_argument(arg_name, nargs='+', type=int, help=help_msg)\n else:\n # Default to strings if item type is not specifically handled\n parser.add_argument(arg_name, nargs='+', type=str, help=help_msg)\n continue\n\n # Handle tuple types\n if origin is tuple:\n # Accept comma-separated values and convert to tuple\n def tuple_type(s):\n try:\n return tuple(map(int, s.split(',')))\n except ValueError:\n raise argparse.ArgumentTypeError(\"Tuples must be comma-separated integers.\")\n\n parser.add_argument(arg_name, type=tuple_type, help=help_msg)\n continue\n\n # Handle dict types\n if origin is dict:\n # Expect JSON string\n def dict_type(s):\n try:\n return json.loads(s)\n except json.JSONDecodeError:\n raise argparse.ArgumentTypeError(\"Dictionaries must be valid JSON strings.\")\n\n parser.add_argument(arg_name, type=dict_type, help=help_msg)\n continue\n\n # Handle basic types\n if field_type is int:\n parser.add_argument(arg_name, type=int, help=help_msg)\n elif field_type is float:\n parser.add_argument(arg_name, type=float, help=help_msg)\n elif field_type is str:\n parser.add_argument(arg_name, type=str, help=help_msg)\n elif field_type is bool:\n parser.add_argument(arg_name, action='store_true', help=help_msg)\n else:\n print(f\"Warning: unsupported type {field_type} for field '{field_name}'\")\n return parser\n
"},{"location":"reference/#src.aeiva.config.omni_config.OmniConfig.create_omni_config","title":"create_omni_config()
staticmethod
","text":"Initializes OmniConfig by aggregating all configuration classes.
Source code in src/aeiva/config/omni_config.py
@staticmethod\ndef create_omni_config():\n \"\"\"\n Initializes OmniConfig by aggregating all configuration classes.\n \"\"\"\n # Aggregating default values from all config classes\n defaults = {}\n for config_class_name, config_class in BaseConfig.subclasses.items():\n if config_class_name == \"OmniConfig\":\n continue\n for field_name, field_obj in config_class.__dataclass_fields__.items():\n if field_name in defaults:\n raise ValueError(f\"Overlapping config argument: '{field_name}' found in {config_class.__name__}\")\n default_value = getattr(config_class(), field_name, None)\n defaults[field_name] = default_value\n\n def __init__(self, **kwargs):\n for key, default_value in defaults.items():\n setattr(self, key, kwargs.get(key, default_value))\n\n OmniConfig.__init__ = __init__\n return OmniConfig\n
"},{"location":"reference/#src.aeiva.config.omni_config.OmniConfig.get_argparse_parser","title":"get_argparse_parser()
","text":"Creates an argument parser that can handle complex types.
Source code in src/aeiva/config/omni_config.py
def get_argparse_parser(self):\n \"\"\"\n Creates an argument parser that can handle complex types.\n \"\"\"\n parser = argparse.ArgumentParser()\n for config_class_name, config_class in BaseConfig.subclasses.items():\n if config_class_name == \"OmniConfig\":\n continue\n for field_name, field_obj in config_class.__dataclass_fields__.items():\n field_type = field_obj.type\n\n # Handle Optional types\n if get_origin(field_type) is Union and type(None) in get_args(field_type):\n field_type = next(arg for arg in get_args(field_type) if arg is not type(None))\n\n arg_name = '--' + field_name\n help_msg = field_obj.metadata.get(\"help\", f\"{field_name} ({field_type})\")\n\n origin = get_origin(field_type)\n args = get_args(field_type)\n\n # Handle Enums\n if isinstance(field_type, type) and issubclass(field_type, enum.Enum):\n choices = [item.value for item in field_type]\n parser.add_argument(arg_name, type=str, choices=choices, help=help_msg)\n continue\n\n # Handle list types\n if origin is list:\n item_type = args[0]\n if item_type is str:\n parser.add_argument(arg_name, nargs='+', type=str, help=help_msg)\n elif item_type is int:\n parser.add_argument(arg_name, nargs='+', type=int, help=help_msg)\n else:\n # Default to strings if item type is not specifically handled\n parser.add_argument(arg_name, nargs='+', type=str, help=help_msg)\n continue\n\n # Handle tuple types\n if origin is tuple:\n # Accept comma-separated values and convert to tuple\n def tuple_type(s):\n try:\n return tuple(map(int, s.split(',')))\n except ValueError:\n raise argparse.ArgumentTypeError(\"Tuples must be comma-separated integers.\")\n\n parser.add_argument(arg_name, type=tuple_type, help=help_msg)\n continue\n\n # Handle dict types\n if origin is dict:\n # Expect JSON string\n def dict_type(s):\n try:\n return json.loads(s)\n except json.JSONDecodeError:\n raise argparse.ArgumentTypeError(\"Dictionaries must be valid JSON strings.\")\n\n parser.add_argument(arg_name, type=dict_type, 
help=help_msg)\n continue\n\n # Handle basic types\n if field_type is int:\n parser.add_argument(arg_name, type=int, help=help_msg)\n elif field_type is float:\n parser.add_argument(arg_name, type=float, help=help_msg)\n elif field_type is str:\n parser.add_argument(arg_name, type=str, help=help_msg)\n elif field_type is bool:\n parser.add_argument(arg_name, action='store_true', help=help_msg)\n else:\n print(f\"Warning: unsupported type {field_type} for field '{field_name}'\")\n return parser\n
"},{"location":"reference/#src.aeiva.config.omni_config.OmniConfig.update_from_args","title":"update_from_args(namespace_args)
","text":"Updates the configuration based on parsed command-line arguments.
Source code in src/aeiva/config/omni_config.py
def update_from_args(self, namespace_args: argparse.Namespace):\n \"\"\"\n Updates the configuration based on parsed command-line arguments.\n \"\"\"\n for key, value in vars(namespace_args).items():\n if hasattr(self, key) and value is not None:\n setattr(self, key, value)\n
"},{"location":"reference/#src.aeiva.data","title":"data
","text":""},{"location":"reference/#src.aeiva.data.processor","title":"processor
","text":"This module contains the data processor.
@Author: Bang Liu (chatsci.ai@gmail.com) @Date: 2023-07-11
Copyright (C) 2023 Bang Liu - All Rights Reserved. This source code is licensed under the license found in the LICENSE file in the root directory of this source tree.
"},{"location":"reference/#src.aeiva.data.processor.process_dataset","title":"process_dataset(formatted_dataset, pipeline, output_dir, dataset_name='')
","text":"Process a dataset with a pipeline of functions.
Parameters:
Name Type Description Default formatted_dataset
DataSet
the dataset to be processed.
required pipeline
list[Callable]
a list of functions to be applied to the dataset.
required output_dir
Optional[str]
the output directory to save the processed dataset.
required dataset_name
Optional[str]
the name of the dataset. Defaults to \"\".
''
Returns:
Name Type Description DataSet
DataSet
the processed dataset.
Source code in src/aeiva/data/processor.py
def process_dataset(formatted_dataset: DataSet,\n pipeline: list[Callable],\n output_dir: Optional[str],\n dataset_name: Optional[str] = \"\") -> DataSet:\n \"\"\"\n Process a dataset with a pipeline of functions.\n\n Args:\n formatted_dataset (DataSet): the dataset to be processed.\n pipeline (list[Callable]): a list of functions to be applied to the dataset.\n output_dir (Optional[str]): the output directory to save the processed dataset.\n dataset_name (Optional[str], optional): the name of the dataset. Defaults to \"\".\n\n Returns:\n DataSet: the processed dataset.\n \"\"\"\n processed_data = []\n pipeline = Pipeline(pipeline)\n for item in formatted_dataset[\"data\"]:\n processed_data.append(pipeline(item.copy()))\n\n output = {\"data\": processed_data, \"metadata\": formatted_dataset[\"metadata\"]}\n if output_dir is not None:\n ensure_dir(output_dir)\n dump_json(output, f\"{output_dir}/{dataset_name}_dataset.processed.json\")\n return output\n
"},{"location":"reference/#src.aeiva.demo","title":"demo
","text":""},{"location":"reference/#src.aeiva.demo.chat_gradio","title":"chat_gradio
","text":""},{"location":"reference/#src.aeiva.demo.chat_gradio.bot","title":"bot(user_input, history)
async
","text":"Handles chatbot logic by emitting perception.gradio events to the Agent and retrieving responses.
Source code in src/aeiva/demo/chat_gradio.py
async def bot(user_input, history):\n \"\"\"\n Handles chatbot logic by emitting perception.gradio events to the Agent and retrieving responses.\n \"\"\"\n if agent is None:\n logger.error(\"Agent is not initialized.\")\n history.append({\"role\": \"assistant\", \"content\": \"Agent is not initialized.\"})\n yield history, ''\n return\n\n try:\n # Append user's message to history\n history.append({\"role\": \"user\", \"content\": user_input})\n # Append an empty assistant response\n history.append({\"role\": \"assistant\", \"content\": \"\"})\n yield history, '' # Display the user's message\n logger.info(f\"User input appended to history: {user_input}\")\n\n stream = config_dict[\"llm_gateway_config\"][\"llm_stream\"]\n use_async = config_dict[\"llm_gateway_config\"][\"llm_use_async\"]\n\n # Emit the 'perception.gradio' event with stream=True\n emit_future = asyncio.run_coroutine_threadsafe(\n agent.event_bus.emit('perception.gradio', payload=user_input), # TODO: maybe simplify payload, Agent can directly read stream and use_async from config.\n agent.event_bus.loop\n )\n emit_future.result() # Ensure the event is emitted\n logger.info(f\"Emitted 'perception.gradio' event with payload: {user_input} | Stream: {stream}\")\n\n assistant_message = ''\n if stream:\n while True:\n try:\n # Non-blocking response retrieval from the thread-safe queue with timeout\n response = await asyncio.wait_for(\n asyncio.to_thread(response_queue.get, True, 30),\n timeout=30\n )\n logger.info(f\"Retrieved response from queue: {response}\")\n if response == \"<END_OF_RESPONSE>\":\n logger.info(\"Received end of response signal.\")\n break\n assistant_message += response\n # Create a new history list to ensure Gradio detects the update\n new_history = history.copy()\n new_history[-1][\"content\"] = assistant_message\n logger.info(f\"Yielding updated history: {new_history}\")\n yield new_history, ''\n except asyncio.TimeoutError:\n logger.warning(\"Timeout: No response received from 
Agent.\")\n # Create a new history list to ensure Gradio detects the update\n new_history = history.copy()\n new_history[-1][\"content\"] = \"I'm sorry, I didn't receive a response in time.\"\n yield new_history, ''\n break\n else:\n try:\n # Non-blocking response retrieval from the thread-safe queue with timeout\n response = await asyncio.wait_for(\n asyncio.to_thread(response_queue.get, True, 30),\n timeout=30\n )\n logger.info(f\"Retrieved response from queue: {response}\")\n assistant_message += response\n # Create a new history list to ensure Gradio detects the update\n new_history = history.copy()\n new_history[-1][\"content\"] = assistant_message\n logger.info(f\"Yielding updated history: {new_history}\")\n yield new_history, ''\n except asyncio.TimeoutError:\n logger.warning(\"Timeout: No response received from Agent.\")\n # Create a new history list to ensure Gradio detects the update\n new_history = history.copy()\n new_history[-1][\"content\"] = \"I'm sorry, I didn't receive a response in time.\"\n yield new_history, ''\n\n except Exception as e:\n logger.error(f\"Unexpected Error in bot function: {e}\")\n # Create a new history list to ensure Gradio detects the update\n new_history = history.copy()\n new_history[-1][\"content\"] = \"An unexpected error occurred.\"\n yield new_history, ''\n
"},{"location":"reference/#src.aeiva.demo.chat_gradio.clear_media","title":"clear_media()
","text":"Clears the uploaded media paths.
Source code in src/aeiva/demo/chat_gradio.py
def clear_media():\n \"\"\"\n Clears the uploaded media paths.\n \"\"\"\n # Implement any necessary logic to clear media paths or data\n logger.info(\"Cleared uploaded media paths.\")\n return \"\"\n
"},{"location":"reference/#src.aeiva.demo.chat_gradio.handle_upload","title":"handle_upload(file)
","text":"Handles file uploads and delegates to specific handlers based on file type.
Parameters:
Name Type Description Default file
Uploaded file object.
required Returns:
Name Type Description str
Message indicating the upload status.
Source code in src/aeiva/demo/chat_gradio.py
def handle_upload(file):\n \"\"\"\n Handles file uploads and delegates to specific handlers based on file type.\n\n Args:\n file: Uploaded file object.\n\n Returns:\n str: Message indicating the upload status.\n \"\"\"\n if file is None:\n return \"\"\n if file.type.startswith(\"image\"):\n return handle_image_upload(file)\n elif file.type.startswith(\"video\"):\n return handle_video_upload(file)\n elif file.type.startswith(\"audio\"):\n return handle_audio_upload(file)\n else:\n logger.warning(f\"Unsupported file type uploaded: {file.type}\")\n return \"Unsupported file type uploaded.\"\n
"},{"location":"reference/#src.aeiva.demo.mm_chatbot","title":"mm_chatbot
","text":"This module defines a multimodal chatbot demo with gradio.
@Author: Bang Liu (chatsci.ai@gmail.com) @Date: 2023-07-13
Copyright (C) 2023 Bang Liu - All Rights Reserved. This source code is licensed under the license found in the LICENSE file in the root directory of this source tree.
"},{"location":"reference/#src.aeiva.environment","title":"environment
","text":""},{"location":"reference/#src.aeiva.environment.environment","title":"environment
","text":""},{"location":"reference/#src.aeiva.environment.environment.Environment","title":"Environment
","text":" Bases: ABC
Abstract base class for an environment in which an intelligent agent operates.
Each environment provides context, defines interactions, and manages its own state. Subclasses should implement specific methods for different types of environments.
Attributes:
Name Type Description config
EnvironmentConfig
Configuration settings for the environment.
state
Any
Current state of the environment, initialized from the config.
entities
List[Any]
Entities present within the environment.
constraints
Dict[str, Any]
Rules or limitations for interactions in the environment.
time
Optional[int]
Time progression within the environment, if enabled.
Source code in src/aeiva/environment/environment.py
class Environment(ABC):\n \"\"\"\n Abstract base class for an environment in which an intelligent agent operates.\n\n Each environment provides context, defines interactions, and manages its own state.\n Subclasses should implement specific methods for different types of environments.\n\n Attributes:\n config (EnvironmentConfig): Configuration settings for the environment.\n state (Any): Current state of the environment, initialized from the config.\n entities (List[Any]): Entities present within the environment.\n constraints (Dict[str, Any]): Rules or limitations for interactions in the environment.\n time (Optional[int]): Time progression within the environment, if enabled.\n \"\"\"\n\n def __init__(self, config: EnvironmentConfig):\n \"\"\"\n Initialize the environment with a given configuration.\n\n Args:\n config (EnvironmentConfig): Configuration settings for the environment.\n \"\"\"\n self.config = config\n self.state = config.initial_state\n self.entities = config.entities\n self.constraints = config.constraints\n self.time = 0 if config.time_enabled else None\n self.setup()\n\n @abstractmethod\n def setup(self):\n \"\"\"\n Set up the environment based on its configuration.\n Subclasses should define any initialization logic here.\n \"\"\"\n pass\n\n @abstractmethod\n def reset(self):\n \"\"\"\n Reset the environment to its initial state as defined by the configuration.\n \"\"\"\n self.state = self.config.initial_state\n self.time = 0 if self.config.time_enabled else None\n\n @abstractmethod\n def step(self, actions: Dict[Any, Any]):\n \"\"\"\n Advance the environment by one step based on actions taken by agents.\n\n Args:\n actions (Dict[Any, Any]): A dictionary of actions performed by agents.\n \"\"\"\n pass\n\n @abstractmethod\n def observe(self, agent: Any) -> Any:\n \"\"\"\n Provide observations to an agent based on the current state.\n\n Args:\n agent (Any): The agent requesting observation.\n\n Returns:\n Any: Observation data formatted according 
to the agent's perception capabilities.\n \"\"\"\n pass\n\n @abstractmethod\n def act(self, action: Any, target: Optional[Any] = None):\n \"\"\"\n Execute an action in the environment, potentially modifying its state.\n\n Args:\n action (Any): The action to be executed.\n target (Optional[Any]): Target entity for the action, if applicable.\n \"\"\"\n pass\n\n def render(self):\n \"\"\"\n Visualize or output the environment's current state. Optional for subclasses.\n \"\"\"\n print(f\"Environment State: {self.state}\")\n\n def get_context(self) -> Any:\n \"\"\"\n Retrieve relevant context information from the environment, useful for agent processing.\n\n Returns:\n Any: Contextual data or state relevant to the agent's tasks.\n \"\"\"\n return self.state\n\n def close(self):\n \"\"\"\n Clean up any resources tied to the environment when it's no longer needed.\n \"\"\"\n print(\"Closing environment and releasing resources.\")\n\n def __repr__(self) -> str:\n return (f\"Environment(type={self.config.environment_type}, \"\n f\"state={self.state}, \"\n f\"entities={self.entities}, \"\n f\"time={self.time}, \"\n f\"constraints={self.constraints})\")\n
"},{"location":"reference/#src.aeiva.environment.environment.Environment.__init__","title":"__init__(config)
","text":"Initialize the environment with a given configuration.
Parameters:
Name Type Description Default config
EnvironmentConfig
Configuration settings for the environment.
required Source code in src/aeiva/environment/environment.py
def __init__(self, config: EnvironmentConfig):\n \"\"\"\n Initialize the environment with a given configuration.\n\n Args:\n config (EnvironmentConfig): Configuration settings for the environment.\n \"\"\"\n self.config = config\n self.state = config.initial_state\n self.entities = config.entities\n self.constraints = config.constraints\n self.time = 0 if config.time_enabled else None\n self.setup()\n
"},{"location":"reference/#src.aeiva.environment.environment.Environment.act","title":"act(action, target=None)
abstractmethod
","text":"Execute an action in the environment, potentially modifying its state.
Parameters:
Name Type Description Default action
Any
The action to be executed.
required target
Optional[Any]
Target entity for the action, if applicable.
None
Source code in src/aeiva/environment/environment.py
@abstractmethod\ndef act(self, action: Any, target: Optional[Any] = None):\n \"\"\"\n Execute an action in the environment, potentially modifying its state.\n\n Args:\n action (Any): The action to be executed.\n target (Optional[Any]): Target entity for the action, if applicable.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.environment.environment.Environment.close","title":"close()
","text":"Clean up any resources tied to the environment when it's no longer needed.
Source code in src/aeiva/environment/environment.py
def close(self):\n \"\"\"\n Clean up any resources tied to the environment when it's no longer needed.\n \"\"\"\n print(\"Closing environment and releasing resources.\")\n
"},{"location":"reference/#src.aeiva.environment.environment.Environment.get_context","title":"get_context()
","text":"Retrieve relevant context information from the environment, useful for agent processing.
Returns:
Name Type Description Any
Any
Contextual data or state relevant to the agent's tasks.
Source code in src/aeiva/environment/environment.py
def get_context(self) -> Any:\n \"\"\"\n Retrieve relevant context information from the environment, useful for agent processing.\n\n Returns:\n Any: Contextual data or state relevant to the agent's tasks.\n \"\"\"\n return self.state\n
"},{"location":"reference/#src.aeiva.environment.environment.Environment.observe","title":"observe(agent)
abstractmethod
","text":"Provide observations to an agent based on the current state.
Parameters:
Name Type Description Default agent
Any
The agent requesting observation.
required Returns:
Name Type Description Any
Any
Observation data formatted according to the agent's perception capabilities.
Source code in src/aeiva/environment/environment.py
@abstractmethod\ndef observe(self, agent: Any) -> Any:\n \"\"\"\n Provide observations to an agent based on the current state.\n\n Args:\n agent (Any): The agent requesting observation.\n\n Returns:\n Any: Observation data formatted according to the agent's perception capabilities.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.environment.environment.Environment.render","title":"render()
","text":"Visualize or output the environment's current state. Optional for subclasses.
Source code in src/aeiva/environment/environment.py
def render(self):\n \"\"\"\n Visualize or output the environment's current state. Optional for subclasses.\n \"\"\"\n print(f\"Environment State: {self.state}\")\n
"},{"location":"reference/#src.aeiva.environment.environment.Environment.reset","title":"reset()
abstractmethod
","text":"Reset the environment to its initial state as defined by the configuration.
Source code in src/aeiva/environment/environment.py
@abstractmethod\ndef reset(self):\n \"\"\"\n Reset the environment to its initial state as defined by the configuration.\n \"\"\"\n self.state = self.config.initial_state\n self.time = 0 if self.config.time_enabled else None\n
"},{"location":"reference/#src.aeiva.environment.environment.Environment.setup","title":"setup()
abstractmethod
","text":"Set up the environment based on its configuration. Subclasses should define any initialization logic here.
Source code in src/aeiva/environment/environment.py
@abstractmethod\ndef setup(self):\n \"\"\"\n Set up the environment based on its configuration.\n Subclasses should define any initialization logic here.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.environment.environment.Environment.step","title":"step(actions)
abstractmethod
","text":"Advance the environment by one step based on actions taken by agents.
Parameters:
Name Type Description Default actions
Dict[Any, Any]
A dictionary of actions performed by agents.
required Source code in src/aeiva/environment/environment.py
@abstractmethod\ndef step(self, actions: Dict[Any, Any]):\n \"\"\"\n Advance the environment by one step based on actions taken by agents.\n\n Args:\n actions (Dict[Any, Any]): A dictionary of actions performed by agents.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.environment.environment_config","title":"environment_config
","text":""},{"location":"reference/#src.aeiva.environment.environment_config.EnvironmentConfig","title":"EnvironmentConfig
dataclass
","text":" Bases: BaseConfig
Configuration class for initializing an environment.
Attributes:
Name Type Description environment_type
str
Type of the environment, see EnvironmentType class.
initial_state
Optional[Any]
Optional initial state of the environment.
constraints
Dict[str, Any]
Rules or constraints governing the environment.
entities
List[Any]
Entities present within the environment.
time_enabled
bool
Whether the environment tracks time progression.
Source code in src/aeiva/environment/environment_config.py
@dataclass\nclass EnvironmentConfig(BaseConfig):\n \"\"\"\n Configuration class for initializing an environment.\n\n Attributes:\n environment_type (str): Type of the environment, see EnvironmentType class.\n initial_state (Optional[Any]): Optional initial state of the environment.\n constraints (Dict[str, Any]): Rules or constraints governing the environment.\n entities (List[Any]): Entities present within the environment.\n time_enabled (bool): Whether the environment tracks time progression.\n \"\"\"\n\n environment_type: str = field(\n metadata={\"help\": \"Type of the environment (e.g., 'user', 'document', 'game').\"}\n )\n initial_state: Optional[Any] = field(\n default=None,\n metadata={\"help\": \"Optional initial state of the environment.\"}\n )\n constraints: Dict[str, Any] = field(\n default_factory=dict,\n metadata={\"help\": \"Rules or constraints for the environment.\"}\n )\n entities: List[Any] = field(\n default_factory=list,\n metadata={\"help\": \"Entities within the environment.\"}\n )\n time_enabled: bool = field(\n default=False,\n metadata={\"help\": \"Flag to enable time progression.\"}\n )\n\n def __post_init__(self):\n super().__post_init__()\n # Perform any necessary validation\n if not self.environment_type:\n raise ValueError(\"Environment type must be provided.\")\n
"},{"location":"reference/#src.aeiva.environment.environment_type","title":"environment_type
","text":""},{"location":"reference/#src.aeiva.environment.environment_type.EnvironmentType","title":"EnvironmentType
","text":"A class to hold constants for various environment types, organized by broad categories to maximize generality while supporting diverse use cases.
Categories - Interaction-Based: Environments with user or agent interaction.
- Digital: Environments involving digital interfaces, applications, or software systems.
- Data-Based: Static or dynamic data collections or document repositories.
- Virtual/Simulated: Simulated, spatial, or immersive virtual environments.
- World-Level: Comprehensive real or virtual world environments.
Source code in src/aeiva/environment/environment_type.py
class EnvironmentType:\n \"\"\"\n A class to hold constants for various environment types, organized by broad categories\n to maximize generality while supporting diverse use cases.\n\n Categories:\n - Interaction-Based: Environments with user or agent interaction.\n - Digital: Environments involving digital interfaces, applications, or software systems.\n - Data-Based: Static or dynamic data collections or document repositories.\n - Virtual/Simulated: Simulated, spatial, or immersive virtual environments.\n - World-Level: Comprehensive real or virtual world environments.\n \"\"\"\n\n # Interaction-Based Environments\n INTERACTIVE = \"Interactive\" # Environments involving user or multi-agent interaction.\n\n # Digital Environments\n DIGITAL_ENVIRONMENT = \"Digital Environment\" # Digital workspaces, applications, OS, or software systems.\n\n # Data-Based Environments\n DATA_REPOSITORY = \"Data Repository\" # Static datasets, dynamic data streams, or document repositories (e.g., knowledge bases).\n\n # Virtual/Simulated Environments\n VIRTUAL_ENVIRONMENT = \"Virtual Environment\" # Simulated or immersive 3D spaces, including games and VR.\n\n # World-Level Environments\n FULL_WORLD = \"Full World\" # Comprehensive virtual or real-world environment.\n\n # Meta/Complex Environments\n HYBRID_ENVIRONMENT = \"Hybrid Environment\" # Combination of multiple types.\n\n # Custom environment type for unique or unspecified cases.\n CUSTOM = \"Custom\"\n
"},{"location":"reference/#src.aeiva.event","title":"event
","text":""},{"location":"reference/#src.aeiva.event.event","title":"event
","text":""},{"location":"reference/#src.aeiva.event.event.Event","title":"Event
dataclass
","text":"Represents an event in the event bus system.
Attributes:
Name Type Description name
str
The name of the event.
payload
Any
The data associated with the event.
timestamp
datetime
The time the event was created.
priority
int
The priority of the event.
Source code in src/aeiva/event/event.py
@dataclass\nclass Event:\n \"\"\"\n Represents an event in the event bus system.\n\n Attributes:\n name (str): The name of the event.\n payload (Any): The data associated with the event.\n timestamp (datetime): The time the event was created.\n priority (int): The priority of the event.\n \"\"\"\n name: str\n payload: Any = None\n timestamp: datetime = field(default_factory=datetime.utcnow)\n priority: int = 0\n
"},{"location":"reference/#src.aeiva.event.event_bus","title":"event_bus
","text":""},{"location":"reference/#src.aeiva.event.event_bus.EventBus","title":"EventBus
","text":"An asynchronous event bus for publishing and subscribing to events.
Features: - Subscribers can use wildcard patterns to subscribe to multiple events. - Subscribers can cancel event propagation. - Subscribers can be set to auto-unsubscribe after one call. - Event-level prioritization in the queue. - Customizable error handling. - Logging for key actions. - emit, emit_after, and emit_only methods for flexible event emission.
Source code in src/aeiva/event/event_bus.py
class EventBus:\n \"\"\"\n An asynchronous event bus for publishing and subscribing to events.\n\n Features:\n - Subscribers can use wildcard patterns to subscribe to multiple events.\n - Subscribers can cancel event propagation.\n - Subscribers can be set to auto-unsubscribe after one call.\n - Event-level prioritization in the queue.\n - Customizable error handling.\n - Logging for key actions.\n - emit, emit_after, and emit_only methods for flexible event emission.\n \"\"\"\n\n def __init__(self):\n \"\"\"\n Initializes the event bus.\n \"\"\"\n self._subscribers: List[Dict] = [] # List of subscriber dictionaries\n self._event_queue = asyncio.PriorityQueue()\n self._processing_task: Optional[asyncio.Task] = None\n self._event_counter = 0 # Counter to maintain order of events with same priority\n self.loop = None\n\n def subscribe(\n self,\n event_pattern: str,\n callback: Callable[[Event], Any],\n *,\n priority: int = 0,\n once: bool = False\n ):\n \"\"\"\n Subscribes a callback function to events matching a pattern.\n\n Args:\n event_pattern (str): The event name or pattern to subscribe to.\n callback (Callable[[Event], Any]): The callback function.\n priority (int, optional): Priority of the callback.\n once (bool, optional): If True, unsubscribe after one call.\n \"\"\"\n subscriber = {\n 'pattern': re.compile(event_pattern.replace('*', '.*')),\n 'callback': callback,\n 'priority': priority,\n 'once': once\n }\n self._subscribers.append(subscriber)\n logger.info(f\"Subscribed '{callback.__name__}' to pattern '{event_pattern}' with priority {priority}.\")\n\n def unsubscribe(self, callback: Callable[[Event], Any]):\n \"\"\"\n Unsubscribes a callback function from all events.\n\n Args:\n callback (Callable[[Event], Any]): The callback function to remove.\n \"\"\"\n self._subscribers = [\n sub for sub in self._subscribers\n if sub['callback'] != callback\n ]\n logger.info(f\"Unsubscribed '{callback.__name__}' from all events.\")\n\n async def publish(self, 
event: Event, only: Union[str, List[str]] = None):\n \"\"\"\n Publishes an event to the event bus.\n\n Args:\n event (Event): The event to publish.\n only (str or List[str], optional): Names of specific subscribers to notify.\n \"\"\"\n self._event_counter += 1\n # Use a tuple of (priority, counter) to ensure proper ordering\n await self._event_queue.put((event.priority * -1, self._event_counter, event, only))\n logger.info(f\"Published event '{event.name}' with priority {event.priority}.\")\n\n async def _process_events(self):\n \"\"\"\n Internal coroutine that processes events from the queue and dispatches them to subscribers.\n \"\"\"\n while True:\n try:\n _, _, event, only = await self._event_queue.get()\n logger.info(f\"Processing event '{event.name}'.\")\n await self._dispatch_event(event, only)\n self._event_queue.task_done()\n except asyncio.CancelledError:\n # Exit the loop gracefully\n break\n except Exception as e:\n logger.error(f\"Error processing event: {e}\")\n self._event_queue.task_done()\n\n async def _dispatch_event(self, event: Event, only: Union[str, List[str]] = None):\n \"\"\"\n Dispatches an event to the appropriate subscribers.\n\n Args:\n event (Event): The event to dispatch.\n only (str or List[str], optional): Names of specific subscribers to notify.\n \"\"\"\n subscribers = sorted(\n [\n sub for sub in self._subscribers\n if sub['pattern'].fullmatch(event.name)\n and (only is None or sub['callback'].__name__ in (only if isinstance(only, list) else [only]))\n ],\n key=lambda x: x['priority'],\n reverse=True\n )\n for subscriber in subscribers:\n callback = subscriber['callback']\n try:\n if asyncio.iscoroutinefunction(callback):\n await callback(event)\n else:\n await asyncio.get_event_loop().run_in_executor(None, callback, event)\n except EventCancelled:\n logger.info(f\"Event '{event.name}' cancelled by '{callback.__name__}'.\")\n break # Stop further propagation\n except Exception as e:\n logger.error(f\"Error in callback 
'{callback.__name__}' for event '{event.name}': {e}\")\n self._handle_callback_exception(e, callback, event)\n finally:\n if subscriber.get('once'):\n self.unsubscribe(callback)\n\n def _handle_callback_exception(self, exception, callback, event):\n \"\"\"\n Handle exceptions raised by subscriber callbacks.\n\n Args:\n exception (Exception): The exception raised.\n callback (Callable): The subscriber callback.\n event (Event): The event being processed.\n \"\"\"\n # Default behavior is to log the exception.\n pass # Can be customized as needed.\n\n def start(self):\n \"\"\"\n Starts the event bus processing loop.\n \"\"\"\n if self._processing_task is None:\n self.loop = asyncio.get_running_loop()\n self._processing_task = asyncio.create_task(self._process_events())\n logger.info(\"Event bus started.\")\n\n def stop(self):\n \"\"\"\n Stops the event bus processing loop.\n \"\"\"\n if self._processing_task:\n self._processing_task.cancel()\n logger.info(\"Event bus stopped.\")\n\n def on(self, event_pattern: str, priority: int = 0, once: bool = False):\n \"\"\"\n Decorator for subscribing a function to events matching a pattern.\n\n Usage:\n @event_bus.on('event.*', priority=10)\n async def handler(event):\n ...\n\n Args:\n event_pattern (str): The event name or pattern to subscribe to.\n priority (int, optional): Priority of the callback.\n once (bool, optional): If True, unsubscribe after one call.\n\n Returns:\n Callable: The decorator function.\n \"\"\"\n def decorator(callback: Callable[[Event], Any]):\n self.subscribe(event_pattern, callback, priority=priority, once=once)\n return callback\n return decorator\n\n def emit_after(self, event_name: str, priority: int = 0):\n \"\"\"\n Decorator that emits an event after the decorated function is called.\n\n Usage:\n @event_bus.emit_after('event_name')\n def some_function():\n ...\n\n Args:\n event_name (str): The name of the event to emit after function execution.\n priority (int, optional): The priority of the 
event.\n\n Returns:\n Callable: The decorator function.\n \"\"\"\n def decorator(func: Callable):\n if asyncio.iscoroutinefunction(func):\n @wraps(func)\n async def async_wrapper(*args, **kwargs):\n result = await func(*args, **kwargs)\n await self.emit(event_name, priority=priority)\n return result\n return async_wrapper\n else:\n @wraps(func)\n def sync_wrapper(*args, **kwargs):\n result = func(*args, **kwargs)\n asyncio.create_task(self.emit(event_name, priority=priority))\n return result\n return sync_wrapper\n return decorator\n\n async def emit(self, event_name: str, payload: Any = None, priority: int = 0):\n \"\"\"\n Emits an event to all matching subscribers.\n\n Args:\n event_name (str): The name of the event to emit.\n payload (Any, optional): The payload of the event.\n priority (int, optional): The priority of the event.\n \"\"\"\n await self.publish(Event(name=event_name, payload=payload, priority=priority))\n\n async def emit_only(self, event_name: str, subscriber_names: Union[str, List[str]], payload: Any = None, priority: int = 0):\n \"\"\"\n Emits an event only to specified subscribers.\n\n Args:\n event_name (str): The name of the event to emit.\n subscriber_names (str or List[str]): The name(s) of subscribers to notify.\n payload (Any, optional): The payload of the event.\n priority (int, optional): The priority of the event.\n \"\"\"\n await self.publish(Event(name=event_name, payload=payload, priority=priority), only=subscriber_names)\n\n async def wait_until_all_events_processed(self):\n \"\"\"\n Waits until all events in the queue have been processed.\n \"\"\"\n await self._event_queue.join()\n
"},{"location":"reference/#src.aeiva.event.event_bus.EventBus.__init__","title":"__init__()
","text":"Initializes the event bus.
Source code in src/aeiva/event/event_bus.py
def __init__(self):\n \"\"\"\n Initializes the event bus.\n \"\"\"\n self._subscribers: List[Dict] = [] # List of subscriber dictionaries\n self._event_queue = asyncio.PriorityQueue()\n self._processing_task: Optional[asyncio.Task] = None\n self._event_counter = 0 # Counter to maintain order of events with same priority\n self.loop = None\n
"},{"location":"reference/#src.aeiva.event.event_bus.EventBus.emit","title":"emit(event_name, payload=None, priority=0)
async
","text":"Emits an event to all matching subscribers.
Parameters:
Name Type Description Default event_name
str
The name of the event to emit.
required payload
Any
The payload of the event.
None
priority
int
The priority of the event.
0
Source code in src/aeiva/event/event_bus.py
async def emit(self, event_name: str, payload: Any = None, priority: int = 0):\n \"\"\"\n Emits an event to all matching subscribers.\n\n Args:\n event_name (str): The name of the event to emit.\n payload (Any, optional): The payload of the event.\n priority (int, optional): The priority of the event.\n \"\"\"\n await self.publish(Event(name=event_name, payload=payload, priority=priority))\n
"},{"location":"reference/#src.aeiva.event.event_bus.EventBus.emit_after","title":"emit_after(event_name, priority=0)
","text":"Decorator that emits an event after the decorated function is called.
Usage @event_bus.emit_after('event_name') def some_function(): ...
Parameters:
Name Type Description Default event_name
str
The name of the event to emit after function execution.
required priority
int
The priority of the event.
0
Returns:
Name Type Description Callable
The decorator function.
Source code in src/aeiva/event/event_bus.py
def emit_after(self, event_name: str, priority: int = 0):\n \"\"\"\n Decorator that emits an event after the decorated function is called.\n\n Usage:\n @event_bus.emit_after('event_name')\n def some_function():\n ...\n\n Args:\n event_name (str): The name of the event to emit after function execution.\n priority (int, optional): The priority of the event.\n\n Returns:\n Callable: The decorator function.\n \"\"\"\n def decorator(func: Callable):\n if asyncio.iscoroutinefunction(func):\n @wraps(func)\n async def async_wrapper(*args, **kwargs):\n result = await func(*args, **kwargs)\n await self.emit(event_name, priority=priority)\n return result\n return async_wrapper\n else:\n @wraps(func)\n def sync_wrapper(*args, **kwargs):\n result = func(*args, **kwargs)\n asyncio.create_task(self.emit(event_name, priority=priority))\n return result\n return sync_wrapper\n return decorator\n
"},{"location":"reference/#src.aeiva.event.event_bus.EventBus.emit_only","title":"emit_only(event_name, subscriber_names, payload=None, priority=0)
async
","text":"Emits an event only to specified subscribers.
Parameters:
Name Type Description Default event_name
str
The name of the event to emit.
required subscriber_names
str or List[str]
The name(s) of subscribers to notify.
required payload
Any
The payload of the event.
None
priority
int
The priority of the event.
0
Source code in src/aeiva/event/event_bus.py
async def emit_only(self, event_name: str, subscriber_names: Union[str, List[str]], payload: Any = None, priority: int = 0):\n \"\"\"\n Emits an event only to specified subscribers.\n\n Args:\n event_name (str): The name of the event to emit.\n subscriber_names (str or List[str]): The name(s) of subscribers to notify.\n payload (Any, optional): The payload of the event.\n priority (int, optional): The priority of the event.\n \"\"\"\n await self.publish(Event(name=event_name, payload=payload, priority=priority), only=subscriber_names)\n
"},{"location":"reference/#src.aeiva.event.event_bus.EventBus.on","title":"on(event_pattern, priority=0, once=False)
","text":"Decorator for subscribing a function to events matching a pattern.
Usage @event_bus.on('event.*', priority=10) async def handler(event): ...
Parameters:
Name Type Description Default event_pattern
str
The event name or pattern to subscribe to.
required priority
int
Priority of the callback.
0
once
bool
If True, unsubscribe after one call.
False
Returns:
Name Type Description Callable
The decorator function.
Source code in src/aeiva/event/event_bus.py
def on(self, event_pattern: str, priority: int = 0, once: bool = False):\n \"\"\"\n Decorator for subscribing a function to events matching a pattern.\n\n Usage:\n @event_bus.on('event.*', priority=10)\n async def handler(event):\n ...\n\n Args:\n event_pattern (str): The event name or pattern to subscribe to.\n priority (int, optional): Priority of the callback.\n once (bool, optional): If True, unsubscribe after one call.\n\n Returns:\n Callable: The decorator function.\n \"\"\"\n def decorator(callback: Callable[[Event], Any]):\n self.subscribe(event_pattern, callback, priority=priority, once=once)\n return callback\n return decorator\n
"},{"location":"reference/#src.aeiva.event.event_bus.EventBus.publish","title":"publish(event, only=None)
async
","text":"Publishes an event to the event bus.
Parameters:
Name Type Description Default event
Event
The event to publish.
required only
str or List[str]
Names of specific subscribers to notify.
None
Source code in src/aeiva/event/event_bus.py
async def publish(self, event: Event, only: Union[str, List[str]] = None):\n \"\"\"\n Publishes an event to the event bus.\n\n Args:\n event (Event): The event to publish.\n only (str or List[str], optional): Names of specific subscribers to notify.\n \"\"\"\n self._event_counter += 1\n # Use a tuple of (priority, counter) to ensure proper ordering\n await self._event_queue.put((event.priority * -1, self._event_counter, event, only))\n logger.info(f\"Published event '{event.name}' with priority {event.priority}.\")\n
"},{"location":"reference/#src.aeiva.event.event_bus.EventBus.start","title":"start()
","text":"Starts the event bus processing loop.
Source code in src/aeiva/event/event_bus.py
def start(self):\n \"\"\"\n Starts the event bus processing loop.\n \"\"\"\n if self._processing_task is None:\n self.loop = asyncio.get_running_loop()\n self._processing_task = asyncio.create_task(self._process_events())\n logger.info(\"Event bus started.\")\n
"},{"location":"reference/#src.aeiva.event.event_bus.EventBus.stop","title":"stop()
","text":"Stops the event bus processing loop.
Source code in src/aeiva/event/event_bus.py
def stop(self):\n \"\"\"\n Stops the event bus processing loop.\n \"\"\"\n if self._processing_task:\n self._processing_task.cancel()\n logger.info(\"Event bus stopped.\")\n
"},{"location":"reference/#src.aeiva.event.event_bus.EventBus.subscribe","title":"subscribe(event_pattern, callback, *, priority=0, once=False)
","text":"Subscribes a callback function to events matching a pattern.
Parameters:
Name Type Description Default event_pattern
str
The event name or pattern to subscribe to.
required callback
Callable[[Event], Any]
The callback function.
required priority
int
Priority of the callback.
0
once
bool
If True, unsubscribe after one call.
False
Source code in src/aeiva/event/event_bus.py
def subscribe(\n self,\n event_pattern: str,\n callback: Callable[[Event], Any],\n *,\n priority: int = 0,\n once: bool = False\n):\n \"\"\"\n Subscribes a callback function to events matching a pattern.\n\n Args:\n event_pattern (str): The event name or pattern to subscribe to.\n callback (Callable[[Event], Any]): The callback function.\n priority (int, optional): Priority of the callback.\n once (bool, optional): If True, unsubscribe after one call.\n \"\"\"\n subscriber = {\n 'pattern': re.compile(event_pattern.replace('*', '.*')),\n 'callback': callback,\n 'priority': priority,\n 'once': once\n }\n self._subscribers.append(subscriber)\n logger.info(f\"Subscribed '{callback.__name__}' to pattern '{event_pattern}' with priority {priority}.\")\n
"},{"location":"reference/#src.aeiva.event.event_bus.EventBus.unsubscribe","title":"unsubscribe(callback)
","text":"Unsubscribes a callback function from all events.
Parameters:
Name Type Description Default callback
Callable[[Event], Any]
The callback function to remove.
required Source code in src/aeiva/event/event_bus.py
def unsubscribe(self, callback: Callable[[Event], Any]):\n \"\"\"\n Unsubscribes a callback function from all events.\n\n Args:\n callback (Callable[[Event], Any]): The callback function to remove.\n \"\"\"\n self._subscribers = [\n sub for sub in self._subscribers\n if sub['callback'] != callback\n ]\n logger.info(f\"Unsubscribed '{callback.__name__}' from all events.\")\n
"},{"location":"reference/#src.aeiva.event.event_bus.EventBus.wait_until_all_events_processed","title":"wait_until_all_events_processed()
async
","text":"Waits until all events in the queue have been processed.
Source code in src/aeiva/event/event_bus.py
async def wait_until_all_events_processed(self):\n \"\"\"\n Waits until all events in the queue have been processed.\n \"\"\"\n await self._event_queue.join()\n
"},{"location":"reference/#src.aeiva.event.event_bus.EventCancelled","title":"EventCancelled
","text":" Bases: Exception
Exception to indicate that an event has been cancelled.
Source code in src/aeiva/event/event_bus.py
class EventCancelled(Exception):\n \"\"\"Exception to indicate that an event has been cancelled.\"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.hypergraph","title":"hypergraph
","text":""},{"location":"reference/#src.aeiva.hypergraph.exceptions","title":"exceptions
","text":""},{"location":"reference/#src.aeiva.hypergraph.exceptions.HypergraphError","title":"HypergraphError
","text":" Bases: Exception
Custom exception class for Hypergraph-related errors.
Source code in src/aeiva/hypergraph/exceptions.py
class HypergraphError(Exception):\n \"\"\"\n Custom exception class for Hypergraph-related errors.\n \"\"\"\n def __init__(self, message: str = \"An error occurred in the Hypergraph module.\"):\n super().__init__(message)\n
"},{"location":"reference/#src.aeiva.hypergraph.hyperedge","title":"hyperedge
","text":""},{"location":"reference/#src.aeiva.hypergraph.hyperedge.HyperEdge","title":"HyperEdge
","text":"Represents a hyperedge in the hypergraph, encapsulating its properties and connected nodes.
Source code in src/aeiva/hypergraph/hyperedge.py
class HyperEdge:\n \"\"\"\n Represents a hyperedge in the hypergraph, encapsulating its properties and connected nodes.\n \"\"\"\n\n def __init__(\n self,\n id: Any,\n nodes: Optional[Iterable[Any]] = None,\n properties: Optional[Dict[str, Any]] = None\n ) -> None:\n \"\"\"\n Initializes a HyperEdge.\n\n Parameters:\n id: Unique identifier for the hyperedge.\n nodes: (Optional) Iterable of node identifiers connected by the hyperedge.\n properties: (Optional) Dictionary of properties.\n \"\"\"\n self.id: Any = id\n self.nodes: Set[Any] = set(nodes) if nodes else set()\n self.properties: Dict[str, Any] = properties.copy() if properties else {}\n\n def add_node(self, node_id: Any) -> None:\n \"\"\"\n Adds a node to the hyperedge.\n\n Parameters:\n node_id: Identifier of the node to add.\n \"\"\"\n self.nodes.add(node_id)\n\n def remove_node(self, node_id: Any) -> None:\n \"\"\"\n Removes a node from the hyperedge.\n\n Parameters:\n node_id: Identifier of the node to remove.\n \"\"\"\n if node_id in self.nodes:\n self.nodes.remove(node_id)\n else:\n raise HypergraphError(f\"Node '{node_id}' not found in HyperEdge '{self.id}'.\")\n\n def add_property(self, key: str, value: Any) -> None:\n \"\"\"\n Adds or updates a property of the hyperedge.\n\n Parameters:\n key: Property name.\n value: Property value.\n \"\"\"\n self.properties[key] = value\n\n def get_property(self, key: str) -> Any:\n \"\"\"\n Retrieves a property of the hyperedge.\n\n Parameters:\n key: Property name.\n\n Returns:\n The value of the property.\n\n Raises:\n HypergraphError: If the property does not exist.\n \"\"\"\n if key in self.properties:\n return self.properties[key]\n else:\n raise HypergraphError(f\"Property '{key}' does not exist for HyperEdge '{self.id}'.\")\n\n def remove_property(self, key: str) -> None:\n \"\"\"\n Removes a property from the hyperedge.\n\n Parameters:\n key: Property name.\n\n Raises:\n HypergraphError: If the property does not exist.\n \"\"\"\n if key in 
self.properties:\n del self.properties[key]\n else:\n raise HypergraphError(f\"Property '{key}' does not exist for HyperEdge '{self.id}'.\")\n\n def to_dict(self):\n return {\n \"id\": self.id,\n \"nodes\": self.nodes,\n \"properties\": self.properties\n }\n
"},{"location":"reference/#src.aeiva.hypergraph.hyperedge.HyperEdge.__init__","title":"__init__(id, nodes=None, properties=None)
","text":"Initializes a HyperEdge.
Parameters:
Name Type Description Default id
Any
Unique identifier for the hyperedge.
required nodes
Optional[Iterable[Any]]
(Optional) Iterable of node identifiers connected by the hyperedge.
None
properties
Optional[Dict[str, Any]]
(Optional) Dictionary of properties.
None
Source code in src/aeiva/hypergraph/hyperedge.py
def __init__(\n self,\n id: Any,\n nodes: Optional[Iterable[Any]] = None,\n properties: Optional[Dict[str, Any]] = None\n) -> None:\n \"\"\"\n Initializes a HyperEdge.\n\n Parameters:\n id: Unique identifier for the hyperedge.\n nodes: (Optional) Iterable of node identifiers connected by the hyperedge.\n properties: (Optional) Dictionary of properties.\n \"\"\"\n self.id: Any = id\n self.nodes: Set[Any] = set(nodes) if nodes else set()\n self.properties: Dict[str, Any] = properties.copy() if properties else {}\n
"},{"location":"reference/#src.aeiva.hypergraph.hyperedge.HyperEdge.add_node","title":"add_node(node_id)
","text":"Adds a node to the hyperedge.
Parameters:
Name Type Description Default node_id
Any
Identifier of the node to add.
required Source code in src/aeiva/hypergraph/hyperedge.py
def add_node(self, node_id: Any) -> None:\n \"\"\"\n Adds a node to the hyperedge.\n\n Parameters:\n node_id: Identifier of the node to add.\n \"\"\"\n self.nodes.add(node_id)\n
"},{"location":"reference/#src.aeiva.hypergraph.hyperedge.HyperEdge.add_property","title":"add_property(key, value)
","text":"Adds or updates a property of the hyperedge.
Parameters:
Name Type Description Default key
str
Property name.
required value
Any
Property value.
required Source code in src/aeiva/hypergraph/hyperedge.py
def add_property(self, key: str, value: Any) -> None:\n \"\"\"\n Adds or updates a property of the hyperedge.\n\n Parameters:\n key: Property name.\n value: Property value.\n \"\"\"\n self.properties[key] = value\n
"},{"location":"reference/#src.aeiva.hypergraph.hyperedge.HyperEdge.get_property","title":"get_property(key)
","text":"Retrieves a property of the hyperedge.
Parameters:
Name Type Description Default key
str
Property name.
required Returns:
Type Description Any
The value of the property.
Raises:
Type Description HypergraphError
If the property does not exist.
Source code in src/aeiva/hypergraph/hyperedge.py
def get_property(self, key: str) -> Any:\n \"\"\"\n Retrieves a property of the hyperedge.\n\n Parameters:\n key: Property name.\n\n Returns:\n The value of the property.\n\n Raises:\n HypergraphError: If the property does not exist.\n \"\"\"\n if key in self.properties:\n return self.properties[key]\n else:\n raise HypergraphError(f\"Property '{key}' does not exist for HyperEdge '{self.id}'.\")\n
"},{"location":"reference/#src.aeiva.hypergraph.hyperedge.HyperEdge.remove_node","title":"remove_node(node_id)
","text":"Removes a node from the hyperedge.
Parameters:
Name Type Description Default node_id
Any
Identifier of the node to remove.
required Source code in src/aeiva/hypergraph/hyperedge.py
def remove_node(self, node_id: Any) -> None:\n \"\"\"\n Removes a node from the hyperedge.\n\n Parameters:\n node_id: Identifier of the node to remove.\n \"\"\"\n if node_id in self.nodes:\n self.nodes.remove(node_id)\n else:\n raise HypergraphError(f\"Node '{node_id}' not found in HyperEdge '{self.id}'.\")\n
"},{"location":"reference/#src.aeiva.hypergraph.hyperedge.HyperEdge.remove_property","title":"remove_property(key)
","text":"Removes a property from the hyperedge.
Parameters:
Name Type Description Default key
str
Property name.
required Raises:
Type Description HypergraphError
If the property does not exist.
Source code in src/aeiva/hypergraph/hyperedge.py
def remove_property(self, key: str) -> None:\n \"\"\"\n Removes a property from the hyperedge.\n\n Parameters:\n key: Property name.\n\n Raises:\n HypergraphError: If the property does not exist.\n \"\"\"\n if key in self.properties:\n del self.properties[key]\n else:\n raise HypergraphError(f\"Property '{key}' does not exist for HyperEdge '{self.id}'.\")\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph","title":"hypergraph
","text":""},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph","title":"Hypergraph
","text":"A simplified Hypergraph class using dictionaries and NetworkX for management.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph--parameters","title":"Parameters","text":"hyperedges : Dict[Any, Dict[str, Any]] A dictionary where keys are hyperedge identifiers and values are dictionaries containing: - 'nodes': Iterable of node identifiers connected by the hyperedge. - 'properties': (Optional) Dictionary of properties for the hyperedge.
node_properties : Optional[Dict[Any, Dict[str, Any]]] = None A dictionary where keys are node identifiers and values are dictionaries of node properties.
hyperedge_properties : Optional[Dict[Any, Dict[str, Any]]] = None A dictionary where keys are hyperedge identifiers and values are dictionaries of hyperedge properties.
name : Optional[str] = None Name assigned to the hypergraph.
Source code in src/aeiva/hypergraph/hypergraph.py
class Hypergraph:\n \"\"\"\n A simplified Hypergraph class using dictionaries and NetworkX for management.\n\n Parameters\n ----------\n hyperedges : Dict[Any, Dict[str, Any]]\n A dictionary where keys are hyperedge identifiers and values are dictionaries containing:\n - 'nodes': Iterable of node identifiers connected by the hyperedge.\n - 'properties': (Optional) Dictionary of properties for the hyperedge.\n\n node_properties : Optional[Dict[Any, Dict[str, Any]]] = None\n A dictionary where keys are node identifiers and values are dictionaries of node properties.\n\n hyperedge_properties : Optional[Dict[Any, Dict[str, Any]]] = None\n A dictionary where keys are hyperedge identifiers and values are dictionaries of hyperedge properties.\n\n name : Optional[str] = None\n Name assigned to the hypergraph.\n \"\"\"\n\n def __init__(\n self,\n hyperedges: Dict[Any, Dict[str, Any]],\n node_properties: Optional[Dict[Any, Dict[str, Any]]] = None,\n hyperedge_properties: Optional[Dict[Any, Dict[str, Any]]] = None,\n name: Optional[str] = None\n ):\n self.name = name\n self.graph = nx.Graph()\n self.bipartite_nodes: Set[Any] = set()\n\n # Initialize node and hyperedge properties using deep copies to ensure full duplication\n self.node_properties: Dict[Any, Dict[str, Any]] = copy.deepcopy(node_properties) if node_properties else {}\n self.hyperedge_properties: Dict[Any, Dict[str, Any]] = copy.deepcopy(hyperedge_properties) if hyperedge_properties else {}\n\n # Add hyperedges and their connections to nodes\n self.hyperedges: Dict[Any, HyperEdge] = {}\n for he_id, he_data in hyperedges.items():\n nodes = he_data.get('nodes', [])\n properties = he_data.get('properties', {})\n hyperedge = HyperEdge(id=he_id, nodes=nodes, properties=properties)\n self.hyperedges[he_id] = hyperedge\n\n # Add hyperedge to bipartite graph with properties\n self.graph.add_node(he_id, bipartite='hyperedge', **self.hyperedge_properties.get(he_id, {}))\n self.bipartite_nodes.add(he_id)\n\n # Add edges 
between hyperedge and nodes with node properties\n for node in hyperedge.nodes:\n if node not in self.graph:\n self.graph.add_node(node, bipartite='node', **self.node_properties.get(node, {}))\n self.graph.add_edge(he_id, node)\n\n def dual(self, name: Optional[str] = None) -> \"Hypergraph\":\n \"\"\"\n Constructs the dual of the current hypergraph by reversing the roles of nodes and hyperedges.\n\n Parameters\n ----------\n name : Optional[str], default=None\n Name for the dual hypergraph. If None, defaults to the original hypergraph's name with '_dual' appended.\n\n Returns\n -------\n Hypergraph\n A new Hypergraph instance representing the dual of the current hypergraph.\n \"\"\"\n # Initialize dual hyperedges, which will correspond to original nodes\n dual_hyperedges = {}\n\n # Invert the node-hyperedge structure\n for he_id, hyperedge in self.hyperedges.items():\n for node in hyperedge.nodes:\n # Each original node becomes a hyperedge in the dual\n if node not in dual_hyperedges:\n dual_hyperedges[node] = {'nodes': [], 'properties': self.node_properties.get(node, {})}\n # The new hyperedge (original node) connects to the original hyperedge id as a \"node\"\n dual_hyperedges[node]['nodes'].append(he_id)\n\n # Define node properties in the dual as the original hyperedge properties\n dual_node_properties = {he_id: self.hyperedge_properties.get(he_id, {}) for he_id in self.hyperedges}\n\n # Create and return the dual Hypergraph\n return Hypergraph(\n hyperedges=dual_hyperedges,\n node_properties=dual_node_properties,\n hyperedge_properties=self.node_properties, # Properties of original nodes now apply to dual hyperedges\n name=name or (self.name + \"_dual\" if self.name else \"dual\")\n )\n\n def nodes(self) -> List[Any]:\n \"\"\"\n Returns a list of all unique node identifiers in the hypergraph.\n\n Returns\n -------\n List[Any]\n List of node IDs.\n \"\"\"\n return list(self.node_properties.keys())\n\n def node_memberships(self) -> Dict[Any, List[Any]]:\n 
\"\"\"\n Returns a dictionary where each key is a node ID and the value is a list of hyperedge IDs that include the node.\n\n Returns\n -------\n Dict[Any, List[Any]]\n Dictionary mapping node IDs to the hyperedge IDs they belong to.\n \"\"\"\n memberships = {}\n for he_id, hyperedge in self.hyperedges.items():\n for node in hyperedge.nodes:\n memberships.setdefault(node, []).append(he_id)\n return memberships\n\n def edges(self) -> List[Any]:\n \"\"\"\n Returns a list of all hyperedge identifiers in the hypergraph.\n\n Returns\n -------\n List[Any]\n List of hyperedge IDs.\n \"\"\"\n return list(self.hyperedges.keys())\n\n def edge_elements(self) -> Dict[Any, List[Any]]:\n \"\"\"\n Returns a dictionary where each key is a hyperedge ID and the value is a list of node IDs within that hyperedge.\n\n Returns\n -------\n Dict[Any, List[Any]]\n Dictionary mapping hyperedge IDs to lists of node IDs they contain.\n \"\"\"\n # Convert the internal node sets to lists so the return value matches the annotation\n return {he_id: list(hyperedge.nodes) for he_id, hyperedge in self.hyperedges.items()}\n\n def __str__(self) -> str:\n \"\"\"\n String representation of the hypergraph.\n\n Returns\n -------\n str\n A string describing the hypergraph with its name, number of nodes, and hyperedges.\n \"\"\"\n return f\"Hypergraph '{self.name}' with {len(self)} nodes and {len(self.hyperedges)} hyperedges.\"\n\n def __repr__(self) -> str:\n \"\"\"\n Official string representation of the hypergraph.\n\n Returns\n -------\n str\n A detailed string describing the hypergraph with its name, number of nodes, and hyperedges.\n \"\"\"\n return (\n f\"Hypergraph(name={self.name!r}, \"\n f\"nodes={len(self)}, hyperedges={len(self.hyperedges)})\"\n )\n\n def __len__(self) -> int:\n \"\"\"\n Returns the number of nodes in the hypergraph.\n\n Returns\n -------\n int\n Number of nodes.\n \"\"\"\n return len(self.node_properties)\n\n def __iter__(self) -> Iterator[Any]:\n \"\"\"\n Allows iteration over the nodes of the hypergraph.\n\n Yields\n ------\n Any\n Node identifiers.\n \"\"\"\n return 
iter(self.node_properties)\n\n def __contains__(self, item: Any) -> bool:\n \"\"\"\n Checks if a node is in the hypergraph.\n\n Parameters\n ----------\n item : Any\n The node identifier to check.\n\n Returns\n -------\n bool\n True if the node exists in the hypergraph, False otherwise.\n \"\"\"\n return item in self.node_properties\n\n def __getitem__(self, node: Any) -> Iterable[Any]:\n \"\"\"\n Retrieves the neighbors of a node in the hypergraph.\n\n Neighbors are nodes that share at least one hyperedge with the given node.\n\n Parameters\n ----------\n node : Any\n The node identifier.\n\n Returns\n -------\n Iterable[Any]\n An iterator over neighboring node identifiers.\n\n Raises\n ------\n HypergraphError\n If the node does not exist in the hypergraph.\n \"\"\"\n if node not in self.node_properties:\n raise HypergraphError(f\"Node '{node}' does not exist in the hypergraph.\")\n\n # Get all hyperedges that include the node\n hyperedges = set(self.graph.neighbors(node))\n\n # Get all nodes connected by these hyperedges\n neighbors = set()\n for he_id in hyperedges:\n neighbors.update(self.hyperedges[he_id].nodes)\n\n neighbors.discard(node) # Remove the node itself\n return neighbors\n\n def __eq__(self, other: Any) -> bool:\n \"\"\"\n Checks if two hypergraphs are equal based on their hyperedges and nodes.\n\n Parameters\n ----------\n other : Any\n The other object to compare.\n\n Returns\n -------\n bool\n True if both hypergraphs have identical nodes and hyperedges with the same properties, False otherwise.\n \"\"\"\n if not isinstance(other, Hypergraph):\n return False\n\n # Compare nodes and their properties\n if self.node_properties != other.node_properties:\n return False\n\n # Compare hyperedges and their properties\n if self.hyperedges.keys() != other.hyperedges.keys():\n return False\n\n for he_id in self.hyperedges:\n if self.hyperedges[he_id].nodes != other.hyperedges[he_id].nodes:\n return False\n if self.hyperedge_properties.get(he_id, {}) != 
other.hyperedge_properties.get(he_id, {}):\n return False\n\n return True\n\n def copy(self, name: Optional[str] = None) -> 'Hypergraph':\n \"\"\"\n Creates a deep copy of the hypergraph instance.\n\n Parameters\n ----------\n name : Optional[str], default=None\n The name for the copied Hypergraph. If not provided, retains the original name.\n\n Returns\n -------\n Hypergraph\n A new Hypergraph instance that is a deep copy of the original.\n \"\"\"\n\n # Deep copy hyperedges\n hyperedges_dict = {}\n for he_id, he in self.hyperedges.items():\n hyperedges_dict[he_id] = {\n 'nodes': list(he.nodes),\n 'properties': copy.deepcopy(he.properties)\n }\n\n # Deep copy node_properties and hyperedge_properties\n node_properties_copy = copy.deepcopy(self.node_properties)\n hyperedge_properties_copy = copy.deepcopy(self.hyperedge_properties)\n\n # Create a new Hypergraph instance with the copied data\n return Hypergraph(\n hyperedges=hyperedges_dict,\n node_properties=node_properties_copy,\n hyperedge_properties=hyperedge_properties_copy,\n name=name if name is not None else self.name\n )\n\n def deepcopy(self, name: Optional[str] = None) -> 'Hypergraph':\n \"\"\"\n Creates a deep copy of the hypergraph.\n\n Parameters\n ----------\n name : Optional[str], default=None\n The name assigned to the cloned hypergraph. 
If None, defaults to the original hypergraph's name suffixed with '_deepcopy'.\n\n Returns\n -------\n Hypergraph\n A deep copy of the hypergraph.\n \"\"\"\n\n # Deep copy hyperedges\n hyperedges_copy = {\n he_id: {\n 'nodes': hyperedge.nodes.copy(),\n 'properties': copy.deepcopy(hyperedge.properties)\n }\n for he_id, hyperedge in self.hyperedges.items()\n }\n\n # Deep copy node properties\n node_properties_copy = copy.deepcopy(self.node_properties)\n\n # Deep copy hyperedge properties\n hyperedge_properties_copy = copy.deepcopy(self.hyperedge_properties)\n\n # Set name\n cloned_name = f\"{self.name}_deepcopy\" if name is None else name\n\n # Initialize the cloned hypergraph\n cloned_H = Hypergraph(\n hyperedges=hyperedges_copy,\n node_properties=node_properties_copy,\n hyperedge_properties=hyperedge_properties_copy,\n name=cloned_name\n )\n\n return cloned_H\n\n # Adding and Removing Hyperedges and Nodes\n\n def add_hyperedge(\n self,\n he_id: Any,\n nodes: Iterable[Any],\n properties: Optional[Dict[str, Any]] = None\n ) -> None:\n \"\"\"\n Adds a hyperedge to the hypergraph.\n\n Parameters\n ----------\n he_id : Any\n Unique identifier for the hyperedge.\n nodes : Iterable[Any]\n Nodes connected by the hyperedge.\n properties : Optional[Dict[str, Any]] = None\n Properties of the hyperedge.\n\n Raises\n ------\n HypergraphError\n If the hyperedge ID already exists.\n \"\"\"\n if he_id in self.hyperedges:\n raise HypergraphError(f\"Hyperedge '{he_id}' already exists.\")\n\n hyperedge = HyperEdge(id=he_id, nodes=nodes, properties=properties)\n self.hyperedges[he_id] = hyperedge\n self.hyperedge_properties[he_id] = copy.deepcopy(properties) if properties else {}\n\n # Add hyperedge to bipartite graph\n self.graph.add_node(he_id, bipartite='hyperedge', **self.hyperedge_properties[he_id])\n self.bipartite_nodes.add(he_id)\n\n # Add edges between hyperedge and nodes\n for node in hyperedge.nodes:\n if node not in self.graph:\n self.graph.add_node(node, bipartite='node', 
**self.node_properties.get(node, {}))\n self.graph.add_edge(he_id, node)\n\n def remove_hyperedge(self, he_id: Any) -> None:\n \"\"\"\n Removes a hyperedge from the hypergraph.\n\n Parameters\n ----------\n he_id : Any\n Identifier of the hyperedge to remove.\n\n Raises\n ------\n HypergraphError\n If the hyperedge does not exist.\n \"\"\"\n if he_id not in self.hyperedges:\n raise HypergraphError(f\"Hyperedge '{he_id}' does not exist.\")\n\n # Remove hyperedge from the graph, which also removes all incidences\n self.graph.remove_node(he_id)\n self.bipartite_nodes.discard(he_id)\n\n # Remove from internal structures\n del self.hyperedges[he_id]\n self.hyperedge_properties.pop(he_id, None)\n\n def add_hyperedges_from(\n self,\n hyperedges: Iterable[Union[Any, Tuple[Any, Dict[str, Any]]]],\n inplace: bool = True\n ) -> 'Hypergraph':\n \"\"\"\n Adds multiple hyperedges with attributes to the hypergraph.\n\n Parameters\n ----------\n hyperedges : Iterable[Union[Any, Tuple[Any, Dict[str, Any]]]]\n An iterable of hyperedge identifiers or tuples of (he_id, attributes).\n inplace : bool, default=True\n If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the added hyperedges.\n\n Returns\n -------\n Hypergraph\n The updated or new Hypergraph instance.\n\n Raises\n ------\n HypergraphError\n If any hyperedge ID already exists.\n ValueError\n If any tuple does not contain exactly two elements or if attributes are not dictionaries.\n \"\"\"\n new_hyperedges = []\n for item in hyperedges:\n if isinstance(item, tuple):\n if len(item) != 2 or not isinstance(item[1], dict):\n raise ValueError(f\"Each tuple must be of the form (he_id, attributes). 
Invalid tuple: {item}\")\n he_id, attrs = item\n else:\n he_id, attrs = item, {}\n\n if he_id in self.hyperedges:\n raise HypergraphError(f\"Hyperedge '{he_id}' already exists.\")\n\n hyperedge = HyperEdge(id=he_id, nodes=[], properties=attrs.copy())\n new_hyperedges.append(hyperedge)\n\n if inplace:\n for hyperedge in new_hyperedges:\n self.hyperedges[hyperedge.id] = hyperedge\n self.hyperedge_properties[hyperedge.id] = copy.deepcopy(hyperedge.properties)\n self.graph.add_node(hyperedge.id, bipartite='hyperedge', **self.hyperedge_properties[hyperedge.id])\n self.bipartite_nodes.add(hyperedge.id)\n return self\n else:\n # Create a new Hypergraph instance with added hyperedges\n new_hyperedges_dict = copy.deepcopy(self.hyperedges)\n new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)\n new_node_properties = copy.deepcopy(self.node_properties)\n new_graph = copy.deepcopy(self.graph)\n new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)\n\n for hyperedge in new_hyperedges:\n new_hyperedges_dict[hyperedge.id] = hyperedge\n new_hyperedge_properties[hyperedge.id] = copy.deepcopy(hyperedge.properties)\n new_graph.add_node(hyperedge.id, bipartite='hyperedge', **new_hyperedge_properties[hyperedge.id])\n new_bipartite_nodes.add(hyperedge.id)\n\n # Reconstruct hyperedges dict for __init__\n hyperedges_dict = {\n he_id: {\n 'nodes': list(he.nodes),\n 'properties': he.properties.copy()\n } for he_id, he in new_hyperedges_dict.items()\n }\n\n return Hypergraph(\n hyperedges=hyperedges_dict,\n node_properties=new_node_properties,\n hyperedge_properties=new_hyperedge_properties,\n name=self.name\n )\n\n def add_node(\n self,\n node_id: Any,\n properties: Optional[Dict[str, Any]] = None,\n inplace: bool = True\n ) -> 'Hypergraph':\n \"\"\"\n Adds a node to the hypergraph.\n\n Parameters\n ----------\n node_id : Any\n Identifier for the node.\n properties : Optional[Dict[str, Any]] = None\n Properties of the node.\n inplace : bool, default=True\n If True, 
modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the added node.\n\n Returns\n -------\n Hypergraph\n The updated or new Hypergraph instance.\n\n Raises\n ------\n HypergraphError\n If the node ID already exists.\n \"\"\"\n if node_id in self.node_properties:\n raise HypergraphError(f\"Node '{node_id}' already exists in the hypergraph.\")\n\n if inplace:\n self.node_properties[node_id] = copy.deepcopy(properties) if properties else {}\n self.graph.add_node(node_id, bipartite='node', **self.node_properties[node_id])\n return self\n else:\n # Create a new Hypergraph instance with the added node\n new_hyperedges = copy.deepcopy(self.hyperedges)\n new_node_properties = copy.deepcopy(self.node_properties)\n new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)\n new_graph = copy.deepcopy(self.graph)\n new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)\n\n new_node_properties[node_id] = copy.deepcopy(properties) if properties else {}\n new_graph.add_node(node_id, bipartite='node', **new_node_properties[node_id])\n\n return Hypergraph(\n hyperedges={\n he_id: {\n 'nodes': list(he.nodes),\n 'properties': he.properties.copy()\n } for he_id, he in new_hyperedges.items()\n },\n node_properties=new_node_properties,\n hyperedge_properties=new_hyperedge_properties,\n name=self.name\n )\n\n def remove_node(self, node_id: Any, inplace: bool = True) -> 'Hypergraph':\n \"\"\"\n Removes a node from the hypergraph.\n\n Parameters\n ----------\n node_id : Any\n Identifier of the node to remove.\n inplace : bool, default=True\n If True, modifies the existing Hypergraph. 
Otherwise, creates a new Hypergraph with the node removed.\n\n Returns\n -------\n Hypergraph\n The updated or new Hypergraph instance.\n\n Raises\n ------\n HypergraphError\n If the node does not exist.\n \"\"\"\n if node_id not in self.node_properties:\n raise HypergraphError(f\"Node '{node_id}' does not exist in the hypergraph.\")\n\n if inplace:\n # Remove node from node_properties\n del self.node_properties[node_id]\n # Remove node from all hyperedges\n for hyperedge in self.hyperedges.values():\n if node_id in hyperedge.nodes:\n hyperedge.remove_node(node_id)\n # Remove node from graph, which also removes all incidences\n self.graph.remove_node(node_id)\n return self\n else:\n # Create a new Hypergraph instance with the node removed\n new_hyperedges = copy.deepcopy(self.hyperedges)\n new_node_properties = copy.deepcopy(self.node_properties)\n new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)\n new_graph = copy.deepcopy(self.graph)\n new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)\n\n # Remove node from node_properties\n del new_node_properties[node_id]\n # Remove node from all hyperedges\n for hyperedge in new_hyperedges.values():\n if node_id in hyperedge.nodes:\n hyperedge.remove_node(node_id)\n # Remove node from graph, which also removes all incidences\n new_graph.remove_node(node_id)\n\n # Remove nodes not connected to any hyperedges\n retained_nodes = set()\n for hyperedge in new_hyperedges.values():\n retained_nodes.update(hyperedge.nodes)\n\n new_node_properties = {node: props for node, props in new_node_properties.items() if node in retained_nodes}\n\n # Reconstruct hyperedges dict for __init__\n hyperedges_dict = {\n he_id: {\n 'nodes': list(he.nodes),\n 'properties': he.properties.copy()\n } for he_id, he in new_hyperedges.items()\n }\n\n return Hypergraph(\n hyperedges=hyperedges_dict,\n node_properties=new_node_properties,\n hyperedge_properties=new_hyperedge_properties,\n name=self.name\n )\n\n def add_nodes_from(\n 
self,\n nodes: Iterable[Union[Any, Tuple[Any, Dict[str, Any]]]],\n inplace: bool = True\n ) -> 'Hypergraph':\n \"\"\"\n Adds multiple nodes with attributes to the hypergraph.\n\n Parameters\n ----------\n nodes : Iterable[Union[Any, Tuple[Any, Dict[str, Any]]]]\n An iterable of node identifiers or tuples of (node_id, attributes).\n inplace : bool, default=True\n If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the added nodes.\n\n Returns\n -------\n Hypergraph\n The updated or new Hypergraph instance.\n\n Raises\n ------\n HypergraphError\n If any node ID already exists.\n ValueError\n If any tuple does not contain exactly two elements or if attributes are not dictionaries.\n \"\"\"\n new_nodes = {}\n for item in nodes:\n if isinstance(item, tuple):\n if len(item) != 2 or not isinstance(item[1], dict):\n raise ValueError(f\"Each tuple must be of the form (node_id, attributes). Invalid tuple: {item}\")\n node_id, attrs = item\n else:\n node_id, attrs = item, {}\n\n if node_id in self.node_properties:\n raise HypergraphError(f\"Node '{node_id}' already exists in the hypergraph.\")\n\n new_nodes[node_id] = copy.deepcopy(attrs)\n\n if inplace:\n for node_id, attrs in new_nodes.items():\n self.node_properties[node_id] = attrs\n self.graph.add_node(node_id, bipartite='node', **self.node_properties[node_id])\n return self\n else:\n # Create a new Hypergraph instance with the added nodes\n new_hyperedges = copy.deepcopy(self.hyperedges)\n new_node_properties = copy.deepcopy(self.node_properties)\n new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)\n new_graph = copy.deepcopy(self.graph)\n new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)\n\n for node_id, attrs in new_nodes.items():\n new_node_properties[node_id] = attrs\n new_graph.add_node(node_id, bipartite='node', **new_node_properties[node_id])\n\n return Hypergraph(\n hyperedges={\n he_id: {\n 'nodes': list(he.nodes),\n 'properties': he.properties.copy()\n } 
for he_id, he in new_hyperedges.items()\n },\n node_properties=new_node_properties,\n hyperedge_properties=new_hyperedge_properties,\n name=self.name\n )\n\n def remove_hyperedges(self, he_ids: Union[Any, Iterable[Any]], inplace: bool = True) -> 'Hypergraph':\n \"\"\"\n Removes the specified hyperedges from the hypergraph.\n\n Parameters\n ----------\n he_ids : Any | Iterable[Any]\n Hyperedge identifier(s) to remove.\n inplace : bool, default=True\n If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the hyperedges removed.\n\n Returns\n -------\n Hypergraph\n The updated or new Hypergraph instance.\n\n Raises\n ------\n HypergraphError\n If any hyperedge ID does not exist.\n \"\"\"\n if isinstance(he_ids, (str, int)):\n he_ids = [he_ids]\n else:\n he_ids = list(he_ids)\n\n non_existing = set(he_ids) - set(self.hyperedges.keys())\n if non_existing:\n raise HypergraphError(f\"Hyperedges {non_existing} do not exist in the hypergraph.\")\n\n if inplace:\n for he_id in he_ids:\n self.remove_hyperedge(he_id)\n return self\n else:\n # Create a new Hypergraph instance with hyperedges removed\n new_hyperedges = copy.deepcopy(self.hyperedges)\n new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)\n new_node_properties = copy.deepcopy(self.node_properties)\n new_graph = copy.deepcopy(self.graph)\n new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)\n\n for he_id in he_ids:\n del new_hyperedges[he_id]\n new_hyperedge_properties.pop(he_id, None)\n new_graph.remove_node(he_id)\n new_bipartite_nodes.discard(he_id)\n\n # Remove nodes not connected to any hyperedges\n retained_nodes = set()\n for hyperedge in new_hyperedges.values():\n retained_nodes.update(hyperedge.nodes)\n\n new_node_properties = {node: props for node, props in new_node_properties.items() if node in retained_nodes}\n\n # Reconstruct hyperedges dict for __init__\n hyperedges_dict = {\n he_id: {\n 'nodes': list(he.nodes),\n 'properties': he.properties.copy()\n 
} for he_id, he in new_hyperedges.items()\n }\n\n return Hypergraph(\n hyperedges=hyperedges_dict,\n node_properties=new_node_properties,\n hyperedge_properties=new_hyperedge_properties,\n name=self.name\n )\n\n def remove_nodes_from(\n self,\n nodes: Union[Any, Iterable[Any]],\n inplace: bool = True\n ) -> 'Hypergraph':\n \"\"\"\n Removes the specified nodes from the hypergraph.\n\n Parameters\n ----------\n nodes : Any | Iterable[Any]\n Node identifier(s) to remove.\n inplace : bool, default=True\n If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the nodes removed.\n\n Returns\n -------\n Hypergraph\n The updated or new Hypergraph instance.\n\n Raises\n ------\n HypergraphError\n If any node ID does not exist.\n \"\"\"\n if isinstance(nodes, (str, int)):\n nodes = [nodes]\n else:\n nodes = list(nodes)\n\n non_existing = set(nodes) - set(self.node_properties.keys())\n if non_existing:\n raise HypergraphError(f\"Nodes {non_existing} do not exist in the hypergraph.\")\n\n if inplace:\n for node_id in nodes:\n self.remove_node(node_id)\n return self\n else:\n # Create a new Hypergraph instance with nodes removed\n new_hyperedges = copy.deepcopy(self.hyperedges)\n new_node_properties = copy.deepcopy(self.node_properties)\n new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)\n new_graph = copy.deepcopy(self.graph)\n new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)\n\n for node_id in nodes:\n del new_node_properties[node_id]\n # Remove node from all hyperedges\n for hyperedge in new_hyperedges.values():\n if node_id in hyperedge.nodes:\n hyperedge.remove_node(node_id)\n # Remove node from graph, which also removes all incidences\n new_graph.remove_node(node_id)\n\n # Remove nodes not connected to any hyperedges\n retained_nodes = set()\n for hyperedge in new_hyperedges.values():\n retained_nodes.update(hyperedge.nodes)\n\n new_node_properties = {node: props for node, props in new_node_properties.items() if node 
in retained_nodes}\n\n # Reconstruct hyperedges dict for __init__\n hyperedges_dict = {\n he_id: {\n 'nodes': list(he.nodes),\n 'properties': he.properties.copy()\n } for he_id, he in new_hyperedges.items()\n }\n\n return Hypergraph(\n hyperedges=hyperedges_dict,\n node_properties=new_node_properties,\n hyperedge_properties=new_hyperedge_properties,\n name=self.name\n )\n\n def add_incidence(\n self,\n he_id: Any,\n node_id: Any,\n attributes: Optional[Dict[str, Any]] = None,\n inplace: bool = True\n ) -> 'Hypergraph':\n \"\"\"\n Adds a single incidence with attributes to the hypergraph.\n\n Parameters\n ----------\n he_id : Any\n Identifier of the hyperedge.\n node_id : Any\n Identifier of the node.\n attributes : Optional[Dict[str, Any]] = None\n Properties to add to the incidence as key-value pairs.\n inplace : bool, default=True\n If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the added incidence.\n\n Returns\n -------\n Hypergraph\n The updated or new Hypergraph instance.\n\n Raises\n ------\n HypergraphError\n If the hyperedge or node does not exist, or if the incidence already exists.\n \"\"\"\n if he_id not in self.hyperedges:\n raise HypergraphError(f\"Hyperedge '{he_id}' does not exist in the hypergraph.\")\n if node_id not in self.node_properties:\n raise HypergraphError(f\"Node '{node_id}' does not exist in the hypergraph.\")\n if node_id in self.hyperedges[he_id].nodes:\n raise HypergraphError(f\"Incidence between hyperedge '{he_id}' and node '{node_id}' already exists.\")\n\n if inplace:\n # Add node to HyperEdge's nodes\n self.hyperedges[he_id].add_node(node_id)\n # Update hyperedge_properties if attributes provided\n if attributes:\n self.hyperedge_properties[he_id].update(attributes)\n # Add edge in graph with attributes\n self.graph.add_edge(he_id, node_id, **(attributes if attributes else {}))\n return self\n else:\n # Create a new Hypergraph instance with the incidence added\n new_hyperedges = 
copy.deepcopy(self.hyperedges)\n new_node_properties = copy.deepcopy(self.node_properties)\n new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)\n new_graph = copy.deepcopy(self.graph)\n new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)\n\n # Add node to HyperEdge's nodes\n new_hyperedges[he_id].add_node(node_id)\n # Update hyperedge_properties if attributes provided\n if attributes:\n new_hyperedge_properties[he_id].update(attributes)\n # Add edge in graph with attributes\n new_graph.add_edge(he_id, node_id, **(attributes if attributes else {}))\n\n # Reconstruct hyperedges dict for __init__\n hyperedges_dict = {\n he_id: {\n 'nodes': list(he.nodes),\n 'properties': he.properties.copy()\n } for he_id, he in new_hyperedges.items()\n }\n\n return Hypergraph(\n hyperedges=hyperedges_dict,\n node_properties=new_node_properties,\n hyperedge_properties=new_hyperedge_properties,\n name=self.name\n )\n\n def remove_incidence(\n self,\n he_id: Any,\n node_id: Any,\n inplace: bool = True\n ) -> 'Hypergraph':\n \"\"\"\n Removes a single incidence from the hypergraph.\n\n Parameters\n ----------\n he_id : Any\n Identifier of the hyperedge.\n node_id : Any\n Identifier of the node.\n inplace : bool, default=True\n If True, modifies the existing Hypergraph. 
Otherwise, creates a new Hypergraph with the incidence removed.\n\n Returns\n -------\n Hypergraph\n The updated or new Hypergraph instance.\n\n Raises\n ------\n HypergraphError\n If the hyperedge or node does not exist, or if the incidence does not exist.\n \"\"\"\n if he_id not in self.hyperedges:\n raise HypergraphError(f\"Hyperedge '{he_id}' does not exist in the hypergraph.\")\n if node_id not in self.node_properties:\n raise HypergraphError(f\"Node '{node_id}' does not exist in the hypergraph.\")\n if node_id not in self.hyperedges[he_id].nodes:\n raise HypergraphError(f\"Incidence between hyperedge '{he_id}' and node '{node_id}' does not exist.\")\n\n if inplace:\n # Remove node from HyperEdge's nodes\n self.hyperedges[he_id].remove_node(node_id)\n # Remove edge from graph\n self.graph.remove_edge(he_id, node_id)\n return self\n else:\n # Create a new Hypergraph instance with the incidence removed\n new_hyperedges = copy.deepcopy(self.hyperedges)\n new_node_properties = copy.deepcopy(self.node_properties)\n new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)\n new_graph = copy.deepcopy(self.graph)\n new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)\n\n # Remove node from HyperEdge's nodes\n new_hyperedges[he_id].remove_node(node_id)\n # Remove edge from graph\n new_graph.remove_edge(he_id, node_id)\n\n # Reconstruct hyperedges dict for __init__\n hyperedges_dict = {\n he_id: {\n 'nodes': list(he.nodes),\n 'properties': he.properties.copy()\n } for he_id, he in new_hyperedges.items()\n }\n\n return Hypergraph(\n hyperedges=hyperedges_dict,\n node_properties=new_node_properties,\n hyperedge_properties=new_hyperedge_properties,\n name=self.name\n )\n\n # Managing Properties and Incidences\n\n def adjacency_matrix(self, s: int = 1, index: bool = False) -> Tuple[Optional[csr_matrix], Dict[Any, int]]:\n \"\"\"\n Generates the adjacency matrix for nodes based on s-node connectivity.\n\n Parameters\n ----------\n s : int, optional, default=1\n The number of shared hyperedges required for nodes to be considered adjacent.\n index : bool, optional, default=False\n If True, returns a mapping from node IDs to matrix indices.\n\n Returns\n -------\n Tuple[Optional[csr_matrix], Dict[Any, int]]\n - The adjacency matrix in CSR format, or None if the hypergraph has no nodes.\n - A dictionary mapping node IDs to matrix indices.\n \"\"\"\n from scipy.sparse import lil_matrix\n\n 
node_ids = list(self.node_properties.keys())\n node_index = {node_id: idx for idx, node_id in enumerate(node_ids)}\n size = len(node_ids)\n if size == 0:\n return None, {}\n\n A = lil_matrix((size, size), dtype=int)\n for he in self.hyperedges.values():\n nodes = list(he.nodes)\n for i in range(len(nodes)):\n for j in range(i + 1, len(nodes)):\n A[node_index[nodes[i]], node_index[nodes[j]]] += 1\n\n # Apply the threshold s and convert to binary\n A = (A >= s).astype(int)\n A = A.tocsr()\n\n if index:\n return A, node_index\n return A, {}\n\n def hyperedge_adjacency_matrix(self, s: int = 1, index: bool = False) -> Tuple[Optional[csr_matrix], Dict[Any, int]]:\n \"\"\"\n Generates the adjacency matrix for hyperedges based on s-hyperedge connectivity.\n\n Parameters\n ----------\n s : int, optional, default=1\n The number of shared nodes required for hyperedges to be considered adjacent.\n index : bool, optional, default=False\n If True, returns a mapping from hyperedge IDs to matrix indices.\n\n Returns\n -------\n Tuple[Optional[csr_matrix], Dict[Any, int]]\n - The adjacency matrix in CSR format, or None if the hypergraph has no hyperedges.\n - A dictionary mapping hyperedge IDs to matrix indices.\n \"\"\"\n from scipy.sparse import lil_matrix\n\n hyperedge_ids = list(self.hyperedges.keys())\n he_index = {he_id: idx for idx, he_id in enumerate(hyperedge_ids)}\n size = len(hyperedge_ids)\n if size == 0:\n return None, {}\n\n A = lil_matrix((size, size), dtype=int)\n for i, he1 in enumerate(hyperedge_ids):\n nodes1 = self.hyperedges[he1].nodes\n for j in range(i + 1, size):\n he2 = hyperedge_ids[j]\n nodes2 = self.hyperedges[he2].nodes\n shared_nodes = nodes1 & nodes2\n if len(shared_nodes) >= s:\n A[i, j] = 1\n A[j, i] = 1\n\n A = A.tocsr()\n\n if index:\n return A, he_index\n return A, {}\n\n def get_hyperedges_of_node(self, node_id: Any) -> Set[Any]:\n \"\"\"\n Retrieves all hyperedges that a given node is part of.\n\n Parameters\n ----------\n node_id : Any\n The node identifier.\n\n Returns\n -------\n 
Set[Any]\n A set of hyperedge IDs that the node belongs to.\n\n Raises\n ------\n HypergraphError\n If the node does not exist in the hypergraph.\n \"\"\"\n if node_id not in self.node_properties:\n raise HypergraphError(f\"Node '{node_id}' does not exist in the hypergraph.\")\n return {he.id for he in self.hyperedges.values() if node_id in he.nodes}\n\n def collapse_duplicate_hyperedges(\n self,\n name: Optional[str] = None,\n use_uids: Optional[List[Any]] = None,\n use_counts: bool = False,\n return_counts: bool = True,\n return_equivalence_classes: bool = False,\n aggregate_properties_by: Optional[Dict[str, str]] = None,\n ) -> Union['Hypergraph', Tuple['Hypergraph', Dict[Any, Set[Any]]]]:\n \"\"\"\n Collapses duplicate hyperedges (hyperedges with identical node memberships) into single hyperedges.\n\n Parameters\n ----------\n name : Optional[str], default=None\n The name assigned to the collapsed hypergraph. If None, defaults to the original name suffixed with '_collapsed_hyperedges'.\n\n use_uids : Optional[List[Any]] = None\n Specifies the hyperedge identifiers to use as representatives for each equivalence class.\n If two identifiers occur in the same equivalence class, the first one found in `use_uids` is used.\n If None, the first encountered hyperedge in each class is used as the representative.\n\n use_counts : bool, optional, default=False\n If True, renames the equivalence class representatives by appending the size of the class (e.g., 'HE1:3').\n\n return_counts : bool, optional, default=True\n If True, adds the size of each equivalence class to the properties of the representative hyperedge under the key 'equivalence_class_size'.\n\n return_equivalence_classes : bool, optional, default=False\n If True, returns a tuple containing the new collapsed hypergraph and a dictionary mapping representatives to their equivalence classes.\n\n aggregate_properties_by : Optional[Dict[str, str]] = None\n A dictionary specifying aggregation methods for hyperedge 
properties. Keys are property names, and values are aggregation functions (e.g., {'weight': 'sum'}).\n Properties not specified will use the 'first' aggregation.\n\n Returns\n -------\n Hypergraph or Tuple[Hypergraph, Dict[Any, Set[Any]]]\n - If `return_equivalence_classes=False`, returns the new collapsed hypergraph.\n - If `return_equivalence_classes=True`, returns a tuple containing the collapsed hypergraph and a dictionary of equivalence classes.\n\n Raises\n ------\n HypergraphError\n If the hypergraph is empty or improperly structured.\n \"\"\"\n if not self.hyperedges:\n raise HypergraphError(\"Cannot collapse hyperedges in an empty hypergraph.\")\n\n # Identify equivalence classes based on identical node memberships\n membership_to_hyperedges: Dict[frozenset, Set[Any]] = {}\n for he_id, hyperedge in self.hyperedges.items():\n key = frozenset(hyperedge.nodes)\n membership_to_hyperedges.setdefault(key, set()).add(he_id)\n\n # Filter out classes with only one hyperedge (no duplicates)\n equivalence_classes = [hes for hes in membership_to_hyperedges.values() if len(hes) > 1]\n if not equivalence_classes:\n # No duplicates to collapse; return the original hypergraph\n return self if not return_equivalence_classes else (self, {})\n\n # Prepare aggregation methods\n aggregate_properties_by = aggregate_properties_by if aggregate_properties_by is not None else {\"weight\": \"sum\"}\n\n # Initialize mapping from old hyperedges to new hyperedges\n hyperedge_mapping: Dict[Any, Any] = {}\n equivalence_class_dict: Dict[Any, Set[Any]] = {}\n\n for eq_class in equivalence_classes:\n # Determine representative\n if use_uids:\n # Select the first UID from use_uids that is in the equivalence class\n representative = next((uid for uid in use_uids if uid in eq_class), None)\n if not representative:\n # Fallback to the first hyperedge in the equivalence class\n representative = next(iter(eq_class))\n else:\n # Use the first hyperedge in the equivalence class as representative\n 
representative = next(iter(eq_class))\n\n # Optionally rename with counts\n if use_counts:\n new_representative = f\"{representative}:{len(eq_class)}\"\n else:\n new_representative = representative\n\n # Map all hyperedges in the class to the representative\n for he in eq_class:\n hyperedge_mapping[he] = new_representative\n\n # Store the equivalence class\n equivalence_class_dict[new_representative] = eq_class\n\n # Replace hyperedge IDs in incidences based on mapping\n new_hyperedges = {}\n for he_id, hyperedge in self.hyperedges.items():\n new_he_id = hyperedge_mapping.get(he_id, he_id)\n if new_he_id not in new_hyperedges:\n new_hyperedges[new_he_id] = HyperEdge(id=new_he_id, nodes=hyperedge.nodes.copy(), properties=copy.deepcopy(hyperedge.properties))\n else:\n new_hyperedges[new_he_id].nodes.update(hyperedge.nodes)\n\n # Aggregate hyperedge properties\n for he_id, hyperedge in new_hyperedges.items():\n if he_id in equivalence_class_dict:\n aggregated_props = {}\n for prop, agg_func in aggregate_properties_by.items():\n values = [self.hyperedge_properties[old_he].get(prop, 0) for old_he in equivalence_class_dict[he_id]]\n if agg_func == 'sum':\n aggregated_props[prop] = sum(values)\n elif agg_func == 'mean':\n aggregated_props[prop] = sum(values) / len(values) if values else 0\n elif agg_func == 'max':\n aggregated_props[prop] = max(values) if values else None\n elif agg_func == 'min':\n aggregated_props[prop] = min(values) if values else None\n else:\n aggregated_props[prop] = values[0] if values else None # Default to first\n new_hyperedges[he_id].properties.update(aggregated_props)\n\n # Handle equivalence class size\n if use_counts:\n for he_id in equivalence_class_dict:\n new_hyperedges[he_id].properties['equivalence_class_size'] = len(equivalence_class_dict[he_id])\n elif return_counts:\n for he_id in new_hyperedges:\n if he_id in equivalence_class_dict:\n new_hyperedges[he_id].properties['equivalence_class_size'] = len(equivalence_class_dict[he_id])\n 
else:\n new_hyperedges[he_id].properties['equivalence_class_size'] = 1\n\n # Initialize the collapsed hypergraph\n collapsed_hypergraph = Hypergraph(\n hyperedges={\n he_id: {\n 'nodes': list(he.nodes),\n 'properties': he.properties.copy()\n } for he_id, he in new_hyperedges.items()\n },\n node_properties=copy.deepcopy(self.node_properties),\n hyperedge_properties={\n he_id: copy.deepcopy(he.properties) for he_id, he in new_hyperedges.items()\n },\n name=name if name else f\"{self.name}_collapsed_hyperedges\"\n )\n\n if return_equivalence_classes:\n return collapsed_hypergraph, equivalence_class_dict\n else:\n return collapsed_hypergraph\n\n def restrict_to_specific_hyperedges(\n self,\n hyperedges_to_retain: Iterable[Any],\n name: Optional[str] = None\n ) -> 'Hypergraph':\n \"\"\"\n Creates a new hypergraph by retaining only the specified hyperedges and removing all others.\n\n Parameters\n ----------\n hyperedges_to_retain : Iterable[Any]\n An iterable of hyperedge identifiers to retain in the new hypergraph.\n\n name : Optional[str], default=None\n The name assigned to the restricted hypergraph. 
If None, defaults to the original name suffixed with '_restricted_hyperedges'.\n\n Returns\n -------\n Hypergraph\n A new hypergraph containing only the specified hyperedges and their associated nodes.\n\n Raises\n ------\n HypergraphError\n If none of the specified hyperedges exist in the hypergraph.\n \"\"\"\n hyperedges_to_retain = set(hyperedges_to_retain)\n existing_hyperedges = set(self.hyperedges.keys())\n invalid_hyperedges = hyperedges_to_retain - existing_hyperedges\n if invalid_hyperedges:\n raise HypergraphError(f\"The following hyperedges do not exist and cannot be retained: {invalid_hyperedges}\")\n\n # Determine hyperedges to remove\n hyperedges_to_remove = existing_hyperedges - hyperedges_to_retain\n if not hyperedges_to_remove:\n # No hyperedges to remove; return the original hypergraph\n return self\n\n # Remove hyperedges using the existing remove_hyperedges method\n restricted_hypergraph = self.remove_hyperedges(hyperedges_to_remove, inplace=False)\n restricted_hypergraph.name = name if name else f\"{self.name}_restricted_hyperedges\"\n\n return restricted_hypergraph\n\n def restrict_to_specific_nodes(\n self,\n nodes_to_retain: Iterable[Any],\n name: Optional[str] = None\n ) -> 'Hypergraph':\n \"\"\"\n Creates a new hypergraph by retaining only the specified nodes and removing all others.\n\n Parameters\n ----------\n nodes_to_retain : Iterable[Any]\n An iterable of node identifiers to retain in the new hypergraph.\n\n name : Optional[str], default=None\n The name assigned to the restricted hypergraph. 
If None, defaults to the original name suffixed with '_restricted_nodes'.\n\n Returns\n -------\n Hypergraph\n A new hypergraph containing only the specified nodes and their associated hyperedges.\n\n Raises\n ------\n HypergraphError\n If none of the specified nodes exist in the hypergraph.\n \"\"\"\n nodes_to_retain = set(nodes_to_retain)\n existing_nodes = set(self.node_properties.keys())\n invalid_nodes = nodes_to_retain - existing_nodes\n if invalid_nodes:\n raise HypergraphError(f\"The following nodes do not exist and cannot be retained: {invalid_nodes}\")\n\n # Determine nodes to remove\n nodes_to_remove = existing_nodes - nodes_to_retain\n if not nodes_to_remove:\n # No nodes to remove; return the original hypergraph\n return self\n\n # Remove nodes using the existing remove_nodes_from method\n restricted_hypergraph = self.remove_nodes_from(nodes_to_remove, inplace=False)\n restricted_hypergraph.name = name if name else f\"{self.name}_restricted_nodes\"\n\n return restricted_hypergraph\n\n def add_incidences_from(\n self,\n incidences: Iterable[Union[Tuple[Any, Any], Tuple[Any, Any, Dict[str, Any]]]],\n inplace: bool = True\n ) -> 'Hypergraph':\n \"\"\"\n Adds a collection of incidences to the hypergraph.\n\n Parameters\n ----------\n incidences : Iterable[Union[Tuple[Any, Any], Tuple[Any, Any, Dict[str, Any]]]]\n Incidence tuples as:\n - (he_id, node_id)\n - (he_id, node_id, attributes)\n\n inplace : bool, default=True\n If True, modifies the existing Hypergraph. 
Otherwise, creates a new Hypergraph with the added incidences.\n\n Returns\n -------\n Hypergraph\n The updated or new Hypergraph instance.\n\n Raises\n ------\n HypergraphError\n If any hyperedge or node does not exist, or if any incidence already exists.\n ValueError\n If the structure of any incidence tuple is invalid.\n \"\"\"\n new_incidences = []\n for pr in incidences:\n if not isinstance(pr, tuple):\n raise ValueError(f\"Each incidence must be a tuple, got {type(pr)}\")\n if len(pr) == 2:\n he_id, node_id = pr\n attrs = {}\n elif len(pr) == 3:\n he_id, node_id, attrs = pr\n if not isinstance(attrs, dict):\n raise ValueError(f\"Attributes must be a dictionary, got {type(attrs)}\")\n else:\n raise ValueError(f\"Incidence tuples must be of length 2 or 3, got {len(pr)}\")\n\n if he_id not in self.hyperedges:\n raise HypergraphError(f\"Hyperedge '{he_id}' does not exist in the hypergraph.\")\n if node_id not in self.node_properties:\n raise HypergraphError(f\"Node '{node_id}' does not exist in the hypergraph.\")\n if node_id in self.hyperedges[he_id].nodes:\n raise HypergraphError(f\"Incidence between hyperedge '{he_id}' and node '{node_id}' already exists.\")\n\n new_incidences.append((he_id, node_id, attrs.copy()))\n\n if inplace:\n for he_id, node_id, attrs in new_incidences:\n # Add node to HyperEdge's nodes\n self.hyperedges[he_id].add_node(node_id)\n # Update hyperedge_properties if attributes provided\n if attrs:\n self.hyperedge_properties[he_id].update(attrs)\n # Add edge in graph with attributes\n self.graph.add_edge(he_id, node_id, **(attrs if attrs else {}))\n return self\n else:\n # Create a new Hypergraph instance with the incidences added\n new_hyperedges = copy.deepcopy(self.hyperedges)\n new_node_properties = copy.deepcopy(self.node_properties)\n new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)\n new_graph = copy.deepcopy(self.graph)\n new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)\n\n for he_id, node_id, attrs in 
new_incidences:\n # Add node to HyperEdge's nodes\n new_hyperedges[he_id].add_node(node_id)\n # Update hyperedge_properties if attributes provided\n if attrs:\n new_hyperedge_properties[he_id].update(attrs)\n # Add edge in graph with attributes\n new_graph.add_edge(he_id, node_id, **(attrs if attrs else {}))\n\n # Reconstruct hyperedges dict for __init__\n hyperedges_dict = {\n he_id: {\n 'nodes': list(he.nodes),\n 'properties': he.properties.copy()\n } for he_id, he in new_hyperedges.items()\n }\n\n return Hypergraph(\n hyperedges=hyperedges_dict,\n node_properties=new_node_properties,\n hyperedge_properties=new_hyperedge_properties,\n name=self.name\n )\n\n def remove_incidences(\n self,\n incidences: Iterable[Tuple[Any, Any]],\n inplace: bool = True\n ) -> 'Hypergraph':\n \"\"\"\n Removes the specified incidences from the hypergraph.\n\n Parameters\n ----------\n incidences : Iterable[Tuple[Any, Any]]\n Incidence identifiers as tuples of (he_id, node_id).\n inplace : bool, default=True\n If True, modifies the existing Hypergraph. 
Otherwise, creates a new Hypergraph with the incidences removed.\n\n Returns\n -------\n Hypergraph\n The updated or new Hypergraph instance.\n\n Raises\n ------\n HypergraphError\n If any incidence does not exist.\n \"\"\"\n incidence_ids = list(incidences)\n\n # Check existence of incidences\n for he_id, node_id in incidence_ids:\n if he_id not in self.hyperedges:\n raise HypergraphError(f\"Hyperedge '{he_id}' does not exist in the hypergraph.\")\n if node_id not in self.node_properties:\n raise HypergraphError(f\"Node '{node_id}' does not exist in the hypergraph.\")\n if node_id not in self.hyperedges[he_id].nodes:\n raise HypergraphError(f\"Incidence between hyperedge '{he_id}' and node '{node_id}' does not exist.\")\n\n if inplace:\n for he_id, node_id in incidence_ids:\n # Remove node from HyperEdge's nodes\n self.hyperedges[he_id].remove_node(node_id)\n # Remove edge from graph\n self.graph.remove_edge(he_id, node_id)\n return self\n else:\n # Create a new Hypergraph instance with the incidences removed\n new_hyperedges = copy.deepcopy(self.hyperedges)\n new_node_properties = copy.deepcopy(self.node_properties)\n new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)\n new_graph = copy.deepcopy(self.graph)\n new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)\n\n for he_id, node_id in incidence_ids:\n # Remove node from HyperEdge's nodes\n new_hyperedges[he_id].remove_node(node_id)\n # Remove edge from graph\n new_graph.remove_edge(he_id, node_id)\n\n # Reconstruct hyperedges dict for __init__\n hyperedges_dict = {\n he_id: {\n 'nodes': list(he.nodes),\n 'properties': he.properties.copy()\n } for he_id, he in new_hyperedges.items()\n }\n\n return Hypergraph(\n hyperedges=hyperedges_dict,\n node_properties=new_node_properties,\n hyperedge_properties=new_hyperedge_properties,\n name=self.name\n )\n\n def collapse_duplicate_nodes(\n self,\n name: Optional[str] = None,\n use_uids: Optional[List[Any]] = None,\n use_counts: bool = False,\n 
return_counts: bool = True,\n return_equivalence_classes: bool = False,\n aggregate_properties_by: Optional[Dict[str, str]] = None,\n ) -> Union['Hypergraph', Tuple['Hypergraph', Dict[Any, Set[Any]]]]:\n \"\"\"\n Collapses duplicate nodes (nodes with identical hyperedge memberships) into single nodes.\n\n Parameters\n ----------\n name : Optional[str], default=None\n The name assigned to the collapsed hypergraph. If None, defaults to the original name suffixed with '_collapsed_nodes'.\n\n use_uids : Optional[List[Any]] = None\n Specifies the node identifiers to use as representatives for each equivalence class.\n If two identifiers occur in the same equivalence class, the first one found in `use_uids` is used.\n If None, the first encountered node in each class is used as the representative.\n\n use_counts : bool, optional, default=False\n If True, renames the equivalence class representatives by appending the size of the class (e.g., 'N1:3').\n\n return_counts : bool, optional, default=True\n If True, adds the size of each equivalence class to the properties of the representative node under the key 'equivalence_class_size'.\n\n return_equivalence_classes : bool, optional, default=False\n If True, returns a tuple containing the new collapsed hypergraph and a dictionary mapping representatives to their equivalence classes.\n\n aggregate_properties_by : Optional[Dict[str, str]] = None\n A dictionary specifying aggregation methods for node properties. 
Keys are property names, and values are aggregation functions (e.g., {'weight': 'sum'}).\n Properties not specified will use the 'first' aggregation.\n\n Returns\n -------\n Hypergraph or Tuple[Hypergraph, Dict[Any, Set[Any]]]\n - If `return_equivalence_classes=False`, returns the new collapsed hypergraph.\n - If `return_equivalence_classes=True`, returns a tuple containing the collapsed hypergraph and a dictionary of equivalence classes.\n\n Raises\n ------\n HypergraphError\n If the hypergraph is empty or improperly structured.\n \"\"\"\n if not self.node_properties:\n raise HypergraphError(\"Cannot collapse nodes in an empty hypergraph.\")\n\n # Identify equivalence classes based on identical hyperedge memberships\n membership_to_nodes: Dict[frozenset, Set[Any]] = {}\n for node_id, node_props in self.node_properties.items():\n key = frozenset(self.get_hyperedges_of_node(node_id))\n membership_to_nodes.setdefault(key, set()).add(node_id)\n\n # Filter out classes with only one node (no duplicates)\n equivalence_classes = [nodes for nodes in membership_to_nodes.values() if len(nodes) > 1]\n if not equivalence_classes:\n # No duplicates to collapse; return the original hypergraph\n return self if not return_equivalence_classes else (self, {})\n\n # Prepare aggregation methods\n aggregate_properties_by = aggregate_properties_by if aggregate_properties_by is not None else {\"weight\": \"sum\"}\n\n # Initialize mapping from old nodes to new nodes\n node_mapping: Dict[Any, Any] = {}\n equivalence_class_dict: Dict[Any, Set[Any]] = {}\n\n for eq_class in equivalence_classes:\n # Determine representative\n if use_uids:\n # Select the first UID from use_uids that is in the equivalence class\n representative = next((uid for uid in use_uids if uid in eq_class), None)\n if representative is None:\n # Fallback to the first node in the equivalence class;\n # compare against None so falsy IDs such as 0 or '' remain valid representatives\n representative = next(iter(eq_class))\n else:\n # Use the first node in the equivalence class as representative\n representative 
= next(iter(eq_class))\n\n # Optionally rename with counts\n if use_counts:\n new_representative = f\"{representative}:{len(eq_class)}\"\n else:\n new_representative = representative\n\n # Map all nodes in the class to the representative\n for node in eq_class:\n node_mapping[node] = new_representative\n\n # Store the equivalence class\n equivalence_class_dict[new_representative] = eq_class\n\n # Replace node IDs in hyperedges based on mapping\n new_hyperedges = {}\n for he_id, hyperedge in self.hyperedges.items():\n new_nodes = set()\n for node_id in hyperedge.nodes:\n new_node_id = node_mapping.get(node_id, node_id)\n new_nodes.add(new_node_id)\n new_hyperedges[he_id] = HyperEdge(id=he_id, nodes=new_nodes, properties=copy.deepcopy(hyperedge.properties))\n\n # Aggregate node properties\n new_node_properties = {}\n for node_id, node_props in self.node_properties.items():\n new_node_id = node_mapping.get(node_id, node_id)\n if new_node_id not in new_node_properties:\n new_node_properties[new_node_id] = copy.deepcopy(node_props)\n else:\n for prop, agg_func in aggregate_properties_by.items():\n if prop in node_props:\n if agg_func == 'sum':\n new_node_properties[new_node_id][prop] = new_node_properties[new_node_id].get(prop, 0) + node_props[prop]\n elif agg_func == 'mean':\n # To calculate mean, store a running sum and count; seed the sum with the\n # representative's existing value so the first node in the class is not dropped\n if 'sum_' + prop not in new_node_properties[new_node_id]:\n new_node_properties[new_node_id]['sum_' + prop] = new_node_properties[new_node_id].get(prop, 0) + node_props[prop]\n new_node_properties[new_node_id]['count_' + prop] = 2\n else:\n new_node_properties[new_node_id]['sum_' + prop] += node_props[prop]\n new_node_properties[new_node_id]['count_' + prop] += 1\n # Calculate mean at the end\n elif agg_func == 'max':\n current_max = new_node_properties[new_node_id].get(prop, float('-inf'))\n new_node_properties[new_node_id][prop] = max(current_max, node_props[prop])\n elif agg_func == 'min':\n current_min = new_node_properties[new_node_id].get(prop, float('inf'))\n 
new_node_properties[new_node_id][prop] = min(current_min, node_props[prop])\n else:\n new_node_properties[new_node_id][prop] = node_props[prop] # Default to last\n # Finalize mean calculations\n for node_id, props in new_node_properties.items():\n for prop in list(props.keys()):\n if prop.startswith('sum_'):\n base_prop = prop[4:]\n sum_val = props[prop]\n count_val = props.get('count_' + base_prop, 1)\n new_node_properties[node_id][base_prop] = sum_val / count_val if count_val > 0 else 0\n del new_node_properties[node_id][prop]\n del new_node_properties[node_id]['count_' + base_prop]\n\n # Handle equivalence class size\n if use_counts:\n for node_id in equivalence_class_dict:\n new_node_properties[node_id]['equivalence_class_size'] = len(equivalence_class_dict[node_id])\n elif return_counts:\n for node_id in new_node_properties:\n if node_id in equivalence_class_dict:\n new_node_properties[node_id]['equivalence_class_size'] = len(equivalence_class_dict[node_id])\n else:\n new_node_properties[node_id]['equivalence_class_size'] = 1\n\n # Initialize the collapsed hypergraph\n collapsed_hypergraph = Hypergraph(\n hyperedges={\n he_id: {\n 'nodes': list(he.nodes),\n 'properties': he.properties.copy()\n } for he_id, he in new_hyperedges.items()\n },\n node_properties=new_node_properties,\n hyperedge_properties={\n he_id: copy.deepcopy(he.properties) for he_id, he in new_hyperedges.items()\n },\n name=name if name else f\"{self.name}_collapsed_nodes\"\n )\n\n if return_equivalence_classes:\n return collapsed_hypergraph, equivalence_class_dict\n else:\n return collapsed_hypergraph\n\n # Analyzing and Querying the Hypergraph\n\n def get_toplexes(self, return_hypergraph: bool = False) -> Union[List[Any], 'Hypergraph']:\n \"\"\"\n Computes a maximal collection of toplexes for the hypergraph.\n A :term:`toplex` is a hyperedge that is not contained in any other hyperedge.\n\n Parameters\n ----------\n return_hypergraph : bool, optional, default=False\n If True, returns a new 
Hypergraph consisting only of the toplexes.\n\n Returns\n -------\n List[Any] or Hypergraph\n - A list of toplex hyperedge IDs.\n - If `return_hypergraph=True`, returns a Hypergraph containing only the toplexes.\n \"\"\"\n toplexes = []\n hyperedges = list(self.hyperedges.values())\n\n for he in hyperedges:\n if not any(he.nodes < other_he.nodes for other_he in hyperedges if he.id != other_he.id):\n toplexes.append(he.id)\n\n if return_hypergraph:\n return self.restrict_to_specific_hyperedges(toplexes, name=\"Toplexes\")\n return toplexes\n\n def is_node_connected(self, s: int = 1) -> bool:\n \"\"\"\n Determines if the hypergraph is s-node-connected.\n\n Parameters\n ----------\n s : int, optional, default=1\n The connectivity level to check.\n\n Returns\n -------\n bool\n True if the hypergraph is s-node-connected, False otherwise.\n \"\"\"\n return self._is_connected(s=s, hyperedges=False)\n\n def is_hyperedge_connected(self, s: int = 1) -> bool:\n \"\"\"\n Determines if the hypergraph is s-hyperedge-connected.\n\n Parameters\n ----------\n s : int, optional, default=1\n The connectivity level to check.\n\n Returns\n -------\n bool\n True if the hypergraph is s-hyperedge-connected, False otherwise.\n \"\"\"\n return self._is_connected(s=s, hyperedges=True)\n\n def _is_connected(self, s: int = 1, hyperedges: bool = False) -> bool:\n \"\"\"\n Internal method to determine connectivity based on nodes or hyperedges.\n\n Parameters\n ----------\n s : int, optional, default=1\n The connectivity level to check.\n hyperedges : bool, optional, default=False\n If True, checks for s-hyperedge-connectedness. 
Otherwise, checks for s-node-connectedness.\n\n Returns\n -------\n bool\n Connectivity status.\n \"\"\"\n if hyperedges:\n # Create hyperedge connectivity graph: hyperedges are nodes, connect if they share >= s nodes\n hyperedge_graph = nx.Graph()\n hyperedge_ids = list(self.hyperedges.keys())\n hyperedge_graph.add_nodes_from(hyperedge_ids)\n\n for i, he1 in enumerate(hyperedge_ids):\n nodes1 = self.hyperedges[he1].nodes\n for he2 in hyperedge_ids[i+1:]:\n nodes2 = self.hyperedges[he2].nodes\n shared_nodes = nodes1 & nodes2\n if len(shared_nodes) >= s:\n hyperedge_graph.add_edge(he1, he2)\n\n try:\n return nx.is_connected(hyperedge_graph)\n except nx.NetworkXPointlessConcept:\n return False\n else:\n # Create node connectivity graph: nodes are nodes, connect if they share >= s hyperedges\n node_graph = nx.Graph()\n node_ids = list(self.node_properties.keys())\n node_graph.add_nodes_from(node_ids)\n\n for i, node1 in enumerate(node_ids):\n hyperedges1 = {he.id for he in self.hyperedges.values() if node1 in he.nodes}\n for node2 in node_ids[i+1:]:\n hyperedges2 = {he.id for he in self.hyperedges.values() if node2 in he.nodes}\n shared_hyperedges = hyperedges1 & hyperedges2\n if len(shared_hyperedges) >= s:\n node_graph.add_edge(node1, node2)\n\n try:\n return nx.is_connected(node_graph)\n except nx.NetworkXPointlessConcept:\n return False\n\n def get_node_connected_components(\n self, s: int = 1, return_singletons: bool = False\n ) -> Iterator[Set[Any]]:\n \"\"\"\n Yields the s-node-connected components of the hypergraph.\n\n Parameters\n ----------\n s : int, optional, default=1\n The connectivity level to check.\n return_singletons : bool, optional, default=False\n If True, includes singleton components. 
Otherwise, excludes them.\n\n Yields\n ------\n Set[Any]\n Sets of node IDs representing each connected component.\n \"\"\"\n return self.s_connected_components(s=s, hyperedges=False, return_singletons=return_singletons)\n\n def get_hyperedge_connected_components(\n self, s: int = 1, return_singletons: bool = False\n ) -> Iterator[Set[Any]]:\n \"\"\"\n Yields the s-hyperedge-connected components of the hypergraph.\n\n Parameters\n ----------\n s : int, optional, default=1\n The connectivity level to check.\n return_singletons : bool, optional, default=False\n If True, includes singleton components. Otherwise, excludes them.\n\n Yields\n ------\n Set[Any]\n Sets of hyperedge IDs representing each connected component.\n \"\"\"\n return self.s_connected_components(s=s, hyperedges=True, return_singletons=return_singletons)\n\n def get_node_connected_subgraphs(\n self, s: int = 1, return_singletons: bool = False, name: Optional[str] = None\n ) -> Iterator['Hypergraph']:\n \"\"\"\n Yields subgraphs corresponding to each s-node-connected component.\n\n Parameters\n ----------\n s : int, optional, default=1\n The connectivity level to check.\n return_singletons : bool, optional, default=False\n If True, includes singleton components. Otherwise, excludes them.\n name : Optional[str], default=None\n Base name for the subgraphs. 
Each subgraph will have a unique name appended.\n\n Yields\n ------\n Hypergraph\n Subgraphs representing each connected component.\n \"\"\"\n return self.s_component_subgraphs(\n s=s,\n hyperedges=False,\n return_singletons=return_singletons,\n name=name\n )\n\n def get_hyperedge_connected_subgraphs(\n self, s: int = 1, return_singletons: bool = False, name: Optional[str] = None\n ) -> Iterator['Hypergraph']:\n \"\"\"\n Yields subgraphs corresponding to each s-hyperedge-connected component.\n\n Parameters\n ----------\n s : int, optional, default=1\n The connectivity level to check.\n return_singletons : bool, optional, default=False\n If True, includes singleton components. Otherwise, excludes them.\n name : Optional[str], default=None\n Base name for the subgraphs. Each subgraph will have a unique name appended.\n\n Yields\n ------\n Hypergraph\n Subgraphs representing each connected component.\n \"\"\"\n return self.s_component_subgraphs(\n s=s,\n hyperedges=True,\n return_singletons=return_singletons,\n name=name\n )\n\n def get_singleton_hyperedges(self) -> List[Any]:\n \"\"\"\n Returns a list of singleton hyperedges.\n A singleton hyperedge is a hyperedge of size 1 where its sole node has degree 1.\n\n Returns\n -------\n List[Any]\n A list of singleton hyperedge IDs.\n \"\"\"\n singletons = []\n for he in self.hyperedges.values():\n if len(he.nodes) == 1:\n node = next(iter(he.nodes))\n node_degree = sum(1 for hyperedge in self.hyperedges.values() if node in hyperedge.nodes)\n if node_degree == 1:\n singletons.append(he.id)\n return singletons\n\n def remove_singleton_hyperedges(self, name: Optional[str] = None) -> 'Hypergraph':\n \"\"\"\n Constructs a clone of the hypergraph with singleton hyperedges removed.\n \"\"\"\n singletons = self.get_singleton_hyperedges()\n if not singletons:\n return self.copy(name=name)\n\n new_hypergraph = self.remove_hyperedges(singletons, inplace=False)\n new_hypergraph.name = name if name else 
f\"{self.name}_no_singleton_hyperedges\"\n return new_hypergraph\n\n def s_connected_components(\n self, \n s: int = 1, \n hyperedges: bool = True, \n return_singletons: bool = False\n ) -> Iterator[Set[Any]]:\n \"\"\"\n Yields the s-hyperedge-connected or s-node-connected components of the hypergraph.\n\n Parameters\n ----------\n s : int, optional, default=1\n The connectivity level to check.\n hyperedges : bool, optional, default=True\n If True, yields s-hyperedge-connected components. Otherwise, yields s-node-connected components.\n return_singletons : bool, optional, default=False\n If True, includes singleton components. Otherwise, excludes them.\n\n Yields\n ------\n Set[Any]\n Sets of hyperedge IDs or node IDs representing each connected component.\n \"\"\"\n if hyperedges:\n # s-hyperedge-connected: hyperedges are connected if they share at least s nodes\n hyperedge_graph = nx.Graph()\n hyperedge_ids = list(self.hyperedges.keys())\n hyperedge_graph.add_nodes_from(hyperedge_ids)\n\n for i, he1 in enumerate(hyperedge_ids):\n nodes1 = self.hyperedges[he1].nodes\n for he2 in hyperedge_ids[i + 1:]:\n nodes2 = self.hyperedges[he2].nodes\n shared_nodes = nodes1 & nodes2\n if len(shared_nodes) >= s:\n hyperedge_graph.add_edge(he1, he2)\n\n components = nx.connected_components(hyperedge_graph)\n for component in components:\n if not return_singletons and len(component) == 1:\n continue\n yield component\n else:\n # s-node-connected: nodes are connected if they share at least s hyperedges\n node_graph = nx.Graph()\n node_ids = list(self.node_properties.keys())\n node_graph.add_nodes_from(node_ids)\n\n for i, node1 in enumerate(node_ids):\n hyperedges1 = {he.id for he in self.hyperedges.values() if node1 in he.nodes}\n for node2 in node_ids[i + 1:]:\n hyperedges2 = {he.id for he in self.hyperedges.values() if node2 in he.nodes}\n shared_hyperedges = hyperedges1 & hyperedges2\n if len(shared_hyperedges) >= s:\n node_graph.add_edge(node1, node2)\n\n components = 
nx.connected_components(node_graph)\n for component in components:\n if not return_singletons and len(component) == 1:\n continue\n yield component\n\n def s_component_subgraphs(\n self,\n s: int = 1,\n hyperedges: bool = True,\n return_singletons: bool = False,\n name: Optional[str] = None\n ) -> Iterator['Hypergraph']:\n \"\"\"\n Yields subgraphs corresponding to each s-hyperedge-connected or s-node-connected component.\n\n Parameters\n ----------\n s : int, optional, default=1\n The connectivity level to check.\n hyperedges : bool, optional, default=True\n If True, yields subgraphs of s-hyperedge-connected components. Otherwise, yields subgraphs of s-node-connected components.\n return_singletons : bool, optional, default=False\n If True, includes singleton components. Otherwise, excludes them.\n name : Optional[str], default=None\n Base name for the subgraphs. Each subgraph will have a unique name appended.\n\n Yields\n ------\n Hypergraph\n Subgraphs representing each connected component.\n \"\"\"\n for idx, component in enumerate(\n self.s_connected_components(s=s, hyperedges=hyperedges, return_singletons=return_singletons)\n ):\n if hyperedges:\n yield self.restrict_to_specific_hyperedges(\n hyperedges_to_retain=component, \n name=f\"{name or self.name}_component_{idx}\"\n )\n else:\n yield self.restrict_to_specific_nodes(\n nodes_to_retain=component, \n name=f\"{name or self.name}_component_{idx}\"\n )\n\n def compute_node_diameters(self, s: int = 1) -> Tuple[int, List[int], List[Set[Any]]]:\n \"\"\"\n Returns the node diameters of the connected components in the hypergraph.\n\n Parameters\n ----------\n s : int, optional, default=1\n The number of shared hyperedges required for nodes to be considered adjacent.\n\n Returns\n -------\n Tuple[int, List[int], List[Set[Any]]]\n - Maximum diameter among all connected components.\n - List of diameters for each s-node connected component.\n - List of sets, each containing node IDs in an s-node connected 
component.\n\n Raises\n ------\n HypergraphError\n If the hypergraph is not s-connected or has no nodes.\n \"\"\"\n A, node_id_map = self.adjacency_matrix(s=s, index=True)\n if A is None or A.shape[0] == 0:\n raise HypergraphError(\"The hypergraph has no nodes to compute diameters.\")\n\n graph = nx.from_scipy_sparse_array(A)\n\n if not nx.is_connected(graph):\n raise HypergraphError(f\"Hypergraph is not s-node-connected. s={s}\")\n\n diams = []\n comps = []\n for component in nx.connected_components(graph):\n subgraph = graph.subgraph(component)\n if len(subgraph) == 1:\n diamc = 0 # Diameter of a single node is 0\n else:\n try:\n diamc = nx.diameter(subgraph)\n except nx.NetworkXError:\n diamc = float('inf') # Infinite diameter if the subgraph is not connected\n diams.append(diamc)\n component_nodes = {node_id_map[node] for node in component}\n comps.append(component_nodes)\n\n if not diams:\n raise HypergraphError(\"No connected components found to compute diameters.\")\n\n max_diam = max(diams)\n return max_diam, diams, comps\n\n def compute_hyperedge_diameters(self, s: int = 1) -> Tuple[int, List[int], List[Set[Any]]]:\n \"\"\"\n Returns the hyperedge diameters of the s-hyperedge-connected component subgraphs in the hypergraph.\n\n Parameters\n ----------\n s : int, optional, default=1\n The number of shared nodes required for hyperedges to be considered adjacent.\n\n Returns\n -------\n Tuple[int, List[int], List[Set[Any]]]\n - Maximum diameter among all s-hyperedge-connected components.\n - List of diameters for each s-hyperedge connected component.\n - List of sets, each containing hyperedge IDs in an s-hyperedge connected component.\n\n Raises\n ------\n HypergraphError\n If the hypergraph is not s-hyperedge-connected or has no hyperedges.\n \"\"\"\n A, he_id_map = self.hyperedge_adjacency_matrix(s=s, index=True)\n if A is None or A.shape[0] == 0:\n raise HypergraphError(\"The hypergraph has no hyperedges to compute diameters.\")\n\n graph = 
nx.from_scipy_sparse_array(A)\n\n if not nx.is_connected(graph):\n raise HypergraphError(f\"Hypergraph is not s-hyperedge-connected. s={s}\")\n\n diams = []\n comps = []\n for component in nx.connected_components(graph):\n subgraph = graph.subgraph(component)\n if len(subgraph) == 1:\n diamc = 0 # Diameter of a single hyperedge is 0\n else:\n try:\n diamc = nx.diameter(subgraph)\n except nx.NetworkXError:\n diamc = float('inf') # Infinite diameter if the subgraph is not connected\n diams.append(diamc)\n component_hyperedges = {he_id_map[he] for he in component}\n comps.append(component_hyperedges)\n\n if not diams:\n raise HypergraphError(\"No connected components found to compute hyperedge diameters.\")\n\n max_diam = max(diams)\n return max_diam, diams, comps\n\n def compute_node_diameter(self, s: int = 1) -> int:\n \"\"\"\n Returns the diameter of the hypergraph based on s-node connectivity.\n\n Parameters\n ----------\n s : int, optional, default=1\n The number of shared hyperedges required for nodes to be considered adjacent.\n\n Returns\n -------\n int\n The diameter of the hypergraph.\n\n Raises\n ------\n HypergraphError\n If the hypergraph is not s-node-connected or has no nodes.\n \"\"\"\n A, _ = self.adjacency_matrix(s=s, index=True)\n if A is None or A.shape[0] == 0:\n raise HypergraphError(\"The hypergraph has no nodes to compute diameter.\")\n\n graph = nx.from_scipy_sparse_array(A)\n if not nx.is_connected(graph):\n raise HypergraphError(f\"Hypergraph is not s-node-connected. 
s={s}\")\n\n try:\n return nx.diameter(graph)\n except nx.NetworkXError as e:\n raise HypergraphError(f\"Could not compute diameter: {e}\")\n\n def compute_hyperedge_diameter(self, s: int = 1) -> int:\n \"\"\"\n Returns the diameter of the hypergraph based on s-hyperedge connectivity.\n\n Parameters\n ----------\n s : int, optional, default=1\n The number of shared nodes required for hyperedges to be considered adjacent.\n\n Returns\n -------\n int\n The diameter of the hypergraph based on hyperedge connectivity.\n\n Raises\n ------\n HypergraphError\n If the hypergraph is not s-hyperedge-connected or has no hyperedges.\n \"\"\"\n A, _ = self.hyperedge_adjacency_matrix(s=s, index=True)\n if A is None or A.shape[0] == 0:\n raise HypergraphError(\"The hypergraph has no hyperedges to compute diameter.\")\n\n graph = nx.from_scipy_sparse_array(A)\n if not nx.is_connected(graph):\n raise HypergraphError(f\"Hypergraph is not s-hyperedge-connected. s={s}\")\n\n try:\n return nx.diameter(graph)\n except nx.NetworkXError as e:\n raise HypergraphError(f\"Could not compute hyperedge diameter: {e}\")\n\n def get_node_distance(self, source: Any, target: Any, s: int = 1) -> Union[int, float]:\n \"\"\"\n Returns the shortest s-walk distance between two nodes in the hypergraph.\n\n Parameters\n ----------\n source : Any\n A node identifier in the hypergraph.\n target : Any\n A node identifier in the hypergraph.\n s : int, optional, default=1\n The number of shared hyperedges required for nodes to be considered adjacent.\n\n Returns\n -------\n Union[int, float]\n The shortest s-walk distance between the source and target nodes.\n Returns `float('inf')` if no path exists.\n\n Raises\n ------\n HypergraphError\n If either the source or target node does not exist in the hypergraph.\n \"\"\"\n if source not in self.node_properties:\n raise HypergraphError(f\"Source node '{source}' does not exist in the hypergraph.\")\n if target not in self.node_properties:\n raise 
HypergraphError(f\"Target node '{target}' does not exist in the hypergraph.\")\n\n A, node_id_map = self.adjacency_matrix(s=s, index=True)\n if A is None:\n raise HypergraphError(\"Adjacency matrix could not be generated.\")\n\n graph = nx.from_scipy_sparse_array(A)\n\n # `nx.from_scipy_sparse_array` labels graph nodes with matrix indices, and\n # `node_id_map` maps those indices to node identifiers. Invert it so the\n # original identifiers can be resolved to graph labels before querying.\n index_map = {node_id: idx for idx, node_id in node_id_map.items()}\n\n try:\n distance = nx.shortest_path_length(graph, source=index_map[source], target=index_map[target])\n return distance\n except (nx.NetworkXNoPath, nx.NodeNotFound, KeyError):\n warnings.warn(f\"No s-walk path between '{source}' and '{target}'. Returning infinity.\")\n return float('inf')\n\n def get_hyperedge_distance(self, source: Any, target: Any, s: int = 1) -> Union[int, float]:\n \"\"\"\n Returns the shortest s-walk distance between two hyperedges in the hypergraph.\n\n Parameters\n ----------\n source : Any\n A hyperedge identifier in the hypergraph.\n target : Any\n A hyperedge identifier in the hypergraph.\n s : int, optional, default=1\n The number of shared nodes required for hyperedges to be considered adjacent.\n\n Returns\n -------\n Union[int, float]\n The shortest s-walk distance between the source and target hyperedges.\n Returns `float('inf')` if no path exists.\n\n Raises\n ------\n HypergraphError\n If either the source or target hyperedge does not exist in the hypergraph.\n \"\"\"\n if source not in self.hyperedges:\n raise HypergraphError(f\"Source hyperedge '{source}' does not exist in the hypergraph.\")\n if target not in self.hyperedges:\n raise HypergraphError(f\"Target hyperedge '{target}' does not exist in the hypergraph.\")\n\n A, he_id_map = self.hyperedge_adjacency_matrix(s=s, index=True)\n if A is None:\n raise HypergraphError(\"Hyperedge adjacency matrix could not be generated.\")\n\n graph = nx.from_scipy_sparse_array(A)\n\n # Invert the index-to-identifier map so hyperedge ids can be resolved to\n # the integer labels used by the graph.\n index_map = {he_id: idx for idx, he_id in he_id_map.items()}\n\n try:\n distance = nx.shortest_path_length(graph, source=index_map[source], target=index_map[target])\n return distance\n except (nx.NetworkXNoPath, nx.NodeNotFound, KeyError):\n warnings.warn(f\"No s-walk path between hyperedges '{source}' and '{target}'. 
Returning infinity.\")\n return float('inf')\n\n # Advanced Operations and Transformations\n\n def union(self, other: 'Hypergraph', inplace: bool = False, name: Optional[str] = None) -> 'Hypergraph':\n \"\"\"\n Returns the union of the current hypergraph with another hypergraph.\n The union combines all nodes and hyperedges from both hypergraphs.\n\n Parameters\n ----------\n other : Hypergraph\n The hypergraph to union with.\n inplace : bool, optional, default=False\n If True, modifies the current hypergraph. Otherwise, returns a new Hypergraph instance.\n name : Optional[str], default=None\n The name for the resulting hypergraph. If None, defaults to 'Union_of_{self.name}_{other.name}'.\n\n Returns\n -------\n Hypergraph\n The resulting union hypergraph.\n\n Raises\n ------\n TypeError\n If `other` is not an instance of Hypergraph.\n \"\"\"\n if not isinstance(other, Hypergraph):\n raise TypeError(\"The `other` parameter must be an instance of Hypergraph.\")\n\n if inplace:\n # Add nodes from other\n for node_id, props in other.node_properties.items():\n if node_id not in self.node_properties:\n self.add_node(node_id, properties=props, inplace=True)\n else:\n # Optionally, merge properties\n self.node_properties[node_id].update(props)\n self.graph.nodes[node_id].update(props)\n\n # Add hyperedges from other\n for he_id, hyperedge in other.hyperedges.items():\n if he_id not in self.hyperedges:\n self.add_hyperedge(he_id, hyperedge.nodes, properties=hyperedge.properties)\n else:\n # Optionally, merge properties and nodes\n self.hyperedges[he_id].nodes.update(hyperedge.nodes)\n self.hyperedge_properties[he_id].update(hyperedge.properties)\n for node in hyperedge.nodes:\n if node not in self.graph:\n self.add_node(node)\n self.graph.add_edge(he_id, node)\n if name:\n self.name = name\n return self\n else:\n # Create a new Hypergraph instance\n new_hyperedges = copy.deepcopy(self.hyperedges)\n new_node_properties = copy.deepcopy(self.node_properties)\n 
new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)\n new_graph = copy.deepcopy(self.graph)\n new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)\n new_name = name if name else f\"Union_of_{self.name}_{other.name}\"\n\n # Add nodes from other\n for node_id, props in other.node_properties.items():\n if node_id not in new_node_properties:\n new_node_properties[node_id] = copy.deepcopy(props)\n new_graph.add_node(node_id, bipartite='node', **props)\n\n # Add hyperedges from other\n for he_id, hyperedge in other.hyperedges.items():\n if he_id not in new_hyperedges:\n new_hyperedges[he_id] = copy.deepcopy(hyperedge)\n new_hyperedge_properties[he_id] = copy.deepcopy(other.hyperedge_properties[he_id])\n new_graph.add_node(he_id, bipartite='hyperedge', **new_hyperedge_properties[he_id])\n new_bipartite_nodes.add(he_id)\n for node in hyperedge.nodes:\n new_graph.add_edge(he_id, node)\n else:\n # Merge nodes and properties\n new_hyperedges[he_id].nodes.update(hyperedge.nodes)\n new_hyperedge_properties[he_id].update(other.hyperedge_properties[he_id])\n for node in hyperedge.nodes:\n new_graph.add_edge(he_id, node)\n\n # Construct the new Hypergraph\n return Hypergraph(\n hyperedges={\n he_id: {\n 'nodes': list(he.nodes),\n 'properties': he.properties.copy()\n } for he_id, he in new_hyperedges.items()\n },\n node_properties=new_node_properties,\n hyperedge_properties=new_hyperedge_properties,\n name=new_name\n )\n\n def intersection(self, other: 'Hypergraph', inplace: bool = False, name: Optional[str] = None) -> 'Hypergraph':\n \"\"\"\n Returns the intersection of the current hypergraph with another hypergraph.\n The intersection includes only nodes and hyperedges present in both hypergraphs.\n\n Parameters\n ----------\n other : Hypergraph\n The hypergraph to intersect with.\n inplace : bool, optional, default=False\n If True, modifies the current hypergraph to keep only the intersecting elements.\n Otherwise, returns a new Hypergraph instance.\n name : 
Optional[str], default=None\n The name for the resulting hypergraph. If None, defaults to 'Intersection_of_{self.name}_{other.name}'.\n\n Returns\n -------\n Hypergraph\n The resulting intersection hypergraph.\n\n Raises\n ------\n TypeError\n If `other` is not an instance of Hypergraph.\n \"\"\"\n if not isinstance(other, Hypergraph):\n raise TypeError(\"The `other` parameter must be an instance of Hypergraph.\")\n\n intersect_nodes = set(self.node_properties.keys()) & set(other.node_properties.keys())\n intersect_hyperedges = set(self.hyperedges.keys()) & set(other.hyperedges.keys())\n\n if inplace:\n # Remove non-intersecting nodes and hyperedges\n nodes_to_remove = set(self.node_properties.keys()) - intersect_nodes\n hyperedges_to_remove = set(self.hyperedges.keys()) - intersect_hyperedges\n self.remove_nodes_from(nodes_to_remove, inplace=True)\n self.remove_hyperedges(hyperedges_to_remove, inplace=True)\n return self\n else:\n # Create a new Hypergraph instance\n new_hyperedges = {}\n new_node_properties = {node_id: copy.deepcopy(self.node_properties[node_id]) for node_id in intersect_nodes}\n new_hyperedge_properties = {}\n new_graph = nx.Graph()\n new_bipartite_nodes = set()\n\n for he_id in intersect_hyperedges:\n he_self = self.hyperedges[he_id]\n he_other = other.hyperedges[he_id]\n # Intersection hyperedges have the same nodes and merged properties\n new_nodes = set(he_self.nodes) & set(he_other.nodes)\n if not new_nodes:\n continue # Skip hyperedges with no common nodes\n new_hyperedges[he_id] = HyperEdge(id=he_id, nodes=new_nodes, properties={})\n # Merge properties (could define specific rules)\n new_hyperedge_properties[he_id] = {**self.hyperedge_properties.get(he_id, {}), \n **other.hyperedge_properties.get(he_id, {})}\n new_graph.add_node(he_id, bipartite='hyperedge', **new_hyperedge_properties[he_id])\n new_bipartite_nodes.add(he_id)\n for node in new_nodes:\n new_graph.add_edge(he_id, node)\n\n return Hypergraph(\n hyperedges={\n he_id: {\n 
'nodes': list(he.nodes),\n 'properties': he.properties.copy()\n } for he_id, he in new_hyperedges.items()\n },\n node_properties=new_node_properties,\n hyperedge_properties=new_hyperedge_properties,\n name=name if name else f\"Intersection_of_{self.name}_{other.name}\"\n )\n\n def difference(self, other: 'Hypergraph', inplace: bool = False, name: Optional[str] = None) -> 'Hypergraph':\n \"\"\"\n Returns the difference of the current hypergraph with another hypergraph.\n The difference includes nodes and hyperedges present in the current hypergraph but not in the other.\n\n Parameters\n ----------\n other : Hypergraph\n The hypergraph to subtract.\n inplace : bool, optional, default=False\n If True, modifies the current hypergraph by removing elements found in `other`.\n Otherwise, returns a new Hypergraph instance.\n name : Optional[str], default=None\n The name for the resulting hypergraph. If None, defaults to 'Difference_of_{self.name}_{other.name}'.\n\n Returns\n -------\n Hypergraph\n The resulting difference hypergraph.\n\n Raises\n ------\n TypeError\n If `other` is not an instance of Hypergraph.\n \"\"\"\n if not isinstance(other, Hypergraph):\n raise TypeError(\"The `other` parameter must be an instance of Hypergraph.\")\n\n if inplace:\n # Remove hyperedges present in other\n hyperedges_to_remove = set(self.hyperedges.keys()) & set(other.hyperedges.keys())\n self.remove_hyperedges(hyperedges_to_remove, inplace=True)\n # Remove nodes present in other\n nodes_to_remove = set(self.node_properties.keys()) & set(other.node_properties.keys())\n self.remove_nodes_from(nodes_to_remove, inplace=True)\n return self\n else:\n # Create a new Hypergraph instance\n new_hyperedges = {he_id: copy.deepcopy(he) for he_id, he in self.hyperedges.items() if he_id not in other.hyperedges}\n new_hyperedge_properties = {he_id: copy.deepcopy(props) for he_id, props in self.hyperedge_properties.items() if he_id not in other.hyperedges}\n new_node_properties = {node_id: 
copy.deepcopy(props) for node_id, props in self.node_properties.items() if node_id not in other.node_properties}\n\n # Reconstruct graph\n new_graph = nx.Graph()\n new_bipartite_nodes = set()\n for he_id, hyperedge in new_hyperedges.items():\n new_graph.add_node(he_id, bipartite='hyperedge', **new_hyperedge_properties[he_id])\n new_bipartite_nodes.add(he_id)\n for node in hyperedge.nodes:\n if node in new_node_properties:\n new_graph.add_edge(he_id, node)\n\n return Hypergraph(\n hyperedges={\n he_id: {\n 'nodes': list(he.nodes),\n 'properties': he.properties.copy()\n } for he_id, he in new_hyperedges.items()\n },\n node_properties=new_node_properties,\n hyperedge_properties=new_hyperedge_properties,\n name=name if name else f\"Difference_of_{self.name}_{other.name}\"\n )\n\n def symmetric_difference(self, other: 'Hypergraph', inplace: bool = False, name: Optional[str] = None) -> 'Hypergraph':\n \"\"\"\n Returns the symmetric difference of the current hypergraph with another hypergraph.\n The symmetric difference includes elements present in either hypergraph but not in both.\n\n Parameters\n ----------\n other : Hypergraph\n The hypergraph to symmetric difference with.\n inplace : bool, optional, default=False\n If True, modifies the current hypergraph to keep only the symmetric difference elements.\n Otherwise, returns a new Hypergraph instance.\n name : Optional[str], default=None\n The name for the resulting hypergraph. 
If None, defaults to 'SymmetricDifference_of_{self.name}_{other.name}'.\n\n Returns\n -------\n Hypergraph\n The resulting symmetric difference hypergraph.\n\n Raises\n ------\n TypeError\n If `other` is not an instance of Hypergraph.\n \"\"\"\n if not isinstance(other, Hypergraph):\n raise TypeError(\"The `other` parameter must be an instance of Hypergraph.\")\n\n if inplace:\n # Hyperedges symmetric difference\n hyperedges_to_add = set(other.hyperedges.keys()) - set(self.hyperedges.keys())\n hyperedges_to_remove = set(self.hyperedges.keys()) & set(other.hyperedges.keys())\n self.remove_hyperedges(hyperedges_to_remove, inplace=True)\n for he_id in hyperedges_to_add:\n hyperedge = other.hyperedges[he_id]\n self.add_hyperedge(he_id, hyperedge.nodes, properties=hyperedge.properties)\n\n # Nodes symmetric difference\n nodes_to_add = set(other.node_properties.keys()) - set(self.node_properties.keys())\n nodes_to_remove = set(self.node_properties.keys()) & set(other.node_properties.keys())\n self.remove_nodes_from(nodes_to_remove, inplace=True)\n for node_id in nodes_to_add:\n props = other.node_properties[node_id]\n self.add_node(node_id, properties=props, inplace=True)\n\n if name:\n self.name = name\n return self\n else:\n # Create a new Hypergraph instance\n union_hg = self.union(other)\n intersection_hg = self.intersection(other)\n return union_hg.difference(intersection_hg, name=name if name else f\"SymmetricDifference_of_{self.name}_{other.name}\")\n\n def transpose(self, name: Optional[str] = None) -> 'Hypergraph':\n \"\"\"\n Transposes the hypergraph by swapping the roles of nodes and hyperedges.\n The resulting hypergraph has hyperedges corresponding to the original nodes and vice versa.\n\n Parameters\n ----------\n name : Optional[str], default=None\n The name assigned to the transposed hypergraph. 
If None, defaults to the original name suffixed with '_transposed'.\n\n Returns\n -------\n Hypergraph\n The transposed hypergraph.\n \"\"\"\n transposed_hyperedges = {node_id: HyperEdge(id=node_id, nodes=set(), properties=copy.deepcopy(props))\n for node_id, props in self.node_properties.items()}\n transposed_node_properties = {he_id: copy.deepcopy(props) for he_id, props in self.hyperedge_properties.items()}\n\n for he_id, hyperedge in self.hyperedges.items():\n for node in hyperedge.nodes:\n if node in transposed_hyperedges:\n transposed_hyperedges[node].nodes.add(he_id)\n\n # Construct the transposed hypergraph\n return Hypergraph(\n hyperedges={\n he_id: {\n 'nodes': list(he.nodes),\n 'properties': he.properties.copy()\n } for he_id, he in transposed_hyperedges.items()\n },\n node_properties=transposed_node_properties,\n hyperedge_properties={he_id: he.properties.copy() for he_id, he in transposed_hyperedges.items()},\n name=name if name else f\"{self.name}_transposed\"\n )\n\n def to_bipartite_graph(self, keep_data=False, directed=False) -> nx.Graph:\n \"\"\"\n Creates a bipartite NetworkX graph from the hypergraph.\n The nodes and hyperedges of the hypergraph become nodes in the bipartite graph.\n For every hyperedge in the hypergraph and each node it connects to, there\n is an edge in the bipartite graph.\n\n Parameters\n ----------\n keep_data : bool, optional, default = False\n If True, includes the node and hyperedge properties in the NetworkX graph.\n directed : bool, optional, default = False\n If True, the edges in the graph are directed with hyperedges as sources and nodes as targets.\n\n Returns\n -------\n networkx.Graph or networkx.DiGraph\n The bipartite graph representation of the hypergraph.\n \"\"\"\n # Choose graph type based on directed flag\n B = nx.DiGraph() if directed else nx.Graph()\n\n if not keep_data:\n # Add nodes with bipartite attributes, where 0 indicates hyperedges and 1 indicates regular nodes\n 
B.add_nodes_from(self.hyperedges.keys(), bipartite=0) # hyperedges\n B.add_nodes_from(self.node_properties.keys(), bipartite=1) # nodes\n\n # Add edges between hyperedges and nodes based on hyperedges data\n for he_id, hyperedge in self.hyperedges.items():\n for node in hyperedge.nodes:\n B.add_edge(he_id, node)\n else:\n # Add nodes with properties if keep_data is True\n for node_id, properties in self.node_properties.items():\n B.add_node(node_id, bipartite=1, **properties)\n\n for he_id, hyperedge in self.hyperedges.items():\n B.add_node(he_id, bipartite=0, **self.hyperedge_properties.get(he_id, {}))\n for node in hyperedge.nodes:\n # Add edges with optional properties if keep_data is True\n B.add_edge(he_id, node)\n\n return B\n\n @classmethod\n def from_bipartite_graph(cls, bipartite_graph: nx.Graph, hyperedge_prefix: str = \"HE\", node_prefix: str = \"N\", name: Optional[str] = None) -> 'Hypergraph':\n \"\"\"\n Constructs a Hypergraph instance from a bipartite graph.\n\n Parameters\n ----------\n bipartite_graph : nx.Graph\n A bipartite graph where one set of nodes represents hyperedges and the other represents regular nodes.\n hyperedge_prefix : str, optional, default=\"HE\"\n The prefix to identify hyperedge nodes in the bipartite graph.\n node_prefix : str, optional, default=\"N\"\n The prefix to identify regular nodes in the bipartite graph.\n name : Optional[str], default=None\n The name assigned to the new Hypergraph. 
If None, defaults to 'FromBipartiteGraph'.\n\n Returns\n -------\n Hypergraph\n The constructed Hypergraph instance.\n\n Raises\n ------\n ValueError\n If the bipartite graph does not contain two distinct sets of nodes identifiable by the provided prefixes.\n \"\"\"\n hyperedges = {}\n node_properties = {}\n hyperedge_properties = {}\n name = name if name else \"FromBipartiteGraph\"\n\n for node in bipartite_graph.nodes(data=True):\n node_id, attrs = node\n if node_id.startswith(hyperedge_prefix):\n # It's a hyperedge\n hyperedges[node_id] = HyperEdge(id=node_id, nodes=set(), properties=attrs)\n hyperedge_properties[node_id] = copy.deepcopy(attrs)\n elif node_id.startswith(node_prefix):\n # It's a regular node\n node_properties[node_id] = copy.deepcopy(attrs)\n else:\n raise ValueError(f\"Node '{node_id}' does not start with either hyperedge_prefix '{hyperedge_prefix}' or node_prefix '{node_prefix}'.\")\n\n # Assign nodes to hyperedges based on edges in bipartite graph\n for he_id in hyperedges:\n connected_nodes = set(bipartite_graph.neighbors(he_id))\n hyperedges[he_id].nodes = connected_nodes\n\n # Construct hyperedges dict for __init__\n hyperedges_dict = {\n he_id: {\n 'nodes': list(he.nodes),\n 'properties': he.properties.copy()\n } for he_id, he in hyperedges.items()\n }\n\n return cls(\n hyperedges=hyperedges_dict,\n node_properties=node_properties,\n hyperedge_properties=hyperedge_properties,\n name=name\n )\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.__contains__","title":"__contains__(item)
","text":"Checks if a node is in the hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.__contains__--parameters","title":"Parameters","text":"item : Any The node identifier to check.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.__contains__--returns","title":"Returns","text":"bool True if the node exists in the hypergraph, False otherwise.
Source code in src/aeiva/hypergraph/hypergraph.py
def __contains__(self, item: Any) -> bool:\n \"\"\"\n Checks if a node is in the hypergraph.\n\n Parameters\n ----------\n item : Any\n The node identifier to check.\n\n Returns\n -------\n bool\n True if the node exists in the hypergraph, False otherwise.\n \"\"\"\n return item in self.node_properties\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.__eq__","title":"__eq__(other)
","text":"Checks if two hypergraphs are equal based on their hyperedges and nodes.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.__eq__--parameters","title":"Parameters","text":"other : Any The other object to compare.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.__eq__--returns","title":"Returns","text":"bool True if both hypergraphs have identical nodes and hyperedges with the same properties, False otherwise.
Source code in src/aeiva/hypergraph/hypergraph.py
def __eq__(self, other: Any) -> bool:\n \"\"\"\n Checks if two hypergraphs are equal based on their hyperedges and nodes.\n\n Parameters\n ----------\n other : Any\n The other object to compare.\n\n Returns\n -------\n bool\n True if both hypergraphs have identical nodes and hyperedges with the same properties, False otherwise.\n \"\"\"\n if not isinstance(other, Hypergraph):\n return False\n\n # Compare nodes and their properties\n if self.node_properties != other.node_properties:\n return False\n\n # Compare hyperedges and their properties\n if self.hyperedges.keys() != other.hyperedges.keys():\n return False\n\n for he_id in self.hyperedges:\n if self.hyperedges[he_id].nodes != other.hyperedges[he_id].nodes:\n return False\n if self.hyperedge_properties.get(he_id, {}) != other.hyperedge_properties.get(he_id, {}):\n return False\n\n return True\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.__getitem__","title":"__getitem__(node)
","text":"Retrieves the neighbors of a node in the hypergraph.
Neighbors are nodes that share at least one hyperedge with the given node.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.__getitem__--parameters","title":"Parameters","text":"node : Any The node identifier.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.__getitem__--returns","title":"Returns","text":"Iterable[Any] An iterator over neighboring node identifiers.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.__getitem__--raises","title":"Raises","text":"HypergraphError If the node does not exist in the hypergraph.
Source code in src/aeiva/hypergraph/hypergraph.py
def __getitem__(self, node: Any) -> Iterable[Any]:\n \"\"\"\n Retrieves the neighbors of a node in the hypergraph.\n\n Neighbors are nodes that share at least one hyperedge with the given node.\n\n Parameters\n ----------\n node : Any\n The node identifier.\n\n Returns\n -------\n Iterable[Any]\n An iterator over neighboring node identifiers.\n\n Raises\n ------\n HypergraphError\n If the node does not exist in the hypergraph.\n \"\"\"\n if node not in self.node_properties:\n raise HypergraphError(f\"Node '{node}' does not exist in the hypergraph.\")\n\n # Get all hyperedges that include the node\n hyperedges = set(self.graph.neighbors(node))\n\n # Get all nodes connected by these hyperedges\n neighbors = set()\n for he_id in hyperedges:\n neighbors.update(self.hyperedges[he_id].nodes)\n\n neighbors.discard(node) # Remove the node itself\n return neighbors\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.__iter__","title":"__iter__()
","text":"Allows iteration over the nodes of the hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.__iter__--yields","title":"Yields","text":"Any Node identifiers.
Source code in src/aeiva/hypergraph/hypergraph.py
def __iter__(self) -> Iterator[Any]:\n \"\"\"\n Allows iteration over the nodes of the hypergraph.\n\n Yields\n ------\n Any\n Node identifiers.\n \"\"\"\n return iter(self.node_properties)\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.__len__","title":"__len__()
","text":"Returns the number of nodes in the hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.__len__--returns","title":"Returns","text":"int Number of nodes.
Source code in src/aeiva/hypergraph/hypergraph.py
def __len__(self) -> int:\n \"\"\"\n Returns the number of nodes in the hypergraph.\n\n Returns\n -------\n int\n Number of nodes.\n \"\"\"\n return len(self.node_properties)\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.__repr__","title":"__repr__()
","text":"Official string representation of the hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.__repr__--returns","title":"Returns","text":"str A detailed string describing the hypergraph with its name, number of nodes, and hyperedges.
Source code in src/aeiva/hypergraph/hypergraph.py
def __repr__(self) -> str:\n \"\"\"\n Official string representation of the hypergraph.\n\n Returns\n -------\n str\n A detailed string describing the hypergraph with its name, number of nodes, and hyperedges.\n \"\"\"\n return (\n f\"Hypergraph(name={self.name!r}, \"\n f\"nodes={len(self)}, hyperedges={len(self.hyperedges)})\"\n )\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.__str__","title":"__str__()
","text":"String representation of the hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.__str__--returns","title":"Returns","text":"str A string describing the hypergraph with its name, number of nodes, and hyperedges.
Source code in src/aeiva/hypergraph/hypergraph.py
def __str__(self) -> str:\n \"\"\"\n String representation of the hypergraph.\n\n Returns\n -------\n str\n A string describing the hypergraph with its name, number of nodes, and hyperedges.\n \"\"\"\n return f\"Hypergraph '{self.name}' with {len(self)} nodes and {len(self.hyperedges)} hyperedges.\"\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.add_hyperedge","title":"add_hyperedge(he_id, nodes, properties=None)
","text":"Adds a hyperedge to the hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.add_hyperedge--parameters","title":"Parameters","text":"he_id : Any Unique identifier for the hyperedge. nodes : Iterable[Any] Nodes connected by the hyperedge. properties : Optional[Dict[str, Any]] = None Properties of the hyperedge.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.add_hyperedge--raises","title":"Raises","text":"HypergraphError If the hyperedge ID already exists.
Source code in src/aeiva/hypergraph/hypergraph.py
def add_hyperedge(\n self,\n he_id: Any,\n nodes: Iterable[Any],\n properties: Optional[Dict[str, Any]] = None\n) -> None:\n \"\"\"\n Adds a hyperedge to the hypergraph.\n\n Parameters\n ----------\n he_id : Any\n Unique identifier for the hyperedge.\n nodes : Iterable[Any]\n Nodes connected by the hyperedge.\n properties : Optional[Dict[str, Any]] = None\n Properties of the hyperedge.\n\n Raises\n ------\n HypergraphError\n If the hyperedge ID already exists.\n \"\"\"\n if he_id in self.hyperedges:\n raise HypergraphError(f\"Hyperedge '{he_id}' already exists.\")\n\n hyperedge = HyperEdge(id=he_id, nodes=nodes, properties=properties)\n self.hyperedges[he_id] = hyperedge\n self.hyperedge_properties[he_id] = copy.deepcopy(properties) if properties else {}\n\n # Add hyperedge to bipartite graph\n self.graph.add_node(he_id, bipartite='hyperedge', **self.hyperedge_properties[he_id])\n self.bipartite_nodes.add(he_id)\n\n # Add edges between hyperedge and nodes\n for node in hyperedge.nodes:\n if node not in self.graph:\n self.graph.add_node(node, bipartite='node', **self.node_properties.get(node, {}))\n self.graph.add_edge(he_id, node)\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.add_hyperedges_from","title":"add_hyperedges_from(hyperedges, inplace=True)
","text":"Adds multiple hyperedges with attributes to the hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.add_hyperedges_from--parameters","title":"Parameters","text":"hyperedges : Iterable[Union[Any, Tuple[Any, Dict[str, Any]]]] An iterable of hyperedge identifiers or tuples of (he_id, attributes). inplace : bool, default=True If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the added hyperedges.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.add_hyperedges_from--returns","title":"Returns","text":"Hypergraph The updated or new Hypergraph instance.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.add_hyperedges_from--raises","title":"Raises","text":"HypergraphError If any hyperedge ID already exists. ValueError If any tuple does not contain exactly two elements or if attributes are not dictionaries.
Source code in src/aeiva/hypergraph/hypergraph.py
def add_hyperedges_from(\n self,\n hyperedges: Iterable[Union[Any, Tuple[Any, Dict[str, Any]]]],\n inplace: bool = True\n) -> 'Hypergraph':\n \"\"\"\n Adds multiple hyperedges with attributes to the hypergraph.\n\n Parameters\n ----------\n hyperedges : Iterable[Union[Any, Tuple[Any, Dict[str, Any]]]]\n An iterable of hyperedge identifiers or tuples of (he_id, attributes).\n inplace : bool, default=True\n If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the added hyperedges.\n\n Returns\n -------\n Hypergraph\n The updated or new Hypergraph instance.\n\n Raises\n ------\n HypergraphError\n If any hyperedge ID already exists.\n ValueError\n If any tuple does not contain exactly two elements or if attributes are not dictionaries.\n \"\"\"\n new_hyperedges = []\n for item in hyperedges:\n if isinstance(item, tuple):\n if len(item) != 2 or not isinstance(item[1], dict):\n raise ValueError(f\"Each tuple must be of the form (he_id, attributes). Invalid tuple: {item}\")\n he_id, attrs = item\n else:\n he_id, attrs = item, {}\n\n if he_id in self.hyperedges:\n raise HypergraphError(f\"Hyperedge '{he_id}' already exists.\")\n\n hyperedge = HyperEdge(id=he_id, nodes=[], properties=attrs.copy())\n new_hyperedges.append(hyperedge)\n\n if inplace:\n for hyperedge in new_hyperedges:\n self.hyperedges[hyperedge.id] = hyperedge\n self.hyperedge_properties[hyperedge.id] = copy.deepcopy(hyperedge.properties)\n self.graph.add_node(hyperedge.id, bipartite='hyperedge', **self.hyperedge_properties[hyperedge.id])\n self.bipartite_nodes.add(hyperedge.id)\n return self\n else:\n # Create a new Hypergraph instance with added hyperedges\n new_hyperedges_dict = copy.deepcopy(self.hyperedges)\n new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)\n new_node_properties = copy.deepcopy(self.node_properties)\n new_graph = copy.deepcopy(self.graph)\n new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)\n\n for hyperedge in new_hyperedges:\n 
new_hyperedges_dict[hyperedge.id] = hyperedge\n new_hyperedge_properties[hyperedge.id] = copy.deepcopy(hyperedge.properties)\n new_graph.add_node(hyperedge.id, bipartite='hyperedge', **new_hyperedge_properties[hyperedge.id])\n new_bipartite_nodes.add(hyperedge.id)\n\n # Reconstruct hyperedges dict for __init__\n hyperedges_dict = {\n he_id: {\n 'nodes': list(he.nodes),\n 'properties': he.properties.copy()\n } for he_id, he in new_hyperedges_dict.items()\n }\n\n return Hypergraph(\n hyperedges=hyperedges_dict,\n node_properties=new_node_properties,\n hyperedge_properties=new_hyperedge_properties,\n name=self.name\n )\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.add_incidence","title":"add_incidence(he_id, node_id, attributes=None, inplace=True)
","text":"Adds a single incidence with attributes to the hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.add_incidence--parameters","title":"Parameters","text":"he_id : Any Identifier of the hyperedge. node_id : Any Identifier of the node. attributes : Optional[Dict[str, Any]] = None Properties to add to the incidence as key-value pairs. inplace : bool, default=True If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the added incidence.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.add_incidence--returns","title":"Returns","text":"Hypergraph The updated or new Hypergraph instance.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.add_incidence--raises","title":"Raises","text":"HypergraphError If the hyperedge or node does not exist, or if the incidence already exists.
Source code in src/aeiva/hypergraph/hypergraph.py
def add_incidence(\n self,\n he_id: Any,\n node_id: Any,\n attributes: Optional[Dict[str, Any]] = None,\n inplace: bool = True\n) -> 'Hypergraph':\n \"\"\"\n Adds a single incidence with attributes to the hypergraph.\n\n Parameters\n ----------\n he_id : Any\n Identifier of the hyperedge.\n node_id : Any\n Identifier of the node.\n attributes : Optional[Dict[str, Any]] = None\n Properties to add to the incidence as key-value pairs.\n inplace : bool, default=True\n If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the added incidence.\n\n Returns\n -------\n Hypergraph\n The updated or new Hypergraph instance.\n\n Raises\n ------\n HypergraphError\n If the hyperedge or node does not exist, or if the incidence already exists.\n \"\"\"\n if he_id not in self.hyperedges:\n raise HypergraphError(f\"Hyperedge '{he_id}' does not exist in the hypergraph.\")\n if node_id not in self.node_properties:\n raise HypergraphError(f\"Node '{node_id}' does not exist in the hypergraph.\")\n if node_id in self.hyperedges[he_id].nodes:\n raise HypergraphError(f\"Incidence between hyperedge '{he_id}' and node '{node_id}' already exists.\")\n\n if inplace:\n # Add node to HyperEdge's nodes\n self.hyperedges[he_id].add_node(node_id)\n # Update hyperedge_properties if attributes provided\n if attributes:\n self.hyperedge_properties[he_id].update(attributes)\n # Add edge in graph with attributes\n self.graph.add_edge(he_id, node_id, **(attributes if attributes else {}))\n return self\n else:\n # Create a new Hypergraph instance with the incidence added\n new_hyperedges = copy.deepcopy(self.hyperedges)\n new_node_properties = copy.deepcopy(self.node_properties)\n new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)\n new_graph = copy.deepcopy(self.graph)\n new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)\n\n # Add node to HyperEdge's nodes\n new_hyperedges[he_id].add_node(node_id)\n # Update hyperedge_properties if attributes 
provided\n if attributes:\n new_hyperedge_properties[he_id].update(attributes)\n # Add edge in graph with attributes\n new_graph.add_edge(he_id, node_id, **(attributes if attributes else {}))\n\n # Reconstruct hyperedges dict for __init__\n hyperedges_dict = {\n he_id: {\n 'nodes': list(he.nodes),\n 'properties': he.properties.copy()\n } for he_id, he in new_hyperedges.items()\n }\n\n return Hypergraph(\n hyperedges=hyperedges_dict,\n node_properties=new_node_properties,\n hyperedge_properties=new_hyperedge_properties,\n name=self.name\n )\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.add_incidences_from","title":"add_incidences_from(incidences, inplace=True)
","text":"Adds a collection of incidences to the hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.add_incidences_from--parameters","title":"Parameters","text":"incidences : Iterable[Union[Tuple[Any, Any], Tuple[Any, Any, Dict[str, Any]]]] Incidence tuples as: - (he_id, node_id) - (he_id, node_id, attributes)
inplace : bool, default=True If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the added incidences.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.add_incidences_from--returns","title":"Returns","text":"Hypergraph The updated or new Hypergraph instance.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.add_incidences_from--raises","title":"Raises","text":"HypergraphError If any hyperedge or node does not exist, or if any incidence already exists. ValueError If the structure of any incidence tuple is invalid.
Source code in src/aeiva/hypergraph/hypergraph.py
def add_incidences_from(\n self,\n incidences: Iterable[Union[Tuple[Any, Any], Tuple[Any, Any, Dict[str, Any]]]],\n inplace: bool = True\n) -> 'Hypergraph':\n \"\"\"\n Adds a collection of incidences to the hypergraph.\n\n Parameters\n ----------\n incidences : Iterable[Union[Tuple[Any, Any], Tuple[Any, Any, Dict[str, Any]]]]\n Incidence tuples as:\n - (he_id, node_id)\n - (he_id, node_id, attributes)\n\n inplace : bool, default=True\n If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the added incidences.\n\n Returns\n -------\n Hypergraph\n The updated or new Hypergraph instance.\n\n Raises\n ------\n HypergraphError\n If any hyperedge or node does not exist, or if any incidence already exists.\n ValueError\n If the structure of any incidence tuple is invalid.\n \"\"\"\n new_incidences = []\n for pr in incidences:\n if not isinstance(pr, tuple):\n raise ValueError(f\"Each incidence must be a tuple, got {type(pr)}\")\n if len(pr) == 2:\n he_id, node_id = pr\n attrs = {}\n elif len(pr) == 3:\n he_id, node_id, attrs = pr\n if not isinstance(attrs, dict):\n raise ValueError(f\"Attributes must be a dictionary, got {type(attrs)}\")\n else:\n raise ValueError(f\"Incidence tuples must be of length 2 or 3, got {len(pr)}\")\n\n if he_id not in self.hyperedges:\n raise HypergraphError(f\"Hyperedge '{he_id}' does not exist in the hypergraph.\")\n if node_id not in self.node_properties:\n raise HypergraphError(f\"Node '{node_id}' does not exist in the hypergraph.\")\n if node_id in self.hyperedges[he_id].nodes:\n raise HypergraphError(f\"Incidence between hyperedge '{he_id}' and node '{node_id}' already exists.\")\n\n new_incidences.append((he_id, node_id, attrs.copy()))\n\n if inplace:\n for he_id, node_id, attrs in new_incidences:\n # Add node to HyperEdge's nodes\n self.hyperedges[he_id].add_node(node_id)\n # Update hyperedge_properties if attributes provided\n if attrs:\n self.hyperedge_properties[he_id].update(attrs)\n # Add edge in 
graph with attributes\n self.graph.add_edge(he_id, node_id, **(attrs if attrs else {}))\n return self\n else:\n # Create a new Hypergraph instance with the incidences added\n new_hyperedges = copy.deepcopy(self.hyperedges)\n new_node_properties = copy.deepcopy(self.node_properties)\n new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)\n new_graph = copy.deepcopy(self.graph)\n new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)\n\n for he_id, node_id, attrs in new_incidences:\n # Add node to HyperEdge's nodes\n new_hyperedges[he_id].add_node(node_id)\n # Update hyperedge_properties if attributes provided\n if attrs:\n new_hyperedge_properties[he_id].update(attrs)\n # Add edge in graph with attributes\n new_graph.add_edge(he_id, node_id, **(attrs if attrs else {}))\n\n # Reconstruct hyperedges dict for __init__\n hyperedges_dict = {\n he_id: {\n 'nodes': list(he.nodes),\n 'properties': he.properties.copy()\n } for he_id, he in new_hyperedges.items()\n }\n\n return Hypergraph(\n hyperedges=hyperedges_dict,\n node_properties=new_node_properties,\n hyperedge_properties=new_hyperedge_properties,\n name=self.name\n )\n
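The tuple validation performed by add_incidences_from (accepting either (he_id, node_id) or (he_id, node_id, attributes)) can be sketched as a standalone helper. The function name normalize_incidences is illustrative, not part of the library:

```python
from typing import Any, Dict, Iterable, List, Tuple, Union

Incidence = Union[Tuple[Any, Any], Tuple[Any, Any, Dict[str, Any]]]

def normalize_incidences(incidences: Iterable[Incidence]) -> List[Tuple[Any, Any, Dict[str, Any]]]:
    """Validate incidence tuples and normalize them to (he_id, node_id, attrs)."""
    out: List[Tuple[Any, Any, Dict[str, Any]]] = []
    for pr in incidences:
        if not isinstance(pr, tuple):
            raise ValueError(f"Each incidence must be a tuple, got {type(pr)}")
        if len(pr) == 2:
            he_id, node_id = pr
            attrs: Dict[str, Any] = {}
        elif len(pr) == 3:
            he_id, node_id, attrs = pr
            if not isinstance(attrs, dict):
                raise ValueError(f"Attributes must be a dictionary, got {type(attrs)}")
        else:
            raise ValueError(f"Incidence tuples must be of length 2 or 3, got {len(pr)}")
        out.append((he_id, node_id, attrs.copy()))
    return out
```

As in the method above, validation happens for every tuple before any mutation, so a malformed entry leaves the hypergraph untouched.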
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.add_node","title":"add_node(node_id, properties=None, inplace=True)
","text":"Adds a node to the hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.add_node--parameters","title":"Parameters","text":"node_id : Any Identifier for the node. properties : Optional[Dict[str, Any]] = None Properties of the node. inplace : bool, default=True If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the added node.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.add_node--returns","title":"Returns","text":"Hypergraph The updated or new Hypergraph instance.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.add_node--raises","title":"Raises","text":"HypergraphError If the node ID already exists.
Source code in src/aeiva/hypergraph/hypergraph.py
def add_node(\n self,\n node_id: Any,\n properties: Optional[Dict[str, Any]] = None,\n inplace: bool = True\n) -> 'Hypergraph':\n \"\"\"\n Adds a node to the hypergraph.\n\n Parameters\n ----------\n node_id : Any\n Identifier for the node.\n properties : Optional[Dict[str, Any]] = None\n Properties of the node.\n inplace : bool, default=True\n If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the added node.\n\n Returns\n -------\n Hypergraph\n The updated or new Hypergraph instance.\n\n Raises\n ------\n HypergraphError\n If the node ID already exists.\n \"\"\"\n if node_id in self.node_properties:\n raise HypergraphError(f\"Node '{node_id}' already exists in the hypergraph.\")\n\n if inplace:\n self.node_properties[node_id] = copy.deepcopy(properties) if properties else {}\n self.graph.add_node(node_id, bipartite='node', **self.node_properties[node_id])\n return self\n else:\n # Create a new Hypergraph instance with the added node\n new_hyperedges = copy.deepcopy(self.hyperedges)\n new_node_properties = copy.deepcopy(self.node_properties)\n new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)\n new_graph = copy.deepcopy(self.graph)\n new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)\n\n new_node_properties[node_id] = copy.deepcopy(properties) if properties else {}\n new_graph.add_node(node_id, bipartite='node', **new_node_properties[node_id])\n\n return Hypergraph(\n hyperedges={\n he_id: {\n 'nodes': list(he.nodes),\n 'properties': he.properties.copy()\n } for he_id, he in new_hyperedges.items()\n },\n node_properties=new_node_properties,\n hyperedge_properties=new_hyperedge_properties,\n name=self.name\n )\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.add_nodes_from","title":"add_nodes_from(nodes, inplace=True)
","text":"Adds multiple nodes with attributes to the hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.add_nodes_from--parameters","title":"Parameters","text":"nodes : Iterable[Union[Any, Tuple[Any, Dict[str, Any]]]] An iterable of node identifiers or tuples of (node_id, attributes). inplace : bool, default=True If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the added nodes.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.add_nodes_from--returns","title":"Returns","text":"Hypergraph The updated or new Hypergraph instance.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.add_nodes_from--raises","title":"Raises","text":"HypergraphError If any node ID already exists. ValueError If any tuple does not contain exactly two elements or if attributes are not dictionaries.
Source code in src/aeiva/hypergraph/hypergraph.py
def add_nodes_from(\n self,\n nodes: Iterable[Union[Any, Tuple[Any, Dict[str, Any]]]],\n inplace: bool = True\n) -> 'Hypergraph':\n \"\"\"\n Adds multiple nodes with attributes to the hypergraph.\n\n Parameters\n ----------\n nodes : Iterable[Union[Any, Tuple[Any, Dict[str, Any]]]]\n An iterable of node identifiers or tuples of (node_id, attributes).\n inplace : bool, default=True\n If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the added nodes.\n\n Returns\n -------\n Hypergraph\n The updated or new Hypergraph instance.\n\n Raises\n ------\n HypergraphError\n If any node ID already exists.\n ValueError\n If any tuple does not contain exactly two elements or if attributes are not dictionaries.\n \"\"\"\n new_nodes = {}\n for item in nodes:\n if isinstance(item, tuple):\n if len(item) != 2 or not isinstance(item[1], dict):\n raise ValueError(f\"Each tuple must be of the form (node_id, attributes). Invalid tuple: {item}\")\n node_id, attrs = item\n else:\n node_id, attrs = item, {}\n\n if node_id in self.node_properties:\n raise HypergraphError(f\"Node '{node_id}' already exists in the hypergraph.\")\n\n new_nodes[node_id] = copy.deepcopy(attrs)\n\n if inplace:\n for node_id, attrs in new_nodes.items():\n self.node_properties[node_id] = attrs\n self.graph.add_node(node_id, bipartite='node', **self.node_properties[node_id])\n return self\n else:\n # Create a new Hypergraph instance with the added nodes\n new_hyperedges = copy.deepcopy(self.hyperedges)\n new_node_properties = copy.deepcopy(self.node_properties)\n new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)\n new_graph = copy.deepcopy(self.graph)\n new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)\n\n for node_id, attrs in new_nodes.items():\n new_node_properties[node_id] = attrs\n new_graph.add_node(node_id, bipartite='node', **new_node_properties[node_id])\n\n return Hypergraph(\n hyperedges={\n he_id: {\n 'nodes': list(he.nodes),\n 'properties': 
he.properties.copy()\n } for he_id, he in new_hyperedges.items()\n },\n node_properties=new_node_properties,\n hyperedge_properties=new_hyperedge_properties,\n name=self.name\n )\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.adjacency_matrix","title":"adjacency_matrix(s=1, index=False)
","text":"Generates the adjacency matrix for nodes based on s-node connectivity.
Source code in src/aeiva/hypergraph/hypergraph.py
def adjacency_matrix(self, s: int = 1, index: bool = False) -> Tuple[Optional[csr_matrix], Dict[int, Any]]:\n \"\"\"\n Generates the adjacency matrix for nodes based on s-node connectivity.\n \"\"\"\n from scipy.sparse import lil_matrix\n\n node_ids = list(self.node_properties.keys())\n node_index = {node_id: idx for idx, node_id in enumerate(node_ids)}\n size = len(node_ids)\n if size == 0:\n return None, {}\n\n A = lil_matrix((size, size), dtype=int)\n for he in self.hyperedges.values():\n nodes = list(he.nodes)\n for i in range(len(nodes)):\n for j in range(i + 1, len(nodes)):\n A[node_index[nodes[i]], node_index[nodes[j]]] += 1\n\n # Apply the threshold s and convert to binary\n A = (A >= s).astype(int)\n A = A.tocsr()\n\n if index:\n return A, node_index\n return A, {}\n
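The construction above counts pairwise co-memberships across hyperedges and then thresholds at s to get a binary adjacency. The same idea can be sketched without scipy, using plain counters (s_node_adjacency is an illustrative name, not library API):

```python
from collections import Counter
from itertools import combinations

def s_node_adjacency(hyperedges: dict, s: int = 1) -> set:
    """Return the set of node pairs that co-occur in at least s hyperedges.

    hyperedges maps hyperedge id -> set of node ids.
    """
    counts = Counter()
    for nodes in hyperedges.values():
        # Count each unordered node pair once per hyperedge containing both.
        for u, v in combinations(sorted(nodes), 2):
            counts[(u, v)] += 1
    # Apply the threshold s, mirroring the binary (A >= s) step above.
    return {pair for pair, c in counts.items() if c >= s}
```

With s=1 this is ordinary co-membership adjacency; larger s keeps only pairs that share several hyperedges.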
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.collapse_duplicate_hyperedges","title":"collapse_duplicate_hyperedges(name=None, use_uids=None, use_counts=False, return_counts=True, return_equivalence_classes=False, aggregate_properties_by=None)
","text":"Collapses duplicate hyperedges (hyperedges with identical node memberships) into single hyperedges.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.collapse_duplicate_hyperedges--parameters","title":"Parameters","text":"name : Optional[str], default=None The name assigned to the collapsed hypergraph. If None, defaults to the original name suffixed with '_collapsed_hyperedges'.
use_uids : Optional[List[Any]] = None Specifies the hyperedge identifiers to use as representatives for each equivalence class. If two identifiers occur in the same equivalence class, the first one found in use_uids
is used. If None, the first encountered hyperedge in each class is used as the representative.
use_counts : bool, optional, default=False If True, renames the equivalence class representatives by appending the size of the class (e.g., 'HE1:3').
return_counts : bool, optional, default=True If True, adds the size of each equivalence class to the properties of the representative hyperedge under the key 'equivalence_class_size'.
return_equivalence_classes : bool, optional, default=False If True, returns a tuple containing the new collapsed hypergraph and a dictionary mapping representatives to their equivalence classes.
aggregate_properties_by : Optional[Dict[str, str]] = None A dictionary specifying aggregation methods for hyperedge properties. Keys are property names, and values are aggregation functions (e.g., {'weight': 'sum'}). Properties not specified will use the 'first' aggregation.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.collapse_duplicate_hyperedges--returns","title":"Returns","text":"Hypergraph or Tuple[Hypergraph, Dict[Any, Set[Any]]] - If return_equivalence_classes=False
, returns the new collapsed hypergraph. - If return_equivalence_classes=True
, returns a tuple containing the collapsed hypergraph and a dictionary of equivalence classes.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.collapse_duplicate_hyperedges--raises","title":"Raises","text":"HypergraphError If the hypergraph is empty or improperly structured.
Source code in src/aeiva/hypergraph/hypergraph.py
def collapse_duplicate_hyperedges(\n self,\n name: Optional[str] = None,\n use_uids: Optional[List[Any]] = None,\n use_counts: bool = False,\n return_counts: bool = True,\n return_equivalence_classes: bool = False,\n aggregate_properties_by: Optional[Dict[str, str]] = None,\n) -> Union['Hypergraph', Tuple['Hypergraph', Dict[Any, Set[Any]]]]:\n \"\"\"\n Collapses duplicate hyperedges (hyperedges with identical node memberships) into single hyperedges.\n\n Parameters\n ----------\n name : Optional[str], default=None\n The name assigned to the collapsed hypergraph. If None, defaults to the original name suffixed with '_collapsed_hyperedges'.\n\n use_uids : Optional[List[Any]] = None\n Specifies the hyperedge identifiers to use as representatives for each equivalence class.\n If two identifiers occur in the same equivalence class, the first one found in `use_uids` is used.\n If None, the first encountered hyperedge in each class is used as the representative.\n\n use_counts : bool, optional, default=False\n If True, renames the equivalence class representatives by appending the size of the class (e.g., 'HE1:3').\n\n return_counts : bool, optional, default=True\n If True, adds the size of each equivalence class to the properties of the representative hyperedge under the key 'equivalence_class_size'.\n\n return_equivalence_classes : bool, optional, default=False\n If True, returns a tuple containing the new collapsed hypergraph and a dictionary mapping representatives to their equivalence classes.\n\n aggregate_properties_by : Optional[Dict[str, str]] = None\n A dictionary specifying aggregation methods for hyperedge properties. 
Keys are property names, and values are aggregation functions (e.g., {'weight': 'sum'}).\n Properties not specified will use the 'first' aggregation.\n\n Returns\n -------\n Hypergraph or Tuple[Hypergraph, Dict[Any, Set[Any]]]\n - If `return_equivalence_classes=False`, returns the new collapsed hypergraph.\n - If `return_equivalence_classes=True`, returns a tuple containing the collapsed hypergraph and a dictionary of equivalence classes.\n\n Raises\n ------\n HypergraphError\n If the hypergraph is empty or improperly structured.\n \"\"\"\n if not self.hyperedges:\n raise HypergraphError(\"Cannot collapse hyperedges in an empty hypergraph.\")\n\n # Identify equivalence classes based on identical node memberships\n membership_to_hyperedges: Dict[frozenset, Set[Any]] = {}\n for he_id, hyperedge in self.hyperedges.items():\n key = frozenset(hyperedge.nodes)\n membership_to_hyperedges.setdefault(key, set()).add(he_id)\n\n # Filter out classes with only one hyperedge (no duplicates)\n equivalence_classes = [hes for hes in membership_to_hyperedges.values() if len(hes) > 1]\n if not equivalence_classes:\n # No duplicates to collapse; return the original hypergraph\n return self if not return_equivalence_classes else (self, {})\n\n # Prepare aggregation methods\n aggregate_properties_by = aggregate_properties_by if aggregate_properties_by is not None else {\"weight\": \"sum\"}\n\n # Initialize mapping from old hyperedges to new hyperedges\n hyperedge_mapping: Dict[Any, Any] = {}\n equivalence_class_dict: Dict[Any, Set[Any]] = {}\n\n for eq_class in equivalence_classes:\n # Determine representative\n if use_uids:\n # Select the first UID from use_uids that is in the equivalence class\n representative = next((uid for uid in use_uids if uid in eq_class), None)\n if not representative:\n # Fallback to the first hyperedge in the equivalence class\n representative = next(iter(eq_class))\n else:\n # Use the first hyperedge in the equivalence class as representative\n 
representative = next(iter(eq_class))\n\n # Optionally rename with counts\n if use_counts:\n new_representative = f\"{representative}:{len(eq_class)}\"\n else:\n new_representative = representative\n\n # Map all hyperedges in the class to the representative\n for he in eq_class:\n hyperedge_mapping[he] = new_representative\n\n # Store the equivalence class\n equivalence_class_dict[new_representative] = eq_class\n\n # Replace hyperedge IDs in incidences based on mapping\n new_hyperedges = {}\n for he_id, hyperedge in self.hyperedges.items():\n new_he_id = hyperedge_mapping.get(he_id, he_id)\n if new_he_id not in new_hyperedges:\n new_hyperedges[new_he_id] = HyperEdge(id=new_he_id, nodes=hyperedge.nodes.copy(), properties=copy.deepcopy(hyperedge.properties))\n else:\n new_hyperedges[new_he_id].nodes.update(hyperedge.nodes)\n\n # Aggregate hyperedge properties\n for he_id, hyperedge in new_hyperedges.items():\n if he_id in equivalence_class_dict:\n aggregated_props = {}\n for prop, agg_func in aggregate_properties_by.items():\n values = [self.hyperedge_properties[old_he].get(prop, 0) for old_he in equivalence_class_dict[he_id]]\n if agg_func == 'sum':\n aggregated_props[prop] = sum(values)\n elif agg_func == 'mean':\n aggregated_props[prop] = sum(values) / len(values) if values else 0\n elif agg_func == 'max':\n aggregated_props[prop] = max(values) if values else None\n elif agg_func == 'min':\n aggregated_props[prop] = min(values) if values else None\n else:\n aggregated_props[prop] = values[0] if values else None # Default to first\n new_hyperedges[he_id].properties.update(aggregated_props)\n\n # Handle equivalence class size\n if use_counts:\n for he_id in equivalence_class_dict:\n new_hyperedges[he_id].properties['equivalence_class_size'] = len(equivalence_class_dict[he_id])\n elif return_counts:\n for he_id in new_hyperedges:\n if he_id in equivalence_class_dict:\n new_hyperedges[he_id].properties['equivalence_class_size'] = len(equivalence_class_dict[he_id])\n 
else:\n new_hyperedges[he_id].properties['equivalence_class_size'] = 1\n\n # Initialize the collapsed hypergraph\n collapsed_hypergraph = Hypergraph(\n hyperedges={\n he_id: {\n 'nodes': list(he.nodes),\n 'properties': he.properties.copy()\n } for he_id, he in new_hyperedges.items()\n },\n node_properties=copy.deepcopy(self.node_properties),\n hyperedge_properties={\n he_id: copy.deepcopy(he.properties) for he_id, he in new_hyperedges.items()\n },\n name=name if name else f\"{self.name}_collapsed_hyperedges\"\n )\n\n if return_equivalence_classes:\n return collapsed_hypergraph, equivalence_class_dict\n else:\n return collapsed_hypergraph\n
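The core of the collapse is grouping hyperedges by their node membership: two hyperedges are duplicates exactly when their node sets are equal, so a frozenset of nodes serves as the grouping key. A minimal sketch (hyperedge_equivalence_classes is an illustrative name):

```python
def hyperedge_equivalence_classes(hyperedges: dict) -> list:
    """Group hyperedge ids with identical node membership.

    hyperedges maps hyperedge id -> set of node ids.
    Returns only the classes with more than one member (the duplicates).
    """
    classes: dict = {}
    for he_id, nodes in hyperedges.items():
        # frozenset is hashable, so identical memberships collide on one key.
        classes.setdefault(frozenset(nodes), set()).add(he_id)
    return [hes for hes in classes.values() if len(hes) > 1]
```

Each returned set corresponds to one collapsed hyperedge in the method above; singleton classes are left untouched.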
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.collapse_duplicate_nodes","title":"collapse_duplicate_nodes(name=None, use_uids=None, use_counts=False, return_counts=True, return_equivalence_classes=False, aggregate_properties_by=None)
","text":"Collapses duplicate nodes (nodes with identical hyperedge memberships) into single nodes.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.collapse_duplicate_nodes--parameters","title":"Parameters","text":"name : Optional[str], default=None The name assigned to the collapsed hypergraph. If None, defaults to the original name suffixed with '_collapsed_nodes'.
use_uids : Optional[List[Any]] = None Specifies the node identifiers to use as representatives for each equivalence class. If two identifiers occur in the same equivalence class, the first one found in use_uids
is used. If None, the first encountered node in each class is used as the representative.
use_counts : bool, optional, default=False If True, renames the equivalence class representatives by appending the size of the class (e.g., 'N1:3').
return_counts : bool, optional, default=True If True, adds the size of each equivalence class to the properties of the representative node under the key 'equivalence_class_size'.
return_equivalence_classes : bool, optional, default=False If True, returns a tuple containing the new collapsed hypergraph and a dictionary mapping representatives to their equivalence classes.
aggregate_properties_by : Optional[Dict[str, str]] = None A dictionary specifying aggregation methods for node properties. Keys are property names, and values are aggregation functions (e.g., {'weight': 'sum'}). Properties not specified will use the 'first' aggregation.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.collapse_duplicate_nodes--returns","title":"Returns","text":"Hypergraph or Tuple[Hypergraph, Dict[Any, Set[Any]]] - If return_equivalence_classes=False
, returns the new collapsed hypergraph. - If return_equivalence_classes=True
, returns a tuple containing the collapsed hypergraph and a dictionary of equivalence classes.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.collapse_duplicate_nodes--raises","title":"Raises","text":"HypergraphError If the hypergraph is empty or improperly structured.
Source code in src/aeiva/hypergraph/hypergraph.py
def collapse_duplicate_nodes(\n self,\n name: Optional[str] = None,\n use_uids: Optional[List[Any]] = None,\n use_counts: bool = False,\n return_counts: bool = True,\n return_equivalence_classes: bool = False,\n aggregate_properties_by: Optional[Dict[str, str]] = None,\n) -> Union['Hypergraph', Tuple['Hypergraph', Dict[Any, Set[Any]]]]:\n \"\"\"\n Collapses duplicate nodes (nodes with identical hyperedge memberships) into single nodes.\n\n Parameters\n ----------\n name : Optional[str], default=None\n The name assigned to the collapsed hypergraph. If None, defaults to the original name suffixed with '_collapsed_nodes'.\n\n use_uids : Optional[List[Any]] = None\n Specifies the node identifiers to use as representatives for each equivalence class.\n If two identifiers occur in the same equivalence class, the first one found in `use_uids` is used.\n If None, the first encountered node in each class is used as the representative.\n\n use_counts : bool, optional, default=False\n If True, renames the equivalence class representatives by appending the size of the class (e.g., 'N1:3').\n\n return_counts : bool, optional, default=True\n If True, adds the size of each equivalence class to the properties of the representative node under the key 'equivalence_class_size'.\n\n return_equivalence_classes : bool, optional, default=False\n If True, returns a tuple containing the new collapsed hypergraph and a dictionary mapping representatives to their equivalence classes.\n\n aggregate_properties_by : Optional[Dict[str, str]] = None\n A dictionary specifying aggregation methods for node properties. 
Keys are property names, and values are aggregation functions (e.g., {'weight': 'sum'}).\n Properties not specified will use the 'first' aggregation.\n\n Returns\n -------\n Hypergraph or Tuple[Hypergraph, Dict[Any, Set[Any]]]\n - If `return_equivalence_classes=False`, returns the new collapsed hypergraph.\n - If `return_equivalence_classes=True`, returns a tuple containing the collapsed hypergraph and a dictionary of equivalence classes.\n\n Raises\n ------\n HypergraphError\n If the hypergraph is empty or improperly structured.\n \"\"\"\n if not self.node_properties:\n raise HypergraphError(\"Cannot collapse nodes in an empty hypergraph.\")\n\n # Identify equivalence classes based on identical hyperedge memberships\n membership_to_nodes: Dict[frozenset, Set[Any]] = {}\n for node_id, node_props in self.node_properties.items():\n key = frozenset(self.get_hyperedges_of_node(node_id))\n membership_to_nodes.setdefault(key, set()).add(node_id)\n\n # Filter out classes with only one node (no duplicates)\n equivalence_classes = [nodes for nodes in membership_to_nodes.values() if len(nodes) > 1]\n if not equivalence_classes:\n # No duplicates to collapse; return the original hypergraph\n return self if not return_equivalence_classes else (self, {})\n\n # Prepare aggregation methods\n aggregate_properties_by = aggregate_properties_by if aggregate_properties_by is not None else {\"weight\": \"sum\"}\n\n # Initialize mapping from old nodes to new nodes\n node_mapping: Dict[Any, Any] = {}\n equivalence_class_dict: Dict[Any, Set[Any]] = {}\n\n for eq_class in equivalence_classes:\n # Determine representative\n if use_uids:\n # Select the first UID from use_uids that is in the equivalence class\n representative = next((uid for uid in use_uids if uid in eq_class), None)\n if not representative:\n # Fallback to the first node in the equivalence class\n representative = next(iter(eq_class))\n else:\n # Use the first node in the equivalence class as representative\n representative 
= next(iter(eq_class))\n\n # Optionally rename with counts\n if use_counts:\n new_representative = f\"{representative}:{len(eq_class)}\"\n else:\n new_representative = representative\n\n # Map all nodes in the class to the representative\n for node in eq_class:\n node_mapping[node] = new_representative\n\n # Store the equivalence class\n equivalence_class_dict[new_representative] = eq_class\n\n # Replace node IDs in hyperedges based on mapping\n new_hyperedges = {}\n for he_id, hyperedge in self.hyperedges.items():\n new_nodes = set()\n for node_id in hyperedge.nodes:\n new_node_id = node_mapping.get(node_id, node_id)\n new_nodes.add(new_node_id)\n new_hyperedges[he_id] = HyperEdge(id=he_id, nodes=new_nodes, properties=copy.deepcopy(hyperedge.properties))\n\n # Aggregate node properties\n new_node_properties = {}\n for node_id, node_props in self.node_properties.items():\n new_node_id = node_mapping.get(node_id, node_id)\n if new_node_id not in new_node_properties:\n new_node_properties[new_node_id] = copy.deepcopy(node_props)\n else:\n for prop, agg_func in aggregate_properties_by.items():\n if prop in node_props:\n if agg_func == 'sum':\n new_node_properties[new_node_id][prop] = new_node_properties[new_node_id].get(prop, 0) + node_props[prop]\n elif agg_func == 'mean':\n # To calculate mean, store sum and count\n if 'sum_' + prop not in new_node_properties[new_node_id]:\n new_node_properties[new_node_id]['sum_' + prop] = node_props[prop]\n new_node_properties[new_node_id]['count_' + prop] = 1\n else:\n new_node_properties[new_node_id]['sum_' + prop] += node_props[prop]\n new_node_properties[new_node_id]['count_' + prop] += 1\n # Calculate mean at the end\n elif agg_func == 'max':\n current_max = new_node_properties[new_node_id].get(prop, float('-inf'))\n new_node_properties[new_node_id][prop] = max(current_max, node_props[prop])\n elif agg_func == 'min':\n current_min = new_node_properties[new_node_id].get(prop, float('inf'))\n 
new_node_properties[new_node_id][prop] = min(current_min, node_props[prop])\n else:\n new_node_properties[new_node_id][prop] = node_props[prop] # Default to last\n # Finalize mean calculations\n for node_id, props in new_node_properties.items():\n for prop in list(props.keys()):\n if prop.startswith('sum_'):\n base_prop = prop[4:]\n sum_val = props[prop]\n count_val = props.get('count_' + base_prop, 1)\n new_node_properties[node_id][base_prop] = sum_val / count_val if count_val > 0 else 0\n del new_node_properties[node_id][prop]\n del new_node_properties[node_id]['count_' + base_prop]\n\n # Handle equivalence class size\n if use_counts:\n for node_id in equivalence_class_dict:\n new_node_properties[node_id]['equivalence_class_size'] = len(equivalence_class_dict[node_id])\n elif return_counts:\n for node_id in new_node_properties:\n if node_id in equivalence_class_dict:\n new_node_properties[node_id]['equivalence_class_size'] = len(equivalence_class_dict[node_id])\n else:\n new_node_properties[node_id]['equivalence_class_size'] = 1\n\n # Initialize the collapsed hypergraph\n collapsed_hypergraph = Hypergraph(\n hyperedges={\n he_id: {\n 'nodes': list(he.nodes),\n 'properties': he.properties.copy()\n } for he_id, he in new_hyperedges.items()\n },\n node_properties=new_node_properties,\n hyperedge_properties={\n he_id: copy.deepcopy(he.properties) for he_id, he in new_hyperedges.items()\n },\n name=name if name else f\"{self.name}_collapsed_nodes\"\n )\n\n if return_equivalence_classes:\n return collapsed_hypergraph, equivalence_class_dict\n else:\n return collapsed_hypergraph\n
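Node collapsing is the dual of hyperedge collapsing: here the grouping key is each node's set of incident hyperedges rather than each hyperedge's node set. A minimal sketch of that dual grouping (node_equivalence_classes is an illustrative name):

```python
def node_equivalence_classes(hyperedges: dict) -> list:
    """Group node ids with identical hyperedge membership.

    hyperedges maps hyperedge id -> set of node ids.
    Returns only the classes with more than one member (the duplicates).
    """
    # Invert the incidence: node id -> set of hyperedge ids containing it.
    membership: dict = {}
    for he_id, nodes in hyperedges.items():
        for node_id in nodes:
            membership.setdefault(node_id, set()).add(he_id)
    # Nodes with the same frozenset of hyperedges are duplicates.
    classes: dict = {}
    for node_id, hes in membership.items():
        classes.setdefault(frozenset(hes), set()).add(node_id)
    return [nodes for nodes in classes.values() if len(nodes) > 1]
```

Each returned set corresponds to one collapsed node in the method above; property aggregation (sum, mean, max, min) is then applied per class as shown in the source.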
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.compute_hyperedge_diameter","title":"compute_hyperedge_diameter(s=1)
","text":"Returns the diameter of the hypergraph based on s-hyperedge connectivity.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.compute_hyperedge_diameter--parameters","title":"Parameters","text":"s : int, optional, default=1 The number of shared nodes required for hyperedges to be considered adjacent.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.compute_hyperedge_diameter--returns","title":"Returns","text":"int The diameter of the hypergraph based on hyperedge connectivity.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.compute_hyperedge_diameter--raises","title":"Raises","text":"HypergraphError If the hypergraph is not s-hyperedge-connected or has no hyperedges.
Source code in src/aeiva/hypergraph/hypergraph.py
def compute_hyperedge_diameter(self, s: int = 1) -> int:\n \"\"\"\n Returns the diameter of the hypergraph based on s-hyperedge connectivity.\n\n Parameters\n ----------\n s : int, optional, default=1\n The number of shared nodes required for hyperedges to be considered adjacent.\n\n Returns\n -------\n int\n The diameter of the hypergraph based on hyperedge connectivity.\n\n Raises\n ------\n HypergraphError\n If the hypergraph is not s-hyperedge-connected or has no hyperedges.\n \"\"\"\n A, _ = self.hyperedge_adjacency_matrix(s=s, index=True)\n if A is None or A.shape[0] == 0:\n raise HypergraphError(\"The hypergraph has no hyperedges to compute diameter.\")\n\n graph = nx.from_scipy_sparse_array(A)\n if not nx.is_connected(graph):\n raise HypergraphError(f\"Hypergraph is not s-hyperedge-connected. s={s}\")\n\n try:\n return nx.diameter(graph)\n except nx.NetworkXError as e:\n raise HypergraphError(f\"Could not compute hyperedge diameter: {e}\")\n
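The method above delegates the diameter computation to networkx after building the s-hyperedge adjacency graph. The underlying computation (longest shortest path over a connected graph) can be sketched with plain BFS; the diameter function name and dict-of-sets adjacency format are assumptions for this sketch:

```python
from collections import deque

def diameter(adj: dict) -> int:
    """Diameter of an unweighted graph given as {vertex: set of neighbors}.

    Runs a BFS from every vertex and takes the largest eccentricity.
    Raises ValueError if the graph is not connected, mirroring the
    s-connectivity check in the method above.
    """
    best = 0
    for src in adj:
        dist = {src: 0}
        queue = deque([src])
        while queue:
            u = queue.popleft()
            for v in adj[u]:
                if v not in dist:
                    dist[v] = dist[u] + 1
                    queue.append(v)
        if len(dist) != len(adj):
            raise ValueError("graph is not connected")
        best = max(best, max(dist.values()))
    return best
```

For the hyperedge diameter, the vertices would be hyperedge ids and an edge joins two hyperedges sharing at least s nodes.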
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.compute_hyperedge_diameters","title":"compute_hyperedge_diameters(s=1)
","text":"Returns the hyperedge diameters of the s-hyperedge-connected component subgraphs in the hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.compute_hyperedge_diameters--parameters","title":"Parameters","text":"s : int, optional, default=1 The number of shared nodes required for hyperedges to be considered adjacent.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.compute_hyperedge_diameters--returns","title":"Returns","text":"Tuple[int, List[int], List[Set[Any]]] - Maximum diameter among all s-hyperedge-connected components. - List of diameters for each s-hyperedge connected component. - List of sets, each containing hyperedge IDs in an s-hyperedge connected component.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.compute_hyperedge_diameters--raises","title":"Raises","text":"HypergraphError If the hypergraph is not s-hyperedge-connected or has no hyperedges.
Source code in src/aeiva/hypergraph/hypergraph.py
def compute_hyperedge_diameters(self, s: int = 1) -> Tuple[int, List[int], List[Set[Any]]]:\n \"\"\"\n Returns the hyperedge diameters of the s-hyperedge-connected component subgraphs in the hypergraph.\n\n Parameters\n ----------\n s : int, optional, default=1\n The number of shared nodes required for hyperedges to be considered adjacent.\n\n Returns\n -------\n Tuple[int, List[int], List[Set[Any]]]\n - Maximum diameter among all s-hyperedge-connected components.\n - List of diameters for each s-hyperedge connected component.\n - List of sets, each containing hyperedge IDs in an s-hyperedge connected component.\n\n Raises\n ------\n HypergraphError\n If the hypergraph is not s-hyperedge-connected or has no hyperedges.\n \"\"\"\n A, he_id_map = self.hyperedge_adjacency_matrix(s=s, index=True)\n if A is None or A.shape[0] == 0:\n raise HypergraphError(\"The hypergraph has no hyperedges to compute diameters.\")\n\n graph = nx.from_scipy_sparse_array(A)\n\n if not nx.is_connected(graph):\n raise HypergraphError(f\"Hypergraph is not s-hyperedge-connected. s={s}\")\n\n diams = []\n comps = []\n for component in nx.connected_components(graph):\n subgraph = graph.subgraph(component)\n if len(subgraph) == 1:\n diamc = 0 # Diameter of a single hyperedge is 0\n else:\n try:\n diamc = nx.diameter(subgraph)\n except nx.NetworkXError:\n diamc = float('inf') # Infinite diameter if the subgraph is not connected\n diams.append(diamc)\n component_hyperedges = {he_id_map[he] for he in component}\n comps.append(component_hyperedges)\n\n if not diams:\n raise HypergraphError(\"No connected components found to compute hyperedge diameters.\")\n\n max_diam = max(diams)\n return max_diam, diams, comps\n
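The s-hyperedge diameter computation above can be sketched standalone with plain dicts and BFS, without the aeiva classes or scipy (the `hyperedge_diameter` helper and its dict-of-sets input format are illustrative assumptions, not part of the library API):

```python
from itertools import combinations
from collections import deque

def hyperedge_diameter(hyperedges, s=1):
    """Diameter of the s-adjacency graph over hyperedges.

    `hyperedges` maps hyperedge id -> set of node ids; two hyperedges
    are s-adjacent when they share at least `s` nodes.
    """
    # Build the s-adjacency graph over hyperedge ids.
    adj = {he: set() for he in hyperedges}
    for a, b in combinations(hyperedges, 2):
        if len(hyperedges[a] & hyperedges[b]) >= s:
            adj[a].add(b)
            adj[b].add(a)

    def bfs_depths(start):
        # Breadth-first search gives shortest s-walk distances from `start`.
        depths = {start: 0}
        queue = deque([start])
        while queue:
            cur = queue.popleft()
            for nxt in adj[cur]:
                if nxt not in depths:
                    depths[nxt] = depths[cur] + 1
                    queue.append(nxt)
        return depths

    # Diameter = maximum eccentricity; fail if any hyperedge is unreachable,
    # mirroring the HypergraphError raised by the method above.
    diam = 0
    for he in hyperedges:
        depths = bfs_depths(he)
        if len(depths) != len(hyperedges):
            raise ValueError(f"Hypergraph is not {s}-hyperedge-connected.")
        diam = max(diam, max(depths.values()))
    return diam
```

For a chain of three hyperedges sharing one node each, the s=1 diameter is 2, while s=2 raises because no pair shares two nodes.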
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.compute_node_diameter","title":"compute_node_diameter(s=1)
","text":"Returns the diameter of the hypergraph based on s-node connectivity.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.compute_node_diameter--parameters","title":"Parameters","text":"s : int, optional, default=1 The number of shared hyperedges required for nodes to be considered adjacent.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.compute_node_diameter--returns","title":"Returns","text":"int The diameter of the hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.compute_node_diameter--raises","title":"Raises","text":"HypergraphError If the hypergraph is not s-node-connected or has no nodes.
Source code in src/aeiva/hypergraph/hypergraph.py
def compute_node_diameter(self, s: int = 1) -> int:\n \"\"\"\n Returns the diameter of the hypergraph based on s-node connectivity.\n\n Parameters\n ----------\n s : int, optional, default=1\n The number of shared hyperedges required for nodes to be considered adjacent.\n\n Returns\n -------\n int\n The diameter of the hypergraph.\n\n Raises\n ------\n HypergraphError\n If the hypergraph is not s-node-connected or has no nodes.\n \"\"\"\n A, _ = self.adjacency_matrix(s=s, index=True)\n if A is None or A.shape[0] == 0:\n raise HypergraphError(\"The hypergraph has no nodes to compute diameter.\")\n\n graph = nx.from_scipy_sparse_array(A)\n if not nx.is_connected(graph):\n raise HypergraphError(f\"Hypergraph is not s-node-connected. s={s}\")\n\n try:\n return nx.diameter(graph)\n except nx.NetworkXError as e:\n raise HypergraphError(f\"Could not compute diameter: {e}\")\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.compute_node_diameters","title":"compute_node_diameters(s=1)
","text":"Returns the node diameters of the connected components in the hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.compute_node_diameters--parameters","title":"Parameters","text":"s : int, optional, default=1 The number of shared hyperedges required for nodes to be considered adjacent.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.compute_node_diameters--returns","title":"Returns","text":"Tuple[int, List[int], List[Set[Any]]] - Maximum diameter among all connected components. - List of diameters for each s-node connected component. - List of sets, each containing node IDs in an s-node connected component.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.compute_node_diameters--raises","title":"Raises","text":"HypergraphError If the hypergraph is not s-connected or has no nodes.
Source code in src/aeiva/hypergraph/hypergraph.py
def compute_node_diameters(self, s: int = 1) -> Tuple[int, List[int], List[Set[Any]]]:\n \"\"\"\n Returns the node diameters of the connected components in the hypergraph.\n\n Parameters\n ----------\n s : int, optional, default=1\n The number of shared hyperedges required for nodes to be considered adjacent.\n\n Returns\n -------\n Tuple[int, List[int], List[Set[Any]]]\n - Maximum diameter among all connected components.\n - List of diameters for each s-node connected component.\n - List of sets, each containing node IDs in an s-node connected component.\n\n Raises\n ------\n HypergraphError\n If the hypergraph is not s-connected or has no nodes.\n \"\"\"\n A, node_id_map = self.adjacency_matrix(s=s, index=True)\n if A is None or A.shape[0] == 0:\n raise HypergraphError(\"The hypergraph has no nodes to compute diameters.\")\n\n graph = nx.from_scipy_sparse_array(A)\n\n if not nx.is_connected(graph):\n raise HypergraphError(f\"Hypergraph is not s-node-connected. s={s}\")\n\n diams = []\n comps = []\n for component in nx.connected_components(graph):\n subgraph = graph.subgraph(component)\n if len(subgraph) == 1:\n diamc = 0 # Diameter of a single node is 0\n else:\n try:\n diamc = nx.diameter(subgraph)\n except nx.NetworkXError:\n diamc = float('inf') # Infinite diameter if the subgraph is not connected\n diams.append(diamc)\n component_nodes = {node_id_map[node] for node in component}\n comps.append(component_nodes)\n\n if not diams:\n raise HypergraphError(\"No connected components found to compute diameters.\")\n\n max_diam = max(diams)\n return max_diam, diams, comps\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.copy","title":"copy(name=None)
","text":"Creates a deep copy of the hypergraph instance.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.copy--parameters","title":"Parameters","text":"name : Optional[str], default=None The name for the copied Hypergraph. If not provided, retains the original name.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.copy--returns","title":"Returns","text":"Hypergraph A new Hypergraph instance that is a deep copy of the original.
Source code in src/aeiva/hypergraph/hypergraph.py
def copy(self, name: Optional[str] = None) -> 'Hypergraph':\n \"\"\"\n Creates a deep copy of the hypergraph instance.\n\n Parameters\n ----------\n name : Optional[str], default=None\n The name for the copied Hypergraph. If not provided, retains the original name.\n\n Returns\n -------\n Hypergraph\n A new Hypergraph instance that is a deep copy of the original.\n \"\"\"\n\n # Deep copy hyperedges\n hyperedges_dict = {}\n for he_id, he in self.hyperedges.items():\n hyperedges_dict[he_id] = {\n 'nodes': list(he.nodes),\n 'properties': copy.deepcopy(he.properties)\n }\n\n # Deep copy node_properties and hyperedge_properties\n node_properties_copy = copy.deepcopy(self.node_properties)\n hyperedge_properties_copy = copy.deepcopy(self.hyperedge_properties)\n\n # Create a new Hypergraph instance with the copied data\n return Hypergraph(\n hyperedges=hyperedges_dict,\n node_properties=node_properties_copy,\n hyperedge_properties=hyperedge_properties_copy,\n name=name if name is not None else self.name\n )\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.deepcopy","title":"deepcopy(name=None)
","text":"Creates a deep copy of the hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.deepcopy--parameters","title":"Parameters","text":"name : Optional[str], default=None The name assigned to the cloned hypergraph. If None, defaults to the original hypergraph's name suffixed with '_deepcopy'.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.deepcopy--returns","title":"Returns","text":"Hypergraph A deep copy of the hypergraph.
Source code in src/aeiva/hypergraph/hypergraph.py
def deepcopy(self, name: Optional[str] = None) -> 'Hypergraph':\n    \"\"\"\n    Creates a deep copy of the hypergraph.\n\n    Parameters\n    ----------\n    name : Optional[str], default=None\n        The name assigned to the cloned hypergraph. If None, defaults to the original hypergraph's name suffixed with '_deepcopy'.\n\n    Returns\n    -------\n    Hypergraph\n        A deep copy of the hypergraph.\n    \"\"\"\n\n    # Deep copy hyperedges\n    hyperedges_copy = {\n        he_id: {\n            'nodes': hyperedge.nodes.copy(),\n            'properties': copy.deepcopy(hyperedge.properties)\n        }\n        for he_id, hyperedge in self.hyperedges.items()\n    }\n\n    # Deep copy node properties\n    node_properties_copy = copy.deepcopy(self.node_properties)\n\n    # Deep copy hyperedge properties\n    hyperedge_properties_copy = copy.deepcopy(self.hyperedge_properties)\n\n    # Set name\n    cloned_name = f\"{self.name}_deepcopy\" if name is None else name\n\n    # Initialize the cloned hypergraph\n    cloned_H = Hypergraph(\n        hyperedges=hyperedges_copy,\n        node_properties=node_properties_copy,\n        hyperedge_properties=hyperedge_properties_copy,\n        name=cloned_name\n    )\n\n    return cloned_H\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.difference","title":"difference(other, inplace=False, name=None)
","text":"Returns the difference of the current hypergraph with another hypergraph. The difference includes nodes and hyperedges present in the current hypergraph but not in the other.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.difference--parameters","title":"Parameters","text":"other : Hypergraph The hypergraph to subtract. inplace : bool, optional, default=False If True, modifies the current hypergraph by removing elements found in other. Otherwise, returns a new Hypergraph instance. name : Optional[str], default=None The name for the resulting hypergraph. If None, defaults to 'Difference_of_{self.name}_{other.name}'.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.difference--returns","title":"Returns","text":"Hypergraph The resulting difference hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.difference--raises","title":"Raises","text":"TypeError If other is not an instance of Hypergraph.
Source code in src/aeiva/hypergraph/hypergraph.py
def difference(self, other: 'Hypergraph', inplace: bool = False, name: Optional[str] = None) -> 'Hypergraph':\n    \"\"\"\n    Returns the difference of the current hypergraph with another hypergraph.\n    The difference includes nodes and hyperedges present in the current hypergraph but not in the other.\n\n    Parameters\n    ----------\n    other : Hypergraph\n        The hypergraph to subtract.\n    inplace : bool, optional, default=False\n        If True, modifies the current hypergraph by removing elements found in `other`.\n        Otherwise, returns a new Hypergraph instance.\n    name : Optional[str], default=None\n        The name for the resulting hypergraph. If None, defaults to 'Difference_of_{self.name}_{other.name}'.\n\n    Returns\n    -------\n    Hypergraph\n        The resulting difference hypergraph.\n\n    Raises\n    ------\n    TypeError\n        If `other` is not an instance of Hypergraph.\n    \"\"\"\n    if not isinstance(other, Hypergraph):\n        raise TypeError(\"The `other` parameter must be an instance of Hypergraph.\")\n\n    if inplace:\n        # Remove hyperedges present in other\n        hyperedges_to_remove = set(self.hyperedges.keys()) & set(other.hyperedges.keys())\n        self.remove_hyperedges(hyperedges_to_remove, inplace=True)\n        # Remove nodes present in other\n        nodes_to_remove = set(self.node_properties.keys()) & set(other.node_properties.keys())\n        self.remove_nodes_from(nodes_to_remove, inplace=True)\n        return self\n    else:\n        # Create a new Hypergraph instance\n        new_hyperedges = {he_id: copy.deepcopy(he) for he_id, he in self.hyperedges.items() if he_id not in other.hyperedges}\n        new_hyperedge_properties = {he_id: copy.deepcopy(props) for he_id, props in self.hyperedge_properties.items() if he_id not in other.hyperedges}\n        new_node_properties = {node_id: copy.deepcopy(props) for node_id, props in self.node_properties.items() if node_id not in other.node_properties}\n\n        # Reconstruct graph\n        new_graph = nx.Graph()\n        new_bipartite_nodes = set()\n        for he_id, hyperedge in new_hyperedges.items():\n            new_graph.add_node(he_id, bipartite='hyperedge', **new_hyperedge_properties[he_id])\n            new_bipartite_nodes.add(he_id)\n            for node in hyperedge.nodes:\n                if node in new_node_properties:\n                    new_graph.add_edge(he_id, node)\n\n        return Hypergraph(\n            hyperedges={\n                he_id: {\n                    'nodes': list(he.nodes),\n                    'properties': he.properties.copy()\n                } for he_id, he in new_hyperedges.items()\n            },\n            node_properties=new_node_properties,\n            hyperedge_properties=new_hyperedge_properties,\n            name=name if name else f\"Difference_of_{self.name}_{other.name}\"\n        )\n
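The non-inplace branch above boils down to set subtraction: keep hyperedges absent from the other hypergraph, and drop nodes that also appear there. A minimal sketch on plain dict-of-sets hypergraphs (the `hypergraph_difference` name and input format are assumptions for illustration, not the library API):

```python
def hypergraph_difference(hes_a, hes_b):
    """Hyperedges of A absent from B, with nodes also in B removed.

    Both arguments map hyperedge id -> set of node ids.
    """
    # All node ids that occur anywhere in B.
    nodes_b = set().union(*hes_b.values()) if hes_b else set()
    return {
        he_id: nodes - nodes_b          # strip shared nodes
        for he_id, nodes in hes_a.items()
        if he_id not in hes_b           # strip shared hyperedges
    }
```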
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.dual","title":"dual(name=None)
","text":"Constructs the dual of the current hypergraph by reversing the roles of nodes and hyperedges.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.dual--parameters","title":"Parameters","text":"name : Optional[str], default=None Name for the dual hypergraph. If None, defaults to the original hypergraph's name with '_dual' appended.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.dual--returns","title":"Returns","text":"Hypergraph A new Hypergraph instance representing the dual of the current hypergraph.
Source code in src/aeiva/hypergraph/hypergraph.py
def dual(self, name: Optional[str] = None) -> \"Hypergraph\":\n \"\"\"\n Constructs the dual of the current hypergraph by reversing the roles of nodes and hyperedges.\n\n Parameters\n ----------\n name : Optional[str], default=None\n Name for the dual hypergraph. If None, defaults to the original hypergraph's name with '_dual' appended.\n\n Returns\n -------\n Hypergraph\n A new Hypergraph instance representing the dual of the current hypergraph.\n \"\"\"\n # Initialize dual hyperedges, which will correspond to original nodes\n dual_hyperedges = {}\n\n # Invert the node-hyperedge structure\n for he_id, hyperedge in self.hyperedges.items():\n for node in hyperedge.nodes:\n # Each original node becomes a hyperedge in the dual\n if node not in dual_hyperedges:\n dual_hyperedges[node] = {'nodes': [], 'properties': self.node_properties.get(node, {})}\n # The new hyperedge (original node) connects to the original hyperedge id as a \"node\"\n dual_hyperedges[node]['nodes'].append(he_id)\n\n # Define node properties in the dual as the original hyperedge properties\n dual_node_properties = {he_id: self.hyperedge_properties.get(he_id, {}) for he_id in self.hyperedges}\n\n # Create and return the dual Hypergraph\n return Hypergraph(\n hyperedges=dual_hyperedges,\n node_properties=dual_node_properties,\n hyperedge_properties=self.node_properties, # Properties of original nodes now apply to dual hyperedges\n name=name or (self.name + \"_dual\" if self.name else \"dual\")\n )\n
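The dual construction is essentially an inversion of the incidence mapping: every node becomes a hyperedge whose members are the ids of the hyperedges it belonged to. A standalone sketch on dict-of-sets inputs (the `dual` helper shown here is illustrative, not the method above):

```python
def dual(hyperedges):
    """Invert hyperedge -> nodes into node -> hyperedges (the dual)."""
    inverted = {}
    for he_id, nodes in hyperedges.items():
        for node in nodes:
            # Each original node collects the hyperedge ids it appeared in.
            inverted.setdefault(node, set()).add(he_id)
    return inverted
```

Applying `dual` twice recovers the original incidence structure, which is the usual sanity check for a dual construction.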
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.edge_elements","title":"edge_elements()
","text":"Returns a dictionary where each key is a hyperedge ID and the value is a list of node IDs within that hyperedge.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.edge_elements--returns","title":"Returns","text":"Dict[Any, List[Any]] Dictionary mapping hyperedge IDs to lists of node IDs they contain.
Source code in src/aeiva/hypergraph/hypergraph.py
def edge_elements(self) -> Dict[Any, List[Any]]:\n \"\"\"\n Returns a dictionary where each key is a hyperedge ID and the value is a list of node IDs within that hyperedge.\n\n Returns\n -------\n Dict[Any, List[Any]]\n Dictionary mapping hyperedge IDs to lists of node IDs they contain.\n \"\"\"\n return {he_id: hyperedge.nodes for he_id, hyperedge in self.hyperedges.items()}\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.edges","title":"edges()
","text":"Returns a list of all hyperedge identifiers in the hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.edges--returns","title":"Returns","text":"List[Any] List of hyperedge IDs.
Source code in src/aeiva/hypergraph/hypergraph.py
def edges(self) -> List[Any]:\n \"\"\"\n Returns a list of all hyperedge identifiers in the hypergraph.\n\n Returns\n -------\n List[Any]\n List of hyperedge IDs.\n \"\"\"\n return list(self.hyperedges.keys())\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.from_bipartite_graph","title":"from_bipartite_graph(bipartite_graph, hyperedge_prefix='HE', node_prefix='N', name=None)
classmethod
","text":"Constructs a Hypergraph instance from a bipartite graph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.from_bipartite_graph--parameters","title":"Parameters","text":"bipartite_graph : nx.Graph A bipartite graph where one set of nodes represents hyperedges and the other represents regular nodes. hyperedge_prefix : str, optional, default=\"HE\" The prefix to identify hyperedge nodes in the bipartite graph. node_prefix : str, optional, default=\"N\" The prefix to identify regular nodes in the bipartite graph. name : Optional[str], default=None The name assigned to the new Hypergraph. If None, defaults to 'FromBipartiteGraph'.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.from_bipartite_graph--returns","title":"Returns","text":"Hypergraph The constructed Hypergraph instance.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.from_bipartite_graph--raises","title":"Raises","text":"ValueError If the bipartite graph does not contain two distinct sets of nodes identifiable by the provided prefixes.
Source code in src/aeiva/hypergraph/hypergraph.py
@classmethod\ndef from_bipartite_graph(cls, bipartite_graph: nx.Graph, hyperedge_prefix: str = \"HE\", node_prefix: str = \"N\", name: Optional[str] = None) -> 'Hypergraph':\n    \"\"\"\n    Constructs a Hypergraph instance from a bipartite graph.\n\n    Parameters\n    ----------\n    bipartite_graph : nx.Graph\n        A bipartite graph where one set of nodes represents hyperedges and the other represents regular nodes.\n    hyperedge_prefix : str, optional, default=\"HE\"\n        The prefix to identify hyperedge nodes in the bipartite graph.\n    node_prefix : str, optional, default=\"N\"\n        The prefix to identify regular nodes in the bipartite graph.\n    name : Optional[str], default=None\n        The name assigned to the new Hypergraph. If None, defaults to 'FromBipartiteGraph'.\n\n    Returns\n    -------\n    Hypergraph\n        The constructed Hypergraph instance.\n\n    Raises\n    ------\n    ValueError\n        If the bipartite graph does not contain two distinct sets of nodes identifiable by the provided prefixes.\n    \"\"\"\n    hyperedges = {}\n    node_properties = {}\n    hyperedge_properties = {}\n    name = name if name else \"FromBipartiteGraph\"\n\n    for node in bipartite_graph.nodes(data=True):\n        node_id, attrs = node\n        if node_id.startswith(hyperedge_prefix):\n            # It's a hyperedge\n            hyperedges[node_id] = HyperEdge(id=node_id, nodes=set(), properties=attrs)\n            hyperedge_properties[node_id] = copy.deepcopy(attrs)\n        elif node_id.startswith(node_prefix):\n            # It's a regular node\n            node_properties[node_id] = copy.deepcopy(attrs)\n        else:\n            raise ValueError(f\"Node '{node_id}' does not start with either hyperedge_prefix '{hyperedge_prefix}' or node_prefix '{node_prefix}'.\")\n\n    # Assign nodes to hyperedges based on edges in bipartite graph\n    for he_id in hyperedges:\n        connected_nodes = set(bipartite_graph.neighbors(he_id))\n        hyperedges[he_id].nodes = connected_nodes\n\n    # Construct hyperedges dict for __init__\n    hyperedges_dict = {\n        he_id: {\n            'nodes': list(he.nodes),\n            'properties': he.properties.copy()\n        } for he_id, he in hyperedges.items()\n    }\n\n    return cls(\n        hyperedges=hyperedges_dict,\n        node_properties=node_properties,\n        hyperedge_properties=hyperedge_properties,\n        name=name\n    )\n
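The prefix convention above can be sketched without networkx, grouping bare (hyperedge, node) edge pairs into an incidence dict (the `hyperedges_from_bipartite` function name and input shape are assumptions for illustration):

```python
def hyperedges_from_bipartite(edges, hyperedge_prefix="HE", node_prefix="N"):
    """Group bipartite edges into a hyperedge -> node-set mapping.

    `edges` is an iterable of (u, v) string pairs; which endpoint is the
    hyperedge is decided by its id prefix, mirroring the prefix convention.
    """
    hyperedges = {}
    for u, v in edges:
        # Orient the pair so `he` carries the hyperedge prefix.
        he, node = (u, v) if u.startswith(hyperedge_prefix) else (v, u)
        if not (he.startswith(hyperedge_prefix) and node.startswith(node_prefix)):
            raise ValueError(f"Edge ({u}, {v}) does not match the expected prefixes.")
        hyperedges.setdefault(he, set()).add(node)
    return hyperedges
```

Edge orientation in the input does not matter, since the prefix decides which endpoint is the hyperedge.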
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.get_hyperedge_connected_components","title":"get_hyperedge_connected_components(s=1, return_singletons=False)
","text":"Yields the s-hyperedge-connected components of the hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.get_hyperedge_connected_components--parameters","title":"Parameters","text":"s : int, optional, default=1 The connectivity level to check. return_singletons : bool, optional, default=False If True, includes singleton components. Otherwise, excludes them.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.get_hyperedge_connected_components--yields","title":"Yields","text":"Set[Any] Sets of hyperedge IDs representing each connected component.
Source code in src/aeiva/hypergraph/hypergraph.py
def get_hyperedge_connected_components(\n self, s: int = 1, return_singletons: bool = False\n) -> Iterator[Set[Any]]:\n \"\"\"\n Yields the s-hyperedge-connected components of the hypergraph.\n\n Parameters\n ----------\n s : int, optional, default=1\n The connectivity level to check.\n return_singletons : bool, optional, default=False\n If True, includes singleton components. Otherwise, excludes them.\n\n Yields\n ------\n Set[Any]\n Sets of hyperedge IDs representing each connected component.\n \"\"\"\n return self.s_connected_components(s=s, hyperedges=True, return_singletons=return_singletons)\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.get_hyperedge_connected_subgraphs","title":"get_hyperedge_connected_subgraphs(s=1, return_singletons=False, name=None)
","text":"Yields subgraphs corresponding to each s-hyperedge-connected component.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.get_hyperedge_connected_subgraphs--parameters","title":"Parameters","text":"s : int, optional, default=1 The connectivity level to check. return_singletons : bool, optional, default=False If True, includes singleton components. Otherwise, excludes them. name : Optional[str], default=None Base name for the subgraphs. Each subgraph will have a unique name appended.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.get_hyperedge_connected_subgraphs--yields","title":"Yields","text":"Hypergraph Subgraphs representing each connected component.
Source code in src/aeiva/hypergraph/hypergraph.py
def get_hyperedge_connected_subgraphs(\n self, s: int = 1, return_singletons: bool = False, name: Optional[str] = None\n) -> Iterator['Hypergraph']:\n \"\"\"\n Yields subgraphs corresponding to each s-hyperedge-connected component.\n\n Parameters\n ----------\n s : int, optional, default=1\n The connectivity level to check.\n return_singletons : bool, optional, default=False\n If True, includes singleton components. Otherwise, excludes them.\n name : Optional[str], default=None\n Base name for the subgraphs. Each subgraph will have a unique name appended.\n\n Yields\n ------\n Hypergraph\n Subgraphs representing each connected component.\n \"\"\"\n return self.s_component_subgraphs(\n s=s,\n hyperedges=True,\n return_singletons=return_singletons,\n name=name\n )\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.get_hyperedge_distance","title":"get_hyperedge_distance(source, target, s=1)
","text":"Returns the shortest s-walk distance between two hyperedges in the hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.get_hyperedge_distance--parameters","title":"Parameters","text":"source : Any A hyperedge identifier in the hypergraph. target : Any A hyperedge identifier in the hypergraph. s : int, optional, default=1 The number of shared nodes required for hyperedges to be considered adjacent.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.get_hyperedge_distance--returns","title":"Returns","text":"Union[int, float] The shortest s-walk distance between the source and target hyperedges. Returns float('inf') if no path exists.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.get_hyperedge_distance--raises","title":"Raises","text":"HypergraphError If either the source or target hyperedge does not exist in the hypergraph.
Source code in src/aeiva/hypergraph/hypergraph.py
def get_hyperedge_distance(self, source: Any, target: Any, s: int = 1) -> Union[int, float]:\n \"\"\"\n Returns the shortest s-walk distance between two hyperedges in the hypergraph.\n\n Parameters\n ----------\n source : Any\n A hyperedge identifier in the hypergraph.\n target : Any\n A hyperedge identifier in the hypergraph.\n s : int, optional, default=1\n The number of shared nodes required for hyperedges to be considered adjacent.\n\n Returns\n -------\n Union[int, float]\n The shortest s-walk distance between the source and target hyperedges.\n Returns `float('inf')` if no path exists.\n\n Raises\n ------\n HypergraphError\n If either the source or target hyperedge does not exist in the hypergraph.\n \"\"\"\n if source not in self.hyperedges:\n raise HypergraphError(f\"Source hyperedge '{source}' does not exist in the hypergraph.\")\n if target not in self.hyperedges:\n raise HypergraphError(f\"Target hyperedge '{target}' does not exist in the hypergraph.\")\n\n A, he_id_map = self.hyperedge_adjacency_matrix(s=s, index=True)\n if A is None:\n raise HypergraphError(\"Hyperedge adjacency matrix could not be generated.\")\n\n graph = nx.from_scipy_sparse_array(A)\n\n try:\n distance = nx.shortest_path_length(graph, source=source, target=target)\n return distance\n except (nx.NetworkXNoPath, nx.NodeNotFound):\n warnings.warn(f\"No s-walk path between hyperedges '{source}' and '{target}'. Returning infinity.\")\n return float('inf')\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.get_hyperedges_of_node","title":"get_hyperedges_of_node(node_id)
","text":"Retrieves all hyperedges that a given node is part of.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.get_hyperedges_of_node--parameters","title":"Parameters","text":"node_id : Any The node identifier.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.get_hyperedges_of_node--returns","title":"Returns","text":"Set[Any] A set of hyperedge IDs that the node belongs to.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.get_hyperedges_of_node--raises","title":"Raises","text":"HypergraphError If the node does not exist in the hypergraph.
Source code in src/aeiva/hypergraph/hypergraph.py
def get_hyperedges_of_node(self, node_id: Any) -> Set[Any]:\n \"\"\"\n Retrieves all hyperedges that a given node is part of.\n\n Parameters\n ----------\n node_id : Any\n The node identifier.\n\n Returns\n -------\n Set[Any]\n A set of hyperedge IDs that the node belongs to.\n\n Raises\n ------\n HypergraphError\n If the node does not exist in the hypergraph.\n \"\"\"\n if node_id not in self.node_properties:\n raise HypergraphError(f\"Node '{node_id}' does not exist in the hypergraph.\")\n return {he.id for he in self.hyperedges.values() if node_id in he.nodes}\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.get_node_connected_components","title":"get_node_connected_components(s=1, return_singletons=False)
","text":"Yields the s-node-connected components of the hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.get_node_connected_components--parameters","title":"Parameters","text":"s : int, optional, default=1 The connectivity level to check. return_singletons : bool, optional, default=False If True, includes singleton components. Otherwise, excludes them.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.get_node_connected_components--yields","title":"Yields","text":"Set[Any] Sets of node IDs representing each connected component.
Source code in src/aeiva/hypergraph/hypergraph.py
def get_node_connected_components(\n self, s: int = 1, return_singletons: bool = False\n) -> Iterator[Set[Any]]:\n \"\"\"\n Yields the s-node-connected components of the hypergraph.\n\n Parameters\n ----------\n s : int, optional, default=1\n The connectivity level to check.\n return_singletons : bool, optional, default=False\n If True, includes singleton components. Otherwise, excludes them.\n\n Yields\n ------\n Set[Any]\n Sets of node IDs representing each connected component.\n \"\"\"\n return self.s_connected_components(s=s, hyperedges=False, return_singletons=return_singletons)\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.get_node_connected_subgraphs","title":"get_node_connected_subgraphs(s=1, return_singletons=False, name=None)
","text":"Yields subgraphs corresponding to each s-node-connected component.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.get_node_connected_subgraphs--parameters","title":"Parameters","text":"s : int, optional, default=1 The connectivity level to check. return_singletons : bool, optional, default=False If True, includes singleton components. Otherwise, excludes them. name : Optional[str], default=None Base name for the subgraphs. Each subgraph will have a unique name appended.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.get_node_connected_subgraphs--yields","title":"Yields","text":"Hypergraph Subgraphs representing each connected component.
Source code in src/aeiva/hypergraph/hypergraph.py
def get_node_connected_subgraphs(\n self, s: int = 1, return_singletons: bool = False, name: Optional[str] = None\n) -> Iterator['Hypergraph']:\n \"\"\"\n Yields subgraphs corresponding to each s-node-connected component.\n\n Parameters\n ----------\n s : int, optional, default=1\n The connectivity level to check.\n return_singletons : bool, optional, default=False\n If True, includes singleton components. Otherwise, excludes them.\n name : Optional[str], default=None\n Base name for the subgraphs. Each subgraph will have a unique name appended.\n\n Yields\n ------\n Hypergraph\n Subgraphs representing each connected component.\n \"\"\"\n return self.s_component_subgraphs(\n s=s,\n hyperedges=False,\n return_singletons=return_singletons,\n name=name\n )\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.get_node_distance","title":"get_node_distance(source, target, s=1)
","text":"Returns the shortest s-walk distance between two nodes in the hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.get_node_distance--parameters","title":"Parameters","text":"source : Any A node identifier in the hypergraph. target : Any A node identifier in the hypergraph. s : int, optional, default=1 The number of shared hyperedges required for nodes to be considered adjacent.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.get_node_distance--returns","title":"Returns","text":"Union[int, float] The shortest s-walk distance between the source and target nodes. Returns float('inf') if no path exists.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.get_node_distance--raises","title":"Raises","text":"HypergraphError If either the source or target node does not exist in the hypergraph.
Source code in src/aeiva/hypergraph/hypergraph.py
def get_node_distance(self, source: Any, target: Any, s: int = 1) -> Union[int, float]:\n    \"\"\"\n    Returns the shortest s-walk distance between two nodes in the hypergraph.\n\n    Parameters\n    ----------\n    source : Any\n        A node identifier in the hypergraph.\n    target : Any\n        A node identifier in the hypergraph.\n    s : int, optional, default=1\n        The number of shared hyperedges required for nodes to be considered adjacent.\n\n    Returns\n    -------\n    Union[int, float]\n        The shortest s-walk distance between the source and target nodes.\n        Returns `float('inf')` if no path exists.\n\n    Raises\n    ------\n    HypergraphError\n        If either the source or target node does not exist in the hypergraph.\n    \"\"\"\n    if source not in self.node_properties:\n        raise HypergraphError(f\"Source node '{source}' does not exist in the hypergraph.\")\n    if target not in self.node_properties:\n        raise HypergraphError(f\"Target node '{target}' does not exist in the hypergraph.\")\n\n    A, node_id_map = self.adjacency_matrix(s=s, index=True)\n    if A is None:\n        raise HypergraphError(\"Adjacency matrix could not be generated.\")\n\n    graph = nx.from_scipy_sparse_array(A)\n\n    try:\n        # nx.from_scipy_sparse_array labels graph nodes by matrix index,\n        # so translate the node identifiers through node_id_map before querying.\n        distance = nx.shortest_path_length(graph, source=node_id_map[source], target=node_id_map[target])\n        return distance\n    except (nx.NetworkXNoPath, nx.NodeNotFound, KeyError):\n        warnings.warn(f\"No s-walk path between '{source}' and '{target}'. Returning infinity.\")\n        return float('inf')\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.get_singleton_hyperedges","title":"get_singleton_hyperedges()
","text":"Returns a list of singleton hyperedges. A singleton hyperedge is a hyperedge of size 1 where its sole node has degree 1.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.get_singleton_hyperedges--returns","title":"Returns","text":"List[Any] A list of singleton hyperedge IDs.
Source code in src/aeiva/hypergraph/hypergraph.py
def get_singleton_hyperedges(self) -> List[Any]:\n \"\"\"\n Returns a list of singleton hyperedges.\n A singleton hyperedge is a hyperedge of size 1 where its sole node has degree 1.\n\n Returns\n -------\n List[Any]\n A list of singleton hyperedge IDs.\n \"\"\"\n singletons = []\n for he in self.hyperedges.values():\n if len(he.nodes) == 1:\n node = next(iter(he.nodes))\n node_degree = sum(1 for hyperedge in self.hyperedges.values() if node in hyperedge.nodes)\n if node_degree == 1:\n singletons.append(he.id)\n return singletons\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.get_toplexes","title":"get_toplexes(return_hypergraph=False)
","text":"Computes a maximal collection of toplexes for the hypergraph. A toplex is a hyperedge that is not contained in any other hyperedge.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.get_toplexes--parameters","title":"Parameters","text":"return_hypergraph : bool, optional, default=False If True, returns a new Hypergraph consisting only of the toplexes.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.get_toplexes--returns","title":"Returns","text":"List[Any] or Hypergraph - A list of toplex hyperedge IDs. - If return_hypergraph=True
, returns a Hypergraph containing only the toplexes.
Source code in src/aeiva/hypergraph/hypergraph.py
def get_toplexes(self, return_hypergraph: bool = False) -> Union[List[Any], 'Hypergraph']:\n \"\"\"\n Computes a maximal collection of toplexes for the hypergraph.\n A :term:`toplex` is a hyperedge that is not contained in any other hyperedge.\n\n Parameters\n ----------\n return_hypergraph : bool, optional, default=False\n If True, returns a new Hypergraph consisting only of the toplexes.\n\n Returns\n -------\n List[Any] or Hypergraph\n - A list of toplex hyperedge IDs.\n - If `return_hypergraph=True`, returns a Hypergraph containing only the toplexes.\n \"\"\"\n toplexes = []\n hyperedges = list(self.hyperedges.values())\n\n for he in hyperedges:\n if not any(he.nodes < other_he.nodes for other_he in hyperedges if he.id != other_he.id):\n toplexes.append(he.id)\n\n if return_hypergraph:\n return self.restrict_to_specific_hyperedges(toplexes, name=\"Toplexes\")\n return toplexes\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.hyperedge_adjacency_matrix","title":"hyperedge_adjacency_matrix(s=1, index=False)
","text":"Generates the adjacency matrix for hyperedges based on s-hyperedge connectivity.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.hyperedge_adjacency_matrix--parameters","title":"Parameters","text":"s : int, optional, default=1 The number of shared nodes required for hyperedges to be considered adjacent. index : bool, optional, default=False If True, returns a mapping from hyperedge IDs to matrix indices.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.hyperedge_adjacency_matrix--returns","title":"Returns","text":"Tuple[Optional[csr_matrix], Dict[Any, int]] - The adjacency matrix in CSR format. - A dictionary mapping hyperedge IDs to matrix indices.
Source code in src/aeiva/hypergraph/hypergraph.py
def hyperedge_adjacency_matrix(self, s: int = 1, index: bool = False) -> Tuple[Optional[csr_matrix], Dict[Any, int]]:\n    \"\"\"\n    Generates the adjacency matrix for hyperedges based on s-hyperedge connectivity.\n\n    Parameters\n    ----------\n    s : int, optional, default=1\n        The number of shared nodes required for hyperedges to be considered adjacent.\n    index : bool, optional, default=False\n        If True, returns a mapping from hyperedge IDs to matrix indices.\n\n    Returns\n    -------\n    Tuple[Optional[csr_matrix], Dict[Any, int]]\n        - The adjacency matrix in CSR format.\n        - A dictionary mapping hyperedge IDs to matrix indices.\n    \"\"\"\n    from scipy.sparse import lil_matrix\n\n    hyperedge_ids = list(self.hyperedges.keys())\n    he_index = {he_id: idx for idx, he_id in enumerate(hyperedge_ids)}\n    size = len(hyperedge_ids)\n    if size == 0:\n        return None, {}\n\n    A = lil_matrix((size, size), dtype=int)\n    for i, he1 in enumerate(hyperedge_ids):\n        nodes1 = self.hyperedges[he1].nodes\n        for j in range(i + 1, size):\n            he2 = hyperedge_ids[j]\n            nodes2 = self.hyperedges[he2].nodes\n            shared_nodes = nodes1 & nodes2\n            if len(shared_nodes) >= s:\n                A[i, j] = 1\n                A[j, i] = 1\n\n    A = A.tocsr()\n\n    if index:\n        return A, he_index\n    return A, {}\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.intersection","title":"intersection(other, inplace=False, name=None)
","text":"Returns the intersection of the current hypergraph with another hypergraph. The intersection includes only nodes and hyperedges present in both hypergraphs.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.intersection--parameters","title":"Parameters","text":"other : Hypergraph The hypergraph to intersect with. inplace : bool, optional, default=False If True, modifies the current hypergraph to keep only the intersecting elements. Otherwise, returns a new Hypergraph instance. name : Optional[str], default=None The name for the resulting hypergraph. If None, defaults to 'Intersection_of_{self.name}_{other.name}'.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.intersection--returns","title":"Returns","text":"Hypergraph The resulting intersection hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.intersection--raises","title":"Raises","text":"TypeError If other
is not an instance of Hypergraph.
Source code in src/aeiva/hypergraph/hypergraph.py
def intersection(self, other: 'Hypergraph', inplace: bool = False, name: Optional[str] = None) -> 'Hypergraph':\n \"\"\"\n Returns the intersection of the current hypergraph with another hypergraph.\n The intersection includes only nodes and hyperedges present in both hypergraphs.\n\n Parameters\n ----------\n other : Hypergraph\n The hypergraph to intersect with.\n inplace : bool, optional, default=False\n If True, modifies the current hypergraph to keep only the intersecting elements.\n Otherwise, returns a new Hypergraph instance.\n name : Optional[str], default=None\n The name for the resulting hypergraph. If None, defaults to 'Intersection_of_{self.name}_{other.name}'.\n\n Returns\n -------\n Hypergraph\n The resulting intersection hypergraph.\n\n Raises\n ------\n TypeError\n If `other` is not an instance of Hypergraph.\n \"\"\"\n if not isinstance(other, Hypergraph):\n raise TypeError(\"The `other` parameter must be an instance of Hypergraph.\")\n\n intersect_nodes = set(self.node_properties.keys()) & set(other.node_properties.keys())\n intersect_hyperedges = set(self.hyperedges.keys()) & set(other.hyperedges.keys())\n\n if inplace:\n # Remove non-intersecting nodes and hyperedges\n nodes_to_remove = set(self.node_properties.keys()) - intersect_nodes\n hyperedges_to_remove = set(self.hyperedges.keys()) - intersect_hyperedges\n self.remove_nodes_from(nodes_to_remove, inplace=True)\n self.remove_hyperedges(hyperedges_to_remove, inplace=True)\n return self\n else:\n # Create a new Hypergraph instance\n new_hyperedges = {}\n new_node_properties = {node_id: copy.deepcopy(self.node_properties[node_id]) for node_id in intersect_nodes}\n new_hyperedge_properties = {}\n new_graph = nx.Graph()\n new_bipartite_nodes = set()\n\n for he_id in intersect_hyperedges:\n he_self = self.hyperedges[he_id]\n he_other = other.hyperedges[he_id]\n # Intersection hyperedges have the same nodes and merged properties\n new_nodes = set(he_self.nodes) & set(he_other.nodes)\n if not 
new_nodes:\n continue # Skip hyperedges with no common nodes\n new_hyperedges[he_id] = HyperEdge(id=he_id, nodes=new_nodes, properties={})\n # Merge properties (could define specific rules)\n new_hyperedge_properties[he_id] = {**self.hyperedge_properties.get(he_id, {}), \n **other.hyperedge_properties.get(he_id, {})}\n new_graph.add_node(he_id, bipartite='hyperedge', **new_hyperedge_properties[he_id])\n new_bipartite_nodes.add(he_id)\n for node in new_nodes:\n new_graph.add_edge(he_id, node)\n\n return Hypergraph(\n hyperedges={\n he_id: {\n 'nodes': list(he.nodes),\n 'properties': he.properties.copy()\n } for he_id, he in new_hyperedges.items()\n },\n node_properties=new_node_properties,\n hyperedge_properties=new_hyperedge_properties,\n name=name if name else f\"Intersection_of_{self.name}_{other.name}\"\n )\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.is_hyperedge_connected","title":"is_hyperedge_connected(s=1)
","text":"Determines if the hypergraph is s-hyperedge-connected.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.is_hyperedge_connected--parameters","title":"Parameters","text":"s : int, optional, default=1 The connectivity level to check.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.is_hyperedge_connected--returns","title":"Returns","text":"bool True if the hypergraph is s-hyperedge-connected, False otherwise.
Source code in src/aeiva/hypergraph/hypergraph.py
def is_hyperedge_connected(self, s: int = 1) -> bool:\n \"\"\"\n Determines if the hypergraph is s-hyperedge-connected.\n\n Parameters\n ----------\n s : int, optional, default=1\n The connectivity level to check.\n\n Returns\n -------\n bool\n True if the hypergraph is s-hyperedge-connected, False otherwise.\n \"\"\"\n return self._is_connected(s=s, hyperedges=True)\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.is_node_connected","title":"is_node_connected(s=1)
","text":"Determines if the hypergraph is s-node-connected.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.is_node_connected--parameters","title":"Parameters","text":"s : int, optional, default=1 The connectivity level to check.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.is_node_connected--returns","title":"Returns","text":"bool True if the hypergraph is s-node-connected, False otherwise.
Source code in src/aeiva/hypergraph/hypergraph.py
def is_node_connected(self, s: int = 1) -> bool:\n \"\"\"\n Determines if the hypergraph is s-node-connected.\n\n Parameters\n ----------\n s : int, optional, default=1\n The connectivity level to check.\n\n Returns\n -------\n bool\n True if the hypergraph is s-node-connected, False otherwise.\n \"\"\"\n return self._is_connected(s=s, hyperedges=False)\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.node_memberships","title":"node_memberships()
","text":"Returns a dictionary where each key is a node ID and the value is a list of hyperedge IDs that include the node.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.node_memberships--returns","title":"Returns","text":"Dict[Any, List[Any]] Dictionary mapping node IDs to the hyperedge IDs they belong to.
Source code in src/aeiva/hypergraph/hypergraph.py
def node_memberships(self) -> Dict[Any, List[Any]]:\n \"\"\"\n Returns a dictionary where each key is a node ID and the value is a list of hyperedge IDs that include the node.\n\n Returns\n -------\n Dict[Any, List[Any]]\n Dictionary mapping node IDs to the hyperedge IDs they belong to.\n \"\"\"\n memberships = {}\n for he_id, hyperedge in self.hyperedges.items():\n for node in hyperedge.nodes:\n memberships.setdefault(node, []).append(he_id)\n return memberships\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.nodes","title":"nodes()
","text":"Returns a list of all unique node identifiers in the hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.nodes--returns","title":"Returns","text":"List[Any] List of node IDs.
Source code in src/aeiva/hypergraph/hypergraph.py
def nodes(self) -> List[Any]:\n \"\"\"\n Returns a list of all unique node identifiers in the hypergraph.\n\n Returns\n -------\n List[Any]\n List of node IDs.\n \"\"\"\n return list(self.node_properties.keys())\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.remove_hyperedge","title":"remove_hyperedge(he_id)
","text":"Removes a hyperedge from the hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.remove_hyperedge--parameters","title":"Parameters","text":"he_id : Any Identifier of the hyperedge to remove.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.remove_hyperedge--raises","title":"Raises","text":"HypergraphError If the hyperedge does not exist.
Source code in src/aeiva/hypergraph/hypergraph.py
def remove_hyperedge(self, he_id: Any) -> None:\n \"\"\"\n Removes a hyperedge from the hypergraph.\n\n Parameters\n ----------\n he_id : Any\n Identifier of the hyperedge to remove.\n\n Raises\n ------\n HypergraphError\n If the hyperedge does not exist.\n \"\"\"\n if he_id not in self.hyperedges:\n raise HypergraphError(f\"Hyperedge '{he_id}' does not exist.\")\n\n # Remove hyperedge from the graph, which also removes all incidences\n self.graph.remove_node(he_id)\n self.bipartite_nodes.discard(he_id)\n\n # Remove from internal structures\n del self.hyperedges[he_id]\n self.hyperedge_properties.pop(he_id, None)\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.remove_hyperedges","title":"remove_hyperedges(he_ids, inplace=True)
","text":"Removes the specified hyperedges from the hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.remove_hyperedges--parameters","title":"Parameters","text":"he_ids : Any | Iterable[Any] Hyperedge identifier(s) to remove. inplace : bool, default=True If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the hyperedges removed.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.remove_hyperedges--returns","title":"Returns","text":"Hypergraph The updated or new Hypergraph instance.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.remove_hyperedges--raises","title":"Raises","text":"HypergraphError If any hyperedge ID does not exist.
Source code in src/aeiva/hypergraph/hypergraph.py
def remove_hyperedges(self, he_ids: Union[Any, Iterable[Any]], inplace: bool = True) -> 'Hypergraph':\n \"\"\"\n Removes the specified hyperedges from the hypergraph.\n\n Parameters\n ----------\n he_ids : Any | Iterable[Any]\n Hyperedge identifier(s) to remove.\n inplace : bool, default=True\n If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the hyperedges removed.\n\n Returns\n -------\n Hypergraph\n The updated or new Hypergraph instance.\n\n Raises\n ------\n HypergraphError\n If any hyperedge ID does not exist.\n \"\"\"\n if isinstance(he_ids, (str, int)):\n he_ids = [he_ids]\n else:\n he_ids = list(he_ids)\n\n non_existing = set(he_ids) - set(self.hyperedges.keys())\n if non_existing:\n raise HypergraphError(f\"Hyperedges {non_existing} do not exist in the hypergraph.\")\n\n if inplace:\n for he_id in he_ids:\n self.remove_hyperedge(he_id)\n return self\n else:\n # Create a new Hypergraph instance with hyperedges removed\n new_hyperedges = copy.deepcopy(self.hyperedges)\n new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)\n new_node_properties = copy.deepcopy(self.node_properties)\n new_graph = copy.deepcopy(self.graph)\n new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)\n\n for he_id in he_ids:\n del new_hyperedges[he_id]\n new_hyperedge_properties.pop(he_id, None)\n new_graph.remove_node(he_id)\n new_bipartite_nodes.discard(he_id)\n\n # Remove nodes not connected to any hyperedges\n retained_nodes = set()\n for hyperedge in new_hyperedges.values():\n retained_nodes.update(hyperedge.nodes)\n\n new_node_properties = {node: props for node, props in new_node_properties.items() if node in retained_nodes}\n\n # Reconstruct hyperedges dict for __init__\n hyperedges_dict = {\n he_id: {\n 'nodes': list(he.nodes),\n 'properties': he.properties.copy()\n } for he_id, he in new_hyperedges.items()\n }\n\n return Hypergraph(\n hyperedges=hyperedges_dict,\n node_properties=new_node_properties,\n 
hyperedge_properties=new_hyperedge_properties,\n name=self.name\n )\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.remove_incidence","title":"remove_incidence(he_id, node_id, inplace=True)
","text":"Removes a single incidence from the hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.remove_incidence--parameters","title":"Parameters","text":"he_id : Any Identifier of the hyperedge. node_id : Any Identifier of the node. inplace : bool, default=True If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the incidence removed.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.remove_incidence--returns","title":"Returns","text":"Hypergraph The updated or new Hypergraph instance.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.remove_incidence--raises","title":"Raises","text":"HypergraphError If the hyperedge or node does not exist, or if the incidence does not exist.
Source code in src/aeiva/hypergraph/hypergraph.py
def remove_incidence(\n self,\n he_id: Any,\n node_id: Any,\n inplace: bool = True\n) -> 'Hypergraph':\n \"\"\"\n Removes a single incidence from the hypergraph.\n\n Parameters\n ----------\n he_id : Any\n Identifier of the hyperedge.\n node_id : Any\n Identifier of the node.\n inplace : bool, default=True\n If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the incidence removed.\n\n Returns\n -------\n Hypergraph\n The updated or new Hypergraph instance.\n\n Raises\n ------\n HypergraphError\n If the hyperedge or node does not exist, or if the incidence does not exist.\n \"\"\"\n if he_id not in self.hyperedges:\n raise HypergraphError(f\"Hyperedge '{he_id}' does not exist in the hypergraph.\")\n if node_id not in self.node_properties:\n raise HypergraphError(f\"Node '{node_id}' does not exist in the hypergraph.\")\n if node_id not in self.hyperedges[he_id].nodes:\n raise HypergraphError(f\"Incidence between hyperedge '{he_id}' and node '{node_id}' does not exist.\")\n\n if inplace:\n # Remove node from HyperEdge's nodes\n self.hyperedges[he_id].remove_node(node_id)\n # Remove edge from graph\n self.graph.remove_edge(he_id, node_id)\n return self\n else:\n # Create a new Hypergraph instance with the incidence removed\n new_hyperedges = copy.deepcopy(self.hyperedges)\n new_node_properties = copy.deepcopy(self.node_properties)\n new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)\n new_graph = copy.deepcopy(self.graph)\n new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)\n\n # Remove node from HyperEdge's nodes\n new_hyperedges[he_id].remove_node(node_id)\n # Remove edge from graph\n new_graph.remove_edge(he_id, node_id)\n\n # Reconstruct hyperedges dict for __init__\n hyperedges_dict = {\n he_id: {\n 'nodes': list(he.nodes),\n 'properties': he.properties.copy()\n } for he_id, he in new_hyperedges.items()\n }\n\n return Hypergraph(\n hyperedges=hyperedges_dict,\n node_properties=new_node_properties,\n 
hyperedge_properties=new_hyperedge_properties,\n name=self.name\n )\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.remove_incidences","title":"remove_incidences(incidences, inplace=True)
","text":"Removes the specified incidences from the hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.remove_incidences--parameters","title":"Parameters","text":"incidences : Iterable[Tuple[Any, Any]] Incidence identifiers as tuples of (he_id, node_id). inplace : bool, default=True If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the incidences removed.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.remove_incidences--returns","title":"Returns","text":"Hypergraph The updated or new Hypergraph instance.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.remove_incidences--raises","title":"Raises","text":"HypergraphError If any incidence does not exist.
Source code in src/aeiva/hypergraph/hypergraph.py
def remove_incidences(\n self,\n incidences: Iterable[Tuple[Any, Any]],\n inplace: bool = True\n) -> 'Hypergraph':\n \"\"\"\n Removes the specified incidences from the hypergraph.\n\n Parameters\n ----------\n incidences : Iterable[Tuple[Any, Any]]\n Incidence identifiers as tuples of (he_id, node_id).\n inplace : bool, default=True\n If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the incidences removed.\n\n Returns\n -------\n Hypergraph\n The updated or new Hypergraph instance.\n\n Raises\n ------\n HypergraphError\n If any incidence does not exist.\n \"\"\"\n incidence_ids = list(incidences)\n\n # Check existence of incidences\n for he_id, node_id in incidence_ids:\n if he_id not in self.hyperedges:\n raise HypergraphError(f\"Hyperedge '{he_id}' does not exist in the hypergraph.\")\n if node_id not in self.node_properties:\n raise HypergraphError(f\"Node '{node_id}' does not exist in the hypergraph.\")\n if node_id not in self.hyperedges[he_id].nodes:\n raise HypergraphError(f\"Incidence between hyperedge '{he_id}' and node '{node_id}' does not exist.\")\n\n if inplace:\n for he_id, node_id in incidence_ids:\n # Remove node from HyperEdge's nodes\n self.hyperedges[he_id].remove_node(node_id)\n # Remove edge from graph\n self.graph.remove_edge(he_id, node_id)\n return self\n else:\n # Create a new Hypergraph instance with the incidences removed\n new_hyperedges = copy.deepcopy(self.hyperedges)\n new_node_properties = copy.deepcopy(self.node_properties)\n new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)\n new_graph = copy.deepcopy(self.graph)\n new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)\n\n for he_id, node_id in incidence_ids:\n # Remove node from HyperEdge's nodes\n new_hyperedges[he_id].remove_node(node_id)\n # Remove edge from graph\n new_graph.remove_edge(he_id, node_id)\n\n # Reconstruct hyperedges dict for __init__\n hyperedges_dict = {\n he_id: {\n 'nodes': list(he.nodes),\n 
'properties': he.properties.copy()\n } for he_id, he in new_hyperedges.items()\n }\n\n return Hypergraph(\n hyperedges=hyperedges_dict,\n node_properties=new_node_properties,\n hyperedge_properties=new_hyperedge_properties,\n name=self.name\n )\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.remove_node","title":"remove_node(node_id, inplace=True)
","text":"Removes a node from the hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.remove_node--parameters","title":"Parameters","text":"node_id : Any Identifier of the node to remove. inplace : bool, default=True If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the node removed.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.remove_node--returns","title":"Returns","text":"Hypergraph The updated or new Hypergraph instance.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.remove_node--raises","title":"Raises","text":"HypergraphError If the node does not exist.
Source code in src/aeiva/hypergraph/hypergraph.py
def remove_node(self, node_id: Any, inplace: bool = True) -> 'Hypergraph':\n \"\"\"\n Removes a node from the hypergraph.\n\n Parameters\n ----------\n node_id : Any\n Identifier of the node to remove.\n inplace : bool, default=True\n If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the node removed.\n\n Returns\n -------\n Hypergraph\n The updated or new Hypergraph instance.\n\n Raises\n ------\n HypergraphError\n If the node does not exist.\n \"\"\"\n if node_id not in self.node_properties:\n raise HypergraphError(f\"Node '{node_id}' does not exist in the hypergraph.\")\n\n if inplace:\n # Remove node from node_properties\n del self.node_properties[node_id]\n # Remove node from all hyperedges\n for hyperedge in self.hyperedges.values():\n if node_id in hyperedge.nodes:\n hyperedge.remove_node(node_id)\n # Remove node from graph, which also removes all incidences\n self.graph.remove_node(node_id)\n return self\n else:\n # Create a new Hypergraph instance with the node removed\n new_hyperedges = copy.deepcopy(self.hyperedges)\n new_node_properties = copy.deepcopy(self.node_properties)\n new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)\n new_graph = copy.deepcopy(self.graph)\n new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)\n\n # Remove node from node_properties\n del new_node_properties[node_id]\n # Remove node from all hyperedges\n for hyperedge in new_hyperedges.values():\n if node_id in hyperedge.nodes:\n hyperedge.remove_node(node_id)\n # Remove node from graph, which also removes all incidences\n new_graph.remove_node(node_id)\n\n # Remove nodes not connected to any hyperedges\n retained_nodes = set()\n for hyperedge in new_hyperedges.values():\n retained_nodes.update(hyperedge.nodes)\n\n new_node_properties = {node: props for node, props in new_node_properties.items() if node in retained_nodes}\n\n # Reconstruct hyperedges dict for __init__\n hyperedges_dict = {\n he_id: {\n 'nodes': 
list(he.nodes),\n 'properties': he.properties.copy()\n } for he_id, he in new_hyperedges.items()\n }\n\n return Hypergraph(\n hyperedges=hyperedges_dict,\n node_properties=new_node_properties,\n hyperedge_properties=new_hyperedge_properties,\n name=self.name\n )\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.remove_nodes_from","title":"remove_nodes_from(nodes, inplace=True)
","text":"Removes the specified nodes from the hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.remove_nodes_from--parameters","title":"Parameters","text":"nodes : Any | Iterable[Any] Node identifier(s) to remove. inplace : bool, default=True If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the nodes removed.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.remove_nodes_from--returns","title":"Returns","text":"Hypergraph The updated or new Hypergraph instance.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.remove_nodes_from--raises","title":"Raises","text":"HypergraphError If any node ID does not exist.
Source code in src/aeiva/hypergraph/hypergraph.py
def remove_nodes_from(\n self,\n nodes: Union[Any, Iterable[Any]],\n inplace: bool = True\n) -> 'Hypergraph':\n \"\"\"\n Removes the specified nodes from the hypergraph.\n\n Parameters\n ----------\n nodes : Any | Iterable[Any]\n Node identifier(s) to remove.\n inplace : bool, default=True\n If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the nodes removed.\n\n Returns\n -------\n Hypergraph\n The updated or new Hypergraph instance.\n\n Raises\n ------\n HypergraphError\n If any node ID does not exist.\n \"\"\"\n if isinstance(nodes, (str, int)):\n nodes = [nodes]\n else:\n nodes = list(nodes)\n\n non_existing = set(nodes) - set(self.node_properties.keys())\n if non_existing:\n raise HypergraphError(f\"Nodes {non_existing} do not exist in the hypergraph.\")\n\n if inplace:\n for node_id in nodes:\n self.remove_node(node_id)\n return self\n else:\n # Create a new Hypergraph instance with nodes removed\n new_hyperedges = copy.deepcopy(self.hyperedges)\n new_node_properties = copy.deepcopy(self.node_properties)\n new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)\n new_graph = copy.deepcopy(self.graph)\n new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)\n\n for node_id in nodes:\n del new_node_properties[node_id]\n # Remove node from all hyperedges\n for hyperedge in new_hyperedges.values():\n if node_id in hyperedge.nodes:\n hyperedge.remove_node(node_id)\n # Remove node from graph, which also removes all incidences\n new_graph.remove_node(node_id)\n\n # Remove nodes not connected to any hyperedges\n retained_nodes = set()\n for hyperedge in new_hyperedges.values():\n retained_nodes.update(hyperedge.nodes)\n\n new_node_properties = {node: props for node, props in new_node_properties.items() if node in retained_nodes}\n\n # Reconstruct hyperedges dict for __init__\n hyperedges_dict = {\n he_id: {\n 'nodes': list(he.nodes),\n 'properties': he.properties.copy()\n } for he_id, he in new_hyperedges.items()\n 
}\n\n return Hypergraph(\n hyperedges=hyperedges_dict,\n node_properties=new_node_properties,\n hyperedge_properties=new_hyperedge_properties,\n name=self.name\n )\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.remove_singleton_hyperedges","title":"remove_singleton_hyperedges(name=None)
","text":"Constructs a clone of the hypergraph with singleton hyperedges removed.
Source code in src/aeiva/hypergraph/hypergraph.py
def remove_singleton_hyperedges(self, name: Optional[str] = None) -> 'Hypergraph':\n \"\"\"\n Constructs a clone of the hypergraph with singleton hyperedges removed.\n \"\"\"\n singletons = self.get_singleton_hyperedges()\n if not singletons:\n return self.copy(name=name)\n\n new_hypergraph = self.remove_hyperedges(singletons, inplace=False)\n new_hypergraph.name = name if name else f\"{self.name}_no_singleton_hyperedges\"\n return new_hypergraph\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.restrict_to_specific_hyperedges","title":"restrict_to_specific_hyperedges(hyperedges_to_retain, name=None)
","text":"Creates a new hypergraph by retaining only the specified hyperedges and removing all others.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.restrict_to_specific_hyperedges--parameters","title":"Parameters","text":"hyperedges_to_retain : Iterable[Any] An iterable of hyperedge identifiers to retain in the new hypergraph.
name : Optional[str], default=None The name assigned to the restricted hypergraph. If None, defaults to the original name suffixed with '_restricted_hyperedges'.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.restrict_to_specific_hyperedges--returns","title":"Returns","text":"Hypergraph A new hypergraph containing only the specified hyperedges and their associated nodes.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.restrict_to_specific_hyperedges--raises","title":"Raises","text":"HypergraphError If none of the specified hyperedges exist in the hypergraph.
Source code in src/aeiva/hypergraph/hypergraph.py
def restrict_to_specific_hyperedges(\n self,\n hyperedges_to_retain: Iterable[Any],\n name: Optional[str] = None\n) -> 'Hypergraph':\n \"\"\"\n Creates a new hypergraph by retaining only the specified hyperedges and removing all others.\n\n Parameters\n ----------\n hyperedges_to_retain : Iterable[Any]\n An iterable of hyperedge identifiers to retain in the new hypergraph.\n\n name : Optional[str], default=None\n The name assigned to the restricted hypergraph. If None, defaults to the original name suffixed with '_restricted_hyperedges'.\n\n Returns\n -------\n Hypergraph\n A new hypergraph containing only the specified hyperedges and their associated nodes.\n\n Raises\n ------\n HypergraphError\n If none of the specified hyperedges exist in the hypergraph.\n \"\"\"\n hyperedges_to_retain = set(hyperedges_to_retain)\n existing_hyperedges = set(self.hyperedges.keys())\n invalid_hyperedges = hyperedges_to_retain - existing_hyperedges\n if invalid_hyperedges:\n raise HypergraphError(f\"The following hyperedges do not exist and cannot be retained: {invalid_hyperedges}\")\n\n # Determine hyperedges to remove\n hyperedges_to_remove = existing_hyperedges - hyperedges_to_retain\n if not hyperedges_to_remove:\n # No hyperedges to remove; return the original hypergraph\n return self\n\n # Remove hyperedges using the existing remove_hyperedges method\n restricted_hypergraph = self.remove_hyperedges(hyperedges_to_remove, inplace=False)\n restricted_hypergraph.name = name if name else f\"{self.name}_restricted_hyperedges\"\n\n return restricted_hypergraph\n
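The restriction pattern (validate the retain set, then remove the complement) can be sketched on a plain dict; `restrict_hyperedges` below is a hypothetical helper, not the library API:

```python
def restrict_hyperedges(hyperedges, retain):
    """Hypothetical sketch of restriction: keep only the named hyperedges.

    hyperedges: {he_id: set of node ids}; retain: iterable of he_ids.
    """
    retain = set(retain)
    invalid = retain - set(hyperedges)
    if invalid:
        raise ValueError(f"Unknown hyperedges: {invalid}")
    return {he: members for he, members in hyperedges.items() if he in retain}

restricted = restrict_hyperedges({"e1": {"a"}, "e2": {"b", "c"}}, ["e2"])
```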
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.restrict_to_specific_nodes","title":"restrict_to_specific_nodes(nodes_to_retain, name=None)
","text":"Creates a new hypergraph by retaining only the specified nodes and removing all others.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.restrict_to_specific_nodes--parameters","title":"Parameters","text":"nodes_to_retain : Iterable[Any] An iterable of node identifiers to retain in the new hypergraph.
name : Optional[str], default=None The name assigned to the restricted hypergraph. If None, defaults to the original name suffixed with '_restricted_nodes'.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.restrict_to_specific_nodes--returns","title":"Returns","text":"Hypergraph A new hypergraph containing only the specified nodes and their associated hyperedges.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.restrict_to_specific_nodes--raises","title":"Raises","text":"HypergraphError If none of the specified nodes exist in the hypergraph.
Source code in src/aeiva/hypergraph/hypergraph.py
def restrict_to_specific_nodes(\n self,\n nodes_to_retain: Iterable[Any],\n name: Optional[str] = None\n) -> 'Hypergraph':\n \"\"\"\n Creates a new hypergraph by retaining only the specified nodes and removing all others.\n\n Parameters\n ----------\n nodes_to_retain : Iterable[Any]\n An iterable of node identifiers to retain in the new hypergraph.\n\n name : Optional[str], default=None\n The name assigned to the restricted hypergraph. If None, defaults to the original name suffixed with '_restricted_nodes'.\n\n Returns\n -------\n Hypergraph\n A new hypergraph containing only the specified nodes and their associated hyperedges.\n\n Raises\n ------\n HypergraphError\n If none of the specified nodes exist in the hypergraph.\n \"\"\"\n nodes_to_retain = set(nodes_to_retain)\n existing_nodes = set(self.node_properties.keys())\n invalid_nodes = nodes_to_retain - existing_nodes\n if invalid_nodes:\n raise HypergraphError(f\"The following nodes do not exist and cannot be retained: {invalid_nodes}\")\n\n # Determine nodes to remove\n nodes_to_remove = existing_nodes - nodes_to_retain\n if not nodes_to_remove:\n # No nodes to remove; return the original hypergraph\n return self\n\n # Remove nodes using the existing remove_nodes_from method\n restricted_hypergraph = self.remove_nodes_from(nodes_to_remove, inplace=False)\n restricted_hypergraph.name = name if name else f\"{self.name}_restricted_nodes\"\n\n return restricted_hypergraph\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.s_component_subgraphs","title":"s_component_subgraphs(s=1, hyperedges=True, return_singletons=False, name=None)
","text":"Yields subgraphs corresponding to each s-hyperedge-connected or s-node-connected component.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.s_component_subgraphs--parameters","title":"Parameters","text":"s : int, optional, default=1 The connectivity level to check. hyperedges : bool, optional, default=True If True, yields subgraphs of s-hyperedge-connected components. Otherwise, yields subgraphs of s-node-connected components. return_singletons : bool, optional, default=False If True, includes singleton components. Otherwise, excludes them. name : Optional[str], default=None Base name for the subgraphs. Each subgraph will have a unique name appended.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.s_component_subgraphs--yields","title":"Yields","text":"Hypergraph Subgraphs representing each connected component.
Source code in src/aeiva/hypergraph/hypergraph.py
def s_component_subgraphs(\n self,\n s: int = 1,\n hyperedges: bool = True,\n return_singletons: bool = False,\n name: Optional[str] = None\n) -> Iterator['Hypergraph']:\n \"\"\"\n Yields subgraphs corresponding to each s-hyperedge-connected or s-node-connected component.\n\n Parameters\n ----------\n s : int, optional, default=1\n The connectivity level to check.\n hyperedges : bool, optional, default=True\n If True, yields subgraphs of s-hyperedge-connected components. Otherwise, yields subgraphs of s-node-connected components.\n return_singletons : bool, optional, default=False\n If True, includes singleton components. Otherwise, excludes them.\n name : Optional[str], default=None\n Base name for the subgraphs. Each subgraph will have a unique name appended.\n\n Yields\n ------\n Hypergraph\n Subgraphs representing each connected component.\n \"\"\"\n for idx, component in enumerate(\n self.s_connected_components(s=s, hyperedges=hyperedges, return_singletons=return_singletons)\n ):\n if hyperedges:\n yield self.restrict_to_specific_hyperedges(\n hyperedges_to_retain=component, \n name=f\"{name or self.name}_component_{idx}\"\n )\n else:\n yield self.restrict_to_specific_nodes(\n nodes_to_retain=component, \n name=f\"{name or self.name}_component_{idx}\"\n )\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.s_connected_components","title":"s_connected_components(s=1, hyperedges=True, return_singletons=False)
","text":"Yields the s-hyperedge-connected or s-node-connected components of the hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.s_connected_components--parameters","title":"Parameters","text":"s : int, optional, default=1 The connectivity level to check. hyperedges : bool, optional, default=True If True, yields s-hyperedge-connected components. Otherwise, yields s-node-connected components. return_singletons : bool, optional, default=False If True, includes singleton components. Otherwise, excludes them.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.s_connected_components--yields","title":"Yields","text":"Set[Any] Sets of hyperedge IDs or node IDs representing each connected component.
Source code in src/aeiva/hypergraph/hypergraph.py
def s_connected_components(\n self, \n s: int = 1, \n hyperedges: bool = True, \n return_singletons: bool = False\n) -> Iterator[Set[Any]]:\n \"\"\"\n Yields the s-hyperedge-connected or s-node-connected components of the hypergraph.\n\n Parameters\n ----------\n s : int, optional, default=1\n The connectivity level to check.\n hyperedges : bool, optional, default=True\n If True, yields s-hyperedge-connected components. Otherwise, yields s-node-connected components.\n return_singletons : bool, optional, default=False\n If True, includes singleton components. Otherwise, excludes them.\n\n Yields\n ------\n Set[Any]\n Sets of hyperedge IDs or node IDs representing each connected component.\n \"\"\"\n if hyperedges:\n # s-hyperedge-connected: hyperedges are connected if they share at least s nodes\n hyperedge_graph = nx.Graph()\n hyperedge_ids = list(self.hyperedges.keys())\n hyperedge_graph.add_nodes_from(hyperedge_ids)\n\n for i, he1 in enumerate(hyperedge_ids):\n nodes1 = self.hyperedges[he1].nodes\n for he2 in hyperedge_ids[i + 1:]:\n nodes2 = self.hyperedges[he2].nodes\n shared_nodes = nodes1 & nodes2\n if len(shared_nodes) >= s:\n hyperedge_graph.add_edge(he1, he2)\n\n components = nx.connected_components(hyperedge_graph)\n for component in components:\n if not return_singletons and len(component) == 1:\n continue\n yield component\n else:\n # s-node-connected: nodes are connected if they share at least s hyperedges\n node_graph = nx.Graph()\n node_ids = list(self.node_properties.keys())\n node_graph.add_nodes_from(node_ids)\n\n for i, node1 in enumerate(node_ids):\n hyperedges1 = {he.id for he in self.hyperedges.values() if node1 in he.nodes}\n for node2 in node_ids[i + 1:]:\n hyperedges2 = {he.id for he in self.hyperedges.values() if node2 in he.nodes}\n shared_hyperedges = hyperedges1 & hyperedges2\n if len(shared_hyperedges) >= s:\n node_graph.add_edge(node1, node2)\n\n components = nx.connected_components(node_graph)\n for component in components:\n if not 
return_singletons and len(component) == 1:\n continue\n yield component\n
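The s-hyperedge-connected case above can be sketched without networkx: build an adjacency over hyperedges that share at least s nodes, then walk it with a BFS. The `s_hyperedge_components` helper below is hypothetical (it also yields singleton components unconditionally, unlike the method's `return_singletons` filter):

```python
from collections import deque

def s_hyperedge_components(hyperedges, s=1):
    """Hypothetical sketch: yield sets of hyperedge ids chained together
    by pairs sharing at least s nodes. hyperedges: {he_id: set of nodes}."""
    ids = list(hyperedges)
    adj = {he: set() for he in ids}
    for i, h1 in enumerate(ids):
        for h2 in ids[i + 1:]:
            if len(hyperedges[h1] & hyperedges[h2]) >= s:
                adj[h1].add(h2)
                adj[h2].add(h1)
    seen = set()
    for he in ids:
        if he in seen:
            continue
        # BFS over the hyperedge adjacency to collect one component.
        comp, queue = {he}, deque([he])
        seen.add(he)
        while queue:
            cur = queue.popleft()
            for nxt in adj[cur] - seen:
                seen.add(nxt)
                comp.add(nxt)
                queue.append(nxt)
        yield comp
```

Raising `s` tightens the overlap requirement, so components split apart as fewer hyperedge pairs qualify as adjacent.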
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.symmetric_difference","title":"symmetric_difference(other, inplace=False, name=None)
","text":"Returns the symmetric difference of the current hypergraph with another hypergraph. The symmetric difference includes elements present in either hypergraph but not in both.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.symmetric_difference--parameters","title":"Parameters","text":"other : Hypergraph The hypergraph to symmetric difference with. inplace : bool, optional, default=False If True, modifies the current hypergraph to keep only the symmetric difference elements. Otherwise, returns a new Hypergraph instance. name : Optional[str], default=None The name for the resulting hypergraph. If None, defaults to 'SymmetricDifference_of_{self.name}_{other.name}'.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.symmetric_difference--returns","title":"Returns","text":"Hypergraph The resulting symmetric difference hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.symmetric_difference--raises","title":"Raises","text":"TypeError If other
is not an instance of Hypergraph.
Source code in src/aeiva/hypergraph/hypergraph.py
def symmetric_difference(self, other: 'Hypergraph', inplace: bool = False, name: Optional[str] = None) -> 'Hypergraph':\n \"\"\"\n Returns the symmetric difference of the current hypergraph with another hypergraph.\n The symmetric difference includes elements present in either hypergraph but not in both.\n\n Parameters\n ----------\n other : Hypergraph\n The hypergraph to symmetric difference with.\n inplace : bool, optional, default=False\n If True, modifies the current hypergraph to keep only the symmetric difference elements.\n Otherwise, returns a new Hypergraph instance.\n name : Optional[str], default=None\n The name for the resulting hypergraph. If None, defaults to 'SymmetricDifference_of_{self.name}_{other.name}'.\n\n Returns\n -------\n Hypergraph\n The resulting symmetric difference hypergraph.\n\n Raises\n ------\n TypeError\n If `other` is not an instance of Hypergraph.\n \"\"\"\n if not isinstance(other, Hypergraph):\n raise TypeError(\"The `other` parameter must be an instance of Hypergraph.\")\n\n if inplace:\n # Hyperedges symmetric difference\n hyperedges_to_add = set(other.hyperedges.keys()) - set(self.hyperedges.keys())\n hyperedges_to_remove = set(self.hyperedges.keys()) & set(other.hyperedges.keys())\n self.remove_hyperedges(hyperedges_to_remove, inplace=True)\n for he_id in hyperedges_to_add:\n hyperedge = other.hyperedges[he_id]\n self.add_hyperedge(he_id, hyperedge.nodes, properties=hyperedge.properties)\n\n # Nodes symmetric difference\n nodes_to_add = set(other.node_properties.keys()) - set(self.node_properties.keys())\n nodes_to_remove = set(self.node_properties.keys()) & set(other.node_properties.keys())\n self.remove_nodes_from(nodes_to_remove, inplace=True)\n for node_id in nodes_to_add:\n props = other.node_properties[node_id]\n self.add_node(node_id, properties=props, inplace=True)\n\n if name:\n self.name = name\n return self\n else:\n # Create a new Hypergraph instance\n union_hg = self.union(other)\n intersection_hg = 
self.intersection(other)\n return union_hg.difference(intersection_hg, name=name if name else f\"SymmetricDifference_of_{self.name}_{other.name}\")\n
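At the level of hyperedge identifiers, the symmetric difference reduces to a set XOR (union minus intersection), as the non-inplace branch computes. A minimal sketch on plain dicts (`symmetric_diff_edges` is a hypothetical helper, not the library API):

```python
def symmetric_diff_edges(edges_a, edges_b):
    """Hypothetical sketch: keep hyperedges present in exactly one side.

    edges_a / edges_b: {he_id: set of node ids}.
    """
    keys = set(edges_a) ^ set(edges_b)  # (A | B) - (A & B)
    return {he: edges_a[he] if he in edges_a else edges_b[he] for he in keys}

result = symmetric_diff_edges(
    {"e1": {"a"}, "e2": {"b"}}, {"e2": {"b"}, "e3": {"c"}}
)
```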
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.to_bipartite_graph","title":"to_bipartite_graph(keep_data=False, directed=False)
","text":"Creates a bipartite NetworkX graph from the hypergraph. The nodes and hyperedges of the hypergraph become nodes in the bipartite graph. For every hyperedge in the hypergraph and each node it connects to, there is an edge in the bipartite graph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.to_bipartite_graph--parameters","title":"Parameters","text":"keep_data : bool, optional, default = False If True, includes the node and hyperedge properties in the NetworkX graph. directed : bool, optional, default = False If True, the edges in the graph are directed with hyperedges as sources and nodes as targets.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.to_bipartite_graph--returns","title":"Returns","text":"networkx.Graph or networkx.DiGraph The bipartite graph representation of the hypergraph.
Source code in src/aeiva/hypergraph/hypergraph.py
def to_bipartite_graph(self, keep_data=False, directed=False) -> nx.Graph:\n \"\"\"\n Creates a bipartite NetworkX graph from the hypergraph.\n The nodes and hyperedges of the hypergraph become nodes in the bipartite graph.\n For every hyperedge in the hypergraph and each node it connects to, there\n is an edge in the bipartite graph.\n\n Parameters\n ----------\n keep_data : bool, optional, default = False\n If True, includes the node and hyperedge properties in the NetworkX graph.\n directed : bool, optional, default = False\n If True, the edges in the graph are directed with hyperedges as sources and nodes as targets.\n\n Returns\n -------\n networkx.Graph or networkx.DiGraph\n The bipartite graph representation of the hypergraph.\n \"\"\"\n # Choose graph type based on directed flag\n B = nx.DiGraph() if directed else nx.Graph()\n\n if not keep_data:\n # Add nodes with bipartite attributes, where 0 indicates hyperedges and 1 indicates regular nodes\n B.add_nodes_from(self.hyperedges.keys(), bipartite=0) # hyperedges\n B.add_nodes_from(self.node_properties.keys(), bipartite=1) # nodes\n\n # Add edges between hyperedges and nodes based on hyperedges data\n for he_id, hyperedge in self.hyperedges.items():\n for node in hyperedge.nodes:\n B.add_edge(he_id, node)\n else:\n # Add nodes with properties if keep_data is True\n for node_id, properties in self.node_properties.items():\n B.add_node(node_id, bipartite=1, **properties)\n\n for he_id, hyperedge in self.hyperedges.items():\n B.add_node(he_id, bipartite=0, **self.hyperedge_properties.get(he_id, {}))\n for node in hyperedge.nodes:\n # Add edges with optional properties if keep_data is True\n B.add_edge(he_id, node)\n\n return B\n
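The bipartite construction is essentially the incidence relation written out as edges: one partition holds hyperedge ids, the other holds node ids, and each (hyperedge, member node) pair becomes an edge. A networkx-free sketch (the `to_bipartite_edges` helper is hypothetical):

```python
def to_bipartite_edges(hyperedges, node_ids):
    """Hypothetical sketch of the bipartite incidence graph as raw data:
    hyperedges on one side (bipartite=0), nodes on the other (bipartite=1)."""
    left = set(hyperedges)          # hyperedge partition
    right = set(node_ids)           # node partition
    edges = [(he, n) for he, members in hyperedges.items() for n in members]
    return left, right, edges

left, right, edges = to_bipartite_edges({"e1": {"a", "b"}}, ["a", "b"])
```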
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.transpose","title":"transpose(name=None)
","text":"Transposes the hypergraph by swapping the roles of nodes and hyperedges. The resulting hypergraph has hyperedges corresponding to the original nodes and vice versa.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.transpose--parameters","title":"Parameters","text":"name : Optional[str], default=None The name assigned to the transposed hypergraph. If None, defaults to the original name suffixed with '_transposed'.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.transpose--returns","title":"Returns","text":"Hypergraph The transposed hypergraph.
Source code in src/aeiva/hypergraph/hypergraph.py
def transpose(self, name: Optional[str] = None) -> 'Hypergraph':\n \"\"\"\n Transposes the hypergraph by swapping the roles of nodes and hyperedges.\n The resulting hypergraph has hyperedges corresponding to the original nodes and vice versa.\n\n Parameters\n ----------\n name : Optional[str], default=None\n The name assigned to the transposed hypergraph. If None, defaults to the original name suffixed with '_transposed'.\n\n Returns\n -------\n Hypergraph\n The transposed hypergraph.\n \"\"\"\n transposed_hyperedges = {node_id: HyperEdge(id=node_id, nodes=set(), properties=copy.deepcopy(props))\n for node_id, props in self.node_properties.items()}\n transposed_node_properties = {he_id: copy.deepcopy(props) for he_id, props in self.hyperedge_properties.items()}\n\n for he_id, hyperedge in self.hyperedges.items():\n for node in hyperedge.nodes:\n if node in transposed_hyperedges:\n transposed_hyperedges[node].nodes.add(he_id)\n\n # Construct the transposed hypergraph\n return Hypergraph(\n hyperedges={\n he_id: {\n 'nodes': list(he.nodes),\n 'properties': he.properties.copy()\n } for he_id, he in transposed_hyperedges.items()\n },\n node_properties=transposed_node_properties,\n hyperedge_properties={he_id: he.properties.copy() for he_id, he in transposed_hyperedges.items()},\n name=name if name else f\"{self.name}_transposed\"\n )\n
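The heart of the transpose is inverting the incidence mapping: each node becomes a hyperedge whose members are the hyperedges that contained it. A minimal sketch (the `transpose_incidence` helper is hypothetical, not the library API):

```python
def transpose_incidence(hyperedges):
    """Hypothetical sketch: invert {he_id: nodes} into
    {node: set of he_ids containing it}."""
    out = {}
    for he, members in hyperedges.items():
        for n in members:
            out.setdefault(n, set()).add(he)
    return out

transposed = transpose_incidence({"e1": {"a", "b"}, "e2": {"b"}})
```

Applying the inversion twice recovers the original incidence relation, which is why the transpose is an involution on the underlying structure.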
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.union","title":"union(other, inplace=False, name=None)
","text":"Returns the union of the current hypergraph with another hypergraph. The union combines all nodes and hyperedges from both hypergraphs.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.union--parameters","title":"Parameters","text":"other : Hypergraph The hypergraph to union with. inplace : bool, optional, default=False If True, modifies the current hypergraph. Otherwise, returns a new Hypergraph instance. name : Optional[str], default=None The name for the resulting hypergraph. If None, defaults to 'Union_of_{self.name}_{other.name}'.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.union--returns","title":"Returns","text":"Hypergraph The resulting union hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.union--raises","title":"Raises","text":"TypeError If other
is not an instance of Hypergraph.
Source code in src/aeiva/hypergraph/hypergraph.py
def union(self, other: 'Hypergraph', inplace: bool = False, name: Optional[str] = None) -> 'Hypergraph':\n \"\"\"\n Returns the union of the current hypergraph with another hypergraph.\n The union combines all nodes and hyperedges from both hypergraphs.\n\n Parameters\n ----------\n other : Hypergraph\n The hypergraph to union with.\n inplace : bool, optional, default=False\n If True, modifies the current hypergraph. Otherwise, returns a new Hypergraph instance.\n name : Optional[str], default=None\n The name for the resulting hypergraph. If None, defaults to 'Union_of_{self.name}_{other.name}'.\n\n Returns\n -------\n Hypergraph\n The resulting union hypergraph.\n\n Raises\n ------\n TypeError\n If `other` is not an instance of Hypergraph.\n \"\"\"\n if not isinstance(other, Hypergraph):\n raise TypeError(\"The `other` parameter must be an instance of Hypergraph.\")\n\n if inplace:\n # Add nodes from other\n for node_id, props in other.node_properties.items():\n if node_id not in self.node_properties:\n self.add_node(node_id, properties=props, inplace=True)\n else:\n # Optionally, merge properties\n self.node_properties[node_id].update(props)\n self.graph.nodes[node_id].update(props)\n\n # Add hyperedges from other\n for he_id, hyperedge in other.hyperedges.items():\n if he_id not in self.hyperedges:\n self.add_hyperedge(he_id, hyperedge.nodes, properties=hyperedge.properties)\n else:\n # Optionally, merge properties and nodes\n self.hyperedges[he_id].nodes.update(hyperedge.nodes)\n self.hyperedge_properties[he_id].update(hyperedge.properties)\n for node in hyperedge.nodes:\n if node not in self.graph:\n self.add_node(node)\n self.graph.add_edge(he_id, node)\n if name:\n self.name = name\n return self\n else:\n # Create a new Hypergraph instance\n new_hyperedges = copy.deepcopy(self.hyperedges)\n new_node_properties = copy.deepcopy(self.node_properties)\n new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)\n new_graph = copy.deepcopy(self.graph)\n 
new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)\n new_name = name if name else f\"Union_of_{self.name}_{other.name}\"\n\n # Add nodes from other\n for node_id, props in other.node_properties.items():\n if node_id not in new_node_properties:\n new_node_properties[node_id] = copy.deepcopy(props)\n new_graph.add_node(node_id, bipartite='node', **props)\n\n # Add hyperedges from other\n for he_id, hyperedge in other.hyperedges.items():\n if he_id not in new_hyperedges:\n new_hyperedges[he_id] = copy.deepcopy(hyperedge)\n new_hyperedge_properties[he_id] = copy.deepcopy(other.hyperedge_properties[he_id])\n new_graph.add_node(he_id, bipartite='hyperedge', **new_hyperedge_properties[he_id])\n new_bipartite_nodes.add(he_id)\n for node in hyperedge.nodes:\n new_graph.add_edge(he_id, node)\n else:\n # Merge nodes and properties\n new_hyperedges[he_id].nodes.update(hyperedge.nodes)\n new_hyperedge_properties[he_id].update(other.hyperedge_properties[he_id])\n for node in hyperedge.nodes:\n new_graph.add_edge(he_id, node)\n\n # Construct the new Hypergraph\n return Hypergraph(\n hyperedges={\n he_id: {\n 'nodes': list(he.nodes),\n 'properties': he.properties.copy()\n } for he_id, he in new_hyperedges.items()\n },\n node_properties=new_node_properties,\n hyperedge_properties=new_hyperedge_properties,\n name=new_name\n )\n
"},{"location":"reference/#src.aeiva.hypergraph.visualization","title":"visualization
","text":""},{"location":"reference/#src.aeiva.hypergraph.visualization.draw_hyper_edge_labels","title":"draw_hyper_edge_labels(H, pos, polys, labels={}, edge_labels_on_edge=True, ax=None, **kwargs)
","text":"Draws a label on the hyper edge boundary.
Should be passed a Matplotlib PolyCollection representing the hyper-edges; see the return value of draw_hyper_edges.
The label will be drawn on the least curvy part of the polygon, and will be aligned parallel to the orientation of the polygon where it is drawn.
"},{"location":"reference/#src.aeiva.hypergraph.visualization.draw_hyper_edge_labels--parameters","title":"Parameters","text":"H: hnx.Hypergraph the entity to be drawn pos: dict mapping of node and edge positions to R^2 polys: PolyCollection collection of polygons returned by draw_hyper_edges labels: dict mapping of node id to string label edge_labels_on_edge: bool if True, the label is drawn on the edge boundary; otherwise it is drawn at the edge position ax: Axis matplotlib axis on which the plot is rendered kwargs: dict Keyword arguments are passed through to Matplotlib's annotate function.
Source code in src/aeiva/hypergraph/visualization.py
def draw_hyper_edge_labels(\n H, pos, polys, labels={}, edge_labels_on_edge=True, ax=None, **kwargs\n):\n \"\"\"\n Draws a label on the hyper edge boundary.\n\n Should be passed Matplotlib PolyCollection representing the hyper-edges, see\n the return value of draw_hyper_edges.\n\n The label will be draw on the least curvy part of the polygon, and will be\n aligned parallel to the orientation of the polygon where it is drawn.\n\n Parameters\n ----------\n H: hnx.Hypergraph\n the entity to be drawn\n polys: PolyCollection\n collection of polygons returned by draw_hyper_edges\n labels: dict\n mapping of node id to string label\n ax: Axis\n matplotlib axis on which the plot is rendered\n kwargs: dict\n Keyword arguments are passed through to Matplotlib's annotate function.\n\n \"\"\"\n ax = ax or plt.gca()\n\n params = transpose_inflated_kwargs(inflate_kwargs(H.edges(), kwargs))\n\n for edge, path, params in zip(H.edges(), polys.get_paths(), params):\n s = labels.get(edge, edge)\n\n theta = 0\n xy = None\n\n if edge_labels_on_edge:\n # calculate the xy location of the annotation\n # this is the midpoint of the pair of adjacent points the most distant\n d = ((path.vertices[:-1] - path.vertices[1:]) ** 2).sum(axis=1)\n i = d.argmax()\n\n x1, x2 = path.vertices[i : i + 2]\n x, y = x2 - x1\n theta = 360 * np.arctan2(y, x) / (2 * np.pi)\n theta = (theta + 360) % 360\n\n while theta > 90:\n theta -= 180\n\n xy = (x1 + x2) / 2\n else:\n xy = pos[edge]\n\n # the string is a comma separated list of the edge uid\n ax.annotate(s, xy, rotation=theta, ha=\"center\", va=\"center\", **params)\n
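The angle normalization in the function above folds the segment direction into (-90, 90] degrees so labels are never rendered upside down. That step in isolation (the `label_angle` helper is hypothetical):

```python
import math

def label_angle(dx, dy):
    """Hypothetical sketch of the angle fold used when placing edge labels:
    map the direction (dx, dy) to a rotation in (-90, 90] degrees."""
    theta = math.degrees(math.atan2(dy, dx)) % 360  # [0, 360)
    while theta > 90:
        theta -= 180
    return theta
```

Opposite directions map to the same rotation, e.g. both (1, 0) and (-1, 0) yield 0 degrees.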
"},{"location":"reference/#src.aeiva.hypergraph.visualization.draw_hyper_edges","title":"draw_hyper_edges(H, pos, ax=None, node_radius={}, contain_hyper_edges=False, dr=None, **kwargs)
","text":"Draws a convex hull around the nodes contained within each edge in H
"},{"location":"reference/#src.aeiva.hypergraph.visualization.draw_hyper_edges--parameters","title":"Parameters","text":"H: hnx.Hypergraph the entity to be drawn pos: dict mapping of node and edge positions to R^2 node_radius: dict mapping of node to R^1 (radius of each node) dr: float the spacing between concentric rings ax: Axis matplotlib axis on which the plot is rendered kwargs: dict keyword arguments, e.g., linewidth, facecolors, are passed through to the PolyCollection constructor
"},{"location":"reference/#src.aeiva.hypergraph.visualization.draw_hyper_edges--returns","title":"Returns","text":"PolyCollection a Matplotlib PolyCollection that can be further styled
Source code in src/aeiva/hypergraph/visualization.py
def draw_hyper_edges(\n H, pos, ax=None, node_radius={}, contain_hyper_edges=False, dr=None, **kwargs\n):\n \"\"\"\n Draws a convex hull around the nodes contained within each edge in H\n\n Parameters\n ----------\n H: hnx.Hypergraph\n the entity to be drawn\n pos: dict\n mapping of node and edge positions to R^2\n node_radius: dict\n mapping of node to R^1 (radius of each node)\n dr: float\n the spacing between concentric rings\n ax: Axis\n matplotlib axis on which the plot is rendered\n kwargs: dict\n keyword arguments, e.g., linewidth, facecolors, are passed through to the PolyCollection constructor\n\n Returns\n -------\n PolyCollection\n a Matplotlib PolyCollection that can be further styled\n \"\"\"\n points = layout_hyper_edges(\n H, pos, node_radius=node_radius, dr=dr, contain_hyper_edges=contain_hyper_edges\n )\n\n polys = PolyCollection(points, **inflate_kwargs(H.edges(), kwargs))\n\n (ax or plt.gca()).add_collection(polys)\n\n return polys\n
"},{"location":"reference/#src.aeiva.hypergraph.visualization.draw_hyper_edges_two_column","title":"draw_hyper_edges_two_column(H, pos, ax=None, **kwargs)
","text":"Renders hyper edges for the two column layout.
Each node-hyper edge membership is rendered as a line connecting the node in the left column to the edge in the right column.
"},{"location":"reference/#src.aeiva.hypergraph.visualization.draw_hyper_edges_two_column--parameters","title":"Parameters","text":"H: hnx.Hypergraph the entity to be drawn pos: dict mapping of node and edge positions to R^2 ax: Axis matplotlib axis on which the plot is rendered kwargs: dict keyword arguments passed to matplotlib.LineCollection
"},{"location":"reference/#src.aeiva.hypergraph.visualization.draw_hyper_edges_two_column--returns","title":"Returns","text":"LineCollection the hyper edges
Source code in src/aeiva/hypergraph/visualization.py
def draw_hyper_edges_two_column(H, pos, ax=None, **kwargs):\n \"\"\"\n Renders hyper edges for the two column layout.\n\n Each node-hyper edge membership is rendered as a line connecting the node\n in the left column to the edge in the right column.\n\n Parameters\n ----------\n H: hnx.Hypergraph\n the entity to be drawn\n pos: dict\n mapping of node and edge positions to R^2\n ax: Axis\n matplotlib axis on which the plot is rendered\n kwargs: dict\n keyword arguments passed to matplotlib.LineCollection\n\n Returns\n -------\n LineCollection\n the hyper edges\n \"\"\"\n ax = ax or plt.gca()\n\n pairs = [(v, e) for e in H.edges() for v in H.edge_elements()[e]]\n\n kwargs = {\n k: v if type(v) != dict else [v.get(e) for _, e in pairs]\n for k, v in kwargs.items()\n }\n\n lines = LineCollection([(pos[u], pos[v]) for u, v in pairs], **kwargs)\n\n ax.add_collection(lines)\n\n return lines\n
"},{"location":"reference/#src.aeiva.hypergraph.visualization.draw_hyper_labels","title":"draw_hyper_labels(H, pos, node_radius={}, ax=None, labels={}, **kwargs)
","text":"Draws text labels for the hypergraph nodes.
The label is drawn to the right of the node. The node radius is needed (see draw_hyper_nodes) so the text can be offset appropriately as the node size changes.
The text label can be customized by passing in a dictionary, labels, mapping a node to its custom label. By default, the label is the string representation of the node.
Keyword arguments are passed through to Matplotlib's annotate function.
"},{"location":"reference/#src.aeiva.hypergraph.visualization.draw_hyper_labels--parameters","title":"Parameters","text":"H: hnx.Hypergraph the entity to be drawn pos: dict mapping of node and edge positions to R^2 node_radius: dict mapping of node to R^1 (radius of each node) ax: Axis matplotlib axis on which the plot is rendered labels: dict mapping of node to text label kwargs: dict keyword arguments passed to matplotlib.annotate
Source code in src/aeiva/hypergraph/visualization.py
def draw_hyper_labels(H, pos, node_radius={}, ax=None, labels={}, **kwargs):\n \"\"\"\n Draws text labels for the hypergraph nodes.\n\n The label is drawn to the right of the node. The node radius is needed (see\n draw_hyper_nodes) so the text can be offset appropriately as the node size\n changes.\n\n The text label can be customized by passing in a dictionary, labels, mapping\n a node to its custom label. By default, the label is the string\n representation of the node.\n\n Keyword arguments are passed through to Matplotlib's annotate function.\n\n Parameters\n ----------\n H: hnx.Hypergraph\n the entity to be drawn\n pos: dict\n mapping of node and edge positions to R^2\n node_radius: dict\n mapping of node to R^1 (radius of each node)\n ax: Axis\n matplotlib axis on which the plot is rendered\n labels: dict\n mapping of node to text label\n kwargs: dict\n keyword arguments passed to matplotlib.annotate\n\n \"\"\"\n ax = ax or plt.gca()\n params = transpose_inflated_kwargs(inflate_kwargs(H.nodes(), kwargs))\n\n for v, v_kwargs in zip(iter(H.nodes()), params):\n xy = np.array([node_radius.get(v, 0), 0]) + pos[v]\n ax.annotate(\n labels.get(v, v),\n xy,\n **{\n k: (\n d[v]\n if hasattr(d, \"__getitem__\") and type(d) not in {str, tuple}\n else d\n )\n for k, d in kwargs.items()\n }\n )\n
"},{"location":"reference/#src.aeiva.hypergraph.visualization.draw_hyper_labels_two_column","title":"draw_hyper_labels_two_column(H, pos, labels={}, with_node_labels=True, with_edge_labels=True, ax=None)
","text":"Renders hyper labels (nodes and edges) for the two column layout.
"},{"location":"reference/#src.aeiva.hypergraph.visualization.draw_hyper_labels_two_column--parameters","title":"Parameters","text":"H: hnx.Hypergraph the entity to be drawn pos: dict mapping of node and edge positions to R^2 labels: dict custom labels for nodes and edges can be supplied with_node_labels: bool False to disable node labels with_edge_labels: bool False to disable edge labels ax: Axis matplotlib axis on which the plot is rendered kwargs: dict keyword arguments passed to matplotlib.LineCollection
Source code in src/aeiva/hypergraph/visualization.py
def draw_hyper_labels_two_column(\n H, pos, labels={}, with_node_labels=True, with_edge_labels=True, ax=None\n):\n \"\"\"\n Renders hyper labels (nodes and edges) for the two column layout.\n\n Parameters\n ----------\n H: hnx.Hypergraph\n the entity to be drawn\n pos: dict\n mapping of node and edge positions to R^2\n labels: dict\n custom labels for nodes and edges can be supplied\n with_node_labels: bool\n False to disable node labels\n with_edge_labels: bool\n False to disable edge labels\n ax: Axis\n matplotlib axis on which the plot is rendered\n kwargs: dict\n keyword arguments passed to matplotlib.LineCollection\n\n \"\"\"\n\n ax = ax or plt.gca()\n\n to_draw = []\n if with_node_labels:\n to_draw.append((list(H.nodes()), \"right\"))\n\n if with_edge_labels:\n to_draw.append((list(H.edges()), \"left\"))\n\n for points, ha in to_draw:\n for p in points:\n ax.annotate(labels.get(p, p), pos[p], ha=ha, va=\"center\")\n
"},{"location":"reference/#src.aeiva.hypergraph.visualization.draw_hyper_nodes","title":"draw_hyper_nodes(H, pos, node_radius={}, r0=None, ax=None, **kwargs)
","text":"Draws a circle for each node in H.
The position of each node is specified by a dictionary/list-like, pos, where pos[v] is the xy-coordinate for the vertex. The radius of each node can be specified as a dictionary where node_radius[v] is the radius. If a node is missing from this dictionary, or the node_radius is not specified at all, a sensible default radius is chosen based on distances between nodes given by pos.
"},{"location":"reference/#src.aeiva.hypergraph.visualization.draw_hyper_nodes--parameters","title":"Parameters","text":"H: hnx.Hypergraph the entity to be drawn pos: dict mapping of node and edge positions to R^2 node_radius: dict mapping of node to R^1 (radius of each node) r0: float minimum distance that concentric rings start from the node position ax: Axis matplotlib axis on which the plot is rendered kwargs: dict keyword arguments, e.g., linewidth, facecolors, are passed through to the PolyCollection constructor
"},{"location":"reference/#src.aeiva.hypergraph.visualization.draw_hyper_nodes--returns","title":"Returns","text":"PolyCollection a Matplotlib PolyCollection that can be further styled
Source code in src/aeiva/hypergraph/visualization.py
def draw_hyper_nodes(H, pos, node_radius={}, r0=None, ax=None, **kwargs):\n \"\"\"\n Draws a circle for each node in H.\n\n The position of each node is specified by the a dictionary/list-like, pos,\n where pos[v] is the xy-coordinate for the vertex. The radius of each node\n can be specified as a dictionary where node_radius[v] is the radius. If a\n node is missing from this dictionary, or the node_radius is not specified at\n all, a sensible default radius is chosen based on distances between nodes\n given by pos.\n\n Parameters\n ----------\n H: hnx.Hypergraph\n the entity to be drawn\n pos: dict\n mapping of node and edge positions to R^2\n node_radius: dict\n mapping of node to R^1 (radius of each node)\n r0: float\n minimum distance that concentric rings start from the node position\n ax: Axis\n matplotlib axis on which the plot is rendered\n kwargs: dict\n keyword arguments, e.g., linewidth, facecolors, are passed through to the PolyCollection constructor\n\n Returns\n -------\n PolyCollection\n a Matplotlib PolyCollection that can be further styled\n \"\"\"\n\n ax = ax or plt.gca()\n\n r0 = r0 or get_default_radius(H, pos)\n\n points = [node_radius.get(v, r0) * cp + pos[v] for v in H.nodes()]\n\n kwargs.setdefault(\"facecolors\", \"black\")\n\n circles = PolyCollection(points, **inflate_kwargs(H, kwargs))\n\n ax.add_collection(circles)\n\n return circles\n
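The core of draw_hyper_nodes is the expression `node_radius.get(v, r0) * cp + pos[v]`: a unit-circle polygon template (`cp`) is scaled by each node's radius and shifted to its position. A minimal stdlib sketch of that step (no hnx or matplotlib; `circle_polygon` is a stand-in for the `cp` template, not the library's actual implementation):

```python
import math

def circle_polygon(center, radius, n_sides=20):
    """Stand-in for the unit-circle template cp: a regular polygon
    scaled by the node radius and shifted to the node position."""
    cx, cy = center
    return [
        (cx + radius * math.cos(2 * math.pi * i / n_sides),
         cy + radius * math.sin(2 * math.pi * i / n_sides))
        for i in range(n_sides)
    ]

pos = {"A": (0.0, 0.0), "B": (3.0, 4.0)}
node_radius = {"A": 0.5}   # "B" is missing, so it falls back to r0
r0 = 0.25

# mirrors: points = [node_radius.get(v, r0) * cp + pos[v] for v in H.nodes()]
points = [circle_polygon(pos[v], node_radius.get(v, r0)) for v in pos]
```

Each entry of `points` is one node's circle outline, ready to be handed to a PolyCollection.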
"},{"location":"reference/#src.aeiva.hypergraph.visualization.draw_rubber_band","title":"draw_rubber_band(H, pos=None, with_color=True, with_node_counts=False, with_edge_counts=False, layout=nx.spring_layout, layout_kwargs={}, ax=None, node_radius=None, edges_kwargs={}, nodes_kwargs={}, edge_labels_on_edge=True, edge_labels={}, edge_labels_kwargs={}, node_labels={}, node_labels_kwargs={}, with_edge_labels=True, with_node_labels=True, node_label_alpha=0.35, edge_label_alpha=0.35, with_additional_edges=None, contain_hyper_edges=False, additional_edges_kwargs={}, return_pos=False)
","text":"Draw a hypergraph as a Matplotlib figure
By default this will draw a colorful \"rubber band\" like hypergraph, where convex hulls represent edges and are drawn around the nodes they contain.
This is a convenience function that wraps calls with sensible parameters to the following lower-level drawing functions:
- draw_hyper_edges,
- draw_hyper_edge_labels,
- draw_hyper_labels, and
- draw_hyper_nodes
The default layout algorithm is nx.spring_layout, but other layouts can be passed in. The Hypergraph is converted to a bipartite graph, and the layout algorithm is passed the bipartite graph.
If you have a pre-determined layout, you can pass in a \"pos\" dictionary. This is a dictionary mapping from node id's to x-y coordinates. For example:
>>> pos = {\n>>> 'A': (0, 0),\n>>> 'B': (1, 2),\n>>> 'C': (5, -3)\n>>> }\n
will position the nodes {A, B, C} manually at the locations specified. The coordinate system is in Matplotlib \"data coordinates\", and the drawing will be centered within the figure.
By default, this will draw in a new figure, but the axis to render in can be specified using ax.
This approach works well for small hypergraphs, and does not guarantee a rigorously \"correct\" drawing. Overlapping of sets in the drawing generally implies that the sets intersect, but sometimes sets overlap if there is no intersection. It is not possible, in general, to draw a \"correct\" hypergraph this way for an arbitrary hypergraph, in the same way that not all graphs have planar drawings.
"},{"location":"reference/#src.aeiva.hypergraph.visualization.draw_rubber_band--parameters","title":"Parameters","text":"H: hnx.Hypergraph the entity to be drawn pos: dict mapping of node and edge positions to R^2 with_color: bool set to False to disable color cycling of edges with_node_counts: bool set to True to replace the label for collapsed nodes with the number of elements with_edge_counts: bool set to True to label collapsed edges with number of elements layout: function layout algorithm to compute layout_kwargs: dict keyword arguments passed to layout function ax: Axis matplotlib axis on which the plot is rendered edges_kwargs: dict keyword arguments passed to matplotlib.collections.PolyCollection for edges node_radius: None, int, float, or dict radius of all nodes, or dictionary of node:value; the default (None) calculates radius based on number of collapsed nodes; reasonable values range between 1 and 3 nodes_kwargs: dict keyword arguments passed to matplotlib.collections.PolyCollection for nodes edge_labels_on_edge: bool whether to draw edge labels on the edge (rubber band) or inside edge_labels_kwargs: dict keyword arguments passed to matplotlib.annotate for edge labels node_labels_kwargs: dict keyword argumetns passed to matplotlib.annotate for node labels with_edge_labels: bool set to False to make edge labels invisible with_node_labels: bool set to False to make node labels invisible node_label_alpha: float the transparency (alpha) of the box behind text drawn in the figure for node labels edge_label_alpha: float the transparency (alpha) of the box behind text drawn in the figure for edge labels with_additional_edges: networkx.Graph ... contain_hyper_edges: bool whether the rubber band shoudl be drawn around the location of the edge in the bipartite graph. This may be invisibile unless \"with_additional_edges\" contains this information.
Source code in src/aeiva/hypergraph/visualization.py
def draw_rubber_band(\n H,\n pos=None,\n with_color=True,\n with_node_counts=False,\n with_edge_counts=False,\n layout=nx.spring_layout,\n layout_kwargs={},\n ax=None,\n node_radius=None,\n edges_kwargs={},\n nodes_kwargs={},\n edge_labels_on_edge=True,\n edge_labels={},\n edge_labels_kwargs={},\n node_labels={},\n node_labels_kwargs={},\n with_edge_labels=True,\n with_node_labels=True,\n node_label_alpha=0.35,\n edge_label_alpha=0.35,\n with_additional_edges=None,\n contain_hyper_edges=False,\n additional_edges_kwargs={},\n return_pos=False,\n):\n \"\"\"\n Draw a hypergraph as a Matplotlib figure\n\n By default this will draw a colorful \"rubber band\" like hypergraph, where\n convex hulls represent edges and are drawn around the nodes they contain.\n\n This is a convenience function that wraps calls with sensible parameters to\n the following lower-level drawing functions:\n\n * draw_hyper_edges,\n * draw_hyper_edge_labels,\n * draw_hyper_labels, and\n * draw_hyper_nodes\n\n The default layout algorithm is nx.spring_layout, but other layouts can be\n passed in. The Hypergraph is converted to a bipartite graph, and the layout\n algorithm is passed the bipartite graph.\n\n If you have a pre-determined layout, you can pass in a \"pos\" dictionary.\n This is a dictionary mapping from node id's to x-y coordinates. For example:\n\n >>> pos = {\n >>> 'A': (0, 0),\n >>> 'B': (1, 2),\n >>> 'C': (5, -3)\n >>> }\n\n will position the nodes {A, B, C} manually at the locations specified. The\n coordinate system is in Matplotlib \"data coordinates\", and the figure will\n be centered within the figure.\n\n By default, this will draw in a new figure, but the axis to render in can be\n specified using :code:`ax`.\n\n This approach works well for small hypergraphs, and does not guarantee\n a rigorously \"correct\" drawing. Overlapping of sets in the drawing generally\n implies that the sets intersect, but sometimes sets overlap if there is no\n intersection. 
It is not possible, in general, to draw a \"correct\" hypergraph\n this way for an arbitrary hypergraph, in the same way that not all graphs\n have planar drawings.\n\n Parameters\n ----------\n H: hnx.Hypergraph\n the entity to be drawn\n pos: dict\n mapping of node and edge positions to R^2\n with_color: bool\n set to False to disable color cycling of edges\n with_node_counts: bool\n set to True to replace the label for collapsed nodes with the number of elements\n with_edge_counts: bool\n set to True to label collapsed edges with number of elements\n layout: function\n layout algorithm to compute\n layout_kwargs: dict\n keyword arguments passed to layout function\n ax: Axis\n matplotlib axis on which the plot is rendered\n edges_kwargs: dict\n keyword arguments passed to matplotlib.collections.PolyCollection for edges\n node_radius: None, int, float, or dict\n radius of all nodes, or dictionary of node:value; the default (None) calculates radius based on number of collapsed nodes; reasonable values range between 1 and 3\n nodes_kwargs: dict\n keyword arguments passed to matplotlib.collections.PolyCollection for nodes\n edge_labels_on_edge: bool\n whether to draw edge labels on the edge (rubber band) or inside\n edge_labels_kwargs: dict\n keyword arguments passed to matplotlib.annotate for edge labels\n node_labels_kwargs: dict\n keyword argumetns passed to matplotlib.annotate for node labels\n with_edge_labels: bool\n set to False to make edge labels invisible\n with_node_labels: bool\n set to False to make node labels invisible\n node_label_alpha: float\n the transparency (alpha) of the box behind text drawn in the figure for node labels\n edge_label_alpha: float\n the transparency (alpha) of the box behind text drawn in the figure for edge labels\n with_additional_edges: networkx.Graph\n ...\n contain_hyper_edges: bool\n whether the rubber band shoudl be drawn around the location of the edge in the bipartite graph. 
This may be invisibile unless \"with_additional_edges\" contains this information.\n\n \"\"\"\n\n ax = ax or plt.gca()\n\n if pos is None:\n pos = layout_node_link(H, with_additional_edges, layout=layout, **layout_kwargs)\n\n r0 = get_default_radius(H, pos)\n a0 = np.pi * r0**2\n\n def get_node_radius(v):\n if node_radius is None:\n return np.sqrt(a0 * get_collapsed_size(v) / np.pi)\n elif hasattr(node_radius, \"get\"):\n return node_radius.get(v, 1) * r0\n return node_radius * r0\n\n # guarantee that node radius is a dictionary mapping nodes to values\n node_radius = {v: get_node_radius(v) for v in H.nodes()}\n\n # for convenience, we are using setdefault to mutate the argument\n # however, we need to copy this to prevent side-effects\n edges_kwargs = edges_kwargs.copy()\n edges_kwargs.setdefault(\"edgecolors\", plt.cm.tab10(np.arange(len((H.edges()))) % 10))\n edges_kwargs.setdefault(\"facecolors\", \"none\")\n\n polys = draw_hyper_edges(\n H,\n pos,\n node_radius=node_radius,\n ax=ax,\n contain_hyper_edges=contain_hyper_edges,\n **edges_kwargs\n )\n\n if with_additional_edges:\n nx.draw_networkx_edges(\n with_additional_edges,\n pos=pos,\n ax=ax,\n **inflate_kwargs(with_additional_edges.edges(), additional_edges_kwargs)\n )\n\n if with_edge_labels:\n labels = get_frozenset_label(\n H.edges(), count=with_edge_counts, override=edge_labels\n )\n\n draw_hyper_edge_labels(\n H,\n pos,\n polys,\n color=edges_kwargs[\"edgecolors\"],\n backgroundcolor=(1, 1, 1, edge_label_alpha),\n labels=labels,\n ax=ax,\n edge_labels_on_edge=edge_labels_on_edge,\n **edge_labels_kwargs\n )\n\n if with_node_labels:\n labels = get_frozenset_label(\n H.nodes(), count=with_node_counts, override=node_labels\n )\n\n draw_hyper_labels(\n H,\n pos,\n node_radius=node_radius,\n labels=labels,\n ax=ax,\n va=\"center\",\n xytext=(5, 0),\n textcoords=\"offset points\",\n backgroundcolor=(1, 1, 1, node_label_alpha),\n **node_labels_kwargs\n )\n\n draw_hyper_nodes(H, pos, node_radius=node_radius, 
ax=ax, **nodes_kwargs)\n\n if len(H.nodes()) == 1:\n x, y = pos[list(H.nodes())[0]]\n s = 20\n\n ax.axis([x - s, x + s, y - s, y + s])\n else:\n ax.axis(\"equal\")\n\n ax.axis(\"off\")\n if return_pos:\n return pos\n
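One non-obvious step in draw_rubber_band is how the `node_radius` argument (None, scalar, or dict) is normalized into a per-node radius dictionary via `get_node_radius`. A stdlib sketch of exactly that resolution logic, with `collapsed_size` standing in for the library's `get_collapsed_size` helper:

```python
import math

def resolve_node_radius(node_radius, nodes, r0, collapsed_size):
    """Mirror of draw_rubber_band's get_node_radius: guarantee a dict
    mapping every node to a concrete radius."""
    a0 = math.pi * r0 ** 2
    def get(v):
        if node_radius is None:
            # default: circle area scales with the collapsed-node count
            return math.sqrt(a0 * collapsed_size(v) / math.pi)
        elif hasattr(node_radius, "get"):
            return node_radius.get(v, 1) * r0   # per-node dict, default 1 * r0
        return node_radius * r0                 # scalar multiplier of r0
    return {v: get(v) for v in nodes}

nodes = ["a", "b"]
size = lambda v: 1  # stand-in for get_collapsed_size

default  = resolve_node_radius(None, nodes, r0=2.0, collapsed_size=size)
scaled   = resolve_node_radius(1.5,  nodes, r0=2.0, collapsed_size=size)
per_node = resolve_node_radius({"a": 3}, nodes, r0=2.0, collapsed_size=size)
```

This is why "reasonable values range between 1 and 3" in the parameter docs: the value is a multiplier on the default radius r0, not an absolute size.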
"},{"location":"reference/#src.aeiva.hypergraph.visualization.draw_two_column","title":"draw_two_column(H, with_node_labels=True, with_edge_labels=True, with_node_counts=False, with_edge_counts=False, with_color=True, edge_kwargs=None, ax=None)
","text":"Draw a hypergraph using a two-collumn layout.
This is intended to reproduce an illustrative technique for bipartite graphs and hypergraphs that is typically used in papers and textbooks.
The left column is reserved for nodes and the right column is reserved for edges. A line is drawn between a node and an edge.
The order of nodes and edges is optimized to reduce line crossings between the two columns. Spacing between disconnected components is adjusted to make the diagram easier to read, by reducing the angle of the lines.
"},{"location":"reference/#src.aeiva.hypergraph.visualization.draw_two_column--parameters","title":"Parameters","text":"H: hnx.Hypergraph the entity to be drawn with_node_labels: bool False to disable node labels with_edge_labels: bool False to disable edge labels with_node_counts: bool set to True to label collapsed nodes with number of elements with_edge_counts: bool set to True to label collapsed edges with number of elements with_color: bool set to False to disable color cycling of hyper edges edge_kwargs: dict keyword arguments to pass to matplotlib.LineCollection ax: Axis matplotlib axis on which the plot is rendered
Source code in src/aeiva/hypergraph/visualization.py
def draw_two_column(\n H,\n with_node_labels=True,\n with_edge_labels=True,\n with_node_counts=False,\n with_edge_counts=False,\n with_color=True,\n edge_kwargs=None,\n ax=None,\n):\n \"\"\"\n Draw a hypergraph using a two-collumn layout.\n\n This is intended reproduce an illustrative technique for bipartite graphs\n and hypergraphs that is typically used in papers and textbooks.\n\n The left column is reserved for nodes and the right column is reserved for\n edges. A line is drawn between a node an an edge\n\n The order of nodes and edges is optimized to reduce line crossings between\n the two columns. Spacing between disconnected components is adjusted to make\n the diagram easier to read, by reducing the angle of the lines.\n\n Parameters\n ----------\n H: hnx.Hypergraph\n the entity to be drawn\n with_node_labels: bool\n False to disable node labels\n with_edge_labels: bool\n False to disable edge labels\n with_node_counts: bool\n set to True to label collapsed nodes with number of elements\n with_edge_counts: bool\n set to True to label collapsed edges with number of elements\n with_color: bool\n set to False to disable color cycling of hyper edges\n edge_kwargs: dict\n keyword arguments to pass to matplotlib.LineCollection\n ax: Axis\n matplotlib axis on which the plot is rendered\n \"\"\"\n\n edge_kwargs = edge_kwargs or {}\n\n ax = ax or plt.gca()\n\n pos = layout_two_column(H)\n\n V = [v for v in H.nodes()]\n E = [e for e in H.edges()]\n\n labels = {}\n labels.update(get_frozenset_label(V, count=with_node_counts))\n labels.update(get_frozenset_label(E, count=with_edge_counts))\n\n if with_color:\n edge_kwargs[\"color\"] = {\n e: plt.cm.tab10(i % 10) for i, e in enumerate(H.edges())\n }\n\n draw_hyper_edges_two_column(H, pos, ax=ax, **edge_kwargs)\n draw_hyper_labels_two_column(\n H,\n pos,\n labels,\n ax=ax,\n with_node_labels=with_node_labels,\n with_edge_labels=with_edge_labels,\n )\n ax.autoscale_view()\n\n ax.axis(\"off\")\n
"},{"location":"reference/#src.aeiva.hypergraph.visualization.get_default_radius","title":"get_default_radius(H, pos)
","text":"Calculate a reasonable default node radius
This function iterates over the hyper edges and finds the most distant pair of points given the positions provided. Then, the node radius is a fraction of the median of this distance taken across all hyper-edges.
"},{"location":"reference/#src.aeiva.hypergraph.visualization.get_default_radius--parameters","title":"Parameters","text":"H: hnx.Hypergraph the entity to be drawn pos: dict mapping of node and edge positions to R^2
"},{"location":"reference/#src.aeiva.hypergraph.visualization.get_default_radius--returns","title":"Returns","text":"float the recommended radius
Source code in src/aeiva/hypergraph/visualization.py
def get_default_radius(H, pos):\n \"\"\"\n Calculate a reasonable default node radius\n\n This function iterates over the hyper edges and finds the most distant\n pair of points given the positions provided. Then, the node radius is a fraction\n of the median of this distance take across all hyper-edges.\n\n Parameters\n ----------\n H: hnx.Hypergraph\n the entity to be drawn\n pos: dict\n mapping of node and edge positions to R^2\n\n Returns\n -------\n float\n the recommended radius\n\n \"\"\"\n if len(H) > 1:\n return 0.0125 * np.median(\n [pdist(np.vstack(list(map(pos.get, H.nodes())))).max() for nodes in H.edges()]\n )\n return 1\n
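The rule the docstring describes (1.25% of the median, over all hyper-edges, of the largest pairwise distance within an edge) can be sketched in plain Python without scipy's pdist; here each edge is represented directly by its node positions, which is an assumption of this sketch rather than the function's actual signature:

```python
import math
import statistics
from itertools import combinations

def default_radius(edge_points):
    """Sketch of get_default_radius: 0.0125 times the median, over all
    hyper-edges, of the max pairwise distance between the edge's points."""
    def max_pdist(points):
        return max(math.dist(p, q) for p, q in combinations(points, 2))
    return 0.0125 * statistics.median(max_pdist(pts) for pts in edge_points)

# two edges with diameters 10 and 20 -> median 15 -> radius 0.1875
edges = [[(0, 0), (10, 0)], [(0, 0), (20, 0)]]
```

Tying the radius to edge extents keeps node circles proportionate whatever scale the layout algorithm produced.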
"},{"location":"reference/#src.aeiva.hypergraph.visualization.get_frozenset_label","title":"get_frozenset_label(S, count=False, override={})
","text":"Helper function for rendering the labels of possibly collapsed nodes and edges
"},{"location":"reference/#src.aeiva.hypergraph.visualization.get_frozenset_label--parameters","title":"Parameters","text":"S: iterable list of entities to be labeled count: bool True if labels should be counts of entities instead of list
"},{"location":"reference/#src.aeiva.hypergraph.visualization.get_frozenset_label--returns","title":"Returns","text":"dict mapping of entity to its string representation
Source code in src/aeiva/hypergraph/visualization.py
def get_frozenset_label(S, count=False, override={}):\n \"\"\"\n Helper function for rendering the labels of possibly collapsed nodes and edges\n\n Parameters\n ----------\n S: iterable\n list of entities to be labeled\n count: bool\n True if labels should be counts of entities instead of list\n\n Returns\n -------\n dict\n mapping of entity to its string representation\n \"\"\"\n\n def helper(v):\n if type(v) == str:\n n = get_collapsed_size(v)\n if count and n > 1:\n return f\"x {n}\"\n elif count:\n return \"\"\n return str(v)\n\n return {v: override.get(v, helper(v)) for v in S}\n
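The override/count precedence in get_frozenset_label is easy to miss: explicit overrides always win, counting only applies to string (collapsed) entities, and a collapsed size of 1 under counting yields an empty label. A sketch of that logic, with `collapsed_size` standing in for the library's `get_collapsed_size`:

```python
def frozenset_labels(S, count=False, override=None, collapsed_size=None):
    """Sketch of get_frozenset_label: override wins, then count logic
    for string entities, then plain str()."""
    override = override or {}
    collapsed_size = collapsed_size or (lambda v: 1)  # stand-in helper
    def helper(v):
        if isinstance(v, str):
            n = collapsed_size(v)
            if count and n > 1:
                return f"x {n}"       # collapsed node: show element count
            elif count:
                return ""             # count requested but nothing collapsed
        return str(v)
    return {v: override.get(v, helper(v)) for v in S}

labels  = frozenset_labels(["a", 5])
counted = frozenset_labels(["a"], count=True, collapsed_size=lambda v: 3)
```

This is what with_node_counts/with_edge_counts feed into from draw_rubber_band.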
"},{"location":"reference/#src.aeiva.hypergraph.visualization.get_line_graph","title":"get_line_graph(H, collapse=True)
","text":"Computes the line graph, a directed graph, where a directed edge (u, v) exists if the edge u is a subset of the edge v in the hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.visualization.get_line_graph--parameters","title":"Parameters","text":"H: hnx.Hypergraph the entity to be drawn collapse: bool True if edges should be added if hyper edges are identical
"},{"location":"reference/#src.aeiva.hypergraph.visualization.get_line_graph--returns","title":"Returns","text":"networkx.DiGraph A directed graph
Source code in src/aeiva/hypergraph/visualization.py
def get_line_graph(H, collapse=True):\n \"\"\"\n Computes the line graph, a directed graph, where a directed edge (u, v)\n exists if the edge u is a subset of the edge v in the hypergraph.\n\n Parameters\n ----------\n H: hnx.Hypergraph\n the entity to be drawn\n collapse: bool\n True if edges should be added if hyper edges are identical\n\n Returns\n -------\n networkx.DiGraph\n A directed graph\n \"\"\"\n D = nx.DiGraph()\n\n V = {edge: set(nodes) for edge, nodes in H.edge_elements().items()}\n\n D.add_nodes_from(V)\n\n for u, v in combinations(V, 2):\n if V[u] != V[v] or not collapse:\n if V[u].issubset(V[v]):\n D.add_edge(u, v)\n elif V[v].issubset(V[u]):\n D.add_edge(v, u)\n\n return D\n
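The subset test at the heart of get_line_graph can be shown without networkx by returning the directed edge list directly; `edge_elements` here plays the role of `H.edge_elements()`:

```python
from itertools import combinations

def line_graph_edges(edge_elements, collapse=True):
    """Sketch of get_line_graph using plain lists instead of networkx:
    a directed edge u -> v exists when edge u's nodes are a subset of v's."""
    V = {e: set(nodes) for e, nodes in edge_elements.items()}
    D = []
    for u, v in combinations(V, 2):
        if V[u] != V[v] or not collapse:   # identical edges skipped unless collapse=False
            if V[u].issubset(V[v]):
                D.append((u, v))
            elif V[v].issubset(V[u]):
                D.append((v, u))
    return D

edges = {"e1": [1, 2], "e2": [1, 2, 3], "e3": [3]}
result = line_graph_edges(edges)
```

Here e1 and e3 are both contained in e2, so both point at it, while e1 and e3 are incomparable.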
"},{"location":"reference/#src.aeiva.hypergraph.visualization.get_set_layering","title":"get_set_layering(H, collapse=True)
","text":"Computes a layering of the edges in the hyper graph.
In this layering, each edge is assigned a level. An edge u will be above (i.e., have a smaller level value) another edge v if v is a subset of u.
"},{"location":"reference/#src.aeiva.hypergraph.visualization.get_set_layering--parameters","title":"Parameters","text":"H: hnx.Hypergraph the entity to be drawn collapse: bool True if edges should be added if hyper edges are identical
"},{"location":"reference/#src.aeiva.hypergraph.visualization.get_set_layering--returns","title":"Returns","text":"dict a mapping of vertices in H to integer levels
Source code in src/aeiva/hypergraph/visualization.py
def get_set_layering(H, collapse=True):\n \"\"\"\n Computes a layering of the edges in the hyper graph.\n\n In this layering, each edge is assigned a level. An edge u will be above\n (e.g., have a smaller level value) another edge v if v is a subset of u.\n\n Parameters\n ----------\n H: hnx.Hypergraph\n the entity to be drawn\n collapse: bool\n True if edges should be added if hyper edges are identical\n\n Returns\n -------\n dict\n a mapping of vertices in H to integer levels\n \"\"\"\n\n D = get_line_graph(H, collapse=collapse)\n\n levels = {}\n\n for v in nx.topological_sort(D):\n parent_levels = [levels[u] for u, _ in D.in_edges(v)]\n levels[v] = max(parent_levels) + 1 if len(parent_levels) else 0\n\n return levels\n
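The level rule (`levels[v] = max(parent_levels) + 1 if parents else 0`) is just a topological pass over the subset DAG. A stdlib sketch using Kahn-style ordering in place of nx.topological_sort, taking the DAG as an explicit edge list:

```python
def set_layering(dag_edges, vertices):
    """Sketch of get_set_layering's level assignment: each vertex sits
    one level below its deepest parent; sources get level 0."""
    parents = {v: [] for v in vertices}
    children = {v: [] for v in vertices}
    for u, v in dag_edges:
        parents[v].append(u)
        children[u].append(v)
    levels = {}
    remaining = {v: len(parents[v]) for v in vertices}  # unresolved parent count
    ready = [v for v, n in remaining.items() if n == 0]
    while ready:
        v = ready.pop()
        levels[v] = max((levels[u] for u in parents[v]), default=-1) + 1
        for w in children[v]:
            remaining[w] -= 1
            if remaining[w] == 0:
                ready.append(w)
    return levels

# e1 and e3 are subsets of e2, so e2 lands one level below both
levels = set_layering([("e1", "e2"), ("e3", "e2")], ["e1", "e2", "e3"])
```

layout_hyper_edges then uses these levels to order the concentric rings so that containing hulls are drawn outside contained ones.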
"},{"location":"reference/#src.aeiva.hypergraph.visualization.inflate_kwargs","title":"inflate_kwargs(items, kwargs)
","text":"Helper function to expand keyword arguments.
"},{"location":"reference/#src.aeiva.hypergraph.visualization.inflate_kwargs--parameters","title":"Parameters","text":"n: int length of resulting list if argument is expanded kwargs: dict keyword arguments to be expanded
"},{"location":"reference/#src.aeiva.hypergraph.visualization.inflate_kwargs--returns","title":"Returns","text":"dict dictionary with same keys as kwargs and whose values are lists of length n
Source code in src/aeiva/hypergraph/visualization.py
def inflate_kwargs(items, kwargs):\n \"\"\"\n Helper function to expand keyword arguments.\n\n Parameters\n ----------\n n: int\n length of resulting list if argument is expanded\n kwargs: dict\n keyword arguments to be expanded\n\n Returns\n -------\n dict\n dictionary with same keys as kwargs and whose values are lists of length n\n \"\"\"\n\n return {k: inflate(items, v) for k, v in kwargs.items()}\n
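The body of the underlying `inflate` helper is not shown in this reference, but its broadcast rule can be sketched plausibly: a dict is looked up per item, a same-length list is passed through, and anything else is repeated for every item. Treat the exact dispatch below as an assumption, not the library's implementation:

```python
def inflate(items, value):
    """Plausible sketch of the broadcast rule behind inflate_kwargs
    (assumed behavior; the real inflate body is not shown here)."""
    items = list(items)
    if isinstance(value, dict):
        return [value[i] for i in items]            # per-item lookup
    if isinstance(value, list) and len(value) == len(items):
        return list(value)                          # already one value per item
    return [value] * len(items)                     # scalar broadcast

def inflate_kwargs(items, kwargs):
    return {k: inflate(items, v) for k, v in kwargs.items()}

out = inflate_kwargs(["a", "b"],
                     {"facecolors": "black", "linewidths": {"a": 1, "b": 2}})
```

This is what lets a single `facecolors="black"` style every node while a dict styles them individually.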
"},{"location":"reference/#src.aeiva.hypergraph.visualization.layout_hyper_edges","title":"layout_hyper_edges(H, pos, node_radius={}, dr=None, contain_hyper_edges=False)
","text":"Draws a convex hull for each edge in H.
Position of the nodes in the graph is specified by the position dictionary, pos. Convex hulls are spaced out such that if one set contains another, the convex hull will surround the contained set. The amount of spacing added between hulls is specified by the parameter, dr.
"},{"location":"reference/#src.aeiva.hypergraph.visualization.layout_hyper_edges--parameters","title":"Parameters","text":"H: hnx.Hypergraph the entity to be drawn pos: dict mapping of node and edge positions to R^2 node_radius: dict mapping of node to R^1 (radius of each node) dr: float the spacing between concentric rings ax: Axis matplotlib axis on which the plot is rendered
"},{"location":"reference/#src.aeiva.hypergraph.visualization.layout_hyper_edges--returns","title":"Returns","text":"dict A mapping from hyper edge ids to paths (Nx2 numpy matrices)
Source code in src/aeiva/hypergraph/visualization.py
def layout_hyper_edges(H, pos, node_radius={}, dr=None, contain_hyper_edges=False):\n \"\"\"\n Draws a convex hull for each edge in H.\n\n Position of the nodes in the graph is specified by the position dictionary,\n pos. Convex hulls are spaced out such that if one set contains another, the\n convex hull will surround the contained set. The amount of spacing added\n between hulls is specified by the parameter, dr.\n\n Parameters\n ----------\n H: hnx.Hypergraph\n the entity to be drawn\n pos: dict\n mapping of node and edge positions to R^2\n node_radius: dict\n mapping of node to R^1 (radius of each node)\n dr: float\n the spacing between concentric rings\n ax: Axis\n matplotlib axis on which the plot is rendered\n\n Returns\n -------\n dict\n A mapping from hyper edge ids to paths (Nx2 numpy matrices)\n \"\"\"\n\n if len(node_radius):\n r0 = min(node_radius.values())\n else:\n r0 = get_default_radius(H, pos)\n\n dr = dr or r0\n\n levels = get_set_layering(H)\n\n radii = {\n v: {v: i for i, v in enumerate(sorted(e, key=levels.get))}\n for v, e in H.node_memberships().items()\n }\n\n def get_padded_hull(uid, edge):\n # make sure the edge contains at least one node\n if len(edge):\n points = [\n cp * (node_radius.get(v, r0) + dr * (2 + radii[v][uid])) + pos[v]\n for v in edge\n ]\n\n if contain_hyper_edges:\n points.append(cp * r0 + pos[uid])\n\n points = np.vstack(points)\n\n # if not, draw an empty edge centered around the location of the edge node (in the bipartite graph)\n else:\n points = 4 * r0 * cp + pos[uid]\n\n hull = ConvexHull(points)\n\n return hull.points[hull.vertices]\n\n return [get_padded_hull(uid, list(H.edge_elements()[uid])) for uid in H.edges()]\n
"},{"location":"reference/#src.aeiva.hypergraph.visualization.layout_node_link","title":"layout_node_link(H, G=None, layout=nx.spring_layout, **kwargs)
","text":"Helper function to use a NetwrokX-like graph layout algorithm on a Hypergraph
The hypergraph is converted to a bipartite graph, allowing the usual graph layout techniques to be applied.
"},{"location":"reference/#src.aeiva.hypergraph.visualization.layout_node_link--parameters","title":"Parameters","text":"H: hnx.Hypergraph the entity to be drawn G: Graph an additional set of links to consider during the layout process layout: function the layout algorithm which accepts a NetworkX graph and keyword arguments kwargs: dict Keyword arguments are passed through to the layout algorithm
"},{"location":"reference/#src.aeiva.hypergraph.visualization.layout_node_link--returns","title":"Returns","text":"dict mapping of node and edge positions to R^2
Source code in src/aeiva/hypergraph/visualization.py
def layout_node_link(H, G=None, layout=nx.spring_layout, **kwargs):\n \"\"\"\n Helper function to use a NetwrokX-like graph layout algorithm on a Hypergraph\n\n The hypergraph is converted to a bipartite graph, allowing the usual graph layout\n techniques to be applied.\n\n Parameters\n ----------\n H: hnx.Hypergraph\n the entity to be drawn\n G: Graph\n an additional set of links to consider during the layout process\n layout: function\n the layout algorithm which accepts a NetworkX graph and keyword arguments\n kwargs: dict\n Keyword arguments are passed through to the layout algorithm\n\n Returns\n -------\n dict\n mapping of node and edge positions to R^2\n \"\"\"\n\n B = H.to_bipartite_graph()\n\n if G is not None:\n B.add_edges_from(G.edges())\n\n return layout(B, **kwargs)\n
"},{"location":"reference/#src.aeiva.hypergraph.visualization.layout_two_column","title":"layout_two_column(H, spacing=2)
","text":"Two column (bipartite) layout algorithm.
This algorithm first converts the hypergraph into a bipartite graph and then computes connected components. Disonneccted components are handled independently and then stacked together.
Within a connected component, the spectral ordering of the bipartite graph provides a quick and dirty ordering that minimizes edge crossings in the diagram.
"},{"location":"reference/#src.aeiva.hypergraph.visualization.layout_two_column--parameters","title":"Parameters","text":"H: hnx.Hypergraph the entity to be drawn spacing: float amount of whitespace between disconnected components
Source code in src/aeiva/hypergraph/visualization.py
def layout_two_column(H, spacing=2):\n \"\"\"\n Two column (bipartite) layout algorithm.\n\n This algorithm first converts the hypergraph into a bipartite graph and\n then computes connected components. Disonneccted components are handled\n independently and then stacked together.\n\n Within a connected component, the spectral ordering of the bipartite graph\n provides a quick and dirty ordering that minimizes edge crossings in the\n diagram.\n\n Parameters\n ----------\n H: hnx.Hypergraph\n the entity to be drawn\n spacing: float\n amount of whitespace between disconnected components\n \"\"\"\n offset = 0\n pos = {}\n\n def stack(vertices, x, height):\n for i, v in enumerate(vertices):\n pos[v] = (x, i + offset + (height - len(vertices)) / 2)\n\n G = H.to_bipartite_graph()\n for ci in nx.connected_components(G):\n Gi = G.subgraph(ci)\n key = {v: i for i, v in enumerate(nx.spectral_ordering(Gi))}.get\n ci_vertices, ci_edges = [\n sorted([v for v, d in Gi.nodes(data=True) if d[\"bipartite\"] == j], key=key)\n for j in [0, 1]\n ]\n\n height = max(len(ci_vertices), len(ci_edges))\n\n stack(ci_vertices, 0, height)\n stack(ci_edges, 1, height)\n\n offset += height + spacing\n\n return pos\n
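The centering arithmetic in layout_two_column's inner `stack()` is the key detail: the shorter column is shifted down by `(height - len(column)) / 2` so both columns share a vertical midpoint. A sketch for a single connected component (the real function also advances an offset between components and orders each column spectrally):

```python
def stack_columns(vertices, edges, offset=0):
    """Sketch of layout_two_column's stack(): nodes at x=0, edges at x=1,
    the shorter column vertically centered against the taller one."""
    pos = {}
    height = max(len(vertices), len(edges))
    for x, col in ((0, vertices), (1, edges)):
        for i, v in enumerate(col):
            # mirrors: pos[v] = (x, i + offset + (height - len(vertices)) / 2)
            pos[v] = (x, i + offset + (height - len(col)) / 2)
    return pos

# three nodes on the left, one edge on the right: the edge sits at mid-height
pos = stack_columns(["u", "v", "w"], ["e"])
```

With three nodes at y = 0, 1, 2, the lone edge lands at y = 1.0, directly opposite the middle node.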
"},{"location":"reference/#src.aeiva.llm","title":"llm
","text":""},{"location":"reference/#src.aeiva.llm.llm_client","title":"llm_client
","text":""},{"location":"reference/#src.aeiva.llm.llm_client.LLMClient","title":"LLMClient
","text":"Language Model interface that supports synchronous, asynchronous, and streaming modes, and optionally, tool usage via function calls.
Source code in src/aeiva/llm/llm_client.py
class LLMClient:\n \"\"\"\n Language Model interface that supports synchronous, asynchronous, and streaming modes,\n and optionally, tool usage via function calls.\n \"\"\"\n\n def __init__(self, config: LLMGatewayConfig):\n self.config = config\n self.metrics = LLMUsageMetrics()\n self.logger = get_logger(__name__, level=config.llm_logging_level.upper())\n self._validate_config()\n\n def _validate_config(self):\n if not self.config.llm_api_key:\n raise ValueError(\"API key must be provided in the configuration.\")\n\n @retry_sync(\n max_attempts=lambda self: self.config.llm_num_retries,\n backoff_factor=lambda self: self.config.llm_retry_backoff_factor,\n exceptions=(LLMGatewayError,), # Catching LLMGatewayError\n )\n def generate(\n self, messages: List[Any], tools: List[Dict[str, Any]] = None, **kwargs\n ) -> str:\n try:\n max_iterations = MAX_TOOL_CALL_LOOP # Prevent infinite loops\n iteration = 0\n\n while iteration < max_iterations:\n iteration += 1\n\n # Build parameters\n params = self._build_params(messages=messages, tools=tools, **kwargs)\n response = llm_completion(**params)\n self._update_metrics(response)\n response_message = response.choices[0].message\n\n tool_calls = response_message.tool_calls\n\n if tool_calls:\n # Append assistant's tool call message\n messages.append({\"role\": \"assistant\", \"tool_calls\": tool_calls})\n\n for tool_call in tool_calls:\n function_name = tool_call.function.name\n function_args = json.loads(tool_call.function.arguments)\n tool_call_id = tool_call.id\n self.logger.info(f\"Tool call id: {tool_call_id}\")\n\n try:\n function_response = self.call_tool_sync(\n api_name=function_name, function_name=function_name, params=function_args\n )\n except Exception as e:\n self.logger.error(f\"Error executing tool '{function_name}': {e}\")\n function_response = f\"Error executing tool '{function_name}': {e}\"\n\n # Append the function response to messages\n messages.append(\n {\n \"tool_call_id\": tool_call_id,\n \"role\": 
\"tool\",\n \"name\": function_name,\n \"content\": str(function_response),\n }\n )\n # Continue the loop to handle further function calls\n continue\n else:\n # Assistant provided a final response\n messages.append({\"role\": \"assistant\", \"content\": response_message.content})\n return response_message.content\n\n # If loop exceeds max iterations\n raise Exception(\"Maximum iterations reached without a final response.\")\n\n except Exception as e:\n self.logger.error(f\"LLM Gateway Error: {e}\")\n raise llm_gateway_exception(e)\n\n @retry_async(\n max_attempts=lambda self: self.config.llm_num_retries,\n backoff_factor=lambda self: self.config.llm_retry_backoff_factor,\n exceptions=(LLMGatewayError,), # Catching LLMGatewayError\n )\n async def agenerate(\n self, messages: List[Any], tools: List[Dict[str, Any]] = None, **kwargs\n ) -> str:\n try:\n max_iterations = MAX_TOOL_CALL_LOOP # Prevent infinite loops\n iteration = 0\n\n while iteration < max_iterations:\n iteration += 1\n\n # Build parameters\n params = self._build_params(messages=messages, tools=tools, **kwargs)\n response = await llm_acompletion(**params)\n self._update_metrics(response)\n response_message = response.choices[0].message\n\n tool_calls = response_message.tool_calls\n\n if tool_calls:\n # Append assistant's tool call message\n messages.append({\"role\": \"assistant\", \"tool_calls\": tool_calls})\n\n for tool_call in tool_calls:\n function_name = tool_call.function.name\n function_args = json.loads(tool_call.function.arguments)\n tool_call_id = tool_call.id\n\n try:\n function_response = await self.call_tool(\n api_name=function_name, function_name=function_name, params=function_args\n )\n except Exception as e:\n self.logger.error(f\"Error executing tool '{function_name}': {e}\")\n function_response = f\"Error executing tool '{function_name}': {e}\"\n\n # Append the function response to messages\n messages.append(\n {\n \"tool_call_id\": tool_call_id,\n \"role\": \"tool\",\n \"name\": 
function_name,\n \"content\": str(function_response),\n }\n )\n # Continue the loop to handle further function calls\n continue\n else:\n # Assistant provided a final response\n messages.append({\"role\": \"assistant\", \"content\": response_message.content})\n return response_message.content\n\n # If loop exceeds max iterations\n raise Exception(\"Maximum iterations reached without a final response.\")\n\n except Exception as e:\n self.logger.error(f\"LLM Asynchronous Generation Error: {e}\")\n raise llm_gateway_exception(e)\n\n async def stream_generate(\n self, messages: List[Any], tools: List[Dict[str, Any]] = None, **kwargs\n ) -> AsyncGenerator[str, None]:\n try:\n max_iterations = MAX_TOOL_CALL_LOOP # Prevent infinite loops\n iteration = 0\n\n while iteration < max_iterations:\n iteration += 1\n\n # Build parameters\n params = self._build_params(messages=messages, tools=tools, **kwargs)\n response_stream = await llm_acompletion(**params)\n\n # Prepare to collect the assistant's reply\n tool_calls = [] # Accumulator for tool calls\n full_delta_content = '' # Accumulator for assistant's content\n\n # Collect streamed responses\n async for response in response_stream:\n delta = response.choices[0].delta\n\n # Collect assistant's content and yield it\n if getattr(delta, 'content', None):\n full_delta_content += delta.content\n yield delta.content\n\n # Check for tool calls in the delta\n if getattr(delta, 'tool_calls', None):\n tc_chunk_list = delta.tool_calls\n for tc_chunk in tc_chunk_list:\n index = tc_chunk.index\n # Ensure tool_calls list is large enough\n while len(tool_calls) <= index:\n tool_calls.append({\"id\": \"\", \"type\": \"function\", \"function\": {\"name\": \"\", \"arguments\": \"\"}})\n tc = tool_calls[index]\n\n if getattr(tc_chunk, 'id', None):\n tc[\"id\"] += tc_chunk.id\n if getattr(tc_chunk.function, 'name', None):\n tc[\"function\"][\"name\"] += tc_chunk.function.name\n if getattr(tc_chunk.function, 'arguments', None):\n 
tc[\"function\"][\"arguments\"] += tc_chunk.function.arguments\n\n # After initial streaming, check if there are tool calls\n if tool_calls:\n # Append the assistant's tool_call message to messages\n messages.append({\"role\": \"assistant\", \"tool_calls\": tool_calls})\n\n # Process each tool_call\n available_functions = [tool[\"function\"][\"name\"] for tool in tools]\n for tool_call in tool_calls:\n function_name = tool_call[\"function\"][\"name\"]\n if function_name not in available_functions:\n # Handle error if function not found\n yield f\"Function {function_name} does not exist.\"\n return\n # Call the function with arguments\n try:\n function_args = json.loads(tool_call[\"function\"][\"arguments\"])\n except json.JSONDecodeError as e:\n self.logger.error(f\"Error decoding function arguments: {e}\")\n function_args = {}\n\n try:\n function_response = await self.call_tool(\n api_name=function_name, function_name=function_name, params=function_args\n )\n except Exception as e:\n self.logger.error(f\"Error executing tool '{function_name}': {e}\")\n function_response = f\"Error executing tool '{function_name}': {e}\"\n\n # Append the function's response to messages\n messages.append(\n {\n \"tool_call_id\": tool_call['id'],\n \"role\": \"tool\",\n \"name\": function_name,\n \"content\": str(function_response),\n }\n )\n # Continue the loop to handle further function calls\n continue\n else:\n # No tool calls, streaming is complete\n messages.append({\"role\": \"assistant\", \"content\": full_delta_content})\n return # Exit the loop\n\n # If loop exceeds max iterations\n yield \"Maximum iterations reached without a final response.\"\n\n except Exception as e:\n self.logger.error(f\"Streaming LLM Gateway Error: {e}\")\n yield \"An error occurred during streaming.\"\n\n def call_tool_via_server(self, api_name: str, function_name: str, params: Dict[str, Any]) -> Any: # TODO: may need revise\n \"\"\"Calls the API via FastAPI server.\"\"\"\n url = 
f\"http://localhost:8000/api/{api_name}/{function_name}\"\n self.logger.info(f\"Calling {api_name} with params: {params}\")\n response = requests.get(url, params=params)\n if response.status_code == 200:\n json_response = response.json()\n if \"result\" in json_response:\n return str(json_response[\"result\"])\n else:\n return f\"Error from API: {json_response.get('error', 'Unknown error')}\"\n else:\n return f\"HTTP Error {response.status_code}: {response.text}\"\n\n async def call_tool(self, api_name: str, function_name: str, params: Dict[str, Any]) -> Any: # TODO: may need revise\n \"\"\"Calls the API via action module.\"\"\"\n tool = Tool(api_name)\n return await tool.aexecute(params)\n\n def call_tool_sync(self, api_name: str, function_name: str, params: Dict[str, Any]) -> Any: # TODO: may need revise\n \"\"\"Calls the API via action module.\"\"\"\n tool = Tool(api_name)\n return tool.execute(params)\n\n def _build_params(\n self, messages: List[Any], tools: List[Dict[str, Any]] = None, **kwargs\n ) -> Dict[str, Any]:\n params = {\n \"model\": self.config.llm_model_name,\n \"messages\": messages,\n \"api_key\": self.config.llm_api_key,\n \"temperature\": self.config.llm_temperature,\n \"top_p\": self.config.llm_top_p,\n \"max_tokens\": self.config.llm_max_output_tokens,\n \"timeout\": self.config.llm_timeout,\n }\n params.update(self.config.llm_additional_params)\n params.update(kwargs)\n\n # Check if the model supports function calling\n if tools and supports_function_calling(self.config.llm_model_name):\n params[\"tools\"] = tools\n params[\"tool_choice\"] = \"auto\"\n\n return params\n\n def _update_metrics(self, response: Any, log: bool = False): # Note: log is False by default. 
Adjust according to the need.\n usage = getattr(response, \"usage\", {})\n self.metrics.add_tokens(\n prompt_tokens=getattr(usage, \"prompt_tokens\", 0),\n completion_tokens=getattr(usage, \"completion_tokens\", 0),\n )\n self.metrics.add_cost(getattr(usage, \"cost\", 0.0))\n if log:\n self.logger.info(\n f\"Tokens used: {self.metrics.total_tokens}, Cost: ${self.metrics.total_cost:.4f}\"\n )\n\n def __call__(\n self, messages: List[Any], tools: List[Dict[str, Any]] = None, **kwargs\n ) -> Any:\n if self.config.llm_use_async:\n if self.config.llm_stream:\n return self.stream_generate(messages, tools=tools, **kwargs)\n else:\n return self.agenerate(messages, tools=tools, **kwargs)\n else:\n if self.config.llm_stream:\n # OpenAI's API does not support synchronous streaming; streaming must be async\n raise NotImplementedError(\"Synchronous streaming is not supported.\")\n else:\n return self.generate(messages, tools=tools, **kwargs)\n
"},{"location":"reference/#src.aeiva.llm.llm_client.LLMClient.call_tool","title":"call_tool(api_name, function_name, params)
async
","text":"Calls the API via action module.
Source code in src/aeiva/llm/llm_client.py
async def call_tool(self, api_name: str, function_name: str, params: Dict[str, Any]) -> Any: # TODO: may need revise\n \"\"\"Calls the API via action module.\"\"\"\n tool = Tool(api_name)\n return await tool.aexecute(params)\n
"},{"location":"reference/#src.aeiva.llm.llm_client.LLMClient.call_tool_sync","title":"call_tool_sync(api_name, function_name, params)
","text":"Calls the API via action module.
Source code in src/aeiva/llm/llm_client.py
def call_tool_sync(self, api_name: str, function_name: str, params: Dict[str, Any]) -> Any: # TODO: may need revise\n \"\"\"Calls the API via action module.\"\"\"\n tool = Tool(api_name)\n return tool.execute(params)\n
"},{"location":"reference/#src.aeiva.llm.llm_client.LLMClient.call_tool_via_server","title":"call_tool_via_server(api_name, function_name, params)
","text":"Calls the API via FastAPI server.
Source code in src/aeiva/llm/llm_client.py
def call_tool_via_server(self, api_name: str, function_name: str, params: Dict[str, Any]) -> Any: # TODO: may need revise\n \"\"\"Calls the API via FastAPI server.\"\"\"\n url = f\"http://localhost:8000/api/{api_name}/{function_name}\"\n self.logger.info(f\"Calling {api_name} with params: {params}\")\n response = requests.get(url, params=params)\n if response.status_code == 200:\n json_response = response.json()\n if \"result\" in json_response:\n return str(json_response[\"result\"])\n else:\n return f\"Error from API: {json_response.get('error', 'Unknown error')}\"\n else:\n return f\"HTTP Error {response.status_code}: {response.text}\"\n
"},{"location":"reference/#src.aeiva.llm.llm_gateway_config","title":"llm_gateway_config
","text":""},{"location":"reference/#src.aeiva.llm.llm_gateway_config.LLMGatewayConfig","title":"LLMGatewayConfig
dataclass
","text":" Bases: BaseConfig
Configuration for the Language Model (LLM).
Source code in src/aeiva/llm/llm_gateway_config.py
@dataclass\nclass LLMGatewayConfig(BaseConfig):\n \"\"\"\n Configuration for the Language Model (LLM).\n \"\"\"\n\n llm_model_name: Optional[str] = field(\n default='gpt-4',\n metadata={\"help\": \"The name of the LLM model to use (e.g., 'gpt-4', 'gpt-3.5-turbo').\"}\n )\n llm_api_key: Optional[str] = field(\n default=None,\n metadata={\"help\": \"The API key for authentication with the LLM provider.\"}\n )\n llm_base_url: Optional[str] = field(\n default=None,\n metadata={\"help\": \"The base URL for API requests to the LLM provider.\"}\n )\n llm_api_version: Optional[str] = field(\n default=None,\n metadata={\"help\": \"The version of the LLM API to use.\"}\n )\n llm_embedding_model: Optional[str] = field(\n default=None,\n metadata={\"help\": \"The embedding model to use for tasks requiring embeddings.\"}\n )\n llm_timeout: Optional[int] = field(\n default=30,\n metadata={\"help\": \"The timeout in seconds for API requests.\"}\n )\n llm_max_input_tokens: Optional[int] = field(\n default=4096,\n metadata={\"help\": \"The maximum number of input tokens allowed in a request.\"}\n )\n llm_max_output_tokens: Optional[int] = field(\n default=1024,\n metadata={\"help\": \"The maximum number of output tokens generated by the LLM.\"}\n )\n llm_temperature: Optional[float] = field(\n default=0.7,\n metadata={\"help\": \"Sampling temperature for response variability (range: 0.0 - 1.0).\"}\n )\n llm_top_p: Optional[float] = field(\n default=0.9,\n metadata={\"help\": \"Nucleus sampling probability for token selection (range: 0.0 - 1.0).\"}\n )\n llm_num_retries: Optional[int] = field(\n default=3,\n metadata={\"help\": \"The number of times to retry failed API requests.\"}\n )\n llm_retry_backoff_factor: Optional[float] = field(\n default=0.5,\n metadata={\"help\": \"Factor for exponential backoff between retries.\"}\n )\n llm_retry_on_status: Optional[Tuple[int, ...]] = field(\n default=(429, 500, 502, 503, 504),\n metadata={\"help\": \"HTTP status codes that should 
trigger a retry.\"}\n )\n llm_use_async: Optional[bool] = field(\n default=False,\n metadata={\"help\": \"Whether to use asynchronous API calls.\"}\n )\n llm_stream: Optional[bool] = field(\n default=False,\n metadata={\"help\": \"Whether to enable streaming responses from the LLM.\"}\n )\n llm_logging_level: Optional[str] = field(\n default='INFO',\n metadata={\"help\": \"Logging level for the LLM module (e.g., 'DEBUG', 'INFO').\"}\n )\n llm_additional_params: Optional[Dict[str, Any]] = field(\n default_factory=dict,\n metadata={\"help\": \"Additional parameters to pass to the LLM API.\"}\n )\n\n def __post_init__(self):\n super().__post_init__()\n # Load API keys from the configuration file if not provided\n if not self.llm_api_key:\n self.load_api_key()\n\n def load_api_key(self):\n config_path = os.path.join(os.path.dirname(__file__), '../../../configs/llm_api_keys.yaml')\n try:\n with open(config_path, 'r') as f:\n keys = yaml.safe_load(f)\n self.llm_api_key = keys.get('openai_api_key')\n except FileNotFoundError:\n raise FileNotFoundError('API keys file not found.')\n except Exception as e:\n raise e\n\n def to_dict(self):\n return {\n key: ('******' if key == 'llm_api_key' and value else value)\n for key, value in self.__dict__.items()\n if not key.startswith('_')\n }\n
"},{"location":"reference/#src.aeiva.llm.llm_gateway_exceptions","title":"llm_gateway_exceptions
","text":""},{"location":"reference/#src.aeiva.llm.llm_gateway_exceptions.LLMGatewayError","title":"LLMGatewayError
","text":" Bases: Exception
Unified exception class for all LLM-related errors.
Source code in src/aeiva/llm/llm_gateway_exceptions.py
class LLMGatewayError(Exception):\n \"\"\"Unified exception class for all LLM-related errors.\"\"\"\n\n def __init__(self, message: str, original_exception: Exception = None):\n super().__init__(message)\n self.original_exception = original_exception\n
"},{"location":"reference/#src.aeiva.llm.llm_gateway_exceptions.llm_gateway_exception","title":"llm_gateway_exception(e)
","text":"Converts a litellm exception to a unified LLMGatewayError.
Source code in src/aeiva/llm/llm_gateway_exceptions.py
def llm_gateway_exception(e: Exception) -> LLMGatewayError:\n \"\"\"Converts a litellm exception to a unified LLMGatewayError.\"\"\"\n exception_type = type(e)\n mapped_exception = LITELLM_EXCEPTION_MAP.get(exception_type, LLMGatewayError)\n return mapped_exception(str(e), original_exception=e)\n
"},{"location":"reference/#src.aeiva.llm.llm_usage_metrics","title":"llm_usage_metrics
","text":""},{"location":"reference/#src.aeiva.llm.llm_usage_metrics.LLMUsageMetrics","title":"LLMUsageMetrics
","text":"Tracks metrics such as token usage and cost.
Source code in src/aeiva/llm/llm_usage_metrics.py
class LLMUsageMetrics:\n \"\"\"\n Tracks metrics such as token usage and cost.\n \"\"\"\n def __init__(self):\n self.total_tokens = 0\n self.prompt_tokens = 0\n self.completion_tokens = 0\n self.total_cost = 0.0\n\n def add_tokens(self, prompt_tokens: int, completion_tokens: int):\n self.prompt_tokens += prompt_tokens\n self.completion_tokens += completion_tokens\n self.total_tokens += prompt_tokens + completion_tokens\n\n def add_cost(self, cost: float):\n self.total_cost += cost\n
"},{"location":"reference/#src.aeiva.model","title":"model
","text":""},{"location":"reference/#src.aeiva.model.macaw_model","title":"macaw_model
","text":""},{"location":"reference/#src.aeiva.model.macaw_model.LlamaAttention","title":"LlamaAttention
","text":" Bases: Module
Multi-headed attention from 'Attention Is All You Need' paper
Source code in src/aeiva/model/macaw_model.py
class LlamaAttention(nn.Module):\n \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\n\n def __init__(self, config: LlamaConfig):\n super().__init__()\n self.config = config\n self.hidden_size = config.hidden_size\n self.num_heads = config.num_attention_heads\n self.head_dim = self.hidden_size // self.num_heads\n self.max_position_embeddings = config.max_position_embeddings\n\n if (self.head_dim * self.num_heads) != self.hidden_size:\n raise ValueError(\n f\"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}\"\n f\" and `num_heads`: {self.num_heads}).\"\n )\n self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)\n self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)\n self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)\n self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)\n self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)\n\n def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):\n return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()\n\n def forward(\n self,\n hidden_states: torch.Tensor,\n attention_mask: Optional[torch.Tensor] = None,\n position_ids: Optional[torch.LongTensor] = None,\n past_key_value: Optional[Tuple[torch.Tensor]] = None,\n output_attentions: bool = False,\n use_cache: bool = False,\n ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n bsz, q_len, _ = hidden_states.size()\n\n query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n\n kv_seq_len = 
key_states.shape[-2]\n if past_key_value is not None:\n kv_seq_len += past_key_value[0].shape[-2]\n cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)\n query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)\n # [bsz, nh, t, hd]\n\n if past_key_value is not None:\n # reuse k, v, self_attention\n key_states = torch.cat([past_key_value[0], key_states], dim=2)\n value_states = torch.cat([past_key_value[1], value_states], dim=2)\n\n past_key_value = (key_states, value_states) if use_cache else None\n\n attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)\n\n if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):\n raise ValueError(\n f\"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is\"\n f\" {attn_weights.size()}\"\n )\n\n if attention_mask is not None:\n if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):\n raise ValueError(\n f\"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}\"\n )\n attn_weights = attn_weights + attention_mask\n attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))\n\n # upcast attention to fp32\n attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)\n attn_output = torch.matmul(attn_weights, value_states)\n\n if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):\n raise ValueError(\n f\"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is\"\n f\" {attn_output.size()}\"\n )\n\n attn_output = attn_output.transpose(1, 2)\n attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)\n\n attn_output = self.o_proj(attn_output)\n\n if not output_attentions:\n attn_weights = None\n\n return attn_output, attn_weights, past_key_value\n
"},{"location":"reference/#src.aeiva.model.macaw_model.LlamaDecoderLayer","title":"LlamaDecoderLayer
","text":" Bases: Module
Source code in src/aeiva/model/macaw_model.py
class LlamaDecoderLayer(nn.Module):\n def __init__(self, config: LlamaConfig):\n super().__init__()\n self.hidden_size = config.hidden_size\n self.self_attn = LlamaAttention(config=config)\n self.mlp = LlamaMLP(\n hidden_size=self.hidden_size,\n intermediate_size=config.intermediate_size,\n hidden_act=config.hidden_act,\n )\n self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n\n def forward(\n self,\n hidden_states: torch.Tensor,\n attention_mask: Optional[torch.Tensor] = None,\n position_ids: Optional[torch.LongTensor] = None,\n past_key_value: Optional[Tuple[torch.Tensor]] = None,\n output_attentions: Optional[bool] = False,\n use_cache: Optional[bool] = False,\n ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:\n \"\"\"\n Args:\n hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`\n attention_mask (`torch.FloatTensor`, *optional*): attention mask of size\n `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.\n output_attentions (`bool`, *optional*):\n Whether or not to return the attentions tensors of all attention layers. 
See `attentions` under\n returned tensors for more detail.\n use_cache (`bool`, *optional*):\n If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding\n (see `past_key_values`).\n past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states\n \"\"\"\n\n residual = hidden_states\n\n hidden_states = self.input_layernorm(hidden_states)\n\n # Self Attention\n hidden_states, self_attn_weights, present_key_value = self.self_attn(\n hidden_states=hidden_states,\n attention_mask=attention_mask,\n position_ids=position_ids,\n past_key_value=past_key_value,\n output_attentions=output_attentions,\n use_cache=use_cache,\n )\n hidden_states = residual + hidden_states\n\n # Fully Connected\n residual = hidden_states\n hidden_states = self.post_attention_layernorm(hidden_states)\n hidden_states = self.mlp(hidden_states)\n hidden_states = residual + hidden_states\n\n outputs = (hidden_states,)\n\n if output_attentions:\n outputs += (self_attn_weights,)\n\n if use_cache:\n outputs += (present_key_value,)\n\n return outputs\n
"},{"location":"reference/#src.aeiva.model.macaw_model.LlamaDecoderLayer.forward","title":"forward(hidden_states, attention_mask=None, position_ids=None, past_key_value=None, output_attentions=False, use_cache=False)
","text":"Parameters:
Name Type Description Default
hidden_states `torch.FloatTensor` input to the layer of shape (batch, seq_len, embed_dim) required
attention_mask `torch.FloatTensor`, *optional* attention mask of size (batch, 1, tgt_len, src_len) where padding elements are indicated by very large negative values. None
output_attentions `bool`, *optional* Whether or not to return the attentions tensors of all attention layers. See attentions under returned tensors for more detail. False
use_cache `bool`, *optional* If set to True, past_key_values key value states are returned and can be used to speed up decoding (see past_key_values). False
past_key_value `Tuple(torch.FloatTensor)`, *optional* cached past key and value projection states None
Source code in src/aeiva/model/macaw_model.py
def forward(\n self,\n hidden_states: torch.Tensor,\n attention_mask: Optional[torch.Tensor] = None,\n position_ids: Optional[torch.LongTensor] = None,\n past_key_value: Optional[Tuple[torch.Tensor]] = None,\n output_attentions: Optional[bool] = False,\n use_cache: Optional[bool] = False,\n) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:\n \"\"\"\n Args:\n hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`\n attention_mask (`torch.FloatTensor`, *optional*): attention mask of size\n `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.\n output_attentions (`bool`, *optional*):\n Whether or not to return the attentions tensors of all attention layers. See `attentions` under\n returned tensors for more detail.\n use_cache (`bool`, *optional*):\n If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding\n (see `past_key_values`).\n past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states\n \"\"\"\n\n residual = hidden_states\n\n hidden_states = self.input_layernorm(hidden_states)\n\n # Self Attention\n hidden_states, self_attn_weights, present_key_value = self.self_attn(\n hidden_states=hidden_states,\n attention_mask=attention_mask,\n position_ids=position_ids,\n past_key_value=past_key_value,\n output_attentions=output_attentions,\n use_cache=use_cache,\n )\n hidden_states = residual + hidden_states\n\n # Fully Connected\n residual = hidden_states\n hidden_states = self.post_attention_layernorm(hidden_states)\n hidden_states = self.mlp(hidden_states)\n hidden_states = residual + hidden_states\n\n outputs = (hidden_states,)\n\n if output_attentions:\n outputs += (self_attn_weights,)\n\n if use_cache:\n outputs += (present_key_value,)\n\n return outputs\n
"},{"location":"reference/#src.aeiva.model.macaw_model.LlamaModel","title":"LlamaModel
","text":" Bases: LlamaPreTrainedModel
Transformer decoder consisting of config.num_hidden_layers layers. Each layer is a [LlamaDecoderLayer]
Parameters:
Name Type Description Default
config LlamaConfig LlamaConfig required
Source code in src/aeiva/model/macaw_model.py
class LlamaModel(LlamaPreTrainedModel):\n \"\"\"\n Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]\n\n Args:\n config: LlamaConfig\n \"\"\"\n\n def __init__(self, config: LlamaConfig):\n super().__init__(config)\n self.padding_idx = config.pad_token_id\n self.vocab_size = config.vocab_size\n\n self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)\n self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])\n self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n\n self.gradient_checkpointing = False\n # Initialize weights and apply final processing\n self.post_init()\n\n def get_input_embeddings(self):\n return self.embed_tokens\n\n def set_input_embeddings(self, value):\n self.embed_tokens = value\n\n # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask\n def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):\n # create causal mask\n # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\n combined_attention_mask = None\n if input_shape[-1] > 1:\n combined_attention_mask = _make_causal_mask(\n input_shape,\n inputs_embeds.dtype,\n device=inputs_embeds.device,\n past_key_values_length=past_key_values_length,\n )\n\n if attention_mask is not None:\n # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\n expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(\n inputs_embeds.device\n )\n combined_attention_mask = (\n expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask\n )\n\n return combined_attention_mask\n\n # @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)\n def forward(\n self,\n input_ids: torch.LongTensor = None,\n attention_mask: Optional[torch.Tensor] = None,\n position_ids: Optional[torch.LongTensor] 
= None,\n past_key_values: Optional[List[torch.FloatTensor]] = None,\n inputs_embeds: Optional[torch.FloatTensor] = None,\n use_cache: Optional[bool] = None,\n output_attentions: Optional[bool] = None,\n output_hidden_states: Optional[bool] = None,\n return_dict: Optional[bool] = None,\n ) -> Union[Tuple, BaseModelOutputWithPast]:\n output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n output_hidden_states = (\n output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n )\n use_cache = use_cache if use_cache is not None else self.config.use_cache\n\n return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n # retrieve input_ids and inputs_embeds\n if input_ids is not None and inputs_embeds is not None:\n raise ValueError(\"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time\")\n elif input_ids is not None:\n batch_size, seq_length = input_ids.shape\n elif inputs_embeds is not None:\n batch_size, seq_length, _ = inputs_embeds.shape\n else:\n raise ValueError(\"You have to specify either decoder_input_ids or decoder_inputs_embeds\")\n\n seq_length_with_past = seq_length\n past_key_values_length = 0\n\n if past_key_values is not None:\n past_key_values_length = past_key_values[0][0].shape[2]\n seq_length_with_past = seq_length_with_past + past_key_values_length\n\n if position_ids is None:\n device = input_ids.device if input_ids is not None else inputs_embeds.device\n position_ids = torch.arange(\n past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device\n )\n position_ids = position_ids.unsqueeze(0).view(-1, seq_length)\n else:\n position_ids = position_ids.view(-1, seq_length).long()\n\n if inputs_embeds is None:\n inputs_embeds = self.embed_tokens(input_ids)\n # embed positions\n if attention_mask is None:\n attention_mask = torch.ones(\n (batch_size, seq_length_with_past), 
dtype=torch.bool, device=inputs_embeds.device\n )\n attention_mask = self._prepare_decoder_attention_mask(\n attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length\n )\n\n hidden_states = inputs_embeds\n\n if self.gradient_checkpointing and self.training:\n if use_cache:\n logger.warning_once(\n \"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...\"\n )\n use_cache = False\n\n # decoder layers\n all_hidden_states = () if output_hidden_states else None\n all_self_attns = () if output_attentions else None\n next_decoder_cache = () if use_cache else None\n\n for idx, decoder_layer in enumerate(self.layers):\n if output_hidden_states:\n all_hidden_states += (hidden_states,)\n\n past_key_value = past_key_values[idx] if past_key_values is not None else None\n\n if self.gradient_checkpointing and self.training:\n\n def create_custom_forward(module):\n def custom_forward(*inputs):\n # None for past_key_value\n return module(*inputs, output_attentions, None)\n\n return custom_forward\n\n layer_outputs = torch.utils.checkpoint.checkpoint(\n create_custom_forward(decoder_layer),\n hidden_states,\n attention_mask,\n position_ids,\n None,\n )\n else:\n layer_outputs = decoder_layer(\n hidden_states,\n attention_mask=attention_mask,\n position_ids=position_ids,\n past_key_value=past_key_value,\n output_attentions=output_attentions,\n use_cache=use_cache,\n )\n\n hidden_states = layer_outputs[0]\n\n if use_cache:\n next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)\n\n if output_attentions:\n all_self_attns += (layer_outputs[1],)\n\n hidden_states = self.norm(hidden_states)\n\n # add hidden states from the last decoder layer\n if output_hidden_states:\n all_hidden_states += (hidden_states,)\n\n next_cache = next_decoder_cache if use_cache else None\n if not return_dict:\n return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)\n return 
BaseModelOutputWithPast(\n last_hidden_state=hidden_states,\n past_key_values=next_cache,\n hidden_states=all_hidden_states,\n attentions=all_self_attns,\n )\n
"},{"location":"reference/#src.aeiva.model.macaw_model.LlamaRMSNorm","title":"LlamaRMSNorm
","text":" Bases: Module
Source code in src/aeiva/model/macaw_model.py
class LlamaRMSNorm(nn.Module):\n def __init__(self, hidden_size, eps=1e-6):\n \"\"\"\n LlamaRMSNorm is equivalent to T5LayerNorm\n \"\"\"\n super().__init__()\n self.weight = nn.Parameter(torch.ones(hidden_size))\n self.variance_epsilon = eps\n\n def forward(self, hidden_states):\n variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)\n hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n\n # convert into half-precision if necessary\n if self.weight.dtype in [torch.float16, torch.bfloat16]:\n hidden_states = hidden_states.to(self.weight.dtype)\n\n return self.weight * hidden_states\n
"},{"location":"reference/#src.aeiva.model.macaw_model.LlamaRMSNorm.__init__","title":"__init__(hidden_size, eps=1e-06)
","text":"LlamaRMSNorm is equivalent to T5LayerNorm
Source code in src/aeiva/model/macaw_model.py
def __init__(self, hidden_size, eps=1e-6):\n \"\"\"\n LlamaRMSNorm is equivalent to T5LayerNorm\n \"\"\"\n super().__init__()\n self.weight = nn.Parameter(torch.ones(hidden_size))\n self.variance_epsilon = eps\n
"},{"location":"reference/#src.aeiva.model.macaw_model.MM_LLMs_Config","title":"MM_LLMs_Config
","text":" Bases: PretrainedConfig
Source code in src/aeiva/model/macaw_model.py
class MM_LLMs_Config(PretrainedConfig):\n model_type = 'mm_llms'\n is_composition = True\n\n def __init__(self, n_frames=6, attention_heads=8, image_conv_kernel=48, image_conv_stride=36, \n video_conv_kernel=36, video_conv_stride=30, audio_conv_kernel=240, audio_conv_stride=220,\n clip_config=None, whisper_config=None, llm_config=None, **kwargs):\n\n self.image_config = clip_config\n self.audio_config = whisper_config\n self.llm_config = llm_config\n self.n_frames = n_frames\n self.attention_heads = attention_heads\n self.image_conv_kernel = image_conv_kernel\n self.image_conv_stride = image_conv_stride\n self.video_conv_kernel = video_conv_kernel\n self.video_conv_stride = video_conv_stride\n self.audio_conv_kernel = audio_conv_kernel\n self.audio_conv_stride = audio_conv_stride\n\n self.hidden_size = max(llm_config.hidden_size, clip_config.projection_dim, whisper_config.d_model, clip_config.projection_dim)\n\n super().__init__(**kwargs)\n\n def to_dict(self):\n \"\"\"\n Serializes this instance to a Python dictionary. 
Override the default [`~PretrainedConfig.to_dict`].\n\n Returns:\n `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,\n \"\"\"\n output = copy.deepcopy(self.__dict__)\n output[\"image_config\"] = self.image_config.to_dict()\n output[\"audio_config\"] = self.audio_config.to_dict()\n output['llm_config'] = self.llm_config.to_dict()\n output['n_frames'] = self.n_frames\n output['attention_heads'] = self.attention_heads\n output['image_conv_kernel'] = self.image_conv_kernel\n output['image_conv_stride'] = self.image_conv_stride\n output['video_conv_kernel'] = self.video_conv_kernel\n output['video_conv_stride'] = self.video_conv_stride\n output['audio_conv_kernel'] = self.audio_conv_kernel\n output['audio_conv_stride'] = self.audio_conv_stride\n output['hidden_size'] = self.hidden_size\n output[\"model_type\"] = self.__class__.model_type\n return output\n @classmethod\n def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):\n config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)\n\n clip_config = CLIPConfig.from_dict(config_dict['image_config'])\n whisper_config = WhisperConfig.from_dict(config_dict['audio_config'])\n llm_config = LlamaConfig.from_dict(config_dict['llm_config'])\n\n return cls(clip_config=clip_config, whisper_config=whisper_config, llm_config=llm_config, **kwargs)\n
"},{"location":"reference/#src.aeiva.model.macaw_model.MM_LLMs_Config.to_dict","title":"to_dict()
","text":"Serializes this instance to a Python dictionary. Override the default [~PretrainedConfig.to_dict
].
Returns:
Type Description Dict[str, any]
: Dictionary of all the attributes that make up this configuration instance.
Source code in src/aeiva/model/macaw_model.py
def to_dict(self):\n \"\"\"\n Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].\n\n Returns:\n `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,\n \"\"\"\n output = copy.deepcopy(self.__dict__)\n output[\"image_config\"] = self.image_config.to_dict()\n output[\"audio_config\"] = self.audio_config.to_dict()\n output['llm_config'] = self.llm_config.to_dict()\n output['n_frames'] = self.n_frames\n output['attention_heads'] = self.attention_heads\n output['image_conv_kernel'] = self.image_conv_kernel\n output['image_conv_stride'] = self.image_conv_stride\n output['video_conv_kernel'] = self.video_conv_kernel\n output['video_conv_stride'] = self.video_conv_stride\n output['audio_conv_kernel'] = self.audio_conv_kernel\n output['audio_conv_stride'] = self.audio_conv_stride\n output['hidden_size'] = self.hidden_size\n output[\"model_type\"] = self.__class__.model_type\n return output\n
"},{"location":"reference/#src.aeiva.model.macaw_model.WhisperEncoder","title":"WhisperEncoder
","text":" Bases: WhisperPreTrainedModel
Transformer encoder consisting of config.encoder_layers self attention layers. Each layer is a [WhisperEncoderLayer
].
Parameters:
Name Type Description Default config
WhisperConfig
WhisperConfig
required Source code in src/aeiva/model/macaw_model.py
class WhisperEncoder(WhisperPreTrainedModel):\n \"\"\"\n Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a\n [`WhisperEncoderLayer`].\n\n Args:\n config: WhisperConfig\n \"\"\"\n\n def __init__(self, config: WhisperConfig):\n super().__init__(config)\n self.dropout = config.dropout\n self.layerdrop = config.encoder_layerdrop\n\n embed_dim = config.d_model\n self.num_mel_bins = config.num_mel_bins\n self.padding_idx = config.pad_token_id\n self.max_source_positions = config.max_source_positions\n self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0\n\n self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)\n self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)\n\n self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim)\n\n self.layers = nn.ModuleList([WhisperEncoderLayer(config) for _ in range(config.encoder_layers)])\n self.layer_norm = nn.LayerNorm(config.d_model)\n\n self.gradient_checkpointing = False\n # Initialize weights and apply final processing\n self.post_init()\n\n def _freeze_parameters(self):\n for param in self.parameters():\n param.requires_grad = False\n self._requires_grad = False\n\n def get_input_embeddings(self) -> nn.Module:\n return self.conv1\n\n def set_input_embeddings(self, value: nn.Module):\n self.conv1 = value\n\n def forward(\n self,\n input_features,\n attention_mask=None,\n head_mask=None,\n output_attentions=None,\n output_hidden_states=None,\n return_dict=None,\n ):\n r\"\"\"\n Args:\n input_features (`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`):\n Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be\n obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a\n `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). 
To prepare the array into\n `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding\n and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]\n attention_mask (`torch.Tensor`)`, *optional*):\n Whisper does not support masking of the `input_features`, this argument is preserved for compatibility,\n but it is not used. By default the silence in the input log mel spectrogram are ignored.\n head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):\n Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:\n\n - 1 indicates the head is **not masked**,\n - 0 indicates the head is **masked**.\n output_attentions (`bool`, *optional*):\n Whether or not to return the attentions tensors of all attention layers. See `attentions` under\n returned tensors for more detail.\n output_hidden_states (`bool`, *optional*):\n Whether or not to return the hidden states of all layers. 
See `hidden_states` under returned tensors\n for more detail.\n return_dict (`bool`, *optional*):\n Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.\n \"\"\"\n output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n output_hidden_states = (\n output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n )\n return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n inputs_embeds = nn.functional.gelu(self.conv1(input_features))\n inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))\n\n inputs_embeds = inputs_embeds.permute(0, 2, 1)\n embed_pos = self.embed_positions.weight\n\n hidden_states = inputs_embeds + embed_pos\n hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)\n\n encoder_states = () if output_hidden_states else None\n all_attentions = () if output_attentions else None\n\n # check if head_mask has a correct number of layers specified if desired\n if head_mask is not None:\n assert head_mask.size()[0] == (\n len(self.layers)\n ), f\"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}.\"\n\n for idx, encoder_layer in enumerate(self.layers):\n if output_hidden_states:\n encoder_states = encoder_states + (hidden_states,)\n # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)\n dropout_probability = random.uniform(0, 1)\n if self.training and (dropout_probability < self.layerdrop): # skip the layer\n layer_outputs = (None, None)\n else:\n if self.gradient_checkpointing and self.training:\n\n def create_custom_forward(module):\n def custom_forward(*inputs):\n return module(*inputs, output_attentions)\n\n return custom_forward\n\n layer_outputs = torch.utils.checkpoint.checkpoint(\n create_custom_forward(encoder_layer),\n hidden_states,\n None,\n (head_mask[idx] if head_mask is not None else None),\n )\n 
else:\n layer_outputs = encoder_layer(\n hidden_states,\n None,\n layer_head_mask=(head_mask[idx] if head_mask is not None else None),\n output_attentions=output_attentions,\n )\n\n hidden_states = layer_outputs[0]\n\n if output_attentions:\n all_attentions = all_attentions + (layer_outputs[1],)\n\n hidden_states = self.layer_norm(hidden_states)\n if output_hidden_states:\n encoder_states = encoder_states + (hidden_states,)\n\n if not return_dict:\n return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)\n return BaseModelOutput(\n last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions\n )\n
"},{"location":"reference/#src.aeiva.model.macaw_model.WhisperEncoder.forward","title":"forward(input_features, attention_mask=None, head_mask=None, output_attentions=None, output_hidden_states=None, return_dict=None)
","text":"Parameters:
Name Type Description Default input_features
`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`
Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by loading a .flac
or .wav
audio file into an array of type List[float]
or a numpy.ndarray
, e.g. via the soundfile library (pip install soundfile
). To prepare the array into input_features
, the [AutoFeatureExtractor
] should be used for extracting the mel features, padding and conversion into a tensor of type torch.FloatTensor
. See [~WhisperFeatureExtractor.__call__
]
required attention_mask
`torch.Tensor`, *optional*
Whisper does not support masking of the input_features
, this argument is preserved for compatibility, but it is not used. By default, silence in the input log-mel spectrogram is ignored.
None
head_mask
`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*
Mask to nullify selected heads of the attention modules. Mask values selected in [0, 1]
:
- 1 indicates the head is not masked,
- 0 indicates the head is masked.
None
output_attentions
`bool`, *optional*
Whether or not to return the attentions tensors of all attention layers. See attentions
under returned tensors for more detail.
None
output_hidden_states
`bool`, *optional*
Whether or not to return the hidden states of all layers. See hidden_states
under returned tensors for more detail.
None
return_dict
`bool`, *optional*
Whether or not to return a [~utils.ModelOutput
] instead of a plain tuple.
None
Source code in src/aeiva/model/macaw_model.py
def forward(\n self,\n input_features,\n attention_mask=None,\n head_mask=None,\n output_attentions=None,\n output_hidden_states=None,\n return_dict=None,\n):\n r\"\"\"\n Args:\n input_features (`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`):\n Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be\n obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a\n `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into\n `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding\n and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]\n attention_mask (`torch.Tensor`)`, *optional*):\n Whisper does not support masking of the `input_features`, this argument is preserved for compatibility,\n but it is not used. By default the silence in the input log mel spectrogram are ignored.\n head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):\n Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:\n\n - 1 indicates the head is **not masked**,\n - 0 indicates the head is **masked**.\n output_attentions (`bool`, *optional*):\n Whether or not to return the attentions tensors of all attention layers. See `attentions` under\n returned tensors for more detail.\n output_hidden_states (`bool`, *optional*):\n Whether or not to return the hidden states of all layers. 
See `hidden_states` under returned tensors\n for more detail.\n return_dict (`bool`, *optional*):\n Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.\n \"\"\"\n output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n output_hidden_states = (\n output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n )\n return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n inputs_embeds = nn.functional.gelu(self.conv1(input_features))\n inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))\n\n inputs_embeds = inputs_embeds.permute(0, 2, 1)\n embed_pos = self.embed_positions.weight\n\n hidden_states = inputs_embeds + embed_pos\n hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)\n\n encoder_states = () if output_hidden_states else None\n all_attentions = () if output_attentions else None\n\n # check if head_mask has a correct number of layers specified if desired\n if head_mask is not None:\n assert head_mask.size()[0] == (\n len(self.layers)\n ), f\"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}.\"\n\n for idx, encoder_layer in enumerate(self.layers):\n if output_hidden_states:\n encoder_states = encoder_states + (hidden_states,)\n # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)\n dropout_probability = random.uniform(0, 1)\n if self.training and (dropout_probability < self.layerdrop): # skip the layer\n layer_outputs = (None, None)\n else:\n if self.gradient_checkpointing and self.training:\n\n def create_custom_forward(module):\n def custom_forward(*inputs):\n return module(*inputs, output_attentions)\n\n return custom_forward\n\n layer_outputs = torch.utils.checkpoint.checkpoint(\n create_custom_forward(encoder_layer),\n hidden_states,\n None,\n (head_mask[idx] if head_mask is not None else None),\n )\n 
else:\n layer_outputs = encoder_layer(\n hidden_states,\n None,\n layer_head_mask=(head_mask[idx] if head_mask is not None else None),\n output_attentions=output_attentions,\n )\n\n hidden_states = layer_outputs[0]\n\n if output_attentions:\n all_attentions = all_attentions + (layer_outputs[1],)\n\n hidden_states = self.layer_norm(hidden_states)\n if output_hidden_states:\n encoder_states = encoder_states + (hidden_states,)\n\n if not return_dict:\n return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)\n return BaseModelOutput(\n last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions\n )\n
"},{"location":"reference/#src.aeiva.model.macaw_model.rotate_half","title":"rotate_half(x)
","text":"Rotates half the hidden dims of the input.
Source code in src/aeiva/model/macaw_model.py
def rotate_half(x):\n \"\"\"Rotates half the hidden dims of the input.\"\"\"\n x1 = x[..., : x.shape[-1] // 2]\n x2 = x[..., x.shape[-1] // 2 :]\n return torch.cat((-x2, x1), dim=-1)\n
"},{"location":"reference/#src.aeiva.model.macaw_model_old","title":"macaw_model_old
","text":"This script contains the implementation of the MACAW model. MACAW is a multimodal transformer model that combines the CLIP and Whisper models.
Author: Bang Liu Date: 2023-06-22
References: - Macaw-LLM code repository: https://github.com/lyuchenyang/Macaw-LLM/blob/main/modeling.py
"},{"location":"reference/#src.aeiva.model.macaw_model_old.LlamaAttention","title":"LlamaAttention
","text":" Bases: Module
Multi-headed attention from 'Attention Is All You Need' paper
Source code in src/aeiva/model/macaw_model_old.py
class LlamaAttention(nn.Module):\n \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\n\n def __init__(self, config: LlamaConfig):\n super().__init__()\n self.config = config\n self.hidden_size = config.hidden_size\n self.num_heads = config.num_attention_heads\n self.head_dim = self.hidden_size // self.num_heads\n self.max_position_embeddings = config.max_position_embeddings # !!! I want to change this variable name.\n\n if (self.head_dim * self.num_heads) != self.hidden_size:\n raise ValueError(\n f\"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}\"\n f\" and `num_heads`: {self.num_heads}).\"\n )\n self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)\n self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)\n self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)\n self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)\n self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)\n\n def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):\n return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()\n\n def forward(\n self,\n hidden_states: torch.Tensor,\n attention_mask: Optional[torch.Tensor] = None,\n position_ids: Optional[torch.LongTensor] = None,\n past_key_value: Optional[Tuple[torch.Tensor]] = None,\n output_attentions: bool = False,\n use_cache: bool = False,\n ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n bsz, q_len, _ = hidden_states.size()\n\n # By placing the num_heads dimension as the second dimension, it allows for \n # efficient batched matrix operations (e.g., matrix multiplication in attention computation) \n # across all the heads. 
It is basically a data layout optimization for computational efficiency \n # in the context of multi-head attention.\n query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n\n kv_seq_len = key_states.shape[-2] # the shape is [batch_size, num_heads, seq_len, head_dim], so -2 dimension is 'seq_len'\n if past_key_value is not None: \n # If past_key_value is not None, this means the model is being used in an autoregressive setting, \n # where the past key-value pairs are given to the current step.\n # past_key_value[0] refers to the previously computed key states,\n # past_key_value[1] refers to the previously computed value states.\n # The shape of past_key_value[0] and past_key_value[1] is [batch_size, num_heads, seq_len, head_dim].\n kv_seq_len += past_key_value[0].shape[-2] # + past seq_len\n\n cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)\n query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)\n\n if past_key_value is not None:\n # reuse k, v, self_attention\n key_states = torch.cat([past_key_value[0], key_states], dim=2)\n value_states = torch.cat([past_key_value[1], value_states], dim=2)\n\n past_key_value = (key_states, value_states) if use_cache else None\n\n attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)\n\n if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):\n raise ValueError(\n f\"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is\"\n f\" {attn_weights.size()}\"\n )\n\n if attention_mask is not None:\n if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):\n raise ValueError(\n f\"Attention mask should be of size {(bsz, 1, q_len, 
kv_seq_len)}, but is {attention_mask.size()}\"\n )\n attn_weights = attn_weights + attention_mask\n # This following line is ensuring numerical stability. It caps the minimum value of the attention weights\n # to be the minimum finite representable number for the data type of attn_weights. This avoids \n # potential issues with underflow when these weights are later passed through the softmax function.\n attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))\n\n # upcast attention to fp32\n # This is done to prevent numerical instability that can occur\n # during operations on very small numbers or very large numbers.\n attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)\n attn_output = torch.matmul(attn_weights, value_states)\n\n if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):\n raise ValueError(\n f\"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is\"\n f\" {attn_output.size()}\"\n )\n\n attn_output = attn_output.transpose(1, 2)\n attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) # self.hidden_size is equivalent to self.num_heads * self.head_dim\n\n attn_output = self.o_proj(attn_output)\n\n if not output_attentions:\n attn_weights = None\n\n return attn_output, attn_weights, past_key_value\n
"},{"location":"reference/#src.aeiva.model.macaw_model_old.LlamaDecoderLayer","title":"LlamaDecoderLayer
","text":" Bases: Module
Source code in src/aeiva/model/macaw_model_old.py
class LlamaDecoderLayer(nn.Module):\n def __init__(self, config: LlamaConfig):\n super().__init__()\n self.hidden_size = config.hidden_size\n self.self_attn = LlamaAttention(config=config)\n self.mlp = LlamaMLP(\n hidden_size=self.hidden_size,\n intermediate_size=config.intermediate_size,\n hidden_act=config.hidden_act,\n )\n self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n\n def forward(\n self,\n hidden_states: torch.Tensor,\n attention_mask: Optional[torch.Tensor] = None,\n position_ids: Optional[torch.LongTensor] = None,\n past_key_value: Optional[Tuple[torch.Tensor]] = None,\n output_attentions: Optional[bool] = False,\n use_cache: Optional[bool] = False,\n ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:\n \"\"\"\n Args:\n hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`\n attention_mask (`torch.FloatTensor`, *optional*): attention mask of size\n `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.\n output_attentions (`bool`, *optional*):\n Whether or not to return the attentions tensors of all attention layers. 
See `attentions` under\n returned tensors for more detail.\n use_cache (`bool`, *optional*):\n If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding\n (see `past_key_values`).\n past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states\n \"\"\"\n\n residual = hidden_states\n\n hidden_states = self.input_layernorm(hidden_states)\n\n # Self Attention\n hidden_states, self_attn_weights, present_key_value = self.self_attn(\n hidden_states=hidden_states,\n attention_mask=attention_mask,\n position_ids=position_ids,\n past_key_value=past_key_value,\n output_attentions=output_attentions,\n use_cache=use_cache,\n )\n hidden_states = residual + hidden_states\n\n # Fully Connected\n residual = hidden_states\n hidden_states = self.post_attention_layernorm(hidden_states)\n hidden_states = self.mlp(hidden_states)\n hidden_states = residual + hidden_states\n\n outputs = (hidden_states,)\n\n if output_attentions:\n outputs += (self_attn_weights,)\n\n if use_cache:\n outputs += (present_key_value,)\n\n return outputs\n
"},{"location":"reference/#src.aeiva.model.macaw_model_old.LlamaDecoderLayer.forward","title":"forward(hidden_states, attention_mask=None, position_ids=None, past_key_value=None, output_attentions=False, use_cache=False)
","text":"Parameters:
Name Type Description Default hidden_states
`torch.FloatTensor`
input to the layer of shape (batch, seq_len, embed_dim)
required attention_mask
`torch.FloatTensor`, *optional*
attention mask of size (batch, 1, tgt_len, src_len)
where padding elements are indicated by very large negative values.
None
output_attentions
`bool`, *optional*
Whether or not to return the attentions tensors of all attention layers. See attentions
under returned tensors for more detail.
False
use_cache
`bool`, *optional*
If set to True
, past_key_values
key value states are returned and can be used to speed up decoding (see past_key_values
).
False
past_key_value
`Tuple(torch.FloatTensor)`, *optional*
cached past key and value projection states
None
Source code in src/aeiva/model/macaw_model_old.py
def forward(\n self,\n hidden_states: torch.Tensor,\n attention_mask: Optional[torch.Tensor] = None,\n position_ids: Optional[torch.LongTensor] = None,\n past_key_value: Optional[Tuple[torch.Tensor]] = None,\n output_attentions: Optional[bool] = False,\n use_cache: Optional[bool] = False,\n) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:\n \"\"\"\n Args:\n hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`\n attention_mask (`torch.FloatTensor`, *optional*): attention mask of size\n `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.\n output_attentions (`bool`, *optional*):\n Whether or not to return the attentions tensors of all attention layers. See `attentions` under\n returned tensors for more detail.\n use_cache (`bool`, *optional*):\n If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding\n (see `past_key_values`).\n past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states\n \"\"\"\n\n residual = hidden_states\n\n hidden_states = self.input_layernorm(hidden_states)\n\n # Self Attention\n hidden_states, self_attn_weights, present_key_value = self.self_attn(\n hidden_states=hidden_states,\n attention_mask=attention_mask,\n position_ids=position_ids,\n past_key_value=past_key_value,\n output_attentions=output_attentions,\n use_cache=use_cache,\n )\n hidden_states = residual + hidden_states\n\n # Fully Connected\n residual = hidden_states\n hidden_states = self.post_attention_layernorm(hidden_states)\n hidden_states = self.mlp(hidden_states)\n hidden_states = residual + hidden_states\n\n outputs = (hidden_states,)\n\n if output_attentions:\n outputs += (self_attn_weights,)\n\n if use_cache:\n outputs += (present_key_value,)\n\n return outputs\n
"},{"location":"reference/#src.aeiva.model.macaw_model_old.LlamaModel","title":"LlamaModel
","text":" Bases: LlamaPreTrainedModel
Transformer decoder consisting of config.num_hidden_layers layers. Each layer is a [LlamaDecoderLayer
]
Parameters:
Name Type Description Default config
LlamaConfig
LlamaConfig
required Source code in src/aeiva/model/macaw_model_old.py
class LlamaModel(LlamaPreTrainedModel):\n \"\"\"\n Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]\n\n Args:\n config: LlamaConfig\n \"\"\"\n\n def __init__(self, config: LlamaConfig):\n super().__init__(config)\n # embedding layer, stacked decoder layers, and layer normalization in llama.\n self.padding_idx = config.pad_token_id\n self.vocab_size = config.vocab_size\n self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)\n\n self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])\n\n self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n\n # Gradient checkpointing is a technique to reduce the memory usage when training deep neural networks.\n # In deep learning, when you perform backpropagation to compute gradients and update the model parameters,\n # you need to store the intermediate activations from the forward pass, so you can use them in the backward pass. \n # For large models or long sequences, this can consume a lot of memory.\n # \n # Gradient checkpointing addresses this by not storing all the intermediate activations in memory during the forward pass. \n # Instead, it stores only a subset of the activations, and recomputes the rest during the backward pass as needed. \n # This trades off computation time (because you need to recompute some values) for memory usage.\n # \n # This technique is particularly useful when training large models that would otherwise not fit into GPU memory. 
\n # However, it can slow down training because of the extra computation.\n self.gradient_checkpointing = False\n\n # Initialize weights and apply final processing\n self.post_init()\n\n def get_input_embeddings(self):\n return self.embed_tokens\n\n def set_input_embeddings(self, value):\n self.embed_tokens = value\n\n # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask\n def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):\n # create causal mask\n # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\n combined_attention_mask = None\n if input_shape[-1] > 1: # seq_len > 1\n combined_attention_mask = _make_causal_mask(\n input_shape,\n inputs_embeds.dtype,\n device=inputs_embeds.device,\n past_key_values_length=past_key_values_length,\n )\n\n if attention_mask is not None:\n # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\n expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(\n inputs_embeds.device\n )\n combined_attention_mask = (\n expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask\n )\n\n return combined_attention_mask\n\n # @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)\n def forward(\n self,\n input_ids: torch.LongTensor = None,\n attention_mask: Optional[torch.Tensor] = None,\n position_ids: Optional[torch.LongTensor] = None,\n past_key_values: Optional[List[torch.FloatTensor]] = None,\n inputs_embeds: Optional[torch.FloatTensor] = None,\n use_cache: Optional[bool] = None,\n output_attentions: Optional[bool] = None,\n output_hidden_states: Optional[bool] = None,\n return_dict: Optional[bool] = None,\n ) -> Union[Tuple, BaseModelOutputWithPast]:\n # set output and cache flags\n output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n output_hidden_states = output_hidden_states if 
output_hidden_states is not None else self.config.output_hidden_states\n use_cache = use_cache if use_cache is not None else self.config.use_cache\n return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n # prepare input_ids/inputs_embeds\n if input_ids is not None and inputs_embeds is not None:\n raise ValueError(\"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time\")\n elif input_ids is not None:\n batch_size, seq_length = input_ids.shape\n elif inputs_embeds is not None:\n batch_size, seq_length, _ = inputs_embeds.shape\n else:\n raise ValueError(\"You have to specify either decoder_input_ids or decoder_inputs_embeds\")\n if inputs_embeds is None:\n inputs_embeds = self.embed_tokens(input_ids)\n\n # prepare attention mask and other parameters for decoder layers\n past_key_values_length = 0\n seq_length_with_past = seq_length\n\n if past_key_values is not None:\n past_key_values_length = past_key_values[0][0].shape[2]\n seq_length_with_past = seq_length_with_past + past_key_values_length\n\n if position_ids is None:\n device = input_ids.device if input_ids is not None else inputs_embeds.device\n position_ids = torch.arange(\n past_key_values_length, past_key_values_length + seq_length, dtype=torch.long, device=device\n )\n position_ids = position_ids.unsqueeze(0).view(-1, seq_length)\n else:\n position_ids = position_ids.view(-1, seq_length).long()\n\n if attention_mask is None:\n attention_mask = torch.ones(\n (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device\n )\n attention_mask = self._prepare_decoder_attention_mask(\n attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length\n )\n\n hidden_states = inputs_embeds\n\n if self.gradient_checkpointing and self.training:\n if use_cache:\n logger.warning_once(\n \"`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`...\"\n )\n use_cache = False\n\n # forward through all decoder layers\n all_hidden_states = () if output_hidden_states else None\n all_self_attns = () if output_attentions else None\n next_decoder_cache = () if use_cache else None\n\n for idx, decoder_layer in enumerate(self.layers):\n if output_hidden_states:\n all_hidden_states += (hidden_states,)\n\n past_key_value = past_key_values[idx] if past_key_values is not None else None\n\n if self.gradient_checkpointing and self.training:\n # define the function for gradient checkpointing\n # in checkpointing, we need to create a custom function for the forward pass \n # (the custom_forward function in your code) and then using the \n # torch.utils.checkpoint.checkpoint function to apply this custom function \n # with gradient checkpointing.\n def create_custom_forward(module):\n def custom_forward(*inputs):\n return module(*inputs, output_attentions, None) # None for past_key_value\n return custom_forward\n\n layer_outputs = torch.utils.checkpoint.checkpoint(\n create_custom_forward(decoder_layer),\n hidden_states,\n attention_mask,\n position_ids,\n None,\n )\n else:\n layer_outputs = decoder_layer(\n hidden_states,\n attention_mask=attention_mask,\n position_ids=position_ids,\n past_key_value=past_key_value,\n output_attentions=output_attentions,\n use_cache=use_cache,\n )\n\n hidden_states = layer_outputs[0]\n\n if use_cache:\n next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)\n\n if output_attentions:\n all_self_attns += (layer_outputs[1],)\n\n hidden_states = self.norm(hidden_states)\n\n # add hidden states from the last decoder layer\n if output_hidden_states:\n all_hidden_states += (hidden_states,)\n\n next_cache = next_decoder_cache if use_cache else None\n\n # output the hidden states, the self attentions and the cache (if needed)\n if not return_dict:\n return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)\n 
return BaseModelOutputWithPast(\n last_hidden_state=hidden_states,\n past_key_values=next_cache,\n hidden_states=all_hidden_states,\n attentions=all_self_attns,\n )\n
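The gradient-checkpointing trade-off described in the comments above can be exercised in isolation. The toy block below is illustrative (not part of the model); it shows that `torch.utils.checkpoint.checkpoint` produces the same values and gradients as a plain forward pass, while discarding intermediate activations and recomputing them during backward:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# A toy two-layer block standing in for one decoder layer.
block = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))

x = torch.randn(4, 16, requires_grad=True)

# Normal forward: intermediate activations are kept for backward.
y_plain = block(x)

# Checkpointed forward: activations inside `block` are discarded and
# recomputed during backward, trading extra compute for less memory.
y_ckpt = checkpoint(block, x)

# Both paths compute identical outputs.
assert torch.allclose(y_plain, y_ckpt)
y_ckpt.sum().backward()
```

This is why the model warns that `use_cache=True` is incompatible with checkpointing: cached key/value tensors are exactly the kind of saved activation that checkpointing tries not to keep.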
"},{"location":"reference/#src.aeiva.model.macaw_model_old.LlamaRMSNorm","title":"LlamaRMSNorm
","text":" Bases: Module
Source code in src/aeiva/model/macaw_model_old.py
class LlamaRMSNorm(nn.Module):\n def __init__(self, hidden_size, eps=1e-6):\n \"\"\"\n LlamaRMSNorm is equivalent to T5LayerNorm.\n It rescales each hidden vector to unit root-mean-square\n over the feature dimension (no mean is subtracted),\n which makes training more stable and faster.\n \"\"\"\n super().__init__()\n self.weight = nn.Parameter(torch.ones(hidden_size)) # trainable parameter for affine transformation\n self.variance_epsilon = eps # for numerical stability\n\n def forward(self, hidden_states):\n variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)\n hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n\n # convert into half-precision if necessary\n if self.weight.dtype in [torch.float16, torch.bfloat16]:\n hidden_states = hidden_states.to(self.weight.dtype)\n\n return self.weight * hidden_states\n
"},{"location":"reference/#src.aeiva.model.macaw_model_old.LlamaRMSNorm.__init__","title":"__init__(hidden_size, eps=1e-06)
","text":"LlamaRMSNorm is equivalent to T5LayerNorm. It rescales each hidden vector to unit root-mean-square over the feature dimension (no mean is subtracted), which makes training more stable and faster.
Source code in src/aeiva/model/macaw_model_old.py
def __init__(self, hidden_size, eps=1e-6):\n \"\"\"\n LlamaRMSNorm is equivalent to T5LayerNorm.\n It rescales each hidden vector to unit root-mean-square\n over the feature dimension (no mean is subtracted),\n which makes training more stable and faster.\n \"\"\"\n super().__init__()\n self.weight = nn.Parameter(torch.ones(hidden_size)) # trainable parameter for affine transformation\n self.variance_epsilon = eps # for numerical stability\n
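To make the normalization concrete, the forward pass above can be reproduced with a few tensor ops. This standalone sketch (names and shapes are illustrative) checks that, with a unit weight, every hidden vector ends up with root-mean-square of 1 over the feature dimension:

```python
import torch

def rms_norm(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Mean of squares over the last (feature) dimension; note: no mean subtraction.
    variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
    normed = hidden_states * torch.rsqrt(variance + eps)
    return weight * normed

h = torch.randn(2, 5, 8)            # (batch, seq, hidden)
out = rms_norm(h, torch.ones(8))

# With weight == 1, each vector has RMS ~= 1 over the hidden dimension.
rms = out.pow(2).mean(-1).sqrt()
assert torch.allclose(rms, torch.ones_like(rms), atol=1e-3)
```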
"},{"location":"reference/#src.aeiva.model.macaw_model_old.LlamaRotaryEmbedding","title":"LlamaRotaryEmbedding
","text":" Bases: Module
Rotary position embedding, as described in https://arxiv.org/pdf/2104.09864.pdf. It modulates the position information in the input embeddings; Llama uses rotary embeddings.
Source code in src/aeiva/model/macaw_model_old.py
class LlamaRotaryEmbedding(torch.nn.Module):\n \"\"\"\n Rotary embedding described in: https://arxiv.org/pdf/2104.09864.pdf.\n It is used to modulate the position information in the input embeddings.\n Llama used rotary embedding.\n \"\"\"\n def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):\n super().__init__()\n # Compute the inverse frequencies, which will be used to modulate the position information\n inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))\n # The register_buffer() function is used in PyTorch to register a tensor that is not a parameter,\n # but you still want it to be a part of the model's state. It's used for tensors that should\n # have their state saved in the model's state_dict and should be moved to the device with the rest of the model.\n self.register_buffer(\"inv_freq\", inv_freq)\n\n # Build here to make `torch.jit.trace` work.\n # max_position_embeddings: max sequence length that this model might ever be used with\n self.max_seq_len_cached = max_position_embeddings\n\n # Compute the positional encodings (both cos and sin parts)\n t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)\n freqs = torch.einsum(\"i,j->ij\", t, self.inv_freq)\n\n # Different from paper, but it uses a different permutation in order to obtain the same calculation\n emb = torch.cat((freqs, freqs), dim=-1)\n self.register_buffer(\"cos_cached\", emb.cos()[None, None, :, :], persistent=False)\n self.register_buffer(\"sin_cached\", emb.sin()[None, None, :, :], persistent=False)\n\n def forward(self, x, seq_len=None):\n # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.\n # x.shape: [batch_size, num_attention_heads, sequence_length, head_size].\n # The forward function then outputs two tensors, each of which is a sin or cos embedding representation of the input x. 
\n # Both output tensors will have a shape of [1, 1, sequence_length, head_size].\n # NOTE: Only the dtype and device attributes of x are relevant here. The values are not used.\n if seq_len > self.max_seq_len_cached:\n self.max_seq_len_cached = seq_len\n t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)\n freqs = torch.einsum(\"i,j->ij\", t, self.inv_freq)\n # Different from paper, but it uses a different permutation in order to obtain the same calculation\n emb = torch.cat((freqs, freqs), dim=-1).to(x.device)\n self.register_buffer(\"cos_cached\", emb.cos()[None, None, :, :], persistent=False)\n self.register_buffer(\"sin_cached\", emb.sin()[None, None, :, :], persistent=False)\n return (\n self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),\n self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),\n )\n
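The cos/sin cache built in `__init__` above can be reproduced directly. This sketch (dimensions chosen arbitrarily) shows the inverse-frequency computation and the resulting cache shapes:

```python
import torch

dim, max_seq_len, base = 8, 16, 10000

# One inverse frequency per pair of dimensions, as in the RoPE paper.
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))  # (dim/2,)

t = torch.arange(max_seq_len, dtype=inv_freq.dtype)                 # (seq,)
freqs = torch.einsum("i,j->ij", t, inv_freq)                        # (seq, dim/2)

# Duplicate so cos/sin cover the full head dimension.
emb = torch.cat((freqs, freqs), dim=-1)                             # (seq, dim)
cos_cached = emb.cos()[None, None, :, :]                            # (1, 1, seq, dim)
sin_cached = emb.sin()[None, None, :, :]

# Position 0 has all angles equal to 0, so its cos entries are exactly 1.
assert cos_cached.shape == (1, 1, max_seq_len, dim)
```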
"},{"location":"reference/#src.aeiva.model.macaw_model_old.MM_LLMs","title":"MM_LLMs
","text":" Bases: PreTrainedModel
This is the multimodal language model that combines CLIP and Whisper encoders with a language model. We need a config file to specify the multimodal encoder configurations.
Source code in src/aeiva/model/macaw_model_old.py
class MM_LLMs(PreTrainedModel):\n \"\"\"\n This is the multimodal language model that combines CLIP and Whisper encoders with a language model.\n We need a config file to specify the multimodal encoder configurations.\n \"\"\"\n def __init__(self, config):\n super().__init__(config)\n # multimodal config\n self.config = config\n\n # multimodal encoders\n self.image_encoder = CLIPModel(config.image_config) # NOTE: here they use CLIP for both image and video.\n self.video_encoder = CLIPModel(config.image_config)\n self.audio_encoder = WhisperModel(config.audio_config)\n self.llm = LlamaForCausalLM(config.llm_config)\n\n # video temporal position embedding layer\n self.temporal_position_embeddings = nn.Embedding(\n config.n_frames, \n config.image_config.projection_dim)\n\n # multimodal attention layers for mapping multimodal features to the same space\n attn_dropout = 0.1\n is_add_bias_kv = True\n is_add_zero_attn = True\n self.temporal_self_attention = nn.MultiheadAttention(config.image_config.projection_dim,\n config.attention_heads,\n dropout=attn_dropout,\n add_bias_kv=is_add_bias_kv,\n add_zero_attn=is_add_zero_attn)\n self.video_align_attention = nn.MultiheadAttention(config.llm_config.hidden_size, \n config.attention_heads,\n dropout=attn_dropout,\n add_bias_kv=is_add_bias_kv,\n add_zero_attn=is_add_zero_attn)\n self.audio_align_attention = nn.MultiheadAttention(config.llm_config.hidden_size, \n config.attention_heads,\n dropout=attn_dropout,\n add_bias_kv=is_add_bias_kv,\n add_zero_attn=is_add_zero_attn)\n self.image_align_attention = nn.MultiheadAttention(config.llm_config.hidden_size, \n config.attention_heads,\n dropout=attn_dropout,\n add_bias_kv=is_add_bias_kv,\n add_zero_attn=is_add_zero_attn)\n\n # multimodal projection layers for mapping multimodal features to the same space\n self.transform_video_to_hidden = nn.Linear(config.image_config.projection_dim, \n config.llm_config.hidden_size)\n self.transform_audio_to_hidden = 
nn.Linear(config.audio_config.d_model, \n config.llm_config.hidden_size)\n self.transform_image_to_hidden = nn.Linear(config.image_config.projection_dim, \n config.llm_config.hidden_size)\n\n self.project_image = nn.Conv1d(config.image_config.projection_dim, config.image_config.projection_dim, \n kernel_size=48, stride=36)\n self.project_video = nn.Conv1d(config.image_config.projection_dim, config.image_config.projection_dim, \n kernel_size=36, stride=30)\n self.project_audio = nn.Conv1d(config.audio_config.d_model, config.audio_config.d_model, \n kernel_size=240, stride=220)\n\n # multimodal fusion layers\n self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))\n\n self.layer_norm = nn.LayerNorm(config.image_config.projection_dim)\n self.softmax = nn.Softmax(dim=-1)\n self.relu = nn.ReLU()\n self.gelu = nn.GELU()\n self.elu = nn.ELU()\n self.sigmoid = nn.Sigmoid()\n\n self.loss_fct = CrossEntropyLoss()\n\n self.init_weights()\n\n def forward(self, inputs=None):\n # \"\"\"\n # :param inputs:\n # video_frames: (B x F)\n # audios: B x 1\n # images: B x 1\n # input_ids: B x L\n # labels: B x L\n #\n # :return: the output of the language model LlamaForCausalLM.\n # \"\"\"\n text_embeddings, attention_mask, labels = self.prepare_inputs_for_generation(inputs)\n\n if 'inference' in inputs and inputs['inference'] is True:\n # generate_ids = self.llm.generate(input_ids=inputs['input_ids'], inputs_embeds=text_embeddings, max_new_tokens=128)\n # generate_ids = self.llm.generate(inputs_embeds=text_embeddings, max_new_tokens=128)\n\n # !!! The code below will possibly trigger an error in : https://github.com/microsoft/DeepSpeed/issues/3156 (the solution only partially resolves the bug for me)\n generate_ids = self.llm.generate(\n inputs_embeds=text_embeddings, max_new_tokens=128, eos_token_id=2, bos_token_id=1, pad_token_id=32006 # !!! revise later. 
use config constants instead.\n )\n return generate_ids\n outputs = self.llm(inputs_embeds=text_embeddings, attention_mask=attention_mask, labels=labels)\n\n return outputs\n\n def prepare_inputs_for_generation(self, inputs):\n \"\"\"\n The purpose of this method is to integrate the different modalities into the text embeddings \n and prepare the associated attention mask and labels for the language model, so the model can \n generate text conditioned on all the input modalities.\n\n inputs is a dictionary containing the following keys: (!!! my hypothesis)\n video_frames: (B x F)\n audios: B x 1\n images: B x 1\n input_ids: B x L\n attention_mask: B x L\n labels: B x L\n video_starts: B x 1\n video_ends: B x 1\n audio_starts: B x 1\n audio_ends: B x 1\n image_starts: B x 1\n image_ends: B x 1\n inference: True/False\n \"\"\"\n # get multimodal embeddings\n image_features = self.encode_image(inputs['images']) if inputs['images'] is not None else None\n audio_features = self.encode_audio(inputs['audios']) if inputs['audios'] is not None else None\n video_features = self.encode_video(inputs['videos']) if inputs['videos'] is not None else None\n embed_tokens = self.llm.model.embed_tokens\n\n\n # for debug !!!!!!\n # Find maximum id in input_ids\n max_id = torch.max(inputs['input_ids'])\n print(f\"Max ID in input_ids: {max_id.item()}\")\n\n # Get vocab size from embedding layer\n vocab_size = embed_tokens.num_embeddings\n print(f\"Vocabulary size: {vocab_size}\")\n\n\n\n text_embeddings = embed_tokens(inputs['input_ids'])\n\n token_embeddings = embed_tokens.weight.unsqueeze(0).repeat(\n text_embeddings.size(0), 1, 1).transpose(0, 1)\n\n # ignore_num seems to be a counter that tracks the total size (or length) of the \n # multimodal input segments (video, audio, image) added to the original text inputs.\n ingore_num = 0\n\n # project and merge video features to the same space as text embeddings\n if video_features is not None:\n # get video starts and ends embeddings\n 
video_starts = embed_tokens(inputs['video_starts']).unsqueeze(1)\n video_ends = embed_tokens(inputs['video_ends']).unsqueeze(1)\n\n # project video features to the same space as text embeddings\n video_features = self.transform_video_to_hidden(video_features)\n\n video_features = self.video_align_attention(\n video_features.transpose(0, 1), token_embeddings, token_embeddings)[0].transpose(0, 1).contiguous()\n\n # concatenate video starts, video features, and video ends embeddings\n video_inputs = torch.cat([torch.cat([video_starts, video_features], dim=1), video_ends], dim=1)\n\n # concatenate video inputs to the original text embeddings\n # NOTE: the first token of text_embeddings keeps at the same position\n text_embeddings = torch.cat([torch.cat([text_embeddings[:, 0, :].unsqueeze(1), video_inputs], dim=1), text_embeddings[:, 1:, :]], dim=1)\n\n ingore_num += (video_inputs.size(1))\n\n # project and merge audio features to the same space as text embeddings\n if audio_features is not None:\n # get audio starts and ends embeddings\n audio_starts = embed_tokens(inputs['audio_starts']).unsqueeze(1)\n audio_ends = embed_tokens(inputs['audio_ends']).unsqueeze(1)\n\n # project audio features to the same space as text embeddings\n audio_features = self.project_audio(audio_features.transpose(1, 2).contiguous()).transpose(1, 2).contiguous()\n audio_features = self.transform_audio_to_hidden(audio_features)\n # mean pooling\n # audio_features = torch.sum(audio_features, dim=1) / audio_features.size(1) \n # audio_features = audio_features.unsqueeze(1)\n audio_features = self.audio_align_attention(\n audio_features.transpose(0, 1), token_embeddings, token_embeddings)[0].transpose(0, 1).contiguous()\n\n # concatenate audio starts, audio features, and audio ends embeddings\n audio_inputs = torch.cat([torch.cat([audio_starts, audio_features], dim=1), audio_ends], dim=1)\n\n # concatenate audio inputs to the original text embeddings\n text_embeddings = torch.cat(\n 
[torch.cat([text_embeddings[:, 0, :].unsqueeze(1), audio_inputs], dim=1), text_embeddings[:, 1:, :]],\n dim=1)\n\n ingore_num += (audio_inputs.size(1))\n\n # project and merge image features to the same space as text embeddings\n if image_features is not None:\n # get image starts and ends embeddings\n image_starts = embed_tokens(inputs['image_starts']).unsqueeze(1)\n image_ends = embed_tokens(inputs['image_ends']).unsqueeze(1)\n\n # project image features to the same space as text embeddings\n image_features = self.project_image(image_features.transpose(1, 2).contiguous()).transpose(1, 2).contiguous()\n image_features = self.transform_image_to_hidden(image_features)\n image_features = self.image_align_attention(\n image_features.transpose(0, 1), token_embeddings, token_embeddings)[0].transpose(0, 1).contiguous()\n\n # concatenate image starts, image features, and image ends embeddings\n image_inputs = torch.cat([torch.cat([image_starts, image_features], dim=1), image_ends], dim=1)\n\n # concatenate image inputs to the original text embeddings\n text_embeddings = torch.cat(\n [torch.cat([text_embeddings[:, 0, :].unsqueeze(1), image_inputs], dim=1), \n text_embeddings[:, 1:, :]], dim=1)\n\n ingore_num += (image_inputs.size(1))\n\n if 'attention_mask' in inputs:\n # increase the length of attention mask by adding the length of multimodal inputs\n attention_mask = torch.tensor([1]*ingore_num*text_embeddings.size(0), device=text_embeddings.device).view(text_embeddings.size(0), -1) # (B X ignore_num)\n attention_mask = torch.cat([attention_mask, inputs['attention_mask']], dim=1)\n else:\n attention_mask = None\n\n if 'labels' in inputs and inputs['labels'] is not None:\n # increase the length of labels by adding the length of labels\n # we use -100 to ignore the loss of labels in multimodal inputs\n # !!! 
we can replace -100 by config constants to make the code better\n\n # since the tokens corresponding to the image_inputs, audio_inputs, and video_inputs are not part of the original text \n # and don't have corresponding labels in the true text sequence, their labels are set to -100. This ensures that \n # the model's predictions for these tokens don't affect the loss and, consequently, the gradients and the model's subsequent learning.\n labels = torch.tensor([-100]*ingore_num*text_embeddings.size(0), device=text_embeddings.device).view(text_embeddings.size(0), -1)\n labels = torch.cat([labels, inputs['labels']], dim=1)\n else:\n labels = None\n\n # text_embeddings: (batch_size, sequence_length + ingore_num, embedding_dim)\n # attention_mask: (batch_size, sequence_length + ingore_num). 1 denotes we should attend to, and 0 denotes we should not attend to the token.\n # labels: (batch_size, sequence_length + ingore_num). -100 denotes we should ignore the token.\n return text_embeddings, attention_mask, labels\n\n def encode_video(self, videos):\n \"\"\"\n Encode video features to video embeddings.\n\n Args:\n videos: (batch_size, n_frames, n_channels, height, width)\n\n Returns:\n video_embeddings: (batch_size, n_frames, embedding_dim)\n \"\"\"\n # simple image encoding without temporal embedding and self attention\n # Reference: https://huggingface.co/docs/transformers/model_doc/clip\n videos = videos.view(-1, videos.size(-3), videos.size(-2), videos.size(-1)) # pixel_values (torch.FloatTensor of shape (batch_size * n_frames, num_channels, height, width)) \n video_outputs = self.video_encoder.get_image_features(videos) # image_features (torch.FloatTensor of shape (batch_size * n_frames, output_dim)\n video_features = video_outputs\n temporal_pos = torch.tensor(\n [[i for i in range(self.config.n_frames)] \n for j in range(videos.size(0) // self.config.n_frames)],\n dtype=torch.int, device=video_features.device).view(-1) # 2d indices to 1d indices, shape: 
(batch_size * n_frames)\n\n frame_temporal_pos_embed = self.temporal_position_embeddings(temporal_pos)\n\n video_features = (video_features + frame_temporal_pos_embed).view(\n videos.size(0) // self.config.n_frames, self.config.n_frames, -1) # (batch_size, n_frames, output_dim)\n\n video_features = video_features.transpose(0, 1).contiguous()\n # nn.MultiheadAttention takes query, key, value as inputs. Their shapes are (sequence_length, batch_size, embedding_dim).\n # The outputs are two elements: attn_output of shape (sequence_length, batch_size, embedding_dim), and attn_output_weights of shape (batch_size, sequence_length, sequence_length).\n self_attn_video_features = self.temporal_self_attention(video_features, video_features, video_features)[0]\n\n return self_attn_video_features.transpose(0, 1).contiguous() # (batch_size, n_frames, output_dim)\n\n def encode_video_long(self, videos):\n \"\"\"\n Encode video features to video embeddings.\n\n Args:\n videos: (batch_size, n_frames, n_channels, height, width)\n\n Returns:\n video_embeddings: (batch_size, n_frames, embedding_dim)\n \"\"\"\n # simple image encoding without temporal embedding and self attention\n videos = videos.view(-1, videos.size(-3), videos.size(-2), videos.size(-1)) # pixel_values (torch.FloatTensor of shape (batch_size * n_frames, num_channels, height, width))\n video_features = self.video_encoder.visual_projection(self.video_encoder.vision_model(videos)[0])[:, 1:, :]\n video_features = video_features.reshape(\n videos.size(0) // self.config.n_frames,\n self.config.n_frames * video_features.size(1),\n -1).contiguous()\n\n return video_features\n\n def encode_audio(self, audios):\n audio_features = self.audio_encoder.encoder(audios)\n return audio_features[0]\n\n def encode_image(self, images):\n # vision_outputs = self.image_encoder.get_image_features(images)\n # image_features = vision_outputs # pooled_output\n # image_features = self.visual_projection(pooled_output)\n # image_features = 
image_features.unsqueeze(1)\n image_features = self.image_encoder.visual_projection(self.image_encoder.vision_model(images)[0])[:, 1:, :]\n return image_features\n
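The `*_align_attention` calls above all follow the same transpose pattern because `nn.MultiheadAttention` defaults to `(seq, batch, embed)` layout. A minimal sketch with invented toy sizes (not the model's real dimensions):

```python
import torch
import torch.nn as nn

hidden, heads, batch, n_tokens, vocab = 32, 4, 2, 6, 50

# Same construction flags as the model's alignment attention layers.
attn = nn.MultiheadAttention(hidden, heads, add_bias_kv=True, add_zero_attn=True)

features = torch.randn(batch, n_tokens, hidden)       # modality features, batch-first
token_embeddings = torch.randn(vocab, batch, hidden)  # keys/values in (seq, batch, embed)

# Transpose the query into (seq, batch, embed), attend over the token
# embedding table, then transpose the result back to batch-first.
aligned = attn(features.transpose(0, 1), token_embeddings, token_embeddings)[0]
aligned = aligned.transpose(0, 1).contiguous()        # (batch, n_tokens, hidden)

assert aligned.shape == (batch, n_tokens, hidden)
```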
"},{"location":"reference/#src.aeiva.model.macaw_model_old.MM_LLMs.encode_video","title":"encode_video(videos)
","text":"Encode video features to video embeddings.
Parameters: videos (required): (batch_size, n_frames, n_channels, height, width).
Returns: video_embeddings: (batch_size, n_frames, embedding_dim).
Source code in src/aeiva/model/macaw_model_old.py
def encode_video(self, videos):\n \"\"\"\n Encode video features to video embeddings.\n\n Args:\n videos: (batch_size, n_frames, n_channels, height, width)\n\n Returns:\n video_embeddings: (batch_size, n_frames, embedding_dim)\n \"\"\"\n # simple image encoding without temporal embedding and self attention\n # Reference: https://huggingface.co/docs/transformers/model_doc/clip\n videos = videos.view(-1, videos.size(-3), videos.size(-2), videos.size(-1)) # pixel_values (torch.FloatTensor of shape (batch_size * n_frames, num_channels, height, width)) \n video_outputs = self.video_encoder.get_image_features(videos) # image_features (torch.FloatTensor of shape (batch_size * n_frames, output_dim)\n video_features = video_outputs\n temporal_pos = torch.tensor(\n [[i for i in range(self.config.n_frames)] \n for j in range(videos.size(0) // self.config.n_frames)],\n dtype=torch.int, device=video_features.device).view(-1) # 2d indices to 1d indices, shape: (batch_size * n_frames)\n\n frame_temporal_pos_embed = self.temporal_position_embeddings(temporal_pos)\n\n video_features = (video_features + frame_temporal_pos_embed).view(\n videos.size(0) // self.config.n_frames, self.config.n_frames, -1) # (batch_size, n_frames, output_dim)\n\n video_features = video_features.transpose(0, 1).contiguous()\n # nn.MultiheadAttention takes query, key, value as inputs. Their shapes are (sequence_length, batch_size, embedding_dim).\n # The outputs are two elements: attn_output of shape (sequence_length, batch_size, embedding_dim), and attn_output_weights of shape (batch_size, sequence_length, sequence_length).\n self_attn_video_features = self.temporal_self_attention(video_features, video_features, video_features)[0]\n\n return self_attn_video_features.transpose(0, 1).contiguous() # (batch_size, n_frames, output_dim)\n
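The `temporal_pos` construction in `encode_video` above builds per-clip frame indices and flattens them to match the flattened frame batch. A minimal reproduction with toy sizes:

```python
import torch

n_frames, n_clips = 4, 3
flat_batch = n_clips * n_frames  # videos were flattened to (n_clips * n_frames, C, H, W)

# Frame indices 0..n_frames-1, repeated once per clip, then flattened to 1-D
# so they line up with the flattened frame batch.
temporal_pos = torch.tensor(
    [[i for i in range(n_frames)] for _ in range(flat_batch // n_frames)],
    dtype=torch.int,
).view(-1)

assert temporal_pos.shape == (flat_batch,)
assert temporal_pos.tolist() == [0, 1, 2, 3] * n_clips
```

Each flattened frame then receives `temporal_position_embeddings(temporal_pos)` before the per-clip reshape.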
"},{"location":"reference/#src.aeiva.model.macaw_model_old.MM_LLMs.encode_video_long","title":"encode_video_long(videos)
","text":"Encode video features to video embeddings.
Parameters: videos (required): (batch_size, n_frames, n_channels, height, width).
Returns: video_embeddings: (batch_size, n_frames, embedding_dim).
Source code in src/aeiva/model/macaw_model_old.py
def encode_video_long(self, videos):\n \"\"\"\n Encode video features to video embeddings.\n\n Args:\n videos: (batch_size, n_frames, n_channels, height, width)\n\n Returns:\n video_embeddings: (batch_size, n_frames, embedding_dim)\n \"\"\"\n # simple image encoding without temporal embedding and self attention\n videos = videos.view(-1, videos.size(-3), videos.size(-2), videos.size(-1)) # pixel_values (torch.FloatTensor of shape (batch_size * n_frames, num_channels, height, width))\n video_features = self.video_encoder.visual_projection(self.video_encoder.vision_model(videos)[0])[:, 1:, :]\n video_features = video_features.reshape(\n videos.size(0) // self.config.n_frames,\n self.config.n_frames * video_features.size(1),\n -1).contiguous()\n\n return video_features\n
"},{"location":"reference/#src.aeiva.model.macaw_model_old.MM_LLMs.prepare_inputs_for_generation","title":"prepare_inputs_for_generation(inputs)
","text":"The purpose of this method is to integrate the different modalities into the text embeddings and prepare the associated attention mask and labels for the language model, so the model can generate text conditioned on all the input modalities.
inputs is a dictionary assumed to contain the following keys: video_frames (B x F), audios (B x 1), images (B x 1), input_ids (B x L), attention_mask (B x L), labels (B x L), video_starts (B x 1), video_ends (B x 1), audio_starts (B x 1), audio_ends (B x 1), image_starts (B x 1), image_ends (B x 1), inference (True/False).
Source code in src/aeiva/model/macaw_model_old.py
def prepare_inputs_for_generation(self, inputs):\n \"\"\"\n The purpose of this method is to integrate the different modalities into the text embeddings \n and prepare the associated attention mask and labels for the language model, so the model can \n generate text conditioned on all the input modalities.\n\n inputs is a dictionary containing the following keys: (!!! my hypothesis)\n video_frames: (B x F)\n audios: B x 1\n images: B x 1\n input_ids: B x L\n attention_mask: B x L\n labels: B x L\n video_starts: B x 1\n video_ends: B x 1\n audio_starts: B x 1\n audio_ends: B x 1\n image_starts: B x 1\n image_ends: B x 1\n inference: True/False\n \"\"\"\n # get multimodal embeddings\n image_features = self.encode_image(inputs['images']) if inputs['images'] is not None else None\n audio_features = self.encode_audio(inputs['audios']) if inputs['audios'] is not None else None\n video_features = self.encode_video(inputs['videos']) if inputs['videos'] is not None else None\n embed_tokens = self.llm.model.embed_tokens\n\n\n # for debug !!!!!!\n # Find maximum id in input_ids\n max_id = torch.max(inputs['input_ids'])\n print(f\"Max ID in input_ids: {max_id.item()}\")\n\n # Get vocab size from embedding layer\n vocab_size = embed_tokens.num_embeddings\n print(f\"Vocabulary size: {vocab_size}\")\n\n\n\n text_embeddings = embed_tokens(inputs['input_ids'])\n\n token_embeddings = embed_tokens.weight.unsqueeze(0).repeat(\n text_embeddings.size(0), 1, 1).transpose(0, 1)\n\n # ignore_num seems to be a counter that tracks the total size (or length) of the \n # multimodal input segments (video, audio, image) added to the original text inputs.\n ingore_num = 0\n\n # project and merge video features to the same space as text embeddings\n if video_features is not None:\n # get video starts and ends embeddings\n video_starts = embed_tokens(inputs['video_starts']).unsqueeze(1)\n video_ends = embed_tokens(inputs['video_ends']).unsqueeze(1)\n\n # project video features to the same space as 
text embeddings\n video_features = self.transform_video_to_hidden(video_features)\n\n video_features = self.video_align_attention(\n video_features.transpose(0, 1), token_embeddings, token_embeddings)[0].transpose(0, 1).contiguous()\n\n # concatenate video starts, video features, and video ends embeddings\n video_inputs = torch.cat([torch.cat([video_starts, video_features], dim=1), video_ends], dim=1)\n\n # concatenate video inputs to the original text embeddings\n # NOTE: the first token of text_embeddings keeps at the same position\n text_embeddings = torch.cat([torch.cat([text_embeddings[:, 0, :].unsqueeze(1), video_inputs], dim=1), text_embeddings[:, 1:, :]], dim=1)\n\n ingore_num += (video_inputs.size(1))\n\n # project and merge audio features to the same space as text embeddings\n if audio_features is not None:\n # get audio starts and ends embeddings\n audio_starts = embed_tokens(inputs['audio_starts']).unsqueeze(1)\n audio_ends = embed_tokens(inputs['audio_ends']).unsqueeze(1)\n\n # project audio features to the same space as text embeddings\n audio_features = self.project_audio(audio_features.transpose(1, 2).contiguous()).transpose(1, 2).contiguous()\n audio_features = self.transform_audio_to_hidden(audio_features)\n # mean pooling\n # audio_features = torch.sum(audio_features, dim=1) / audio_features.size(1) \n # audio_features = audio_features.unsqueeze(1)\n audio_features = self.audio_align_attention(\n audio_features.transpose(0, 1), token_embeddings, token_embeddings)[0].transpose(0, 1).contiguous()\n\n # concatenate audio starts, audio features, and audio ends embeddings\n audio_inputs = torch.cat([torch.cat([audio_starts, audio_features], dim=1), audio_ends], dim=1)\n\n # concatenate audio inputs to the original text embeddings\n text_embeddings = torch.cat(\n [torch.cat([text_embeddings[:, 0, :].unsqueeze(1), audio_inputs], dim=1), text_embeddings[:, 1:, :]],\n dim=1)\n\n ingore_num += (audio_inputs.size(1))\n\n # project and merge image features 
to the same space as text embeddings\n if image_features is not None:\n # get image starts and ends embeddings\n image_starts = embed_tokens(inputs['image_starts']).unsqueeze(1)\n image_ends = embed_tokens(inputs['image_ends']).unsqueeze(1)\n\n # project image features to the same space as text embeddings\n image_features = self.project_image(image_features.transpose(1, 2).contiguous()).transpose(1, 2).contiguous()\n image_features = self.transform_image_to_hidden(image_features)\n image_features = self.image_align_attention(\n image_features.transpose(0, 1), token_embeddings, token_embeddings)[0].transpose(0, 1).contiguous()\n\n # concatenate image starts, image features, and image ends embeddings\n image_inputs = torch.cat([torch.cat([image_starts, image_features], dim=1), image_ends], dim=1)\n\n # concatenate image inputs to the original text embeddings\n text_embeddings = torch.cat(\n [torch.cat([text_embeddings[:, 0, :].unsqueeze(1), image_inputs], dim=1), \n text_embeddings[:, 1:, :]], dim=1)\n\n ingore_num += (image_inputs.size(1))\n\n if 'attention_mask' in inputs:\n # increase the length of attention mask by adding the length of multimodal inputs\n attention_mask = torch.tensor([1]*ingore_num*text_embeddings.size(0), device=text_embeddings.device).view(text_embeddings.size(0), -1) # (B X ignore_num)\n attention_mask = torch.cat([attention_mask, inputs['attention_mask']], dim=1)\n else:\n attention_mask = None\n\n if 'labels' in inputs and inputs['labels'] is not None:\n # increase the length of labels by adding the length of labels\n # we use -100 to ignore the loss of labels in multimodal inputs\n # !!! we can replace -100 by config constants to make the code better\n\n # since the tokens corresponding to the image_inputs, audio_inputs, and video_inputs are not part of the original text \n # and don't have corresponding labels in the true text sequence, their labels are set to -100. 
This ensures that \n # the model's predictions for these tokens don't affect the loss and, consequently, the gradients and the model's subsequent learning.\n labels = torch.tensor([-100]*ingore_num*text_embeddings.size(0), device=text_embeddings.device).view(text_embeddings.size(0), -1)\n labels = torch.cat([labels, inputs['labels']], dim=1)\n else:\n labels = None\n\n # text_embeddings: (batch_size, sequence_length + ingore_num, embedding_dim)\n # attention_mask: (batch_size, sequence_length + ingore_num). 1 denotes we should attend to, and 0 denotes we should not attend to the token.\n # labels: (batch_size, sequence_length + ingore_num). -100 denotes we should ignore the token.\n return text_embeddings, attention_mask, labels\n
"},{"location":"reference/#src.aeiva.model.macaw_model_old.MM_LLMs_Config","title":"MM_LLMs_Config
","text":" Bases: PretrainedConfig
This is the configuration class to store the configuration of a MM_LLMsModel
. It contains class-level and instance-level attributes, as well as the load (from_pretrained) and save (to_dict) methods for saving and loading configuration files.
Source code in src/aeiva/model/macaw_model_old.py
class MM_LLMs_Config(PretrainedConfig):\n \"\"\"\n This is the configuration class to store the configuration of a `MM_LLMsModel`.\n It contains class level and instance level attributes.\n It also contains the load (from_pretrained) and save (to_dict) methods for saving and loading configuration files.\n \"\"\"\n # general class attributes for all model instances\n model_type = 'mm_llms'\n is_composition = True\n\n def __init__(self, n_frames=6, attention_heads=8, clip_config=None, whisper_config=None, llm_config=None, **kwargs):\n self.image_config = clip_config\n self.audio_config = whisper_config\n self.llm_config = llm_config # language model config\n self.n_frames = n_frames # video config information. How many frames are used for each video clip.\n self.attention_heads = attention_heads\n self.hidden_size = max(llm_config.hidden_size, clip_config.projection_dim, whisper_config.d_model, clip_config.projection_dim)\n super().__init__(**kwargs)\n\n def to_dict(self):\n \"\"\"\n Serializes this instance to a Python dictionary. 
Override the default [`~PretrainedConfig.to_dict`].\n This method overrides the base class method to include serialization of the \n image, audio, and language model configurations along with the base configuration.\n\n Returns:\n `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,\n \"\"\"\n output = copy.deepcopy(self.__dict__)\n output[\"image_config\"] = self.image_config.to_dict()\n output[\"audio_config\"] = self.audio_config.to_dict()\n output['llm_config'] = self.llm_config.to_dict()\n output['n_frames'] = self.n_frames\n output['attention_heads'] = self.attention_heads\n output['hidden_size'] = self.hidden_size\n output[\"model_type\"] = self.__class__.model_type\n return output\n\n @classmethod\n def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):\n config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)\n\n clip_config = CLIPConfig.from_dict(config_dict['image_config'])\n whisper_config = WhisperConfig.from_dict(config_dict['audio_config'])\n llm_config = LlamaConfig.from_dict(config_dict['llm_config'])\n\n return cls(clip_config=clip_config, whisper_config=whisper_config, llm_config=llm_config, **kwargs)\n
"},{"location":"reference/#src.aeiva.model.macaw_model_old.MM_LLMs_Config.to_dict","title":"to_dict()
","text":"Serializes this instance to a Python dictionary. Override the default [~PretrainedConfig.to_dict
]. This method overrides the base class method to include serialization of the image, audio, and language model configurations along with the base configuration.
Returns:
Type Description Dict[str, any]
: Dictionary of all the attributes that make up this configuration instance.
Source code in src/aeiva/model/macaw_model_old.py
def to_dict(self):\n \"\"\"\n Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].\n This method overrides the base class method to include serialization of the \n image, audio, and language model configurations along with the base configuration.\n\n Returns:\n `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,\n \"\"\"\n output = copy.deepcopy(self.__dict__)\n output[\"image_config\"] = self.image_config.to_dict()\n output[\"audio_config\"] = self.audio_config.to_dict()\n output['llm_config'] = self.llm_config.to_dict()\n output['n_frames'] = self.n_frames\n output['attention_heads'] = self.attention_heads\n output['hidden_size'] = self.hidden_size\n output[\"model_type\"] = self.__class__.model_type\n return output\n
"},{"location":"reference/#src.aeiva.model.macaw_model_old.WhisperEncoder","title":"WhisperEncoder
","text":" Bases: WhisperPreTrainedModel
Transformer encoder consisting of config.encoder_layers self-attention layers. Each layer is a [WhisperEncoderLayer
].
Parameters:
Name Type Description Default config
WhisperConfig
WhisperConfig
required Source code in src/aeiva/model/macaw_model_old.py
class WhisperEncoder(WhisperPreTrainedModel):\n \"\"\"\n Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a\n [`WhisperEncoderLayer`].\n\n Args:\n config: WhisperConfig\n \"\"\"\n\n def __init__(self, config: WhisperConfig):\n super().__init__(config)\n self.dropout = config.dropout\n self.layerdrop = config.encoder_layerdrop\n\n embed_dim = config.d_model\n # num_mel_bins corresponds to the number of features extracted from the audio signal for each time step. \n # When we convert audio to a Mel spectrogram, each time step (or frame) in the spectrogram \n # is represented by a feature vector of size num_mel_bins. \n self.num_mel_bins = config.num_mel_bins\n self.padding_idx = config.pad_token_id\n self.max_source_positions = config.max_source_positions\n # embed_scale is a scaling factor that is applied to the embeddings.\n self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0\n\n self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)\n self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)\n\n # position embedding layer\n self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim)\n\n self.layers = nn.ModuleList([WhisperEncoderLayer(config) for _ in range(config.encoder_layers)])\n self.layer_norm = nn.LayerNorm(config.d_model)\n\n self.gradient_checkpointing = False\n # Initialize weights and apply final processing\n self.post_init()\n\n def _freeze_parameters(self):\n for param in self.parameters():\n param.requires_grad = False\n self._requires_grad = False\n\n def get_input_embeddings(self) -> nn.Module:\n return self.conv1\n\n def set_input_embeddings(self, value: nn.Module):\n self.conv1 = value\n\n def forward(\n self,\n input_features,\n attention_mask=None,\n head_mask=None,\n output_attentions=None,\n output_hidden_states=None,\n return_dict=None,\n ):\n r\"\"\"\n Args:\n input_features (`torch.LongTensor` of shape 
`(batch_size, feature_size, sequence_length)`):\n Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be\n obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a\n `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into\n `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding\n and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]\n attention_mask (`torch.Tensor`)`, *optional*):\n Whisper does not support masking of the `input_features`, this argument is preserved for compatibility,\n but it is not used. By default the silence in the input log mel spectrogram are ignored.\n head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):\n Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:\n\n - 1 indicates the head is **not masked**,\n - 0 indicates the head is **masked**.\n output_attentions (`bool`, *optional*):\n Whether or not to return the attentions tensors of all attention layers. See `attentions` under\n returned tensors for more detail.\n output_hidden_states (`bool`, *optional*):\n Whether or not to return the hidden states of all layers. 
See `hidden_states` under returned tensors\n for more detail.\n return_dict (`bool`, *optional*):\n Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.\n \"\"\"\n # set output flags\n output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n # embed audio features\n # input_features shape: (batch_size, feature_size, sequence_length)\n inputs_embeds = nn.functional.gelu(self.conv1(input_features)) # (batch_size, embed_dim, sequence_length)\n inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) # (batch_size, embed_dim, sequence_length/2), because the stride is 2. Downsampling by 2.\n inputs_embeds = inputs_embeds.permute(0, 2, 1) # (batch_size, sequence_length/2, embed_dim)\n embed_pos = self.embed_positions.weight # (max_source_positions, embed_dim)\n\n # add position embedding to audio features embedding\n # !!!: Do max_source_positions and sequence_length/2 must be the same??? 
Kind of confusing.\n hidden_states = inputs_embeds + embed_pos\n hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)\n\n encoder_states = () if output_hidden_states else None\n all_attentions = () if output_attentions else None\n\n # check if head_mask has a correct number of layers specified if desired\n if head_mask is not None:\n assert head_mask.size()[0] == (\n len(self.layers)\n ), f\"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}.\"\n\n # go through the whisper encoder layers to get the hidden states and attentions in all layers\n for idx, encoder_layer in enumerate(self.layers):\n if output_hidden_states:\n encoder_states = encoder_states + (hidden_states,)\n # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)\n dropout_probability = random.uniform(0, 1)\n if self.training and (dropout_probability < self.layerdrop): # skip the layer\n layer_outputs = (None, None)\n else:\n if self.gradient_checkpointing and self.training:\n\n def create_custom_forward(module):\n def custom_forward(*inputs):\n return module(*inputs, output_attentions)\n\n return custom_forward\n\n layer_outputs = torch.utils.checkpoint.checkpoint(\n create_custom_forward(encoder_layer),\n hidden_states,\n None,\n (head_mask[idx] if head_mask is not None else None),\n )\n else:\n # The layer_outputs is a tuple of (hidden_states, attention).\n # The attention is None if output_attentions is False.\n # hidden_states shape: (batch_size, sequence_length/2, embed_dim), as stride is 2 in the self.conv2\n # attention shape: (batch_size, num_heads, sequence_length/2, sequence_length/2)\n layer_outputs = encoder_layer(\n hidden_states,\n None,\n layer_head_mask=(head_mask[idx] if head_mask is not None else None),\n output_attentions=output_attentions,\n )\n\n hidden_states = layer_outputs[0]\n\n if output_attentions:\n all_attentions = all_attentions + (layer_outputs[1],)\n\n hidden_states 
= self.layer_norm(hidden_states)\n if output_hidden_states:\n encoder_states = encoder_states + (hidden_states,)\n\n # output\n if not return_dict:\n return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)\n return BaseModelOutput(\n last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions\n )\n
"},{"location":"reference/#src.aeiva.model.macaw_model_old.WhisperEncoder.forward","title":"forward(input_features, attention_mask=None, head_mask=None, output_attentions=None, output_hidden_states=None, return_dict=None)
","text":"Parameters:
Name Type Description Default input_features
`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`
Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by loading a .flac
or .wav
audio file into an array of type List[float]
or a numpy.ndarray
, e.g. via the soundfile library (pip install soundfile
). To prepare the array into input_features
, the [AutoFeatureExtractor
] should be used for extracting the mel features, padding and conversion into a tensor of type torch.FloatTensor
. See [~WhisperFeatureExtractor.__call__
]
required attention_mask
`torch.Tensor`, *optional*
Whisper does not support masking of the input_features
; this argument is preserved for compatibility, but it is not used. By default, silence in the input log-mel spectrogram is ignored.
None
head_mask
`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*
Mask to nullify selected heads of the attention modules. Mask values selected in [0, 1]
:
- 1 indicates the head is not masked,
- 0 indicates the head is masked.
None
output_attentions
`bool`, *optional*
Whether or not to return the attentions tensors of all attention layers. See attentions
under returned tensors for more detail.
None
output_hidden_states
`bool`, *optional*
Whether or not to return the hidden states of all layers. See hidden_states
under returned tensors for more detail.
None
return_dict
`bool`, *optional*
Whether or not to return a [~utils.ModelOutput
] instead of a plain tuple.
None
Source code in src/aeiva/model/macaw_model_old.py
def forward(\n self,\n input_features,\n attention_mask=None,\n head_mask=None,\n output_attentions=None,\n output_hidden_states=None,\n return_dict=None,\n):\n r\"\"\"\n Args:\n input_features (`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`):\n Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be\n obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a\n `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into\n `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding\n and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]\n attention_mask (`torch.Tensor`)`, *optional*):\n Whisper does not support masking of the `input_features`, this argument is preserved for compatibility,\n but it is not used. By default the silence in the input log mel spectrogram are ignored.\n head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):\n Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:\n\n - 1 indicates the head is **not masked**,\n - 0 indicates the head is **masked**.\n output_attentions (`bool`, *optional*):\n Whether or not to return the attentions tensors of all attention layers. See `attentions` under\n returned tensors for more detail.\n output_hidden_states (`bool`, *optional*):\n Whether or not to return the hidden states of all layers. 
See `hidden_states` under returned tensors\n for more detail.\n return_dict (`bool`, *optional*):\n Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.\n \"\"\"\n # set output flags\n output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n # embed audio features\n # input_features shape: (batch_size, feature_size, sequence_length)\n inputs_embeds = nn.functional.gelu(self.conv1(input_features)) # (batch_size, embed_dim, sequence_length)\n inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) # (batch_size, embed_dim, sequence_length/2), because the stride is 2. Downsampling by 2.\n inputs_embeds = inputs_embeds.permute(0, 2, 1) # (batch_size, sequence_length/2, embed_dim)\n embed_pos = self.embed_positions.weight # (max_source_positions, embed_dim)\n\n # add position embedding to audio features embedding\n # !!!: Do max_source_positions and sequence_length/2 must be the same??? 
Kind of confusing.\n hidden_states = inputs_embeds + embed_pos\n hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)\n\n encoder_states = () if output_hidden_states else None\n all_attentions = () if output_attentions else None\n\n # check if head_mask has a correct number of layers specified if desired\n if head_mask is not None:\n assert head_mask.size()[0] == (\n len(self.layers)\n ), f\"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}.\"\n\n # go through the whisper encoder layers to get the hidden states and attentions in all layers\n for idx, encoder_layer in enumerate(self.layers):\n if output_hidden_states:\n encoder_states = encoder_states + (hidden_states,)\n # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)\n dropout_probability = random.uniform(0, 1)\n if self.training and (dropout_probability < self.layerdrop): # skip the layer\n layer_outputs = (None, None)\n else:\n if self.gradient_checkpointing and self.training:\n\n def create_custom_forward(module):\n def custom_forward(*inputs):\n return module(*inputs, output_attentions)\n\n return custom_forward\n\n layer_outputs = torch.utils.checkpoint.checkpoint(\n create_custom_forward(encoder_layer),\n hidden_states,\n None,\n (head_mask[idx] if head_mask is not None else None),\n )\n else:\n # The layer_outputs is a tuple of (hidden_states, attention).\n # The attention is None if output_attentions is False.\n # hidden_states shape: (batch_size, sequence_length/2, embed_dim), as stride is 2 in the self.conv2\n # attention shape: (batch_size, num_heads, sequence_length/2, sequence_length/2)\n layer_outputs = encoder_layer(\n hidden_states,\n None,\n layer_head_mask=(head_mask[idx] if head_mask is not None else None),\n output_attentions=output_attentions,\n )\n\n hidden_states = layer_outputs[0]\n\n if output_attentions:\n all_attentions = all_attentions + (layer_outputs[1],)\n\n hidden_states 
= self.layer_norm(hidden_states)\n if output_hidden_states:\n encoder_states = encoder_states + (hidden_states,)\n\n # output\n if not return_dict:\n return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)\n return BaseModelOutput(\n last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions\n )\n
"},{"location":"reference/#src.aeiva.model.macaw_model_old.rotate_half","title":"rotate_half(x)
","text":"Rotates half the hidden dims of the input.
Source code in src/aeiva/model/macaw_model_old.py
def rotate_half(x):\n \"\"\"Rotates half the hidden dims of the input.\"\"\"\n x1 = x[..., : x.shape[-1] // 2]\n x2 = x[..., x.shape[-1] // 2 :]\n return torch.cat((-x2, x1), dim=-1)\n
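For intuition, here is a pure-Python sketch of what rotate_half computes. The real function operates on torch tensors; the list-based version below is illustration only:

```python
# Pure-Python sketch of rotate_half's semantics (illustration only; the
# original operates on torch tensors along the last dimension).
def rotate_half_list(x):
    """Split x in half and return the concatenation of (-second_half, first_half)."""
    half = len(x) // 2
    x1, x2 = x[:half], x[half:]
    return [-v for v in x2] + list(x1)

print(rotate_half_list([1, 2, 3, 4]))  # [-3, -4, 1, 2]
```

This rotation is the standard building block of rotary position embeddings (RoPE): paired dimensions are treated as 2-D coordinates and rotated by a position-dependent angle.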
"},{"location":"reference/#src.aeiva.operator","title":"operator
","text":""},{"location":"reference/#src.aeiva.operator.custom_ops","title":"custom_ops
","text":""},{"location":"reference/#src.aeiva.operator.custom_ops.macaw_dataitem_ops","title":"macaw_dataitem_ops
","text":"This module contains the data item processing functions.
A data item processing function takes a data example (a dict) as input and returns a processed data example.
@Author: Bang Liu (chatsci.ai@gmail.com) @Date: 2023-07-11
Copyright (C) 2023 Bang Liu - All Rights Reserved. This source code is licensed under the license found in the LICENSE file in the root directory of this source tree.
"},{"location":"reference/#src.aeiva.operator.dataitem_ops","title":"dataitem_ops
","text":"This module contains the data item processing functions.
A data item processing function takes a data example (a dict) as input and returns a processed data example.
@Author: Bang Liu (chatsci.ai@gmail.com) @Date: 2023-07-11
Copyright (C) 2023 Bang Liu - All Rights Reserved. This source code is licensed under the license found in the LICENSE file in the root directory of this source tree.
"},{"location":"reference/#src.aeiva.operator.dataset_ops","title":"dataset_ops
","text":"This module contains the utils for processing datasets.
A dataset in aeiva is a dictionary with the following structure: { \"data\": [ {sample1}, {sample2}, ..., {sampleN} ], \"metadata\": { \"num_samples\": XX, ... } } where each sample is a dictionary itself, and metadata is a dictionary that contains the number of samples and possibly other fields.
@Author: Bang Liu (chatsci.ai@gmail.com) @Date: 2023-07-13
Copyright (C) 2023 Bang Liu - All Rights Reserved. This source code is licensed under the license found in the LICENSE file in the root directory of this source tree.
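For concreteness, a toy dataset in the structure described above looks like this (the field values are illustrative):

```python
# A minimal aeiva-style dataset: a "data" list of sample dicts plus a
# "metadata" dict whose num_samples matches the length of the data list.
dataset = {
    "data": [
        {"instruction": "Describe the image.", "output": "A cat."},
        {"instruction": "Transcribe the audio.", "output": "Hello."},
    ],
    "metadata": {"num_samples": 2},
}

assert dataset["metadata"]["num_samples"] == len(dataset["data"])
```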
"},{"location":"reference/#src.aeiva.operator.dataset_ops.build_and_merge_datasets","title":"build_and_merge_datasets(dataset_names, input_filepaths_dict, pipeline, output_dir, max_samples=sys.maxsize)
","text":"Build multiple datasets by formatting and processing them.
Source code in src/aeiva/operator/dataset_ops.py
def build_and_merge_datasets(dataset_names: list[str],\n input_filepaths_dict: dict[str, str],\n pipeline: list[Callable],\n output_dir: Optional[str],\n max_samples: Optional[int] = sys.maxsize) -> DataSet:\n r\"\"\" Build multiple datasets by formatting and processing them.\n \"\"\"\n merged_datasets = []\n for dataset_name in dataset_names:\n dataset = build_dataset(dataset_name, input_filepaths_dict, pipeline, output_dir, max_samples)\n merged_datasets.append(dataset)\n result = merge_datasets(merged_datasets)\n return result\n
"},{"location":"reference/#src.aeiva.operator.dataset_ops.build_dataset","title":"build_dataset(dataset_name, input_filepaths_dict, pipeline, output_dir, max_samples=sys.maxsize)
","text":"Build a dataset by formatting and processing it.
Source code in src/aeiva/operator/dataset_ops.py
def build_dataset(dataset_name: str,\n input_filepaths_dict: dict[str, str],\n pipeline: list[Callable],\n output_dir: Optional[str],\n max_samples: Optional[int] = sys.maxsize) -> DataSet:\n r\"\"\" Build a dataset by formatting and processing it.\n \"\"\"\n operator_type = 'data_formatter'\n format_func = OPERATORS[operator_type][dataset_name]\n formatted_dataset = format_func(input_filepaths_dict, output_dir, max_samples)\n processed_dataset = process_dataset(formatted_dataset, pipeline, output_dir, dataset_name)\n print(f\"Completed processing dataset: {dataset_name} (output_dir: {output_dir})\")\n return processed_dataset\n
"},{"location":"reference/#src.aeiva.operator.dataset_ops.filter_dataset","title":"filter_dataset(dataset, filter_criteria, *args, **kwargs)
","text":"Filter a dataset by a filter function.
Source code in src/aeiva/operator/dataset_ops.py
def filter_dataset(dataset: DataSet, filter_criteria: str, *args, **kwargs) -> DataSet:\n r\"\"\" Filter a dataset by a filter function.\n \"\"\"\n operator_type = 'data_filter'\n filter_func = OPERATORS[operator_type][filter_criteria]\n filtered_data = filter_func(dataset, *args, **kwargs)\n return filtered_data\n
"},{"location":"reference/#src.aeiva.operator.dataset_ops.filter_dataset_by_keys","title":"filter_dataset_by_keys(dataset, keys_to_preserve)
","text":"Filter the dataset to only include specified keys in each sample.
Source code in src/aeiva/operator/dataset_ops.py
@register_data_filter(\"filter_dataset_by_keys\")\ndef filter_dataset_by_keys(dataset: DataSet, keys_to_preserve: list[str]) -> DataSet:\n r\"\"\" Filter the dataset to only include specified keys in each sample.\n \"\"\"\n filtered_data = []\n for sample in dataset[\"data\"]:\n for key in keys_to_preserve:\n if key not in sample:\n raise KeyError(f\"Key {key} not found in sample\")\n filtered_sample = {key: sample[key] for key in keys_to_preserve if key in sample}\n filtered_data.append(filtered_sample)\n return {\"data\": filtered_data, \"metadata\": dataset[\"metadata\"]}\n
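A usage sketch follows; the @register_data_filter decorator is omitted so the snippet runs without the aeiva operator registry:

```python
# Standalone sketch of filter_dataset_by_keys (decorator and type aliases
# omitted for self-containment; behavior mirrors the source above).
def filter_dataset_by_keys(dataset, keys_to_preserve):
    filtered_data = []
    for sample in dataset["data"]:
        for key in keys_to_preserve:
            if key not in sample:
                raise KeyError(f"Key {key} not found in sample")
        filtered_data.append({k: sample[k] for k in keys_to_preserve})
    return {"data": filtered_data, "metadata": dataset["metadata"]}

ds = {"data": [{"a": 1, "b": 2}, {"a": 3, "b": 4}], "metadata": {"num_samples": 2}}
print(filter_dataset_by_keys(ds, ["a"])["data"])  # [{'a': 1}, {'a': 3}]
```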
"},{"location":"reference/#src.aeiva.operator.dataset_ops.merge_datasets","title":"merge_datasets(datasets)
","text":"Merge multiple datasets into one.
Source code in src/aeiva/operator/dataset_ops.py
def merge_datasets(datasets: list[DataSet]) -> DataSet:\n r\"\"\" Merge multiple datasets into one.\n \"\"\"\n merged_data = []\n total_samples = 0\n for dataset in datasets:\n merged_data.extend(dataset[\"data\"])\n total_samples += dataset[\"metadata\"][\"num_samples\"]\n result = {\"data\": merged_data, \"metadata\": {\"num_samples\": total_samples}}\n return result\n
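A usage sketch with toy datasets (the function body is reproduced without type aliases so the snippet is self-contained):

```python
# Merging concatenates the "data" lists and sums num_samples in "metadata".
def merge_datasets(datasets):
    merged_data = []
    total_samples = 0
    for ds in datasets:
        merged_data.extend(ds["data"])
        total_samples += ds["metadata"]["num_samples"]
    return {"data": merged_data, "metadata": {"num_samples": total_samples}}

a = {"data": [{"id": 1}], "metadata": {"num_samples": 1}}
b = {"data": [{"id": 2}, {"id": 3}], "metadata": {"num_samples": 2}}
print(merge_datasets([a, b])["metadata"]["num_samples"])  # 3
```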
"},{"location":"reference/#src.aeiva.operator.dataset_ops.sample_dataset","title":"sample_dataset(dataset, n_samples)
","text":"Sample a number of samples from a dataset.
Source code in src/aeiva/operator/dataset_ops.py
def sample_dataset(dataset: DataSet, n_samples: int) -> DataSet:\n r\"\"\" Sample a number of samples from a dataset.\n \"\"\"\n random_indices = random.sample(range(dataset[\"metadata\"][\"num_samples\"]), n_samples)\n sampled_data = [dataset[\"data\"][i] for i in random_indices]\n return {\"data\": sampled_data, \"metadata\": {\"num_samples\": n_samples}}\n
"},{"location":"reference/#src.aeiva.operator.dataset_ops.save_dataset","title":"save_dataset(dataset, output_path)
","text":"Save a dataset to a file by pickling it.
Source code in src/aeiva/operator/dataset_ops.py
def save_dataset(dataset: DataSet, output_path: str) -> None:\n r\"\"\" Save a dataset to a file by pickling it.\n \"\"\"\n ensure_dir(output_path)\n pickle.dump(dataset, open(output_path, \"wb\"), protocol=4)\n
"},{"location":"reference/#src.aeiva.operator.dataset_ops.split_dataset","title":"split_dataset(dataset, train_ratio, seed=42)
","text":"Split a dataset into a training set and a validation set.
Source code in src/aeiva/operator/dataset_ops.py
def split_dataset(dataset: dict, train_ratio: float, seed: int = 42) -> Tuple[dict]:\n r\"\"\" Split a dataset into a training set and a validation set.\n \"\"\"\n np.random.seed(seed) # ensures the function is deterministic\n\n data = dataset[\"data\"]\n metadata = dataset[\"metadata\"]\n\n # Create a permutation of indices and shuffle the data.\n perm = np.random.permutation(len(data))\n shuffled_data = [data[i] for i in perm]\n\n # Calculate split index\n split_idx = int(train_ratio * len(shuffled_data))\n\n # Split the shuffled data\n train_data = shuffled_data[:split_idx]\n val_data = shuffled_data[split_idx:]\n\n # Create metadata for training and validation datasets\n train_metadata = metadata.copy()\n train_metadata[\"num_samples\"] = len(train_data)\n val_metadata = metadata.copy()\n val_metadata[\"num_samples\"] = len(val_data)\n\n # Create training and validation datasets\n train_dataset = {\"data\": train_data, \"metadata\": train_metadata}\n val_dataset = {\"data\": val_data, \"metadata\": val_metadata}\n\n return train_dataset, val_dataset\n
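The same splitting logic can be sketched with only the standard library (the actual implementation shuffles with numpy; this version is for illustration, not the aeiva source):

```python
import random

# Standard-library sketch of split_dataset: deterministic shuffle, then an
# index split at train_ratio, with num_samples updated in each metadata copy.
def split_dataset(dataset, train_ratio, seed=42):
    rng = random.Random(seed)  # fixed seed keeps the split deterministic
    data = list(dataset["data"])
    rng.shuffle(data)
    split_idx = int(train_ratio * len(data))
    train, val = data[:split_idx], data[split_idx:]
    return (
        {"data": train, "metadata": {**dataset["metadata"], "num_samples": len(train)}},
        {"data": val, "metadata": {**dataset["metadata"], "num_samples": len(val)}},
    )

ds = {"data": [{"id": i} for i in range(10)], "metadata": {"num_samples": 10}}
train_ds, val_ds = split_dataset(ds, 0.8)
print(train_ds["metadata"]["num_samples"], val_ds["metadata"]["num_samples"])  # 8 2
```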
"},{"location":"reference/#src.aeiva.perception","title":"perception
","text":""},{"location":"reference/#src.aeiva.perception.base_perception_system","title":"base_perception_system
","text":""},{"location":"reference/#src.aeiva.perception.base_perception_system.PerceptionSystem","title":"PerceptionSystem
","text":" Bases: ABC
Abstract base class representing the Perception System of an agent.
The Perception System is responsible for capturing raw sensory data from the environment, processing this data into meaningful observations, and providing access to these observations for other components of the cognitive architecture.
Attributes:
Name Type Description config
Any
Configuration settings for the Perception System.
state
Any
The internal state of the Perception System, including raw data and observations.
Source code in src/aeiva/perception/base_perception_system.py
class PerceptionSystem(ABC):\n \"\"\"\n Abstract base class representing the Perception System of an agent.\n\n The Perception System is responsible for capturing raw sensory data from the environment,\n processing this data into meaningful observations, and providing access to these observations\n for other components of the cognitive architecture.\n\n Attributes:\n config (Any): Configuration settings for the Perception System.\n state (Any): The internal state of the Perception System, including raw data and observations.\n \"\"\"\n\n def __init__(self, config: Any):\n \"\"\"\n Initialize the Perception System with the provided configuration.\n\n Args:\n config (Any): Configuration settings for the Perception System.\n \"\"\"\n self.config = config\n self.state = self.init_state()\n\n @abstractmethod\n def init_state(self) -> Any:\n \"\"\"\n Initialize the internal state of the Perception System.\n\n This method should set up the initial state required for the Perception System's operations.\n\n Returns:\n Any: The initial state of the Perception System.\n \"\"\"\n pass\n\n @abstractmethod\n async def setup(self) -> None:\n \"\"\"\n Asynchronously set up the Perception System's components.\n\n This method should initialize any necessary components or resources based on the provided configuration.\n\n Raises:\n ConfigurationError: If the configuration is invalid or incomplete.\n \"\"\"\n pass\n\n @abstractmethod\n async def capture(self, raw_data: Any) -> None:\n \"\"\"\n Asynchronously capture raw sensory data from the environment.\n\n Args:\n raw_data (Any): The raw sensory data to capture.\n\n Raises:\n CaptureError: If capturing the raw data fails.\n \"\"\"\n pass\n\n @abstractmethod\n async def process(self) -> None:\n \"\"\"\n Asynchronously process the captured raw sensory data into meaningful observations.\n\n This method should transform raw data stored in the internal state into structured observations\n that can be utilized by other components of the 
cognitive architecture.\n\n Raises:\n ProcessingError: If processing the raw data fails.\n \"\"\"\n pass\n\n async def perceive(self, raw_data: Any) -> None:\n \"\"\"\n Asynchronously perform the full perception cycle: capture and process raw sensory data.\n\n Args:\n raw_data (Any): The raw sensory data to perceive.\n\n Raises:\n CaptureError: If capturing the raw data fails.\n ProcessingError: If processing the raw data fails.\n \"\"\"\n try:\n await self.capture(raw_data)\n await self.process()\n except Exception as e:\n self.handle_error(e)\n raise e\n\n def get_observations(self) -> Any:\n \"\"\"\n Retrieve the current processed observations from the Perception System.\n\n Returns:\n Any: The current observations.\n \"\"\"\n return self.state.get(\"observations\", None)\n\n def handle_error(self, error: Exception) -> None:\n \"\"\"\n Handle errors that occur during perception operations.\n\n This method can be overridden to implement custom error handling logic, such as logging\n or retry mechanisms.\n\n Args:\n error (Exception): The exception that was raised.\n \"\"\"\n # Default error handling: log the error\n print(f\"PerceptionSystem encountered an error: {error}\")\n
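A hypothetical minimal subclass-style sketch of the capture -> process -> observe cycle (TextPerception and its tokenizing behavior are illustrative, not part of aeiva):

```python
import asyncio

# Toy PerceptionSystem-like class: raw text is captured into state, then
# processed into token observations. Mirrors the abstract interface above
# without inheriting from the aeiva base class, so it runs standalone.
class TextPerception:
    def __init__(self, config=None):
        self.config = config
        self.state = {"raw_data": None, "observations": None}  # init_state()

    async def capture(self, raw_data):
        self.state["raw_data"] = raw_data

    async def process(self):
        self.state["observations"] = self.state["raw_data"].lower().split()

    async def perceive(self, raw_data):
        await self.capture(raw_data)
        await self.process()

    def get_observations(self):
        return self.state.get("observations")

perception = TextPerception()
asyncio.run(perception.perceive("Hello Perception World"))
print(perception.get_observations())  # ['hello', 'perception', 'world']
```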
"},{"location":"reference/#src.aeiva.perception.base_perception_system.PerceptionSystem.__init__","title":"__init__(config)
","text":"Initialize the Perception System with the provided configuration.
Parameters:
Name Type Description Default config
Any
Configuration settings for the Perception System.
required Source code in src/aeiva/perception/base_perception_system.py
def __init__(self, config: Any):\n \"\"\"\n Initialize the Perception System with the provided configuration.\n\n Args:\n config (Any): Configuration settings for the Perception System.\n \"\"\"\n self.config = config\n self.state = self.init_state()\n
"},{"location":"reference/#src.aeiva.perception.base_perception_system.PerceptionSystem.capture","title":"capture(raw_data)
abstractmethod
async
","text":"Asynchronously capture raw sensory data from the environment.
Parameters:
Name Type Description Default raw_data
Any
The raw sensory data to capture.
required Raises:
Type Description CaptureError
If capturing the raw data fails.
Source code in src/aeiva/perception/base_perception_system.py
@abstractmethod\nasync def capture(self, raw_data: Any) -> None:\n \"\"\"\n Asynchronously capture raw sensory data from the environment.\n\n Args:\n raw_data (Any): The raw sensory data to capture.\n\n Raises:\n CaptureError: If capturing the raw data fails.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.perception.base_perception_system.PerceptionSystem.get_observations","title":"get_observations()
","text":"Retrieve the current processed observations from the Perception System.
Returns:
Name Type Description Any
Any
The current observations.
Source code in src/aeiva/perception/base_perception_system.py
def get_observations(self) -> Any:\n \"\"\"\n Retrieve the current processed observations from the Perception System.\n\n Returns:\n Any: The current observations.\n \"\"\"\n return self.state.get(\"observations\", None)\n
"},{"location":"reference/#src.aeiva.perception.base_perception_system.PerceptionSystem.handle_error","title":"handle_error(error)
","text":"Handle errors that occur during perception operations.
This method can be overridden to implement custom error handling logic, such as logging or retry mechanisms.
Parameters:
Name Type Description Default error
Exception
The exception that was raised.
required Source code in src/aeiva/perception/base_perception_system.py
def handle_error(self, error: Exception) -> None:\n \"\"\"\n Handle errors that occur during perception operations.\n\n This method can be overridden to implement custom error handling logic, such as logging\n or retry mechanisms.\n\n Args:\n error (Exception): The exception that was raised.\n \"\"\"\n # Default error handling: log the error\n print(f\"PerceptionSystem encountered an error: {error}\")\n
"},{"location":"reference/#src.aeiva.perception.base_perception_system.PerceptionSystem.init_state","title":"init_state()
abstractmethod
","text":"Initialize the internal state of the Perception System.
This method should set up the initial state required for the Perception System's operations.
Returns:
Name Type Description Any
Any
The initial state of the Perception System.
Source code in src/aeiva/perception/base_perception_system.py
@abstractmethod\ndef init_state(self) -> Any:\n \"\"\"\n Initialize the internal state of the Perception System.\n\n This method should set up the initial state required for the Perception System's operations.\n\n Returns:\n Any: The initial state of the Perception System.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.perception.base_perception_system.PerceptionSystem.perceive","title":"perceive(raw_data)
async
","text":"Asynchronously perform the full perception cycle: capture and process raw sensory data.
Parameters:
Name Type Description Default raw_data
Any
The raw sensory data to perceive.
required Raises:
Type Description CaptureError
If capturing the raw data fails.
ProcessingError
If processing the raw data fails.
Source code in src/aeiva/perception/base_perception_system.py
async def perceive(self, raw_data: Any) -> None:\n \"\"\"\n Asynchronously perform the full perception cycle: capture and process raw sensory data.\n\n Args:\n raw_data (Any): The raw sensory data to perceive.\n\n Raises:\n CaptureError: If capturing the raw data fails.\n ProcessingError: If processing the raw data fails.\n \"\"\"\n try:\n await self.capture(raw_data)\n await self.process()\n except Exception as e:\n self.handle_error(e)\n raise e\n
"},{"location":"reference/#src.aeiva.perception.base_perception_system.PerceptionSystem.process","title":"process()
abstractmethod
async
","text":"Asynchronously process the captured raw sensory data into meaningful observations.
This method should transform raw data stored in the internal state into structured observations that can be utilized by other components of the cognitive architecture.
Raises:
Type Description ProcessingError
If processing the raw data fails.
Source code in src/aeiva/perception/base_perception_system.py
@abstractmethod\nasync def process(self) -> None:\n \"\"\"\n Asynchronously process the captured raw sensory data into meaningful observations.\n\n This method should transform raw data stored in the internal state into structured observations\n that can be utilized by other components of the cognitive architecture.\n\n Raises:\n ProcessingError: If processing the raw data fails.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.perception.base_perception_system.PerceptionSystem.setup","title":"setup()
abstractmethod
async
","text":"Asynchronously set up the Perception System's components.
This method should initialize any necessary components or resources based on the provided configuration.
Raises:
Type Description ConfigurationError
If the configuration is invalid or incomplete.
Source code in src/aeiva/perception/base_perception_system.py
@abstractmethod\nasync def setup(self) -> None:\n \"\"\"\n Asynchronously set up the Perception System's components.\n\n This method should initialize any necessary components or resources based on the provided configuration.\n\n Raises:\n ConfigurationError: If the configuration is invalid or incomplete.\n \"\"\"\n pass\n
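The abstract interface above can be exercised with a small concrete subclass. The sketch below re-declares a condensed copy of the base class from the source shown here so it runs stand-alone; `TextPerceptionSystem` and its upper-casing `process` step are hypothetical, not part of Aeiva.

```python
import asyncio
from abc import ABC, abstractmethod
from typing import Any

class PerceptionSystem(ABC):
    """Condensed copy of the abstract base class documented above."""

    def __init__(self, config: Any):
        self.config = config
        self.state = self.init_state()

    @abstractmethod
    def init_state(self) -> Any: ...

    @abstractmethod
    async def setup(self) -> None: ...

    @abstractmethod
    async def capture(self, raw_data: Any) -> None: ...

    @abstractmethod
    async def process(self) -> None: ...

    async def perceive(self, raw_data: Any) -> None:
        # Full cycle: capture raw data, then process it into observations.
        await self.capture(raw_data)
        await self.process()

    def get_observations(self) -> Any:
        return self.state.get("observations", None)

class TextPerceptionSystem(PerceptionSystem):
    """Hypothetical subclass: raw text in, upper-cased observation out."""

    def init_state(self) -> dict:
        return {"raw_data": None, "observations": None}

    async def setup(self) -> None:
        pass  # nothing to allocate for this toy system

    async def capture(self, raw_data: Any) -> None:
        self.state["raw_data"] = raw_data

    async def process(self) -> None:
        self.state["observations"] = str(self.state["raw_data"]).upper()

system = TextPerceptionSystem(config={})
asyncio.run(system.perceive("hello world"))
print(system.get_observations())  # HELLO WORLD
```

Because `perceive` runs `capture` then `process` in order, the observation is available from `get_observations()` as soon as the coroutine returns.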
"},{"location":"reference/#src.aeiva.perception.perception_system","title":"perception_system
","text":""},{"location":"reference/#src.aeiva.perception.perception_system.PerceptionSystem","title":"PerceptionSystem
","text":"Manages multiple sensors and emits stimuli via the EventBus.
Source code in src/aeiva/perception/perception_system.py
class PerceptionSystem:\n \"\"\"\n Manages multiple sensors and emits stimuli via the EventBus.\n \"\"\"\n def __init__(self, config: Dict, event_bus):\n \"\"\"\n Initializes the PerceptionSystem with a list of sensors.\n\n Args:\n config (Any): Configuration dictionary for the sensors.\n event_bus: The EventBus instance for emitting events.\n \"\"\"\n self.config = config\n self.event_bus = event_bus\n self.sensors: List[Sensor] = []\n self.logger = logging.getLogger('PerceptionSystem')\n\n def setup(self) -> None:\n \"\"\"\n Sets up the perception system by initializing all configured sensors.\n \"\"\"\n for sensor_config in self.config.get(\"sensors\", []):\n sensor_name = sensor_config.get(\"sensor_name\")\n sensor_params = sensor_config.get(\"sensor_params\", {})\n # TODO: revise later\n if sensor_name == 'percept_terminal_input':\n sensor = TerminalInputSensor(sensor_name, sensor_params, self.event_bus)\n self.sensors.append(sensor)\n else:\n self.logger.warning(f\"Unknown sensor type: {sensor_name}\")\n self.logger.info(\"PerceptionSystem setup complete.\")\n\n async def start(self) -> None: # TODO: maybe rename in the future\n \"\"\"\n Starts all sensors asynchronously.\n \"\"\"\n self.logger.info(\"Starting all sensors.\")\n for sensor in self.sensors:\n await sensor.start()\n\n async def stop(self) -> None:\n \"\"\"\n Stops all sensors asynchronously.\n \"\"\"\n self.logger.info(\"Stopping all sensors.\")\n for sensor in self.sensors:\n await sensor.stop()\n\n def signal_to_stimuli(self, data: Any) -> Any:\n \"\"\"\n Processes raw data from sensors into structured stimuli.\n\n Args:\n data: The raw data emitted by sensors.\n\n Returns:\n Processed data (stimuli).\n \"\"\"\n # Implement your data processing logic here\n signal = Signal(\n data=data,\n modularity=\"text\", # Or appropriate modality\n type=\"input\", # Or appropriate type\n # TODO: After revised Sensor class, Include other metadata as needed\n )\n stimuli = Stimuli(signals=[signal]) # TODO: 
add more fields\n return stimuli\n
"},{"location":"reference/#src.aeiva.perception.perception_system.PerceptionSystem.__init__","title":"__init__(config, event_bus)
","text":"Initializes the PerceptionSystem with a list of sensors.
Parameters:
Name Type Description Default config
Dict
Configuration dictionary for the sensors.
required event_bus
The EventBus instance for emitting events.
required Source code in src/aeiva/perception/perception_system.py
def __init__(self, config: Dict, event_bus):\n    \"\"\"\n    Initializes the PerceptionSystem with a list of sensors.\n\n    Args:\n        config (Dict): Configuration dictionary for the sensors.\n        event_bus: The EventBus instance for emitting events.\n    \"\"\"\n    self.config = config\n    self.event_bus = event_bus\n    self.sensors: List[Sensor] = []\n    self.logger = logging.getLogger('PerceptionSystem')\n
"},{"location":"reference/#src.aeiva.perception.perception_system.PerceptionSystem.setup","title":"setup()
","text":"Sets up the perception system by initializing all configured sensors.
Source code in src/aeiva/perception/perception_system.py
def setup(self) -> None:\n \"\"\"\n Sets up the perception system by initializing all configured sensors.\n \"\"\"\n for sensor_config in self.config.get(\"sensors\", []):\n sensor_name = sensor_config.get(\"sensor_name\")\n sensor_params = sensor_config.get(\"sensor_params\", {})\n # TODO: revise later\n if sensor_name == 'percept_terminal_input':\n sensor = TerminalInputSensor(sensor_name, sensor_params, self.event_bus)\n self.sensors.append(sensor)\n else:\n self.logger.warning(f\"Unknown sensor type: {sensor_name}\")\n self.logger.info(\"PerceptionSystem setup complete.\")\n
"},{"location":"reference/#src.aeiva.perception.perception_system.PerceptionSystem.signal_to_stimuli","title":"signal_to_stimuli(data)
","text":"Processes raw data from sensors into structured stimuli.
Parameters:
Name Type Description Default data
Any
The raw data emitted by sensors.
required Returns:
Type Description Any
Processed data (stimuli).
Source code in src/aeiva/perception/perception_system.py
def signal_to_stimuli(self, data: Any) -> Any:\n \"\"\"\n Processes raw data from sensors into structured stimuli.\n\n Args:\n data: The raw data emitted by sensors.\n\n Returns:\n Processed data (stimuli).\n \"\"\"\n # Implement your data processing logic here\n signal = Signal(\n data=data,\n modularity=\"text\", # Or appropriate modality\n type=\"input\", # Or appropriate type\n # TODO: After revised Sensor class, Include other metadata as needed\n )\n stimuli = Stimuli(signals=[signal]) # TODO: add more fields\n return stimuli\n
"},{"location":"reference/#src.aeiva.perception.perception_system.PerceptionSystem.start","title":"start()
async
","text":"Starts all sensors asynchronously.
Source code in src/aeiva/perception/perception_system.py
async def start(self) -> None: # TODO: maybe rename in the future\n \"\"\"\n Starts all sensors asynchronously.\n \"\"\"\n self.logger.info(\"Starting all sensors.\")\n for sensor in self.sensors:\n await sensor.start()\n
"},{"location":"reference/#src.aeiva.perception.perception_system.PerceptionSystem.stop","title":"stop()
async
","text":"Stops all sensors asynchronously.
Source code in src/aeiva/perception/perception_system.py
async def stop(self) -> None:\n \"\"\"\n Stops all sensors asynchronously.\n \"\"\"\n self.logger.info(\"Stopping all sensors.\")\n for sensor in self.sensors:\n await sensor.stop()\n
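The `setup` method above expects `config["sensors"]` to be a list of `{sensor_name, sensor_params}` dictionaries. A minimal sketch of that dispatch logic, with stubs standing in for the real `TerminalInputSensor` and `EventBus` (both illustrative, not Aeiva's actual classes):

```python
import logging
from typing import Dict, List

class StubEventBus:
    """Illustrative stand-in for Aeiva's EventBus."""
    def __init__(self):
        self.events: List[tuple] = []

def build_sensors(config: Dict, event_bus) -> List[str]:
    # Mirrors the dispatch in PerceptionSystem.setup(): walk the configured
    # sensors and instantiate each recognized sensor_name.
    built = []
    for sensor_config in config.get("sensors", []):
        sensor_name = sensor_config.get("sensor_name")
        sensor_params = sensor_config.get("sensor_params", {})
        if sensor_name == "percept_terminal_input":
            built.append(sensor_name)  # real code: TerminalInputSensor(name, params, event_bus)
        else:
            logging.warning(f"Unknown sensor type: {sensor_name}")
    return built

config = {
    "sensors": [
        {"sensor_name": "percept_terminal_input",
         "sensor_params": {"prompt_message": "You: "}},
        {"sensor_name": "percept_camera"},  # unknown -> warning, skipped
    ]
}
print(build_sensors(config, StubEventBus()))  # ['percept_terminal_input']
```

Unrecognized sensor names are logged and skipped rather than raising, matching the `logger.warning` branch in the source above.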
"},{"location":"reference/#src.aeiva.perception.sensation","title":"sensation
","text":""},{"location":"reference/#src.aeiva.perception.sensation.Signal","title":"Signal
","text":"Represents an atomic unit of perception that carries raw data from the environment. This class defines a signal, its characteristics, and its dependencies on other signals.
Source code in src/aeiva/perception/sensation.py
class Signal:\n \"\"\"\n Represents an atomic unit of perception that carries raw data from the environment.\n This class defines a signal, its characteristics, and its dependencies on other signals.\n \"\"\"\n\n def __init__(self, \n data: Any,\n name: Optional[str] = None, # Optional name for the signal\n modularity: Optional[str] = None,\n type: Optional[str] = None, # Renamed to avoid keyword conflict\n timestamp: Optional[datetime] = None,\n id: Optional[str] = None, # Optional unique identifier for the signal\n dependencies: Optional[Dict[str, Any]] = None, # Dependencies by other signal IDs with edge attributes\n description: Optional[str] = None,\n metadata: Optional[Dict[str, Any]] = None):\n \"\"\"\n Initialize a signal with its data and other optional metadata.\n\n Args:\n data (Any): The raw data of the signal.\n name (Optional[str]): An optional name for the signal.\n modularity (Optional[str]): The modality of the signal (e.g., image, video, text, audio).\n type (Optional[str]): A more detailed signal type (e.g., 'text', 'document', etc.).\n timestamp (Optional[datetime]): The time when the signal was created or captured.\n id (Optional[str]): Unique identifier for the signal.\n dependencies (Optional[Dict[str, Any]]): Attributes of dependencies (e.g., relationship types).\n description (Optional[str]): Description of the signal.\n metadata (Optional[Dict[str, Any]]): Optional additional metadata for the signal.\n \"\"\"\n self.data = data\n self.name = name\n self.modularity = modularity\n self.type = type\n self.timestamp = timestamp or datetime.now()\n self.id = id\n self.dependencies = dependencies or {} # Edge attributes (could be string, embedding, etc.)\n self.description = description\n self.metadata = metadata or {}\n\n def to_dict(self) -> Dict[str, Any]:\n \"\"\"\n Converts the signal into a dictionary representation.\n \"\"\"\n return {\n \"data\": self.data,\n \"name\": self.name,\n \"modularity\": self.modularity,\n \"type\": 
self.type,\n \"timestamp\": self.timestamp,\n \"id\": self.id,\n \"dependencies\": self.dependencies,\n \"description\": self.description,\n \"metadata\": self.metadata\n }\n
"},{"location":"reference/#src.aeiva.perception.sensation.Signal.__init__","title":"__init__(data, name=None, modularity=None, type=None, timestamp=None, id=None, dependencies=None, description=None, metadata=None)
","text":"Initialize a signal with its data and other optional metadata.
Parameters:
Name Type Description Default data
Any
The raw data of the signal.
required name
Optional[str]
An optional name for the signal.
None
modularity
Optional[str]
The modality of the signal (e.g., image, video, text, audio).
None
type
Optional[str]
A more detailed signal type (e.g., 'text', 'document', etc.).
None
timestamp
Optional[datetime]
The time when the signal was created or captured.
None
id
Optional[str]
Unique identifier for the signal.
None
dependencies
Optional[Dict[str, Any]]
Attributes of dependencies (e.g., relationship types).
None
description
Optional[str]
Description of the signal.
None
metadata
Optional[Dict[str, Any]]
Optional additional metadata for the signal.
None
Source code in src/aeiva/perception/sensation.py
def __init__(self, \n data: Any,\n name: Optional[str] = None, # Optional name for the signal\n modularity: Optional[str] = None,\n type: Optional[str] = None, # Renamed to avoid keyword conflict\n timestamp: Optional[datetime] = None,\n id: Optional[str] = None, # Optional unique identifier for the signal\n dependencies: Optional[Dict[str, Any]] = None, # Dependencies by other signal IDs with edge attributes\n description: Optional[str] = None,\n metadata: Optional[Dict[str, Any]] = None):\n \"\"\"\n Initialize a signal with its data and other optional metadata.\n\n Args:\n data (Any): The raw data of the signal.\n name (Optional[str]): An optional name for the signal.\n modularity (Optional[str]): The modality of the signal (e.g., image, video, text, audio).\n type (Optional[str]): A more detailed signal type (e.g., 'text', 'document', etc.).\n timestamp (Optional[datetime]): The time when the signal was created or captured.\n id (Optional[str]): Unique identifier for the signal.\n dependencies (Optional[Dict[str, Any]]): Attributes of dependencies (e.g., relationship types).\n description (Optional[str]): Description of the signal.\n metadata (Optional[Dict[str, Any]]): Optional additional metadata for the signal.\n \"\"\"\n self.data = data\n self.name = name\n self.modularity = modularity\n self.type = type\n self.timestamp = timestamp or datetime.now()\n self.id = id\n self.dependencies = dependencies or {} # Edge attributes (could be string, embedding, etc.)\n self.description = description\n self.metadata = metadata or {}\n
"},{"location":"reference/#src.aeiva.perception.sensation.Signal.to_dict","title":"to_dict()
","text":"Converts the signal into a dictionary representation.
Source code in src/aeiva/perception/sensation.py
def to_dict(self) -> Dict[str, Any]:\n \"\"\"\n Converts the signal into a dictionary representation.\n \"\"\"\n return {\n \"data\": self.data,\n \"name\": self.name,\n \"modularity\": self.modularity,\n \"type\": self.type,\n \"timestamp\": self.timestamp,\n \"id\": self.id,\n \"dependencies\": self.dependencies,\n \"description\": self.description,\n \"metadata\": self.metadata\n }\n
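A `Signal` can be built from its `data` alone; everything else is optional metadata. The sketch below condenses the class from the source above so it runs stand-alone (note the attribute is spelled `modularity` in the source, even though the docstring describes it as the signal's modality):

```python
from datetime import datetime
from typing import Any, Dict, Optional

class Signal:
    """Condensed copy of the Signal class documented above."""

    def __init__(self, data: Any, name=None, modularity=None, type=None,
                 timestamp: Optional[datetime] = None, id=None,
                 dependencies: Optional[Dict[str, Any]] = None,
                 description=None, metadata: Optional[Dict[str, Any]] = None):
        self.data = data
        self.name = name
        self.modularity = modularity  # modality of the signal, e.g. "text"
        self.type = type
        self.timestamp = timestamp or datetime.now()
        self.id = id
        self.dependencies = dependencies or {}
        self.description = description
        self.metadata = metadata or {}

    def to_dict(self) -> Dict[str, Any]:
        return {"data": self.data, "name": self.name,
                "modularity": self.modularity, "type": self.type,
                "timestamp": self.timestamp, "id": self.id,
                "dependencies": self.dependencies,
                "description": self.description, "metadata": self.metadata}

sig = Signal(data="hello", modularity="text", type="input", id="s1")
d = sig.to_dict()
print(d["modularity"], d["type"])  # text input
```

Unset optional fields default to `None` (or an empty dict for `dependencies` and `metadata`), so `to_dict()` always yields the same nine keys.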
"},{"location":"reference/#src.aeiva.perception.sensor","title":"sensor
","text":""},{"location":"reference/#src.aeiva.perception.sensor.Sensor","title":"Sensor
","text":" Bases: ABC
Abstract base class for all sensors.
Source code in src/aeiva/perception/sensor.py
class Sensor(ABC):\n \"\"\"\n Abstract base class for all sensors.\n \"\"\"\n def __init__(self, name: str, params: dict, event_bus):\n \"\"\"\n Initializes the BaseSensor.\n\n Args:\n name (str): The name of the sensor.\n params (dict): Configuration parameters for the sensor.\n event_bus: The EventBus instance for emitting events.\n \"\"\"\n self.name = name\n self.params = params\n self.event_bus = event_bus\n\n @abstractmethod\n async def start(self):\n \"\"\"\n Starts the sensor.\n \"\"\"\n pass\n\n @abstractmethod\n async def stop(self):\n \"\"\"\n Stops the sensor.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.perception.sensor.Sensor.__init__","title":"__init__(name, params, event_bus)
","text":"Initializes the BaseSensor.
Parameters:
Name Type Description Default name
str
The name of the sensor.
required params
dict
Configuration parameters for the sensor.
required event_bus
The EventBus instance for emitting events.
required Source code in src/aeiva/perception/sensor.py
def __init__(self, name: str, params: dict, event_bus):\n \"\"\"\n Initializes the BaseSensor.\n\n Args:\n name (str): The name of the sensor.\n params (dict): Configuration parameters for the sensor.\n event_bus: The EventBus instance for emitting events.\n \"\"\"\n self.name = name\n self.params = params\n self.event_bus = event_bus\n
"},{"location":"reference/#src.aeiva.perception.sensor.Sensor.start","title":"start()
abstractmethod
async
","text":"Starts the sensor.
Source code in src/aeiva/perception/sensor.py
@abstractmethod\nasync def start(self):\n \"\"\"\n Starts the sensor.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.perception.sensor.Sensor.stop","title":"stop()
abstractmethod
async
","text":"Stops the sensor.
Source code in src/aeiva/perception/sensor.py
@abstractmethod\nasync def stop(self):\n \"\"\"\n Stops the sensor.\n \"\"\"\n pass\n
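Implementing the two abstract coroutines is all a concrete sensor needs. Below, a condensed copy of the `Sensor` base and a hypothetical `ConstantSensor`; a plain list stands in for the real `EventBus`:

```python
import asyncio
from abc import ABC, abstractmethod

class Sensor(ABC):
    """Condensed copy of the Sensor base class documented above."""

    def __init__(self, name: str, params: dict, event_bus):
        self.name = name
        self.params = params
        self.event_bus = event_bus

    @abstractmethod
    async def start(self): ...

    @abstractmethod
    async def stop(self): ...

class ConstantSensor(Sensor):
    """Hypothetical sensor that emits one fixed reading on start."""

    async def start(self):
        self.event_bus.append((self.name, self.params.get("value")))

    async def stop(self):
        pass  # nothing to tear down

bus: list = []  # stand-in for the real EventBus
sensor = ConstantSensor("const", {"value": 42}, bus)
asyncio.run(sensor.start())
print(bus)  # [('const', 42)]
```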
"},{"location":"reference/#src.aeiva.perception.stimuli","title":"stimuli
","text":""},{"location":"reference/#src.aeiva.perception.stimuli.Stimuli","title":"Stimuli
","text":"Represents a structured composition of signals, where each node can be a Signal or a sub-Stimuli. The graph allows flexible, directed relationships between nodes, and the graph can contain cycles.
Source code in src/aeiva/perception/stimuli.py
class Stimuli:\n \"\"\"\n Represents a structured composition of signals, where each node can be a Signal or a sub-Stimuli.\n The graph allows flexible, directed relationships between nodes, and the graph can contain cycles.\n \"\"\"\n\n def __init__(self, \n signals: List[Union[Signal, 'Stimuli']],\n id: Optional[str] = None,\n name: Optional[str] = None,\n type: Optional[str] = None,\n modularity: Optional[str] = None,\n timestamp: Optional[str] = None,\n dependencies: Optional[Dict[str, Dict[str, Any]]] = None,\n description: Optional[str] = None,\n metadata: Optional[Dict[str, Any]] = None):\n \"\"\"\n Initializes the Stimuli object by organizing signals or sub-stimuli in a graph structure.\n \"\"\"\n self.signals = signals or [] # Default to an empty list if no signals provided\n self.id = id\n self.name = name\n self.type = type\n self.modularity = modularity\n self.timestamp = timestamp\n self.description = description\n self.metadata = metadata or {}\n self.dependencies = dependencies or {}\n\n # Graph to represent the structure of signals and their relationships\n self.graph = nx.DiGraph()\n\n # Add all signals and sub-stimuli as nodes in the graph\n for signal in signals:\n self.graph.add_node(signal)\n\n # Handle dependencies for signals or sub-stimuli\n for signal in signals:\n if signal.id in self.dependencies:\n for dep_id, edge_attr in self.dependencies[signal.id].items():\n dep_node = next((s for s in signals if s.id == dep_id), None)\n if dep_node and (isinstance(dep_node, Signal) or isinstance(dep_node, Stimuli)):\n self.graph.add_edge(dep_node, signal, **edge_attr)\n else:\n raise ValueError(f\"Dependency {dep_id} not found or is not valid for signal or stimuli {signal.id}.\")\n\n def traverse(self, method: str = 'dfs') -> List[Union[Signal, 'Stimuli']]:\n \"\"\"\n Traverses the graph using the specified method ('dfs' or 'bfs').\n\n Args:\n method (str): The traversal method to use, either 'dfs' (Depth-First Search) or 'bfs' (Breadth-First 
Search).\n\n Returns:\n List[Union[Signal, 'Stimuli']]: A list of signals or sub-stimuli in the order they were visited.\n \"\"\"\n if not self.graph.nodes:\n return []\n\n if method == 'dfs':\n return list(nx.dfs_postorder_nodes(self.graph))\n elif method == 'bfs':\n return list(nx.bfs_tree(self.graph, list(self.graph.nodes)[0])) # BFS starting from an arbitrary node\n else:\n raise ValueError(f\"Unknown traversal method: {method}\")\n\n def to_dict(self) -> Dict[str, Any]:\n \"\"\"\n Converts the stimuli into a dictionary representation, including its signals and their relationships.\n \"\"\"\n return {\n \"id\": self.id,\n \"name\": self.name,\n \"type\": self.type,\n \"modularity\": self.modularity,\n \"timestamp\": self.timestamp,\n \"description\": self.description,\n \"metadata\": self.metadata,\n \"signals\": [signal.to_dict() if isinstance(signal, Signal) else signal.to_dict() for signal in self.signals],\n \"dependencies\": self.dependencies\n }\n\n def visualize(self, save_path: Optional[str] = None):\n \"\"\"\n Visualizes the procedure's structure using networkx and matplotlib.\n \"\"\"\n pos = nx.spring_layout(self.graph) # Layout for the graph\n labels = {node: f\"{node.id} ({node.type})\" if isinstance(node, Signal) else f\"{node.id} (Stimuli)\"\n for node in self.graph.nodes()}\n\n # Draw the graph with labels\n nx.draw(self.graph, pos, with_labels=True, labels=labels, node_color='lightblue', node_size=2000, font_size=10, font_weight='bold', arrows=True)\n\n plt.title(f\"{self.type} {self.description} Visualization\")\n if save_path:\n plt.savefig(save_path)\n else:\n plt.show()\n
"},{"location":"reference/#src.aeiva.perception.stimuli.Stimuli.__init__","title":"__init__(signals, id=None, name=None, type=None, modularity=None, timestamp=None, dependencies=None, description=None, metadata=None)
","text":"Initializes the Stimuli object by organizing signals or sub-stimuli in a graph structure.
Source code in src/aeiva/perception/stimuli.py
def __init__(self, \n signals: List[Union[Signal, 'Stimuli']],\n id: Optional[str] = None,\n name: Optional[str] = None,\n type: Optional[str] = None,\n modularity: Optional[str] = None,\n timestamp: Optional[str] = None,\n dependencies: Optional[Dict[str, Dict[str, Any]]] = None,\n description: Optional[str] = None,\n metadata: Optional[Dict[str, Any]] = None):\n \"\"\"\n Initializes the Stimuli object by organizing signals or sub-stimuli in a graph structure.\n \"\"\"\n self.signals = signals or [] # Default to an empty list if no signals provided\n self.id = id\n self.name = name\n self.type = type\n self.modularity = modularity\n self.timestamp = timestamp\n self.description = description\n self.metadata = metadata or {}\n self.dependencies = dependencies or {}\n\n # Graph to represent the structure of signals and their relationships\n self.graph = nx.DiGraph()\n\n # Add all signals and sub-stimuli as nodes in the graph\n for signal in signals:\n self.graph.add_node(signal)\n\n # Handle dependencies for signals or sub-stimuli\n for signal in signals:\n if signal.id in self.dependencies:\n for dep_id, edge_attr in self.dependencies[signal.id].items():\n dep_node = next((s for s in signals if s.id == dep_id), None)\n if dep_node and (isinstance(dep_node, Signal) or isinstance(dep_node, Stimuli)):\n self.graph.add_edge(dep_node, signal, **edge_attr)\n else:\n raise ValueError(f\"Dependency {dep_id} not found or is not valid for signal or stimuli {signal.id}.\")\n
"},{"location":"reference/#src.aeiva.perception.stimuli.Stimuli.to_dict","title":"to_dict()
","text":"Converts the stimuli into a dictionary representation, including its signals and their relationships.
Source code in src/aeiva/perception/stimuli.py
def to_dict(self) -> Dict[str, Any]:\n \"\"\"\n Converts the stimuli into a dictionary representation, including its signals and their relationships.\n \"\"\"\n return {\n \"id\": self.id,\n \"name\": self.name,\n \"type\": self.type,\n \"modularity\": self.modularity,\n \"timestamp\": self.timestamp,\n \"description\": self.description,\n \"metadata\": self.metadata,\n \"signals\": [signal.to_dict() if isinstance(signal, Signal) else signal.to_dict() for signal in self.signals],\n \"dependencies\": self.dependencies\n }\n
"},{"location":"reference/#src.aeiva.perception.stimuli.Stimuli.traverse","title":"traverse(method='dfs')
","text":"Traverses the graph using the specified method ('dfs' or 'bfs').
Parameters:
Name Type Description Default method
str
The traversal method to use, either 'dfs' (Depth-First Search) or 'bfs' (Breadth-First Search).
'dfs'
Returns:
Type Description List[Union[Signal, Stimuli]]
List[Union[Signal, 'Stimuli']]: A list of signals or sub-stimuli in the order they were visited.
Source code in src/aeiva/perception/stimuli.py
def traverse(self, method: str = 'dfs') -> List[Union[Signal, 'Stimuli']]:\n \"\"\"\n Traverses the graph using the specified method ('dfs' or 'bfs').\n\n Args:\n method (str): The traversal method to use, either 'dfs' (Depth-First Search) or 'bfs' (Breadth-First Search).\n\n Returns:\n List[Union[Signal, 'Stimuli']]: A list of signals or sub-stimuli in the order they were visited.\n \"\"\"\n if not self.graph.nodes:\n return []\n\n if method == 'dfs':\n return list(nx.dfs_postorder_nodes(self.graph))\n elif method == 'bfs':\n return list(nx.bfs_tree(self.graph, list(self.graph.nodes)[0])) # BFS starting from an arbitrary node\n else:\n raise ValueError(f\"Unknown traversal method: {method}\")\n
"},{"location":"reference/#src.aeiva.perception.stimuli.Stimuli.visualize","title":"visualize(save_path=None)
","text":"Visualizes the procedure's structure using networkx and matplotlib.
Source code in src/aeiva/perception/stimuli.py
def visualize(self, save_path: Optional[str] = None):\n \"\"\"\n Visualizes the procedure's structure using networkx and matplotlib.\n \"\"\"\n pos = nx.spring_layout(self.graph) # Layout for the graph\n labels = {node: f\"{node.id} ({node.type})\" if isinstance(node, Signal) else f\"{node.id} (Stimuli)\"\n for node in self.graph.nodes()}\n\n # Draw the graph with labels\n nx.draw(self.graph, pos, with_labels=True, labels=labels, node_color='lightblue', node_size=2000, font_size=10, font_weight='bold', arrows=True)\n\n plt.title(f\"{self.type} {self.description} Visualization\")\n if save_path:\n plt.savefig(save_path)\n else:\n plt.show()\n
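The `dependencies` argument maps a dependent node's `id` to a dict of `{prerequisite_id: edge_attributes}`; each entry becomes a directed edge from the prerequisite to the dependent node. A pure-Python sketch of that edge derivation (the real class stores these edges in a networkx `DiGraph`; the ids and attributes below are illustrative):

```python
# Signals are stubbed as dicts here; in Aeiva they are Signal/Stimuli objects.
signals = [{"id": "s1"}, {"id": "s2"}, {"id": "s3"}]
dependencies = {
    # s3 depends on s1 and s2, each with its own edge attributes
    "s3": {"s1": {"relation": "follows"}, "s2": {"relation": "follows"}},
}

edges = []
for signal in signals:
    for dep_id, edge_attr in dependencies.get(signal["id"], {}).items():
        if any(s["id"] == dep_id for s in signals):
            # Edge direction matches the source: prerequisite -> dependent
            edges.append((dep_id, signal["id"], edge_attr))
        else:
            raise ValueError(f"Dependency {dep_id} not found for {signal['id']}")

print(edges)
```

As in the constructor above, a dependency id that does not match any node raises `ValueError` rather than being silently dropped.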
"},{"location":"reference/#src.aeiva.perception.terminal_input_sensor","title":"terminal_input_sensor
","text":""},{"location":"reference/#src.aeiva.perception.terminal_input_sensor.TerminalInputSensor","title":"TerminalInputSensor
","text":" Bases: Sensor
A sensor that reads input from the terminal and emits stimuli via the EventBus.
Source code in src/aeiva/perception/terminal_input_sensor.py
class TerminalInputSensor(Sensor):\n \"\"\"\n A sensor that reads input from the terminal and emits stimuli via the EventBus.\n \"\"\"\n def __init__(self, name: str, params: dict, event_bus):\n super().__init__(name, params, event_bus)\n self.prompt_message = params.get('prompt_message', 'You: ')\n self._running = False\n self._thread = None\n # self.logger = logging.getLogger(f'TerminalInputSensor-{self.name}')\n\n async def start(self):\n \"\"\"\n Starts the sensor by launching the input thread.\n \"\"\"\n self._running = True\n self._thread = threading.Thread(target=self._run, daemon=True)\n self._thread.start()\n # self.logger.info(f\"{self.name} started.\")\n\n async def stop(self):\n \"\"\"\n Stops the sensor by signaling the thread to stop and waiting for it to finish.\n \"\"\"\n self._running = False\n if self._thread:\n self._thread.join()\n # self.logger.info(f\"{self.name} stopped.\")\n\n def _run(self):\n \"\"\"\n The main loop that reads user input and emits events.\n \"\"\"\n loop = self.event_bus.loop\n if loop is None:\n # self.logger.error(\"EventBus loop is not set. Cannot emit events.\")\n return\n\n while self._running:\n try:\n user_input = input(self.prompt_message)\n if not self._running:\n break # Exit if stopped during input\n\n # # Process input into stimuli\n # stimuli = self.signal_to_stimuli(user_input)\n\n # Emit the stimuli as an event\n asyncio.run_coroutine_threadsafe(\n self.event_bus.emit('perception.stimuli', payload=user_input), # TODO: rename event later\n loop\n )\n except EOFError:\n # Handle end of input (Ctrl+D)\n # self.logger.info(\"EOF received. Stopping TerminalInputSensor.\")\n self._running = False\n except KeyboardInterrupt:\n # Handle Ctrl+C\n # self.logger.info(\"KeyboardInterrupt received. Stopping TerminalInputSensor.\")\n self._running = False\n except Exception as e:\n # self.logger.error(f\"Error in TerminalInputSensor: {e}\")\n self._running = False\n
"},{"location":"reference/#src.aeiva.perception.terminal_input_sensor.TerminalInputSensor.start","title":"start()
async
","text":"Starts the sensor by launching the input thread.
Source code in src/aeiva/perception/terminal_input_sensor.py
async def start(self):\n \"\"\"\n Starts the sensor by launching the input thread.\n \"\"\"\n self._running = True\n self._thread = threading.Thread(target=self._run, daemon=True)\n self._thread.start()\n
"},{"location":"reference/#src.aeiva.perception.terminal_input_sensor.TerminalInputSensor.stop","title":"stop()
async
","text":"Stops the sensor by signaling the thread to stop and waiting for it to finish.
Source code in src/aeiva/perception/terminal_input_sensor.py
async def stop(self):\n \"\"\"\n Stops the sensor by signaling the thread to stop and waiting for it to finish.\n \"\"\"\n self._running = False\n if self._thread:\n self._thread.join()\n
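`TerminalInputSensor` reads input on a background thread and hands each line to the event loop with `asyncio.run_coroutine_threadsafe`. The sketch below isolates that handoff pattern, replacing `input()` with a fixed list so it runs unattended; `emit` is an illustrative stand-in for `event_bus.emit`:

```python
import asyncio
import threading

received = []

async def emit(payload):
    # Stand-in for event_bus.emit('perception.stimuli', payload=...)
    received.append(payload)

async def main():
    loop = asyncio.get_running_loop()
    done = threading.Event()

    def worker():
        # In the real sensor this loop calls input(prompt_message).
        for line in ["hello", "world"]:
            fut = asyncio.run_coroutine_threadsafe(emit(line), loop)
            fut.result()  # block this thread until the loop has handled it
        done.set()

    t = threading.Thread(target=worker, daemon=True)
    t.start()
    while not done.is_set():
        await asyncio.sleep(0.01)  # keep the loop free to run emitted coroutines
    t.join()

asyncio.run(main())
print(received)  # ['hello', 'world']
```

Scheduling through `run_coroutine_threadsafe` is what makes the blocking `input()` thread safe to combine with the async event bus: the coroutine always executes on the loop's thread, never the reader thread.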
"},{"location":"reference/#src.aeiva.perception.test","title":"test
","text":""},{"location":"reference/#src.aeiva.perception.test.handle_observation","title":"handle_observation(stimuli)
async
","text":"Processes stimuli using the cognition system and outputs the response.
Source code in src/aeiva/perception/test.py
async def handle_observation(stimuli):\n \"\"\"\n Processes stimuli using the cognition system and outputs the response.\n \"\"\"\n for signal in stimuli.signals:\n user_input = signal.data\n stimuli_data = [{\"role\": \"user\", \"content\": user_input}]\n response = await llm_brain.think(stimuli_data, stream=True)\n print(f\"LLM Response: {response}\")\n
"},{"location":"reference/#src.aeiva.plugin","title":"plugin
","text":""},{"location":"reference/#src.aeiva.plugin.ability","title":"ability
","text":""},{"location":"reference/#src.aeiva.plugin.ability.plugin_a","title":"plugin_a
","text":""},{"location":"reference/#src.aeiva.plugin.ability.plugin_a.plugin","title":"plugin
","text":""},{"location":"reference/#src.aeiva.plugin.ability.plugin_a.plugin.PluginA","title":"PluginA
","text":" Bases: Plugin
Example Plugin A.
Source code in src/aeiva/plugin/ability/plugin_a/plugin.py
class PluginA(Plugin):\n \"\"\"\n Example Plugin A.\n \"\"\"\n\n def activate(self) -> None:\n print(\"PluginA activated.\")\n\n def deactivate(self) -> None:\n print(\"PluginA deactivated.\")\n\n def run(self) -> None:\n print(\"PluginA is running.\")\n
"},{"location":"reference/#src.aeiva.plugin.ability.plugin_b","title":"plugin_b
","text":""},{"location":"reference/#src.aeiva.plugin.ability.plugin_b.plugin","title":"plugin
","text":""},{"location":"reference/#src.aeiva.plugin.ability.plugin_b.plugin.PluginB","title":"PluginB
","text":" Bases: Plugin
Example Plugin B.
Source code in src/aeiva/plugin/ability/plugin_b/plugin.py
class PluginB(Plugin):\n \"\"\"\n Example Plugin B.\n \"\"\"\n\n def activate(self) -> None:\n print(\"PluginB activated.\")\n\n def deactivate(self) -> None:\n print(\"PluginB deactivated.\")\n\n def run(self) -> None:\n print(\"PluginB is running.\")\n
"},{"location":"reference/#src.aeiva.plugin.plug","title":"plug
","text":""},{"location":"reference/#src.aeiva.plugin.plug--plug-module","title":"Plug Module","text":"This module provides a flexible plugin system with support for:
- Multiple plugin sources with isolation
- Context managers and import hooks
- Resource loading from plugins
- Loading plugins from directories and zip files
- Hot swapping and lazy loading of plugins
Author: Bang Liu Date: 2024-11-19
"},{"location":"reference/#src.aeiva.plugin.plug.Plugin","title":"Plugin
","text":" Bases: ABC
Abstract base class that all plugins must inherit from.
Source code in src/aeiva/plugin/plug.py
class Plugin(abc.ABC):\n \"\"\"\n Abstract base class that all plugins must inherit from.\n \"\"\"\n\n @abc.abstractmethod\n def activate(self) -> None:\n \"\"\"Method called when the plugin is activated.\"\"\"\n pass\n\n @abc.abstractmethod\n def deactivate(self) -> None:\n \"\"\"Method called when the plugin is deactivated.\"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.plugin.plug.Plugin.activate","title":"activate()
abstractmethod
","text":"Method called when the plugin is activated.
Source code in src/aeiva/plugin/plug.py
@abc.abstractmethod\ndef activate(self) -> None:\n \"\"\"Method called when the plugin is activated.\"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.plugin.plug.Plugin.deactivate","title":"deactivate()
abstractmethod
","text":"Method called when the plugin is deactivated.
Source code in src/aeiva/plugin/plug.py
@abc.abstractmethod\ndef deactivate(self) -> None:\n \"\"\"Method called when the plugin is deactivated.\"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.plugin.plug.PluginFinder","title":"PluginFinder
","text":" Bases: MetaPathFinder
Custom finder for plugin modules. Finds plugins as directories containing a plugin.py
file.
Source code in src/aeiva/plugin/plug.py
class PluginFinder(importlib.abc.MetaPathFinder):\n \"\"\"\n Custom finder for plugin modules.\n Finds plugins as directories containing a `plugin.py` file.\n \"\"\"\n\n def __init__(self, plugin_source: 'PluginSource') -> None:\n self.plugin_source = plugin_source\n\n def find_spec(\n self,\n fullname: str,\n path: Optional[List[str]],\n target: Optional[ModuleType] = None\n ) -> Optional[importlib.machinery.ModuleSpec]:\n \"\"\"\n Find the module spec for the given module.\n Handles both the namespace package and its submodules (plugins).\n \"\"\"\n if fullname == self.plugin_source.namespace:\n # Handle the namespace package itself\n print(f\"PluginFinder: Creating namespace package '{fullname}'\")\n spec = importlib.machinery.ModuleSpec(fullname, loader=None, is_package=True)\n spec.submodule_search_locations = []\n return spec\n\n elif fullname.startswith(self.plugin_source.namespace + '.'):\n # Handle submodules (plugins)\n plugin_name = fullname[len(self.plugin_source.namespace) + 1:]\n if plugin_name in self.plugin_source.list_plugins():\n print(f\"PluginFinder: Found plugin '{plugin_name}' for module '{fullname}'\")\n loader = PluginLoader(self.plugin_source, plugin_name)\n spec = importlib.util.spec_from_loader(fullname, loader)\n spec.submodule_search_locations = []\n return spec\n\n # If not handling this module, return None\n print(f\"PluginFinder: Not handling module '{fullname}'\")\n return None\n
"},{"location":"reference/#src.aeiva.plugin.plug.PluginFinder.find_spec","title":"find_spec(fullname, path, target=None)
","text":"Find the module spec for the given module. Handles both the namespace package and its submodules (plugins).
Source code in src/aeiva/plugin/plug.py
def find_spec(\n self,\n fullname: str,\n path: Optional[List[str]],\n target: Optional[ModuleType] = None\n) -> Optional[importlib.machinery.ModuleSpec]:\n \"\"\"\n Find the module spec for the given module.\n Handles both the namespace package and its submodules (plugins).\n \"\"\"\n if fullname == self.plugin_source.namespace:\n # Handle the namespace package itself\n print(f\"PluginFinder: Creating namespace package '{fullname}'\")\n spec = importlib.machinery.ModuleSpec(fullname, loader=None, is_package=True)\n spec.submodule_search_locations = []\n return spec\n\n elif fullname.startswith(self.plugin_source.namespace + '.'):\n # Handle submodules (plugins)\n plugin_name = fullname[len(self.plugin_source.namespace) + 1:]\n if plugin_name in self.plugin_source.list_plugins():\n print(f\"PluginFinder: Found plugin '{plugin_name}' for module '{fullname}'\")\n loader = PluginLoader(self.plugin_source, plugin_name)\n spec = importlib.util.spec_from_loader(fullname, loader)\n spec.submodule_search_locations = []\n return spec\n\n # If not handling this module, return None\n print(f\"PluginFinder: Not handling module '{fullname}'\")\n return None\n
"},{"location":"reference/#src.aeiva.plugin.plug.PluginLoader","title":"PluginLoader
","text":" Bases: Loader
Custom loader for plugin modules. Loads the plugin.py
file within the plugin directory.
Source code in src/aeiva/plugin/plug.py
class PluginLoader(importlib.abc.Loader):\n \"\"\"\n Custom loader for plugin modules.\n Loads the `plugin.py` file within the plugin directory.\n \"\"\"\n\n def __init__(self, plugin_source: 'PluginSource', plugin_name: str) -> None:\n self.plugin_source = plugin_source\n self.plugin_name = plugin_name\n\n def create_module(self, spec: importlib.machinery.ModuleSpec) -> Optional[ModuleType]:\n \"\"\"Use default module creation semantics.\"\"\"\n return None\n\n def exec_module(self, module: ModuleType) -> None:\n \"\"\"Execute the plugin's `plugin.py` module.\"\"\"\n try:\n code = self.plugin_source.get_plugin_code(self.plugin_name)\n except ImportError as e:\n print(f\"PluginLoader: Failed to get code for plugin '{self.plugin_name}': {e}\")\n raise\n\n # Compute project_root dynamically based on plug.py's location\n plugin_dir = os.path.dirname(os.path.abspath(__file__))\n project_root = os.path.abspath(os.path.join(plugin_dir, '../../../'))\n print(f\"PluginLoader: Adding '{project_root}' to sys.path for plugin '{self.plugin_name}'\")\n sys.path.insert(0, project_root)\n\n try:\n print(f\"PluginLoader: Executing plugin '{self.plugin_name}'\")\n exec(code, module.__dict__)\n print(f\"PluginLoader: Plugin '{self.plugin_name}' executed successfully\")\n except Exception as e:\n print(f\"PluginLoader: Error executing plugin '{self.plugin_name}': {e}\")\n raise\n finally:\n sys.path.pop(0)\n
"},{"location":"reference/#src.aeiva.plugin.plug.PluginLoader.create_module","title":"create_module(spec)
","text":"Use default module creation semantics.
Source code in src/aeiva/plugin/plug.py
def create_module(self, spec: importlib.machinery.ModuleSpec) -> Optional[ModuleType]:\n \"\"\"Use default module creation semantics.\"\"\"\n return None\n
"},{"location":"reference/#src.aeiva.plugin.plug.PluginLoader.exec_module","title":"exec_module(module)
","text":"Execute the plugin's plugin.py
module.
Source code in src/aeiva/plugin/plug.py
def exec_module(self, module: ModuleType) -> None:\n \"\"\"Execute the plugin's `plugin.py` module.\"\"\"\n try:\n code = self.plugin_source.get_plugin_code(self.plugin_name)\n except ImportError as e:\n print(f\"PluginLoader: Failed to get code for plugin '{self.plugin_name}': {e}\")\n raise\n\n # Compute project_root dynamically based on plug.py's location\n plugin_dir = os.path.dirname(os.path.abspath(__file__))\n project_root = os.path.abspath(os.path.join(plugin_dir, '../../../'))\n print(f\"PluginLoader: Adding '{project_root}' to sys.path for plugin '{self.plugin_name}'\")\n sys.path.insert(0, project_root)\n\n try:\n print(f\"PluginLoader: Executing plugin '{self.plugin_name}'\")\n exec(code, module.__dict__)\n print(f\"PluginLoader: Plugin '{self.plugin_name}' executed successfully\")\n except Exception as e:\n print(f\"PluginLoader: Error executing plugin '{self.plugin_name}': {e}\")\n raise\n finally:\n sys.path.pop(0)\n
"},{"location":"reference/#src.aeiva.plugin.plug.PluginManager","title":"PluginManager
","text":"Manages multiple PluginSources and controls plugin imports.
Source code in src/aeiva/plugin/plug.py
class PluginManager:\n \"\"\"\n Manages multiple PluginSources and controls plugin imports.\n \"\"\"\n\n def __init__(self) -> None:\n self.plugin_sources: Dict[str, PluginSource] = {}\n\n def create_plugin_source(self, name: str, search_path: Optional[List[str]] = None) -> PluginSource:\n \"\"\"\n Creates a new PluginSource.\n\n :param name: Unique name for the plugin source.\n :param search_path: List of paths to search for plugins.\n :return: The created PluginSource.\n \"\"\"\n if name in self.plugin_sources:\n raise ValueError(f\"Plugin source '{name}' already exists.\")\n source = PluginSource(name, search_path)\n self.plugin_sources[name] = source\n print(f\"PluginManager: Created plugin source '{name}' with search paths {search_path}.\")\n return source\n\n def get_plugin_source(self, name: str) -> Optional[PluginSource]:\n \"\"\"\n Retrieves a PluginSource by name.\n\n :param name: Name of the PluginSource.\n :return: The PluginSource instance, or None if not found.\n \"\"\"\n return self.plugin_sources.get(name)\n\n def remove_plugin_source(self, name: str) -> None:\n \"\"\"\n Removes a PluginSource.\n\n :param name: Name of the PluginSource to remove.\n \"\"\"\n source = self.plugin_sources.pop(name, None)\n if source:\n source.disable()\n for plugin_name in list(source._modules.keys()):\n source.unload_plugin(plugin_name)\n print(f\"PluginManager: Removed plugin source '{name}'.\")\n else:\n print(f\"PluginManager: Plugin source '{name}' does not exist.\")\n
"},{"location":"reference/#src.aeiva.plugin.plug.PluginManager.create_plugin_source","title":"create_plugin_source(name, search_path=None)
","text":"Creates a new PluginSource.
:param name: Unique name for the plugin source. :param search_path: List of paths to search for plugins. :return: The created PluginSource.
Source code in src/aeiva/plugin/plug.py
def create_plugin_source(self, name: str, search_path: Optional[List[str]] = None) -> PluginSource:\n \"\"\"\n Creates a new PluginSource.\n\n :param name: Unique name for the plugin source.\n :param search_path: List of paths to search for plugins.\n :return: The created PluginSource.\n \"\"\"\n if name in self.plugin_sources:\n raise ValueError(f\"Plugin source '{name}' already exists.\")\n source = PluginSource(name, search_path)\n self.plugin_sources[name] = source\n print(f\"PluginManager: Created plugin source '{name}' with search paths {search_path}.\")\n return source\n
"},{"location":"reference/#src.aeiva.plugin.plug.PluginManager.get_plugin_source","title":"get_plugin_source(name)
","text":"Retrieves a PluginSource by name.
:param name: Name of the PluginSource. :return: The PluginSource instance, or None if not found.
Source code in src/aeiva/plugin/plug.py
def get_plugin_source(self, name: str) -> Optional[PluginSource]:\n \"\"\"\n Retrieves a PluginSource by name.\n\n :param name: Name of the PluginSource.\n :return: The PluginSource instance, or None if not found.\n \"\"\"\n return self.plugin_sources.get(name)\n
"},{"location":"reference/#src.aeiva.plugin.plug.PluginManager.remove_plugin_source","title":"remove_plugin_source(name)
","text":"Removes a PluginSource.
:param name: Name of the PluginSource to remove.
Source code in src/aeiva/plugin/plug.py
def remove_plugin_source(self, name: str) -> None:\n \"\"\"\n Removes a PluginSource.\n\n :param name: Name of the PluginSource to remove.\n \"\"\"\n source = self.plugin_sources.pop(name, None)\n if source:\n source.disable()\n for plugin_name in list(source._modules.keys()):\n source.unload_plugin(plugin_name)\n print(f\"PluginManager: Removed plugin source '{name}'.\")\n else:\n print(f\"PluginManager: Plugin source '{name}' does not exist.\")\n
"},{"location":"reference/#src.aeiva.plugin.plug.PluginSource","title":"PluginSource
","text":"Represents an isolated source of plugins. Each plugin is a directory containing a plugin.py
file.
Source code in src/aeiva/plugin/plug.py
class PluginSource:\n \"\"\"\n Represents an isolated source of plugins.\n Each plugin is a directory containing a `plugin.py` file.\n \"\"\"\n\n def __init__(self, name: str, search_path: Optional[List[str]] = None) -> None:\n \"\"\"\n Initializes the PluginSource.\n\n :param name: Unique name for the plugin source.\n :param search_path: List of paths (directories or zip files) to search for plugins.\n \"\"\"\n self.name = name\n self.search_path = search_path or []\n self._lock = threading.Lock()\n self._modules: Dict[str, ModuleType] = {}\n self.namespace = f\"_plug_{self.name}\"\n self._finder = PluginFinder(self)\n self._finder_enabled = False\n\n def __enter__(self) -> 'PluginSource':\n \"\"\"Enter the runtime context related to this object.\"\"\"\n self.enable()\n return self\n\n def __exit__(self, exc_type, exc_value, traceback) -> None:\n \"\"\"Exit the runtime context.\"\"\"\n self.disable()\n\n def enable(self) -> None:\n \"\"\"Enable the plugin import mechanism.\"\"\"\n if not self._finder_enabled:\n sys.meta_path.insert(0, self._finder)\n self._finder_enabled = True\n print(f\"PluginSource: Import hook enabled for namespace '{self.namespace}'.\")\n\n def disable(self) -> None:\n \"\"\"Disable the plugin import mechanism.\"\"\"\n if self._finder_enabled:\n try:\n sys.meta_path.remove(self._finder)\n print(f\"PluginSource: Import hook disabled for namespace '{self.namespace}'.\")\n except ValueError:\n print(f\"PluginSource: Import hook for namespace '{self.namespace}' was not found in sys.meta_path.\")\n self._finder_enabled = False\n\n def list_plugins(self) -> List[str]:\n \"\"\"\n Lists available plugins in the search paths.\n Each plugin is a directory containing a `plugin.py` file.\n\n :return: List of plugin names.\n \"\"\"\n plugins = set()\n for path in self.search_path:\n if zipfile.is_zipfile(path):\n with zipfile.ZipFile(path, 'r') as z:\n # Identify top-level directories containing `plugin.py`\n plugin_dirs = set()\n for file in 
z.namelist():\n parts = file.split('/')\n if len(parts) >= 2 and parts[-1] == 'plugin.py':\n plugin_dir = parts[0]\n plugin_dirs.add(plugin_dir)\n plugins.update(plugin_dirs)\n else:\n # Assume it's a directory\n if not os.path.isdir(path):\n print(f\"PluginSource: Path '{path}' is not a directory or a zip file. Skipping.\")\n continue\n for entry in os.listdir(path):\n plugin_path = os.path.join(path, entry)\n if os.path.isdir(plugin_path):\n plugin_main = os.path.join(plugin_path, 'plugin.py')\n if os.path.isfile(plugin_main):\n plugins.add(entry)\n return list(plugins)\n\n def get_plugin_code(self, plugin_name: str) -> str:\n \"\"\"\n Get the source code of the plugin's `plugin.py`.\n\n :param plugin_name: Name of the plugin to load.\n :return: Source code of `plugin.py` as a string.\n \"\"\"\n for path in self.search_path:\n if zipfile.is_zipfile(path):\n with zipfile.ZipFile(path, 'r') as z:\n plugin_main = f\"{plugin_name}/plugin.py\"\n if plugin_main in z.namelist():\n print(f\"PluginSource: Found plugin '{plugin_name}' in zip file '{path}'.\")\n return z.read(plugin_main).decode('utf-8')\n else:\n # Assume it's a directory\n plugin_dir = os.path.join(path, plugin_name)\n plugin_main = os.path.join(plugin_dir, 'plugin.py')\n if os.path.isfile(plugin_main):\n print(f\"PluginSource: Found plugin '{plugin_name}' as module file '{plugin_main}'.\")\n with open(plugin_main, 'r', encoding='utf-8') as f:\n return f.read()\n raise ImportError(f\"Cannot find plugin '{plugin_name}'.\")\n\n def load_plugin(self, plugin_name: str) -> ModuleType:\n \"\"\"\n Loads a plugin by name.\n\n :param plugin_name: Name of the plugin to load.\n :return: The loaded plugin module.\n \"\"\"\n with self._lock:\n full_name = f\"{self.namespace}.{plugin_name}\"\n if full_name in sys.modules:\n print(f\"PluginSource: Plugin '{plugin_name}' is already loaded as '{full_name}'.\")\n return sys.modules[full_name]\n # Enable the finder if not already enabled\n self.enable()\n try:\n 
print(f\"PluginSource: Loading plugin '{plugin_name}' as '{full_name}'.\")\n module = importlib.import_module(full_name)\n self._modules[plugin_name] = module\n return module\n except ImportError as e:\n print(f\"PluginSource: Cannot import plugin '{plugin_name}': {e}\")\n raise\n\n def unload_plugin(self, plugin_name: str) -> None:\n \"\"\"\n Unloads a plugin by name.\n\n :param plugin_name: Name of the plugin to unload.\n \"\"\"\n with self._lock:\n full_name = f\"{self.namespace}.{plugin_name}\"\n module = self._modules.pop(plugin_name, None)\n if module:\n if hasattr(module, 'deactivate'):\n try:\n print(f\"PluginSource: Deactivating plugin '{plugin_name}'.\")\n getattr(module, 'deactivate')()\n except Exception as e:\n print(f\"PluginSource: Error during deactivation of plugin '{plugin_name}': {e}\")\n if full_name in sys.modules:\n del sys.modules[full_name]\n print(f\"PluginSource: Plugin '{plugin_name}' unloaded and removed from sys.modules.\")\n else:\n print(f\"PluginSource: Plugin '{plugin_name}' is not loaded.\")\n\n def load_resource(self, plugin_name: str, resource_name: str) -> bytes:\n \"\"\"\n Loads a resource from a plugin.\n\n :param plugin_name: Name of the plugin.\n :param resource_name: Name of the resource file.\n :return: Contents of the resource file as bytes.\n \"\"\"\n for path in self.search_path:\n if zipfile.is_zipfile(path):\n with zipfile.ZipFile(path, 'r') as z:\n resource_file = f\"{plugin_name}/{resource_name}\"\n if resource_file in z.namelist():\n print(f\"PluginSource: Loading resource '{resource_name}' from plugin '{plugin_name}' in zip '{path}'.\")\n return z.read(resource_file)\n else:\n # Assume it's a directory\n resource_path = os.path.join(path, plugin_name, resource_name)\n if os.path.isfile(resource_path):\n print(f\"PluginSource: Loading resource '{resource_name}' from plugin '{plugin_name}' at '{resource_path}'.\")\n with open(resource_path, 'rb') as f:\n return f.read()\n raise FileNotFoundError(f\"Resource 
'{resource_name}' not found in plugin '{plugin_name}'.\")\n
"},{"location":"reference/#src.aeiva.plugin.plug.PluginSource.__enter__","title":"__enter__()
","text":"Enter the runtime context related to this object.
Source code in src/aeiva/plugin/plug.py
def __enter__(self) -> 'PluginSource':\n \"\"\"Enter the runtime context related to this object.\"\"\"\n self.enable()\n return self\n
"},{"location":"reference/#src.aeiva.plugin.plug.PluginSource.__exit__","title":"__exit__(exc_type, exc_value, traceback)
","text":"Exit the runtime context.
Source code in src/aeiva/plugin/plug.py
def __exit__(self, exc_type, exc_value, traceback) -> None:\n \"\"\"Exit the runtime context.\"\"\"\n self.disable()\n
"},{"location":"reference/#src.aeiva.plugin.plug.PluginSource.__init__","title":"__init__(name, search_path=None)
","text":"Initializes the PluginSource.
:param name: Unique name for the plugin source. :param search_path: List of paths (directories or zip files) to search for plugins.
Source code in src/aeiva/plugin/plug.py
def __init__(self, name: str, search_path: Optional[List[str]] = None) -> None:\n \"\"\"\n Initializes the PluginSource.\n\n :param name: Unique name for the plugin source.\n :param search_path: List of paths (directories or zip files) to search for plugins.\n \"\"\"\n self.name = name\n self.search_path = search_path or []\n self._lock = threading.Lock()\n self._modules: Dict[str, ModuleType] = {}\n self.namespace = f\"_plug_{self.name}\"\n self._finder = PluginFinder(self)\n self._finder_enabled = False\n
"},{"location":"reference/#src.aeiva.plugin.plug.PluginSource.disable","title":"disable()
","text":"Disable the plugin import mechanism.
Source code in src/aeiva/plugin/plug.py
def disable(self) -> None:\n \"\"\"Disable the plugin import mechanism.\"\"\"\n if self._finder_enabled:\n try:\n sys.meta_path.remove(self._finder)\n print(f\"PluginSource: Import hook disabled for namespace '{self.namespace}'.\")\n except ValueError:\n print(f\"PluginSource: Import hook for namespace '{self.namespace}' was not found in sys.meta_path.\")\n self._finder_enabled = False\n
"},{"location":"reference/#src.aeiva.plugin.plug.PluginSource.enable","title":"enable()
","text":"Enable the plugin import mechanism.
Source code in src/aeiva/plugin/plug.py
def enable(self) -> None:\n \"\"\"Enable the plugin import mechanism.\"\"\"\n if not self._finder_enabled:\n sys.meta_path.insert(0, self._finder)\n self._finder_enabled = True\n print(f\"PluginSource: Import hook enabled for namespace '{self.namespace}'.\")\n
"},{"location":"reference/#src.aeiva.plugin.plug.PluginSource.get_plugin_code","title":"get_plugin_code(plugin_name)
","text":"Get the source code of the plugin's plugin.py
.
:param plugin_name: Name of the plugin to load. :return: Source code of plugin.py
as a string.
Source code in src/aeiva/plugin/plug.py
def get_plugin_code(self, plugin_name: str) -> str:\n \"\"\"\n Get the source code of the plugin's `plugin.py`.\n\n :param plugin_name: Name of the plugin to load.\n :return: Source code of `plugin.py` as a string.\n \"\"\"\n for path in self.search_path:\n if zipfile.is_zipfile(path):\n with zipfile.ZipFile(path, 'r') as z:\n plugin_main = f\"{plugin_name}/plugin.py\"\n if plugin_main in z.namelist():\n print(f\"PluginSource: Found plugin '{plugin_name}' in zip file '{path}'.\")\n return z.read(plugin_main).decode('utf-8')\n else:\n # Assume it's a directory\n plugin_dir = os.path.join(path, plugin_name)\n plugin_main = os.path.join(plugin_dir, 'plugin.py')\n if os.path.isfile(plugin_main):\n print(f\"PluginSource: Found plugin '{plugin_name}' as module file '{plugin_main}'.\")\n with open(plugin_main, 'r', encoding='utf-8') as f:\n return f.read()\n raise ImportError(f\"Cannot find plugin '{plugin_name}'.\")\n
"},{"location":"reference/#src.aeiva.plugin.plug.PluginSource.list_plugins","title":"list_plugins()
","text":"Lists available plugins in the search paths. Each plugin is a directory containing a plugin.py
file.
:return: List of plugin names.
Source code in src/aeiva/plugin/plug.py
def list_plugins(self) -> List[str]:\n \"\"\"\n Lists available plugins in the search paths.\n Each plugin is a directory containing a `plugin.py` file.\n\n :return: List of plugin names.\n \"\"\"\n plugins = set()\n for path in self.search_path:\n if zipfile.is_zipfile(path):\n with zipfile.ZipFile(path, 'r') as z:\n # Identify top-level directories containing `plugin.py`\n plugin_dirs = set()\n for file in z.namelist():\n parts = file.split('/')\n if len(parts) >= 2 and parts[-1] == 'plugin.py':\n plugin_dir = parts[0]\n plugin_dirs.add(plugin_dir)\n plugins.update(plugin_dirs)\n else:\n # Assume it's a directory\n if not os.path.isdir(path):\n print(f\"PluginSource: Path '{path}' is not a directory or a zip file. Skipping.\")\n continue\n for entry in os.listdir(path):\n plugin_path = os.path.join(path, entry)\n if os.path.isdir(plugin_path):\n plugin_main = os.path.join(plugin_path, 'plugin.py')\n if os.path.isfile(plugin_main):\n plugins.add(entry)\n return list(plugins)\n
"},{"location":"reference/#src.aeiva.plugin.plug.PluginSource.load_plugin","title":"load_plugin(plugin_name)
","text":"Loads a plugin by name.
:param plugin_name: Name of the plugin to load. :return: The loaded plugin module.
Source code in src/aeiva/plugin/plug.py
def load_plugin(self, plugin_name: str) -> ModuleType:\n \"\"\"\n Loads a plugin by name.\n\n :param plugin_name: Name of the plugin to load.\n :return: The loaded plugin module.\n \"\"\"\n with self._lock:\n full_name = f\"{self.namespace}.{plugin_name}\"\n if full_name in sys.modules:\n print(f\"PluginSource: Plugin '{plugin_name}' is already loaded as '{full_name}'.\")\n return sys.modules[full_name]\n # Enable the finder if not already enabled\n self.enable()\n try:\n print(f\"PluginSource: Loading plugin '{plugin_name}' as '{full_name}'.\")\n module = importlib.import_module(full_name)\n self._modules[plugin_name] = module\n return module\n except ImportError as e:\n print(f\"PluginSource: Cannot import plugin '{plugin_name}': {e}\")\n raise\n
"},{"location":"reference/#src.aeiva.plugin.plug.PluginSource.load_resource","title":"load_resource(plugin_name, resource_name)
","text":"Loads a resource from a plugin.
:param plugin_name: Name of the plugin. :param resource_name: Name of the resource file. :return: Contents of the resource file as bytes.
Source code in src/aeiva/plugin/plug.py
def load_resource(self, plugin_name: str, resource_name: str) -> bytes:\n \"\"\"\n Loads a resource from a plugin.\n\n :param plugin_name: Name of the plugin.\n :param resource_name: Name of the resource file.\n :return: Contents of the resource file as bytes.\n \"\"\"\n for path in self.search_path:\n if zipfile.is_zipfile(path):\n with zipfile.ZipFile(path, 'r') as z:\n resource_file = f\"{plugin_name}/{resource_name}\"\n if resource_file in z.namelist():\n print(f\"PluginSource: Loading resource '{resource_name}' from plugin '{plugin_name}' in zip '{path}'.\")\n return z.read(resource_file)\n else:\n # Assume it's a directory\n resource_path = os.path.join(path, plugin_name, resource_name)\n if os.path.isfile(resource_path):\n print(f\"PluginSource: Loading resource '{resource_name}' from plugin '{plugin_name}' at '{resource_path}'.\")\n with open(resource_path, 'rb') as f:\n return f.read()\n raise FileNotFoundError(f\"Resource '{resource_name}' not found in plugin '{plugin_name}'.\")\n
"},{"location":"reference/#src.aeiva.plugin.plug.PluginSource.unload_plugin","title":"unload_plugin(plugin_name)
","text":"Unloads a plugin by name.
:param plugin_name: Name of the plugin to unload.
Source code in src/aeiva/plugin/plug.py
def unload_plugin(self, plugin_name: str) -> None:\n \"\"\"\n Unloads a plugin by name.\n\n :param plugin_name: Name of the plugin to unload.\n \"\"\"\n with self._lock:\n full_name = f\"{self.namespace}.{plugin_name}\"\n module = self._modules.pop(plugin_name, None)\n if module:\n if hasattr(module, 'deactivate'):\n try:\n print(f\"PluginSource: Deactivating plugin '{plugin_name}'.\")\n getattr(module, 'deactivate')()\n except Exception as e:\n print(f\"PluginSource: Error during deactivation of plugin '{plugin_name}': {e}\")\n if full_name in sys.modules:\n del sys.modules[full_name]\n print(f\"PluginSource: Plugin '{plugin_name}' unloaded and removed from sys.modules.\")\n else:\n print(f\"PluginSource: Plugin '{plugin_name}' is not loaded.\")\n
"},{"location":"reference/#src.aeiva.plugin.test","title":"test
","text":""},{"location":"reference/#src.aeiva.plugin.test--main-application","title":"Main Application","text":"This script demonstrates the usage of the plug module and plugin system.
"},{"location":"reference/#src.aeiva.society","title":"society
","text":""},{"location":"reference/#src.aeiva.society.society","title":"society
","text":""},{"location":"reference/#src.aeiva.society.society.Society","title":"Society
","text":" Bases: ABC
Abstract base class representing a Society that connects an environment and agents.
The Society enables agents to interact with each other and with the environment, providing mechanisms for integrating social systems, such as communication or economy.
Attributes:
Name Type Description config
Any
Configuration settings for the society.
environment
Environment
The environment in which agents operate.
agents
Dict[str, Any]
A dictionary of agents within the society.
social_systems
Dict[str, Any]
A dictionary representing various social systems (e.g., communication).
Source code in src/aeiva/society/society.py
class Society(ABC):\n \"\"\"\n Abstract base class representing a Society that connects an environment and agents.\n\n The Society enables agents to interact with each other and with the environment, providing\n mechanisms for integrating social systems, such as communication or economy.\n\n Attributes:\n config (Any): Configuration settings for the society.\n environment (Environment): The environment in which agents operate.\n agents (Dict[str, Any]): A dictionary of agents within the society.\n social_systems (Dict[str, Any]): A dictionary representing various social systems (e.g., communication).\n \"\"\"\n\n def __init__(self, config: Any, environment: Any, agents: Dict[str, Any]):\n \"\"\"\n Initialize the Society with the provided configuration, environment, and agents.\n\n Args:\n config (Any): Configuration settings for the society.\n env (Environment): The environment in which agents operate.\n agents (Dict[str, Any]): A dictionary of agents within the society, keyed by their IDs.\n \"\"\"\n self.config = config\n self.environment = environment\n self.agents = agents # Agents are stored in a dictionary with IDs as keys\n self.social_systems = self.init_social_systems()\n\n @abstractmethod\n def init_social_systems(self) -> Dict[str, Any]:\n \"\"\"\n Initialize the social systems that operate within the society (e.g., communication, financial, law, political, social network systems).\n\n Returns:\n Dict[str, Any]: A dictionary of initialized social systems.\n \"\"\"\n pass\n\n @abstractmethod\n async def setup(self) -> None:\n \"\"\"\n Asynchronously set up the society's components, such as initializing the environment and agents.\n \"\"\"\n await self.env.setup()\n await asyncio.gather(*(agent.setup() for agent in self.agents.values()))\n print(\"Society: Setup completed.\")\n\n @abstractmethod\n async def run(self) -> None:\n \"\"\"\n Asynchronously run the society, managing interactions between agents and the environment.\n\n This method should control 
the flow of interactions between agents and the environment,\n and it can be designed as a continuous loop or a task-based execution.\n \"\"\"\n pass\n\n def add_agent(self, agent_id: str, agent: Any) -> None:\n \"\"\"\n Add a new agent to the society.\n\n Args:\n agent_id (str): The unique identifier of the agent.\n agent (Any): The agent object to add to the society.\n \"\"\"\n self.agents[agent_id] = agent\n\n def remove_agent(self, agent_id: str) -> None:\n \"\"\"\n Remove an agent from the society by its ID.\n\n Args:\n agent_id (str): The unique identifier of the agent.\n \"\"\"\n if agent_id in self.agents:\n del self.agents[agent_id]\n\n def get_agent(self, agent_id: str) -> Any:\n \"\"\"\n Retrieve an agent by its ID.\n\n Args:\n agent_id (str): The unique identifier of the agent.\n\n Returns:\n Any: The agent object, if found.\n \"\"\"\n return self.agents.get(agent_id, None)\n\n def handle_error(self, error: Exception) -> None:\n \"\"\"\n Handle errors that occur during society operations.\n\n Args:\n error (Exception): The exception that was raised.\n \"\"\"\n print(f\"Society encountered an error: {error}\")\n
"},{"location":"reference/#src.aeiva.society.society.Society.__init__","title":"__init__(config, environment, agents)
","text":"Initialize the Society with the provided configuration, environment, and agents.
Parameters:
Name Type Description Default config
Any
Configuration settings for the society.
required environment
Environment
The environment in which agents operate.
required agents
Dict[str, Any]
A dictionary of agents within the society, keyed by their IDs.
required Source code in src/aeiva/society/society.py
def __init__(self, config: Any, environment: Any, agents: Dict[str, Any]):\n    \"\"\"\n    Initialize the Society with the provided configuration, environment, and agents.\n\n    Args:\n        config (Any): Configuration settings for the society.\n        environment (Environment): The environment in which agents operate.\n        agents (Dict[str, Any]): A dictionary of agents within the society, keyed by their IDs.\n    \"\"\"\n    self.config = config\n    self.environment = environment\n    self.agents = agents  # Agents are stored in a dictionary with IDs as keys\n    self.social_systems = self.init_social_systems()\n
"},{"location":"reference/#src.aeiva.society.society.Society.add_agent","title":"add_agent(agent_id, agent)
","text":"Add a new agent to the society.
Parameters:
Name Type Description Default agent_id
str
The unique identifier of the agent.
required agent
Any
The agent object to add to the society.
required Source code in src/aeiva/society/society.py
def add_agent(self, agent_id: str, agent: Any) -> None:\n \"\"\"\n Add a new agent to the society.\n\n Args:\n agent_id (str): The unique identifier of the agent.\n agent (Any): The agent object to add to the society.\n \"\"\"\n self.agents[agent_id] = agent\n
"},{"location":"reference/#src.aeiva.society.society.Society.get_agent","title":"get_agent(agent_id)
","text":"Retrieve an agent by its ID.
Parameters:
Name Type Description Default agent_id
str
The unique identifier of the agent.
required Returns:
Name Type Description Any
Any
The agent object, if found.
Source code in src/aeiva/society/society.py
def get_agent(self, agent_id: str) -> Any:\n \"\"\"\n Retrieve an agent by its ID.\n\n Args:\n agent_id (str): The unique identifier of the agent.\n\n Returns:\n Any: The agent object, if found.\n \"\"\"\n return self.agents.get(agent_id, None)\n
"},{"location":"reference/#src.aeiva.society.society.Society.handle_error","title":"handle_error(error)
","text":"Handle errors that occur during society operations.
Parameters:
Name Type Description Default error
Exception
The exception that was raised.
required Source code in src/aeiva/society/society.py
def handle_error(self, error: Exception) -> None:\n \"\"\"\n Handle errors that occur during society operations.\n\n Args:\n error (Exception): The exception that was raised.\n \"\"\"\n print(f\"Society encountered an error: {error}\")\n
"},{"location":"reference/#src.aeiva.society.society.Society.init_social_systems","title":"init_social_systems()
abstractmethod
","text":"Initialize the social systems that operate within the society (e.g., communication, financial, law, political, social network systems).
Returns:
Type Description Dict[str, Any]
Dict[str, Any]: A dictionary of initialized social systems.
Source code in src/aeiva/society/society.py
@abstractmethod\ndef init_social_systems(self) -> Dict[str, Any]:\n \"\"\"\n Initialize the social systems that operate within the society (e.g., communication, financial, law, political, social network systems).\n\n Returns:\n Dict[str, Any]: A dictionary of initialized social systems.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.society.society.Society.remove_agent","title":"remove_agent(agent_id)
","text":"Remove an agent from the society by its ID.
Parameters:
Name Type Description Default agent_id
str
The unique identifier of the agent.
required Source code in src/aeiva/society/society.py
def remove_agent(self, agent_id: str) -> None:\n \"\"\"\n Remove an agent from the society by its ID.\n\n Args:\n agent_id (str): The unique identifier of the agent.\n \"\"\"\n if agent_id in self.agents:\n del self.agents[agent_id]\n
"},{"location":"reference/#src.aeiva.society.society.Society.run","title":"run()
abstractmethod
async
","text":"Asynchronously run the society, managing interactions between agents and the environment.
This method should control the flow of interactions between agents and the environment, and it can be designed as a continuous loop or a task-based execution.
Source code in src/aeiva/society/society.py
@abstractmethod\nasync def run(self) -> None:\n \"\"\"\n Asynchronously run the society, managing interactions between agents and the environment.\n\n This method should control the flow of interactions between agents and the environment,\n and it can be designed as a continuous loop or a task-based execution.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.society.society.Society.setup","title":"setup()
abstractmethod
async
","text":"Asynchronously set up the society's components, such as initializing the environment and agents.
Source code in src/aeiva/society/society.py
@abstractmethod\nasync def setup(self) -> None:\n    \"\"\"\n    Asynchronously set up the society's components, such as initializing the environment and agents.\n    \"\"\"\n    await self.environment.setup()\n    await asyncio.gather(*(agent.setup() for agent in self.agents.values()))\n    print(\"Society: Setup completed.\")\n
"},{"location":"reference/#src.aeiva.storage","title":"storage
","text":""},{"location":"reference/#src.aeiva.storage.azure_ai_search","title":"azure_ai_search
","text":""},{"location":"reference/#src.aeiva.storage.azure_ai_search.azure_ai_search_config","title":"azure_ai_search_config
","text":""},{"location":"reference/#src.aeiva.storage.azure_ai_search.azure_ai_search_config.AzureAISearchConfig","title":"AzureAISearchConfig
dataclass
","text":" Bases: BaseConfig
Configuration for Azure Cognitive Search vector database.
Source code in src/aeiva/storage/azure_ai_search/azure_ai_search_config.py
@dataclass\nclass AzureAISearchConfig(BaseConfig):\n \"\"\"\n Configuration for Azure Cognitive Search vector database.\n \"\"\"\n\n collection_name: str = field(\n default=\"mem0\",\n metadata={\"help\": \"Name of the collection (index name).\"}\n )\n service_name: Optional[str] = field(\n default=None,\n metadata={\"help\": \"Azure Cognitive Search service name.\"}\n )\n api_key: Optional[str] = field(\n default=None,\n metadata={\"help\": \"API key for the Azure Cognitive Search service.\"}\n )\n embedding_model_dims: int = field(\n default=1536,\n metadata={\"help\": \"Dimension of the embedding vector.\"}\n )\n use_compression: bool = field(\n default=False,\n metadata={\"help\": \"Whether to use scalar quantization vector compression.\"}\n )\n\n def __post_init__(self):\n super().__post_init__()\n # Validate that service_name and api_key are provided\n if not self.service_name or not self.api_key:\n raise ValueError(\"Both 'service_name' and 'api_key' must be provided.\")\n
"},{"location":"reference/#src.aeiva.storage.azure_ai_search.azure_ai_search_database","title":"azure_ai_search_database
","text":""},{"location":"reference/#src.aeiva.storage.azure_ai_search.azure_ai_search_database.AzureAISearchDatabase","title":"AzureAISearchDatabase
","text":" Bases: VectorDatabase
Concrete implementation of VectorStoreBase using Azure Cognitive Search.
Source code in src/aeiva/storage/azure_ai_search/azure_ai_search_database.py
class AzureAISearchDatabase(VectorDatabase):\n \"\"\"\n Concrete implementation of VectorStoreBase using Azure Cognitive Search.\n \"\"\"\n\n def __init__(self, config: Dict[str, Any]) -> None:\n \"\"\"\n Initialize the Azure Cognitive Search vector store.\n\n Args:\n config (Dict[str, Any]): Configuration dictionary.\n \"\"\"\n self.config = config\n self.index_name = config.get('collection_name')\n self.service_name = config.get('service_name')\n self.api_key = config.get('api_key')\n self.embedding_model_dims = config.get('embedding_model_dims')\n self.use_compression = config.get('use_compression', False)\n\n if not all([self.service_name, self.api_key, self.index_name, self.embedding_model_dims]):\n raise ValueError(\"Required configuration parameters are missing.\")\n\n self.create_client(\n uri=None,\n service_name=self.service_name,\n api_key=self.api_key\n )\n self.create_collection(\n collection_name=self.index_name,\n vector_size=self.embedding_model_dims,\n distance_metric='cosine'\n )\n\n def create_client(\n self,\n uri: Optional[str] = None,\n service_name: Optional[str] = None,\n api_key: Optional[str] = None,\n **kwargs\n ) -> None:\n \"\"\"\n Initializes the client connection to the vector store.\n\n Args:\n uri (Optional[str]): Not used for Azure Cognitive Search.\n service_name (str): Azure Cognitive Search service name.\n api_key (str): API key for the Azure Cognitive Search service.\n **kwargs: Additional parameters.\n \"\"\"\n if not service_name or not api_key:\n raise ValueError(\"Both 'service_name' and 'api_key' must be provided.\")\n\n endpoint = f\"https://{service_name}.search.windows.net\"\n credential = AzureKeyCredential(api_key)\n self.search_client = SearchClient(\n endpoint=endpoint,\n index_name=self.index_name,\n credential=credential\n )\n self.index_client = SearchIndexClient(\n endpoint=endpoint,\n credential=credential\n )\n\n def create_collection(self, collection_name: str, vector_size: int, distance_metric: str) -> 
None:\n \"\"\"\n Create a new vector collection (index) in Azure Cognitive Search.\n\n Args:\n collection_name (str): The name of the collection.\n vector_size (int): The dimensionality of the vectors.\n distance_metric (str): The distance metric to use (e.g., 'cosine').\n \"\"\"\n # Check if the index already exists\n try:\n self.index_client.get_index(collection_name)\n logger.info(f\"Index {collection_name} already exists. Skipping creation.\")\n return\n except ResourceNotFoundError:\n pass # Index does not exist, proceed to create\n\n if self.use_compression:\n vector_type = \"Collection(Edm.Half)\"\n compression_name = \"myCompression\"\n compression_configurations = [ScalarQuantizationCompression(compression_name=compression_name)]\n else:\n vector_type = \"Collection(Edm.Single)\"\n compression_name = None\n compression_configurations = []\n\n fields = [\n SimpleField(name=\"id\", type=SearchFieldDataType.String, key=True),\n SearchField(\n name=\"vector\",\n type=vector_type,\n searchable=True,\n vector_search_dimensions=vector_size,\n vector_search_profile_name=\"my-vector-config\",\n ),\n SimpleField(name=\"payload\", type=SearchFieldDataType.String, searchable=True),\n ]\n\n vector_search = VectorSearch(\n profiles=[\n VectorSearchProfile(name=\"my-vector-config\", algorithm_configuration_name=\"my-algorithms-config\")\n ],\n algorithms=[HnswAlgorithmConfiguration(name=\"my-algorithms-config\")],\n compressions=compression_configurations,\n )\n index = SearchIndex(name=collection_name, fields=fields, vector_search=vector_search)\n self.index_client.create_or_update_index(index)\n logger.info(f\"Index {collection_name} created successfully.\")\n\n def insert_vectors(\n self,\n collection_name: str,\n vectors: List[List[float]],\n payloads: Optional[List[Dict[str, Any]]] = None,\n ids: Optional[List[str]] = None\n ) -> None:\n \"\"\"\n Insert vectors into the index.\n\n Args:\n collection_name (str): The name of the collection.\n vectors 
(List[List[float]]): A list of vectors to insert.\n            payloads (Optional[List[Dict[str, Any]]]): Optional metadata associated with each vector.\n            ids (Optional[List[str]]): Optional unique identifiers for each vector.\n        \"\"\"\n        if collection_name != self.index_name:\n            raise ValueError(\"Collection name does not match initialized index name.\")\n\n        if ids is None:\n            ids = [str(i) for i in range(len(vectors))]\n        if payloads is None:\n            payloads = [{} for _ in range(len(vectors))]\n        if not (len(ids) == len(vectors) == len(payloads)):\n            raise ValueError(\"Lengths of ids, vectors, and payloads must be equal.\")\n\n        documents = [\n            {\"id\": id_, \"vector\": vector, \"payload\": json.dumps(payload)}\n            for id_, vector, payload in zip(ids, vectors, payloads)\n        ]\n        self.search_client.upload_documents(documents)\n        logger.info(f\"Inserted {len(vectors)} vectors into index {collection_name}.\")\n\n    def search_vectors(\n        self,\n        collection_name: str,\n        query_vector: List[float],\n        top_k: int = 5,\n        filters: Optional[Dict[str, Any]] = None\n    ) -> List[Dict[str, Any]]:\n        \"\"\"\n        Search for similar vectors in a collection.\n\n        Args:\n            collection_name (str): The name of the collection.\n            query_vector (List[float]): The vector to search with.\n            top_k (int): The number of top results to return.\n            filters (Optional[Dict[str, Any]]): Optional filters to apply to the search.\n\n        Returns:\n            List[Dict[str, Any]]: A list of search results.\n        \"\"\"\n        if collection_name != self.index_name:\n            raise ValueError(\"Collection name does not match initialized index name.\")\n\n        vector_query = VectorizedQuery(vector=query_vector, k_nearest_neighbors=top_k, fields=\"vector\")\n        search_results = self.search_client.search(vector_queries=[vector_query], top=top_k)\n\n        results = []\n        for result in search_results:\n            payload = json.loads(result[\"payload\"])\n            if filters:\n                # Skip this result unless the payload matches all filter criteria\n                if any(payload.get(key) != value for key, value in filters.items()):\n                    continue\n            result_dict = {\n                \"id\": 
result[\"id\"],\n \"score\": result[\"@search.score\"],\n \"payload\": payload\n }\n results.append(result_dict)\n return results\n\n def delete_vector(self, collection_name: str, vector_id: str) -> None:\n \"\"\"\n Delete a vector from a collection by its ID.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector to delete.\n \"\"\"\n if collection_name != self.index_name:\n raise ValueError(\"Collection name does not match initialized index name.\")\n self.search_client.delete_documents(documents=[{\"id\": vector_id}])\n logger.info(f\"Deleted vector with ID {vector_id} from collection {collection_name}.\")\n\n def update_vector(\n self,\n collection_name: str,\n vector_id: str,\n vector: Optional[List[float]] = None,\n payload: Optional[Dict[str, Any]] = None\n ) -> None:\n \"\"\"\n Update a vector's data or payload.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector to update.\n vector (Optional[List[float]]): The new vector data.\n payload (Optional[Dict[str, Any]]): The new payload data.\n \"\"\"\n if collection_name != self.index_name:\n raise ValueError(\"Collection name does not match initialized index name.\")\n\n document = {\"id\": vector_id}\n if vector is not None:\n document[\"vector\"] = vector\n if payload is not None:\n document[\"payload\"] = json.dumps(payload)\n self.search_client.merge_or_upload_documents(documents=[document])\n logger.info(f\"Updated vector with ID {vector_id} in collection {collection_name}.\")\n\n def get_vector(self, collection_name: str, vector_id: str) -> Dict[str, Any]:\n \"\"\"\n Retrieve a vector by its ID.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector.\n\n Returns:\n Dict[str, Any]: A dictionary containing the vector data and payload.\n \"\"\"\n if collection_name != self.index_name:\n raise ValueError(\"Collection 
name does not match initialized index name.\")\n try:\n result = self.search_client.get_document(key=vector_id)\n payload = json.loads(result[\"payload\"])\n vector_data = {\n \"id\": result[\"id\"],\n \"vector\": result[\"vector\"],\n \"payload\": payload\n }\n return vector_data\n except ResourceNotFoundError:\n raise KeyError(f\"Vector with ID {vector_id} not found in collection {collection_name}.\")\n\n def list_collections(self) -> List[str]:\n \"\"\"\n List all available vector collections.\n\n Returns:\n List[str]: A list of collection names.\n \"\"\"\n indexes = self.index_client.list_indexes()\n return [index.name for index in indexes]\n\n def delete_collection(self, collection_name: str) -> None:\n \"\"\"\n Delete an entire vector collection.\n\n Args:\n collection_name (str): The name of the collection to delete.\n \"\"\"\n self.index_client.delete_index(collection_name)\n logger.info(f\"Deleted collection {collection_name}.\")\n\n def get_collection_info(self, collection_name: str) -> Dict[str, Any]:\n \"\"\"\n Get information about a collection.\n\n Args:\n collection_name (str): The name of the collection.\n\n Returns:\n Dict[str, Any]: Information about the collection.\n \"\"\"\n index = self.index_client.get_index(collection_name)\n return {\n \"name\": index.name,\n \"fields\": [field.name for field in index.fields],\n \"vector_search\": index.vector_search\n }\n\n def __del__(self):\n \"\"\"Clean up resources.\"\"\"\n self.search_client.close()\n self.index_client.close()\n
"},{"location":"reference/#src.aeiva.storage.azure_ai_search.azure_ai_search_database.AzureAISearchDatabase.__del__","title":"__del__()
","text":"Clean up resources.
Source code in src/aeiva/storage/azure_ai_search/azure_ai_search_database.py
def __del__(self):\n \"\"\"Clean up resources.\"\"\"\n self.search_client.close()\n self.index_client.close()\n
"},{"location":"reference/#src.aeiva.storage.azure_ai_search.azure_ai_search_database.AzureAISearchDatabase.__init__","title":"__init__(config)
","text":"Initialize the Azure Cognitive Search vector store.
Parameters:
Name Type Description Default config
Dict[str, Any]
Configuration dictionary.
required Source code in src/aeiva/storage/azure_ai_search/azure_ai_search_database.py
def __init__(self, config: Dict[str, Any]) -> None:\n \"\"\"\n Initialize the Azure Cognitive Search vector store.\n\n Args:\n config (Dict[str, Any]): Configuration dictionary.\n \"\"\"\n self.config = config\n self.index_name = config.get('collection_name')\n self.service_name = config.get('service_name')\n self.api_key = config.get('api_key')\n self.embedding_model_dims = config.get('embedding_model_dims')\n self.use_compression = config.get('use_compression', False)\n\n if not all([self.service_name, self.api_key, self.index_name, self.embedding_model_dims]):\n raise ValueError(\"Required configuration parameters are missing.\")\n\n self.create_client(\n uri=None,\n service_name=self.service_name,\n api_key=self.api_key\n )\n self.create_collection(\n collection_name=self.index_name,\n vector_size=self.embedding_model_dims,\n distance_metric='cosine'\n )\n
"},{"location":"reference/#src.aeiva.storage.azure_ai_search.azure_ai_search_database.AzureAISearchDatabase.create_client","title":"create_client(uri=None, service_name=None, api_key=None, **kwargs)
","text":"Initializes the client connection to the vector store.
Parameters:
Name Type Description Default uri
Optional[str]
Not used for Azure Cognitive Search.
None
service_name
str
Azure Cognitive Search service name.
None
api_key
str
API key for the Azure Cognitive Search service.
None
**kwargs
Additional parameters.
{}
Source code in src/aeiva/storage/azure_ai_search/azure_ai_search_database.py
def create_client(\n self,\n uri: Optional[str] = None,\n service_name: Optional[str] = None,\n api_key: Optional[str] = None,\n **kwargs\n) -> None:\n \"\"\"\n Initializes the client connection to the vector store.\n\n Args:\n uri (Optional[str]): Not used for Azure Cognitive Search.\n service_name (str): Azure Cognitive Search service name.\n api_key (str): API key for the Azure Cognitive Search service.\n **kwargs: Additional parameters.\n \"\"\"\n if not service_name or not api_key:\n raise ValueError(\"Both 'service_name' and 'api_key' must be provided.\")\n\n endpoint = f\"https://{service_name}.search.windows.net\"\n credential = AzureKeyCredential(api_key)\n self.search_client = SearchClient(\n endpoint=endpoint,\n index_name=self.index_name,\n credential=credential\n )\n self.index_client = SearchIndexClient(\n endpoint=endpoint,\n credential=credential\n )\n
"},{"location":"reference/#src.aeiva.storage.azure_ai_search.azure_ai_search_database.AzureAISearchDatabase.create_collection","title":"create_collection(collection_name, vector_size, distance_metric)
","text":"Create a new vector collection (index) in Azure Cognitive Search.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required vector_size
int
The dimensionality of the vectors.
required distance_metric
str
The distance metric to use (e.g., 'cosine').
required Source code in src/aeiva/storage/azure_ai_search/azure_ai_search_database.py
def create_collection(self, collection_name: str, vector_size: int, distance_metric: str) -> None:\n \"\"\"\n Create a new vector collection (index) in Azure Cognitive Search.\n\n Args:\n collection_name (str): The name of the collection.\n vector_size (int): The dimensionality of the vectors.\n distance_metric (str): The distance metric to use (e.g., 'cosine').\n \"\"\"\n # Check if the index already exists\n try:\n self.index_client.get_index(collection_name)\n logger.info(f\"Index {collection_name} already exists. Skipping creation.\")\n return\n except ResourceNotFoundError:\n pass # Index does not exist, proceed to create\n\n if self.use_compression:\n vector_type = \"Collection(Edm.Half)\"\n compression_name = \"myCompression\"\n compression_configurations = [ScalarQuantizationCompression(compression_name=compression_name)]\n else:\n vector_type = \"Collection(Edm.Single)\"\n compression_name = None\n compression_configurations = []\n\n fields = [\n SimpleField(name=\"id\", type=SearchFieldDataType.String, key=True),\n SearchField(\n name=\"vector\",\n type=vector_type,\n searchable=True,\n vector_search_dimensions=vector_size,\n vector_search_profile_name=\"my-vector-config\",\n ),\n SimpleField(name=\"payload\", type=SearchFieldDataType.String, searchable=True),\n ]\n\n vector_search = VectorSearch(\n profiles=[\n VectorSearchProfile(name=\"my-vector-config\", algorithm_configuration_name=\"my-algorithms-config\")\n ],\n algorithms=[HnswAlgorithmConfiguration(name=\"my-algorithms-config\")],\n compressions=compression_configurations,\n )\n index = SearchIndex(name=collection_name, fields=fields, vector_search=vector_search)\n self.index_client.create_or_update_index(index)\n logger.info(f\"Index {collection_name} created successfully.\")\n
"},{"location":"reference/#src.aeiva.storage.azure_ai_search.azure_ai_search_database.AzureAISearchDatabase.delete_collection","title":"delete_collection(collection_name)
","text":"Delete an entire vector collection.
Parameters:
Name Type Description Default collection_name
str
The name of the collection to delete.
required Source code in src/aeiva/storage/azure_ai_search/azure_ai_search_database.py
def delete_collection(self, collection_name: str) -> None:\n \"\"\"\n Delete an entire vector collection.\n\n Args:\n collection_name (str): The name of the collection to delete.\n \"\"\"\n self.index_client.delete_index(collection_name)\n logger.info(f\"Deleted collection {collection_name}.\")\n
"},{"location":"reference/#src.aeiva.storage.azure_ai_search.azure_ai_search_database.AzureAISearchDatabase.delete_vector","title":"delete_vector(collection_name, vector_id)
","text":"Delete a vector from a collection by its ID.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required vector_id
str
The unique identifier of the vector to delete.
required Source code in src/aeiva/storage/azure_ai_search/azure_ai_search_database.py
def delete_vector(self, collection_name: str, vector_id: str) -> None:\n \"\"\"\n Delete a vector from a collection by its ID.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector to delete.\n \"\"\"\n if collection_name != self.index_name:\n raise ValueError(\"Collection name does not match initialized index name.\")\n self.search_client.delete_documents(documents=[{\"id\": vector_id}])\n logger.info(f\"Deleted vector with ID {vector_id} from collection {collection_name}.\")\n
"},{"location":"reference/#src.aeiva.storage.azure_ai_search.azure_ai_search_database.AzureAISearchDatabase.get_collection_info","title":"get_collection_info(collection_name)
","text":"Get information about a collection.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required Returns:
Type Description Dict[str, Any]
Dict[str, Any]: Information about the collection.
Source code in src/aeiva/storage/azure_ai_search/azure_ai_search_database.py
def get_collection_info(self, collection_name: str) -> Dict[str, Any]:\n \"\"\"\n Get information about a collection.\n\n Args:\n collection_name (str): The name of the collection.\n\n Returns:\n Dict[str, Any]: Information about the collection.\n \"\"\"\n index = self.index_client.get_index(collection_name)\n return {\n \"name\": index.name,\n \"fields\": [field.name for field in index.fields],\n \"vector_search\": index.vector_search\n }\n
"},{"location":"reference/#src.aeiva.storage.azure_ai_search.azure_ai_search_database.AzureAISearchDatabase.get_vector","title":"get_vector(collection_name, vector_id)
","text":"Retrieve a vector by its ID.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required vector_id
str
The unique identifier of the vector.
required Returns:
Type Description Dict[str, Any]
Dict[str, Any]: A dictionary containing the vector data and payload.
Source code in src/aeiva/storage/azure_ai_search/azure_ai_search_database.py
def get_vector(self, collection_name: str, vector_id: str) -> Dict[str, Any]:\n \"\"\"\n Retrieve a vector by its ID.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector.\n\n Returns:\n Dict[str, Any]: A dictionary containing the vector data and payload.\n \"\"\"\n if collection_name != self.index_name:\n raise ValueError(\"Collection name does not match initialized index name.\")\n try:\n result = self.search_client.get_document(key=vector_id)\n payload = json.loads(result[\"payload\"])\n vector_data = {\n \"id\": result[\"id\"],\n \"vector\": result[\"vector\"],\n \"payload\": payload\n }\n return vector_data\n except ResourceNotFoundError:\n raise KeyError(f\"Vector with ID {vector_id} not found in collection {collection_name}.\")\n
"},{"location":"reference/#src.aeiva.storage.azure_ai_search.azure_ai_search_database.AzureAISearchDatabase.insert_vectors","title":"insert_vectors(collection_name, vectors, payloads=None, ids=None)
","text":"Insert vectors into the index.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required vectors
List[List[float]]
A list of vectors to insert.
required payloads
Optional[List[Dict[str, Any]]]
Optional metadata associated with each vector.
None
ids
Optional[List[str]]
Optional unique identifiers for each vector.
None
Source code in src/aeiva/storage/azure_ai_search/azure_ai_search_database.py
def insert_vectors(\n self,\n collection_name: str,\n vectors: List[List[float]],\n payloads: Optional[List[Dict[str, Any]]] = None,\n ids: Optional[List[str]] = None\n) -> None:\n \"\"\"\n Insert vectors into the index.\n\n Args:\n collection_name (str): The name of the collection.\n vectors (List[List[float]]): A list of vectors to insert.\n payloads (Optional[List[Dict[str, Any]]]): Optional metadata associated with each vector.\n ids (Optional[List[str]]): Optional unique identifiers for each vector.\n \"\"\"\n if collection_name != self.index_name:\n raise ValueError(\"Collection name does not match initialized index name.\")\n\n if ids is None:\n ids = [str(i) for i in range(len(vectors))]\n if payloads is None:\n payloads = [{} for _ in range(len(vectors))]\n if not (len(ids) == len(vectors) == len(payloads)):\n raise ValueError(\"Lengths of ids, vectors, and payloads must be equal.\")\n\n documents = [\n {\"id\": id_, \"vector\": vector, \"payload\": json.dumps(payload)}\n for id_, vector, payload in zip(ids, vectors, payloads)\n ]\n self.search_client.upload_documents(documents)\n logger.info(f\"Inserted {len(vectors)} vectors into index {collection_name}.\")\n
"},{"location":"reference/#src.aeiva.storage.azure_ai_search.azure_ai_search_database.AzureAISearchDatabase.list_collections","title":"list_collections()
","text":"List all available vector collections.
Returns:
Type Description List[str]
List[str]: A list of collection names.
Source code in src/aeiva/storage/azure_ai_search/azure_ai_search_database.py
def list_collections(self) -> List[str]:\n \"\"\"\n List all available vector collections.\n\n Returns:\n List[str]: A list of collection names.\n \"\"\"\n indexes = self.index_client.list_indexes()\n return [index.name for index in indexes]\n
"},{"location":"reference/#src.aeiva.storage.azure_ai_search.azure_ai_search_database.AzureAISearchDatabase.search_vectors","title":"search_vectors(collection_name, query_vector, top_k=5, filters=None)
","text":"Search for similar vectors in a collection.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required query_vector
List[float]
The vector to search with.
required top_k
int
The number of top results to return.
5
filters
Optional[Dict[str, Any]]
Optional filters to apply to the search.
None
Returns:
Type Description List[Dict[str, Any]]
List[Dict[str, Any]]: A list of search results.
Source code in src/aeiva/storage/azure_ai_search/azure_ai_search_database.py
def search_vectors(\n    self,\n    collection_name: str,\n    query_vector: List[float],\n    top_k: int = 5,\n    filters: Optional[Dict[str, Any]] = None\n) -> List[Dict[str, Any]]:\n    \"\"\"\n    Search for similar vectors in a collection.\n\n    Args:\n        collection_name (str): The name of the collection.\n        query_vector (List[float]): The vector to search with.\n        top_k (int): The number of top results to return.\n        filters (Optional[Dict[str, Any]]): Optional filters to apply to the search.\n\n    Returns:\n        List[Dict[str, Any]]: A list of search results.\n    \"\"\"\n    if collection_name != self.index_name:\n        raise ValueError(\"Collection name does not match initialized index name.\")\n\n    vector_query = VectorizedQuery(vector=query_vector, k_nearest_neighbors=top_k, fields=\"vector\")\n    search_results = self.search_client.search(vector_queries=[vector_query], top=top_k)\n\n    results = []\n    for result in search_results:\n        payload = json.loads(result[\"payload\"])\n        if filters:\n            # Skip this result unless the payload matches all filter criteria\n            if any(payload.get(key) != value for key, value in filters.items()):\n                continue\n        result_dict = {\n            \"id\": result[\"id\"],\n            \"score\": result[\"@search.score\"],\n            \"payload\": payload\n        }\n        results.append(result_dict)\n    return results\n
"},{"location":"reference/#src.aeiva.storage.azure_ai_search.azure_ai_search_database.AzureAISearchDatabase.update_vector","title":"update_vector(collection_name, vector_id, vector=None, payload=None)
","text":"Update a vector's data or payload.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required vector_id
str
The unique identifier of the vector to update.
required vector
Optional[List[float]]
The new vector data.
None
payload
Optional[Dict[str, Any]]
The new payload data.
None
Source code in src/aeiva/storage/azure_ai_search/azure_ai_search_database.py
def update_vector(\n self,\n collection_name: str,\n vector_id: str,\n vector: Optional[List[float]] = None,\n payload: Optional[Dict[str, Any]] = None\n) -> None:\n \"\"\"\n Update a vector's data or payload.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector to update.\n vector (Optional[List[float]]): The new vector data.\n payload (Optional[Dict[str, Any]]): The new payload data.\n \"\"\"\n if collection_name != self.index_name:\n raise ValueError(\"Collection name does not match initialized index name.\")\n\n document = {\"id\": vector_id}\n if vector is not None:\n document[\"vector\"] = vector\n if payload is not None:\n document[\"payload\"] = json.dumps(payload)\n self.search_client.merge_or_upload_documents(documents=[document])\n logger.info(f\"Updated vector with ID {vector_id} in collection {collection_name}.\")\n
"},{"location":"reference/#src.aeiva.storage.chroma","title":"chroma
","text":""},{"location":"reference/#src.aeiva.storage.chroma.chroma_config","title":"chroma_config
","text":""},{"location":"reference/#src.aeiva.storage.chroma.chroma_config.ChromaConfig","title":"ChromaConfig
dataclass
","text":" Bases: BaseConfig
Configuration for ChromaDB vector database.
Source code in src/aeiva/storage/chroma/chroma_config.py
@dataclass\nclass ChromaConfig(BaseConfig):\n \"\"\"\n Configuration for ChromaDB vector database.\n \"\"\"\n\n collection_name: str = field(\n default=\"mem0\",\n metadata={\"help\": \"Name of the collection.\"}\n )\n client: Optional[Any] = field(\n default=None,\n metadata={\"help\": \"Existing ChromaDB client instance (if any).\"}\n )\n path: Optional[str] = field(\n default=None,\n metadata={\"help\": \"Path to the database directory for local storage.\"}\n )\n host: Optional[str] = field(\n default=None,\n metadata={\"help\": \"Remote host address for ChromaDB.\"}\n )\n port: Optional[int] = field(\n default=None,\n metadata={\"help\": \"Remote port for ChromaDB.\"}\n )\n\n def __post_init__(self):\n super().__post_init__()\n # Validate that either path or host and port are provided\n if not self.path and not (self.host and self.port):\n raise ValueError(\"Either 'path' for local storage or both 'host' and 'port' for remote connection must be provided.\")\n
"},{"location":"reference/#src.aeiva.storage.chroma.chroma_database","title":"chroma_database
","text":""},{"location":"reference/#src.aeiva.storage.chroma.chroma_database.ChromaDatabase","title":"ChromaDatabase
","text":" Bases: VectorDatabase
Concrete implementation of VectorStoreBase using ChromaDB.
Source code in src/aeiva/storage/chroma/chroma_database.py
class ChromaDatabase(VectorDatabase):\n \"\"\"\n Concrete implementation of VectorStoreBase using ChromaDB.\n \"\"\"\n\n def __init__(self, config: Dict[str, Any]) -> None:\n \"\"\"\n Initialize the ChromaDB vector store.\n\n Args:\n config (Dict[str, Any]): Configuration dictionary.\n \"\"\"\n self.config = config\n self.collection_name = config.get('collection_name')\n self.client = config.get('client')\n self.host = config.get('host')\n self.port = config.get('port')\n self.path = config.get('path')\n\n if not self.collection_name:\n raise ValueError(\"Collection name must be provided in the configuration.\")\n\n self.create_client(\n host=self.host,\n port=self.port,\n path=self.path\n )\n self.create_collection(\n collection_name=self.collection_name,\n vector_size=None, # ChromaDB does not require specifying vector size upfront\n distance_metric='cosine'\n )\n\n def create_client(\n self,\n uri: Optional[str] = None,\n host: Optional[str] = None,\n port: Optional[int] = None,\n path: Optional[str] = None,\n **kwargs\n ) -> None:\n \"\"\"\n Initializes the client connection to the vector store.\n\n Args:\n uri (Optional[str]): Not used for ChromaDB.\n host (Optional[str]): Host address for ChromaDB server.\n port (Optional[int]): Port for ChromaDB server.\n path (Optional[str]): Path to the database directory.\n **kwargs: Additional parameters.\n \"\"\"\n if self.client:\n return # Client already provided\n\n settings = Settings(anonymized_telemetry=False)\n\n if host and port:\n settings.chroma_api_impl = \"chromadb.api.fastapi.FastAPI\"\n settings.chroma_server_host = host\n settings.chroma_server_http_port = port\n else:\n if not path:\n path = \"db\"\n settings.persist_directory = path\n settings.is_persistent = True\n\n self.client = chromadb.Client(settings)\n logger.info(\"ChromaDB client initialized.\")\n\n def create_collection(self, collection_name: str, vector_size: int, distance_metric: str) -> None:\n \"\"\"\n Create a new vector collection in 
ChromaDB.\n\n Args:\n collection_name (str): The name of the collection.\n vector_size (int): Not used for ChromaDB.\n distance_metric (str): Not used for ChromaDB.\n \"\"\"\n # Check if collection exists\n existing_collections = self.list_collections()\n if collection_name in existing_collections:\n logger.info(f\"Collection {collection_name} already exists. Skipping creation.\")\n self.collection = self.client.get_collection(name=collection_name)\n else:\n self.collection = self.client.create_collection(name=collection_name)\n logger.info(f\"Collection {collection_name} created successfully.\")\n\n def insert_vectors(\n self,\n collection_name: str,\n vectors: List[List[float]],\n payloads: Optional[List[Dict[str, Any]]] = None,\n ids: Optional[List[str]] = None\n ) -> None:\n \"\"\"\n Insert vectors into a collection.\n\n Args:\n collection_name (str): The name of the collection.\n vectors (List[List[float]]): A list of vectors to insert.\n payloads (Optional[List[Dict[str, Any]]]): Optional metadata associated with each vector.\n ids (Optional[List[str]]): Optional unique identifiers for each vector.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n if ids is None:\n ids = [str(i) for i in range(len(vectors))]\n if payloads is None:\n payloads = [{} for _ in range(len(vectors))]\n if not (len(ids) == len(vectors) == len(payloads)):\n raise ValueError(\"Lengths of ids, vectors, and payloads must be equal.\")\n\n self.collection.add(ids=ids, embeddings=vectors, metadatas=payloads)\n logger.info(f\"Inserted {len(vectors)} vectors into collection {collection_name}.\")\n\n def search_vectors(\n self,\n collection_name: str,\n query_vector: List[float],\n top_k: int = 5,\n filters: Optional[Dict[str, Any]] = None\n ) -> List[Dict[str, Any]]:\n \"\"\"\n Search for similar vectors in a collection.\n\n Args:\n collection_name (str): The name of the collection.\n query_vector 
(List[float]): The vector to search with.\n top_k (int): The number of top results to return.\n filters (Optional[Dict[str, Any]]): Optional filters to apply to the search.\n\n Returns:\n List[Dict[str, Any]]: A list of search results.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n results = self.collection.query(\n query_embeddings=[query_vector],\n where=filters,\n n_results=top_k\n )\n # Parse the results\n output = []\n for idx, (ids, distances, metadatas) in enumerate(zip(results['ids'], results['distances'], results['metadatas'])):\n for i in range(len(ids)):\n result = {\n 'id': ids[i],\n 'score': distances[i],\n 'payload': metadatas[i]\n }\n output.append(result)\n return output\n\n def delete_vector(self, collection_name: str, vector_id: str) -> None:\n \"\"\"\n Delete a vector from a collection by its ID.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector to delete.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n self.collection.delete(ids=[vector_id])\n logger.info(f\"Deleted vector with ID {vector_id} from collection {collection_name}.\")\n\n def update_vector(\n self,\n collection_name: str,\n vector_id: str,\n vector: Optional[List[float]] = None,\n payload: Optional[Dict[str, Any]] = None\n ) -> None:\n \"\"\"\n Update a vector's data or payload.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector to update.\n vector (Optional[List[float]]): The new vector data.\n payload (Optional[Dict[str, Any]]): The new payload data.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n self.collection.update(ids=[vector_id], embeddings=[vector] if vector 
else None, metadatas=[payload] if payload else None)\n logger.info(f\"Updated vector with ID {vector_id} in collection {collection_name}.\")\n\n def get_vector(self, collection_name: str, vector_id: str) -> Dict[str, Any]:\n \"\"\"\n Retrieve a vector by its ID.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector.\n\n Returns:\n Dict[str, Any]: A dictionary containing the vector data and payload.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n result = self.collection.get(ids=[vector_id])\n if not result['ids']:\n raise KeyError(f\"Vector with ID {vector_id} not found in collection {collection_name}.\")\n\n vector_data = {\n 'id': result['ids'][0],\n 'vector': result['embeddings'][0] if 'embeddings' in result else None,\n 'payload': result['metadatas'][0]\n }\n return vector_data\n\n def list_collections(self) -> List[str]:\n \"\"\"\n List all available vector collections.\n\n Returns:\n List[str]: A list of collection names.\n \"\"\"\n collections = self.client.list_collections()\n return [collection.name for collection in collections]\n\n def delete_collection(self, collection_name: str) -> None:\n \"\"\"\n Delete an entire vector collection.\n\n Args:\n collection_name (str): The name of the collection to delete.\n \"\"\"\n self.client.delete_collection(name=collection_name)\n logger.info(f\"Deleted collection {collection_name}.\")\n\n def get_collection_info(self, collection_name: str) -> Dict[str, Any]:\n \"\"\"\n Get information about a collection.\n\n Args:\n collection_name (str): The name of the collection.\n\n Returns:\n Dict[str, Any]: Information about the collection.\n \"\"\"\n collection = self.client.get_collection(name=collection_name)\n return {\n 'name': collection.name,\n 'metadata': collection.metadata\n }\n
"},{"location":"reference/#src.aeiva.storage.chroma.chroma_database.ChromaDatabase.__init__","title":"__init__(config)
","text":"Initialize the ChromaDB vector store.
Parameters:
Name Type Description Default config
Dict[str, Any]
Configuration dictionary.
required Source code in src/aeiva/storage/chroma/chroma_database.py
def __init__(self, config: Dict[str, Any]) -> None:\n \"\"\"\n Initialize the ChromaDB vector store.\n\n Args:\n config (Dict[str, Any]): Configuration dictionary.\n \"\"\"\n self.config = config\n self.collection_name = config.get('collection_name')\n self.client = config.get('client')\n self.host = config.get('host')\n self.port = config.get('port')\n self.path = config.get('path')\n\n if not self.collection_name:\n raise ValueError(\"Collection name must be provided in the configuration.\")\n\n self.create_client(\n host=self.host,\n port=self.port,\n path=self.path\n )\n self.create_collection(\n collection_name=self.collection_name,\n vector_size=None, # ChromaDB does not require specifying vector size upfront\n distance_metric='cosine'\n )\n
"},{"location":"reference/#src.aeiva.storage.chroma.chroma_database.ChromaDatabase.create_client","title":"create_client(uri=None, host=None, port=None, path=None, **kwargs)
","text":"Initializes the client connection to the vector store.
Parameters:
Name Type Description Default uri
Optional[str]
Not used for ChromaDB.
None
host
Optional[str]
Host address for ChromaDB server.
None
port
Optional[int]
Port for ChromaDB server.
None
path
Optional[str]
Path to the database directory.
None
**kwargs
Additional parameters.
{}
Source code in src/aeiva/storage/chroma/chroma_database.py
def create_client(\n self,\n uri: Optional[str] = None,\n host: Optional[str] = None,\n port: Optional[int] = None,\n path: Optional[str] = None,\n **kwargs\n) -> None:\n \"\"\"\n Initializes the client connection to the vector store.\n\n Args:\n uri (Optional[str]): Not used for ChromaDB.\n host (Optional[str]): Host address for ChromaDB server.\n port (Optional[int]): Port for ChromaDB server.\n path (Optional[str]): Path to the database directory.\n **kwargs: Additional parameters.\n \"\"\"\n if self.client:\n return # Client already provided\n\n settings = Settings(anonymized_telemetry=False)\n\n if host and port:\n settings.chroma_api_impl = \"chromadb.api.fastapi.FastAPI\"\n settings.chroma_server_host = host\n settings.chroma_server_http_port = port\n else:\n if not path:\n path = \"db\"\n settings.persist_directory = path\n settings.is_persistent = True\n\n self.client = chromadb.Client(settings)\n logger.info(\"ChromaDB client initialized.\")\n
"},{"location":"reference/#src.aeiva.storage.chroma.chroma_database.ChromaDatabase.create_collection","title":"create_collection(collection_name, vector_size, distance_metric)
","text":"Create a new vector collection in ChromaDB.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required vector_size
int
Not used for ChromaDB.
required distance_metric
str
Not used for ChromaDB.
required Source code in src/aeiva/storage/chroma/chroma_database.py
def create_collection(self, collection_name: str, vector_size: int, distance_metric: str) -> None:\n \"\"\"\n Create a new vector collection in ChromaDB.\n\n Args:\n collection_name (str): The name of the collection.\n vector_size (int): Not used for ChromaDB.\n distance_metric (str): Not used for ChromaDB.\n \"\"\"\n # Check if collection exists\n existing_collections = self.list_collections()\n if collection_name in existing_collections:\n logger.info(f\"Collection {collection_name} already exists. Skipping creation.\")\n self.collection = self.client.get_collection(name=collection_name)\n else:\n self.collection = self.client.create_collection(name=collection_name)\n logger.info(f\"Collection {collection_name} created successfully.\")\n
"},{"location":"reference/#src.aeiva.storage.chroma.chroma_database.ChromaDatabase.delete_collection","title":"delete_collection(collection_name)
","text":"Delete an entire vector collection.
Parameters:
Name Type Description Default collection_name
str
The name of the collection to delete.
required Source code in src/aeiva/storage/chroma/chroma_database.py
def delete_collection(self, collection_name: str) -> None:\n \"\"\"\n Delete an entire vector collection.\n\n Args:\n collection_name (str): The name of the collection to delete.\n \"\"\"\n self.client.delete_collection(name=collection_name)\n logger.info(f\"Deleted collection {collection_name}.\")\n
"},{"location":"reference/#src.aeiva.storage.chroma.chroma_database.ChromaDatabase.delete_vector","title":"delete_vector(collection_name, vector_id)
","text":"Delete a vector from a collection by its ID.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required vector_id
str
The unique identifier of the vector to delete.
required Source code in src/aeiva/storage/chroma/chroma_database.py
def delete_vector(self, collection_name: str, vector_id: str) -> None:\n \"\"\"\n Delete a vector from a collection by its ID.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector to delete.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n self.collection.delete(ids=[vector_id])\n logger.info(f\"Deleted vector with ID {vector_id} from collection {collection_name}.\")\n
"},{"location":"reference/#src.aeiva.storage.chroma.chroma_database.ChromaDatabase.get_collection_info","title":"get_collection_info(collection_name)
","text":"Get information about a collection.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required Returns:
Type Description Dict[str, Any]
Dict[str, Any]: Information about the collection.
Source code in src/aeiva/storage/chroma/chroma_database.py
def get_collection_info(self, collection_name: str) -> Dict[str, Any]:\n \"\"\"\n Get information about a collection.\n\n Args:\n collection_name (str): The name of the collection.\n\n Returns:\n Dict[str, Any]: Information about the collection.\n \"\"\"\n collection = self.client.get_collection(name=collection_name)\n return {\n 'name': collection.name,\n 'metadata': collection.metadata\n }\n
"},{"location":"reference/#src.aeiva.storage.chroma.chroma_database.ChromaDatabase.get_vector","title":"get_vector(collection_name, vector_id)
","text":"Retrieve a vector by its ID.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required vector_id
str
The unique identifier of the vector.
required Returns:
Type Description Dict[str, Any]
Dict[str, Any]: A dictionary containing the vector data and payload.
Source code in src/aeiva/storage/chroma/chroma_database.py
def get_vector(self, collection_name: str, vector_id: str) -> Dict[str, Any]:\n \"\"\"\n Retrieve a vector by its ID.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector.\n\n Returns:\n Dict[str, Any]: A dictionary containing the vector data and payload.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n result = self.collection.get(ids=[vector_id])\n if not result['ids']:\n raise KeyError(f\"Vector with ID {vector_id} not found in collection {collection_name}.\")\n\n vector_data = {\n 'id': result['ids'][0],\n 'vector': result['embeddings'][0] if 'embeddings' in result else None,\n 'payload': result['metadatas'][0]\n }\n return vector_data\n
"},{"location":"reference/#src.aeiva.storage.chroma.chroma_database.ChromaDatabase.insert_vectors","title":"insert_vectors(collection_name, vectors, payloads=None, ids=None)
","text":"Insert vectors into a collection.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required vectors
List[List[float]]
A list of vectors to insert.
required payloads
Optional[List[Dict[str, Any]]]
Optional metadata associated with each vector.
None
ids
Optional[List[str]]
Optional unique identifiers for each vector.
None
Source code in src/aeiva/storage/chroma/chroma_database.py
def insert_vectors(\n self,\n collection_name: str,\n vectors: List[List[float]],\n payloads: Optional[List[Dict[str, Any]]] = None,\n ids: Optional[List[str]] = None\n) -> None:\n \"\"\"\n Insert vectors into a collection.\n\n Args:\n collection_name (str): The name of the collection.\n vectors (List[List[float]]): A list of vectors to insert.\n payloads (Optional[List[Dict[str, Any]]]): Optional metadata associated with each vector.\n ids (Optional[List[str]]): Optional unique identifiers for each vector.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n if ids is None:\n ids = [str(i) for i in range(len(vectors))]\n if payloads is None:\n payloads = [{} for _ in range(len(vectors))]\n if not (len(ids) == len(vectors) == len(payloads)):\n raise ValueError(\"Lengths of ids, vectors, and payloads must be equal.\")\n\n self.collection.add(ids=ids, embeddings=vectors, metadatas=payloads)\n logger.info(f\"Inserted {len(vectors)} vectors into collection {collection_name}.\")\n
"},{"location":"reference/#src.aeiva.storage.chroma.chroma_database.ChromaDatabase.list_collections","title":"list_collections()
","text":"List all available vector collections.
Returns:
Type Description List[str]
List[str]: A list of collection names.
Source code in src/aeiva/storage/chroma/chroma_database.py
def list_collections(self) -> List[str]:\n \"\"\"\n List all available vector collections.\n\n Returns:\n List[str]: A list of collection names.\n \"\"\"\n collections = self.client.list_collections()\n return [collection.name for collection in collections]\n
"},{"location":"reference/#src.aeiva.storage.chroma.chroma_database.ChromaDatabase.search_vectors","title":"search_vectors(collection_name, query_vector, top_k=5, filters=None)
","text":"Search for similar vectors in a collection.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required query_vector
List[float]
The vector to search with.
required top_k
int
The number of top results to return.
5
filters
Optional[Dict[str, Any]]
Optional filters to apply to the search.
None
Returns:
Type Description List[Dict[str, Any]]
List[Dict[str, Any]]: A list of search results.
Source code in src/aeiva/storage/chroma/chroma_database.py
def search_vectors(\n self,\n collection_name: str,\n query_vector: List[float],\n top_k: int = 5,\n filters: Optional[Dict[str, Any]] = None\n) -> List[Dict[str, Any]]:\n \"\"\"\n Search for similar vectors in a collection.\n\n Args:\n collection_name (str): The name of the collection.\n query_vector (List[float]): The vector to search with.\n top_k (int): The number of top results to return.\n filters (Optional[Dict[str, Any]]): Optional filters to apply to the search.\n\n Returns:\n List[Dict[str, Any]]: A list of search results.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n results = self.collection.query(\n query_embeddings=[query_vector],\n where=filters,\n n_results=top_k\n )\n # Parse the results\n output = []\n for idx, (ids, distances, metadatas) in enumerate(zip(results['ids'], results['distances'], results['metadatas'])):\n for i in range(len(ids)):\n result = {\n 'id': ids[i],\n 'score': distances[i],\n 'payload': metadatas[i]\n }\n output.append(result)\n return output\n
"},{"location":"reference/#src.aeiva.storage.chroma.chroma_database.ChromaDatabase.update_vector","title":"update_vector(collection_name, vector_id, vector=None, payload=None)
","text":"Update a vector's data or payload.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required vector_id
str
The unique identifier of the vector to update.
required vector
Optional[List[float]]
The new vector data.
None
payload
Optional[Dict[str, Any]]
The new payload data.
None
Source code in src/aeiva/storage/chroma/chroma_database.py
def update_vector(\n self,\n collection_name: str,\n vector_id: str,\n vector: Optional[List[float]] = None,\n payload: Optional[Dict[str, Any]] = None\n) -> None:\n \"\"\"\n Update a vector's data or payload.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector to update.\n vector (Optional[List[float]]): The new vector data.\n payload (Optional[Dict[str, Any]]): The new payload data.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n self.collection.update(ids=[vector_id], embeddings=[vector] if vector else None, metadatas=[payload] if payload else None)\n logger.info(f\"Updated vector with ID {vector_id} in collection {collection_name}.\")\n
"},{"location":"reference/#src.aeiva.storage.database_factory","title":"database_factory
","text":""},{"location":"reference/#src.aeiva.storage.database_factory.DatabaseConfigFactory","title":"DatabaseConfigFactory
","text":"Factory class to create database configuration objects based on the provider name.
Example: config = DatabaseConfigFactory.create('milvus', host='localhost', port=19530, embedding_model_dims=128, ...)
Source code in src/aeiva/storage/database_factory.py
class DatabaseConfigFactory:\n \"\"\"\n Factory class to create database configuration objects based on the provider name.\n\n Example:\n config = DatabaseConfigFactory.create(\n 'milvus',\n host='localhost',\n port=19530,\n embedding_model_dims=128,\n ...\n )\n \"\"\"\n\n provider_to_class = {\n \"milvus\": \"aeiva.storage.milvus.milvus_config.MilvusConfig\",\n \"chroma\": \"aeiva.storage.chroma.chroma_config.ChromaConfig\",\n \"azure_ai_search\": \"aeiva.storage.azure_ai_search.azure_ai_search_config.AzureAISearchConfig\",\n \"pgvector\": \"aeiva.storage.pgvector.pgvector_config.PGVectorConfig\",\n \"qdrant\": \"aeiva.storage.qdrant.qdrant_config.QdrantConfig\",\n \"neo4j\": \"aeiva.storage.neo4jdb.neo4j_config.Neo4jConfig\",\n \"sqlite\": \"aeiva.storage.sqlite.sqlite_config.SQLiteConfig\",\n \"postgresql\": \"aeiva.storage.postgresql.postgresql_config.PostgreSQLConfig\",\n \"weaviate\": \"aeiva.storage.weaviate.weaviate_config.WeaviateConfig\",\n }\n\n @classmethod\n def create(cls, provider_name: str, **kwargs) -> Any:\n \"\"\"\n Create a database configuration object based on the provider name.\n\n Args:\n provider_name (str): The name of the database provider (e.g., 'milvus', 'chroma').\n **kwargs: Configuration parameters specific to the database provider.\n\n Returns:\n Any: An instance of the database configuration class.\n\n Raises:\n ValueError: If the provider name is not supported.\n ImportError: If the configuration class cannot be imported.\n \"\"\"\n class_path = cls.provider_to_class.get(provider_name.lower())\n if class_path:\n config_class = load_class(class_path)\n return config_class(**kwargs)\n else:\n raise ValueError(f\"Unsupported database provider: {provider_name}\")\n
"},{"location":"reference/#src.aeiva.storage.database_factory.DatabaseConfigFactory.create","title":"create(provider_name, **kwargs)
classmethod
","text":"Create a database configuration object based on the provider name.
Parameters:
Name Type Description Default provider_name
str
The name of the database provider (e.g., 'milvus', 'chroma').
required **kwargs
Configuration parameters specific to the database provider.
{}
Returns:
Name Type Description Any
Any
An instance of the database configuration class.
Raises:
Type Description ValueError
If the provider name is not supported.
ImportError
If the configuration class cannot be imported.
Source code in src/aeiva/storage/database_factory.py
@classmethod\ndef create(cls, provider_name: str, **kwargs) -> Any:\n \"\"\"\n Create a database configuration object based on the provider name.\n\n Args:\n provider_name (str): The name of the database provider (e.g., 'milvus', 'chroma').\n **kwargs: Configuration parameters specific to the database provider.\n\n Returns:\n Any: An instance of the database configuration class.\n\n Raises:\n ValueError: If the provider name is not supported.\n ImportError: If the configuration class cannot be imported.\n \"\"\"\n class_path = cls.provider_to_class.get(provider_name.lower())\n if class_path:\n config_class = load_class(class_path)\n return config_class(**kwargs)\n else:\n raise ValueError(f\"Unsupported database provider: {provider_name}\")\n
"},{"location":"reference/#src.aeiva.storage.database_factory.DatabaseFactory","title":"DatabaseFactory
","text":"Factory class to create database objects based on the provider name and configuration.
Example: db = DatabaseFactory.create('milvus', config)
Source code in src/aeiva/storage/database_factory.py
class DatabaseFactory:\n \"\"\"\n Factory class to create database objects based on the provider name and configuration.\n\n Example:\n db = DatabaseFactory.create('milvus', config)\n \"\"\"\n\n provider_to_class = {\n \"milvus\": \"aeiva.storage.milvus.milvus_database.MilvusDatabase\",\n \"chroma\": \"aeiva.storage.chroma.chroma_database.ChromaDatabase\",\n \"azure_ai_search\": \"aeiva.storage.azure_ai_search.azure_ai_search_database.AzureAISearchDatabase\",\n \"pgvector\": \"aeiva.storage.pgvector.pgvector_database.PGVectorDatabase\",\n \"qdrant\": \"aeiva.storage.qdrant.qdrant_database.QdrantDatabase\",\n \"neo4j\": \"aeiva.storage.neo4jdb.neo4j_database.Neo4jDatabase\",\n \"sqlite\": \"aeiva.storage.sqlite.sqlite_database.SQLiteDatabase\",\n \"postgresql\": \"aeiva.storage.postgresql.postgresql_database.PostgreSQLDatabase\",\n \"weaviate\": \"aeiva.storage.weaviate.weaviate_database.WeaviateDatabase\",\n }\n\n @classmethod\n def create(cls, provider_name: str, config: Any) -> Any:\n \"\"\"\n Create a database object based on the provider name and configuration.\n\n Args:\n provider_name (str): The name of the database provider.\n config (Any): Configuration object or dictionary for the database.\n\n Returns:\n Any: An instance of the database class.\n\n Raises:\n ValueError: If the provider name is not supported.\n ImportError: If the database class cannot be imported.\n TypeError: If the configuration cannot be converted to a dictionary.\n \"\"\"\n class_path = cls.provider_to_class.get(provider_name.lower())\n if class_path:\n db_class = load_class(class_path)\n if isinstance(config, dict):\n return db_class(config)\n elif hasattr(config, 'to_dict'):\n # Assuming config is a dataclass with a 'to_dict' method\n return db_class(config.to_dict())\n elif hasattr(config, '__dict__'):\n # If config is a dataclass without 'to_dict', use __dict__\n return db_class(config.__dict__)\n else:\n raise TypeError(\n \"Config must be a dict or an object with 'to_dict' or 
'__dict__' method.\"\n )\n else:\n raise ValueError(f\"Unsupported database provider: {provider_name}\")\n
"},{"location":"reference/#src.aeiva.storage.database_factory.DatabaseFactory.create","title":"create(provider_name, config)
classmethod
","text":"Create a database object based on the provider name and configuration.
Parameters:
Name Type Description Default provider_name
str
The name of the database provider.
required config
Any
Configuration object or dictionary for the database.
required Returns:
Name Type Description Any
Any
An instance of the database class.
Raises:
Type Description ValueError
If the provider name is not supported.
ImportError
If the database class cannot be imported.
TypeError
If the configuration cannot be converted to a dictionary.
Source code in src/aeiva/storage/database_factory.py
@classmethod\ndef create(cls, provider_name: str, config: Any) -> Any:\n \"\"\"\n Create a database object based on the provider name and configuration.\n\n Args:\n provider_name (str): The name of the database provider.\n config (Any): Configuration object or dictionary for the database.\n\n Returns:\n Any: An instance of the database class.\n\n Raises:\n ValueError: If the provider name is not supported.\n ImportError: If the database class cannot be imported.\n TypeError: If the configuration cannot be converted to a dictionary.\n \"\"\"\n class_path = cls.provider_to_class.get(provider_name.lower())\n if class_path:\n db_class = load_class(class_path)\n if isinstance(config, dict):\n return db_class(config)\n elif hasattr(config, 'to_dict'):\n # Assuming config is a dataclass with a 'to_dict' method\n return db_class(config.to_dict())\n elif hasattr(config, '__dict__'):\n # If config is a dataclass without 'to_dict', use __dict__\n return db_class(config.__dict__)\n else:\n raise TypeError(\n \"Config must be a dict or an object with 'to_dict' or '__dict__' method.\"\n )\n else:\n raise ValueError(f\"Unsupported database provider: {provider_name}\")\n
"},{"location":"reference/#src.aeiva.storage.database_factory.load_class","title":"load_class(class_path)
","text":"Dynamically load a class from a string.
Parameters:
Name Type Description Default class_path
str
The full path to the class, e.g., 'module.submodule.ClassName'.
required Returns:
Name Type Description Type
Type
The class type.
Raises:
Type Description ImportError
If the module or class cannot be found.
Source code in src/aeiva/storage/database_factory.py
def load_class(class_path: str) -> Type:\n    \"\"\"\n    Dynamically load a class from a string.\n\n    Args:\n        class_path (str): The full path to the class, e.g., 'module.submodule.ClassName'.\n\n    Returns:\n        Type: The class type.\n\n    Raises:\n        ImportError: If the module or class cannot be found.\n    \"\"\"\n    try:\n        module_path, class_name = class_path.rsplit('.', 1)\n        module = importlib.import_module(module_path)\n        return getattr(module, class_name)\n    except (ImportError, AttributeError, ValueError) as e:\n        # ValueError covers paths without a module separator; report the full\n        # path since module/class names may not have been unpacked yet.\n        raise ImportError(f\"Cannot import '{class_path}': {e}\")\n
"},{"location":"reference/#src.aeiva.storage.graph_database","title":"graph_database
","text":""},{"location":"reference/#src.aeiva.storage.graph_database.GraphDatabase","title":"GraphDatabase
","text":" Bases: ABC
Abstract base class for graph database operations.
Source code in src/aeiva/storage/graph_database.py
class GraphDatabase(ABC):\n \"\"\"\n Abstract base class for graph database operations.\n \"\"\"\n\n @abstractmethod\n def add_node(\n self, \n node_id: str, \n properties: Optional[Dict[str, Any]] = None, \n labels: Optional[List[str]] = None\n ) -> None:\n \"\"\"\n Adds a node to the graph.\n\n Args:\n node_id (str): Unique identifier for the node.\n properties (Optional[Dict[str, Any]]): Properties associated with the node.\n labels (Optional[List[str]]): Labels or types associated with the node.\n\n Raises:\n StorageError: If there is an issue adding the node.\n \"\"\"\n pass\n\n @abstractmethod\n def add_edge(\n self, \n source_id: str, \n target_id: str, \n relationship: str, \n properties: Optional[Dict[str, Any]] = None\n ) -> None:\n \"\"\"\n Adds an edge (relationship) between two nodes.\n\n Args:\n source_id (str): Unique identifier of the source node.\n target_id (str): Unique identifier of the target node.\n relationship (str): Type of the relationship.\n properties (Optional[Dict[str, Any]]): Properties associated with the edge.\n\n Raises:\n NodeNotFoundError: If either the source or target node does not exist.\n StorageError: If there is an issue adding the edge.\n \"\"\"\n pass\n\n @abstractmethod\n def get_node(self, node_id: str) -> Dict[str, Any]:\n \"\"\"\n Retrieves a node by its identifier.\n\n Args:\n node_id (str): Unique identifier of the node.\n\n Returns:\n Dict[str, Any]: A dictionary containing the node's properties and labels.\n\n Raises:\n NodeNotFoundError: If the node does not exist.\n StorageError: If there is an issue retrieving the node.\n \"\"\"\n pass\n\n @abstractmethod\n def update_node(self, node_id: str, properties: Dict[str, Any]) -> None:\n \"\"\"\n Updates properties of a node.\n\n Args:\n node_id (str): Unique identifier of the node.\n properties (Dict[str, Any]): Properties to update.\n\n Raises:\n NodeNotFoundError: If the node does not exist.\n StorageError: If there is an issue updating the node.\n \"\"\"\n 
pass\n\n @abstractmethod\n def delete_node(self, node_id: str) -> None:\n \"\"\"\n Deletes a node from the graph.\n\n Args:\n node_id (str): Unique identifier of the node.\n\n Raises:\n NodeNotFoundError: If the node does not exist.\n StorageError: If there is an issue deleting the node.\n \"\"\"\n pass\n\n @abstractmethod\n def delete_all(self) -> None:\n \"\"\"\n Deletes all nodes and their associated relationships from the graph.\n\n Raises:\n StorageError: If there is an issue deleting all nodes and relationships.\n \"\"\"\n pass\n\n @abstractmethod\n def delete_all_edges(self) -> None:\n \"\"\"\n Deletes all edges from the graph without deleting the nodes.\n\n Raises:\n StorageError: If there is an issue deleting all relationships.\n \"\"\"\n pass\n\n @abstractmethod\n def delete_relationships_by_type(self, relationship: str) -> None:\n \"\"\"\n Deletes all relationships of a specific type from the graph.\n\n Args:\n relationship (str): The type of relationships to delete.\n\n Raises:\n StorageError: If there is an issue deleting the relationships.\n \"\"\"\n pass\n\n @abstractmethod\n def delete_edge(\n self, \n source_id: str, \n target_id: str, \n relationship: str\n ) -> None:\n \"\"\"\n Deletes a specific relationship between two nodes.\n\n Args:\n source_id (str): Unique identifier of the source node.\n target_id (str): Unique identifier of the target node.\n relationship (str): Type of the relationship to delete.\n\n Raises:\n RelationshipNotFoundError: If the relationship does not exist.\n StorageError: If there is an issue deleting the relationship.\n \"\"\"\n pass\n\n @abstractmethod\n def update_edge(\n self, \n source_id: str, \n target_id: str, \n relationship: str, \n properties: Dict[str, Any]\n ) -> None:\n \"\"\"\n Updates properties of a specific relationship between two nodes.\n\n Args:\n source_id (str): Unique identifier of the source node.\n target_id (str): Unique identifier of the target node.\n relationship (str): Type of the 
relationship to update.\n properties (Dict[str, Any]): Properties to update on the relationship.\n\n Raises:\n RelationshipNotFoundError: If the relationship does not exist.\n StorageError: If there is an issue updating the relationship.\n \"\"\"\n pass\n\n @abstractmethod\n def get_relationship(\n self, \n source_id: str, \n target_id: str, \n relationship: str\n ) -> Dict[str, Any]:\n \"\"\"\n Retrieves a specific relationship between two nodes.\n\n Args:\n source_id (str): Unique identifier of the source node.\n target_id (str): Unique identifier of the target node.\n relationship (str): Type of the relationship to retrieve.\n\n Returns:\n Dict[str, Any]: A dictionary containing the relationship's properties.\n\n Raises:\n RelationshipNotFoundError: If the relationship does not exist.\n StorageError: If there is an issue retrieving the relationship.\n \"\"\"\n pass\n\n @abstractmethod\n def get_neighbors(\n self, \n node_id: str, \n relationship: Optional[str] = None, \n direction: str = \"both\"\n ) -> List[Dict[str, Any]]:\n \"\"\"\n Retrieves neighboring nodes connected by edges.\n\n Args:\n node_id (str): Unique identifier of the node.\n relationship (Optional[str]): Filter by relationship type.\n direction (str): Direction of the relationships ('in', 'out', 'both').\n\n Returns:\n List[Dict[str, Any]]: A list of neighboring nodes.\n\n Raises:\n NodeNotFoundError: If the node does not exist.\n StorageError: If there is an issue retrieving neighbors.\n \"\"\"\n pass\n\n @abstractmethod\n def query_nodes(\n self, \n properties: Dict[str, Any], \n labels: Optional[List[str]] = None\n ) -> List[Dict[str, Any]]:\n \"\"\"\n Queries nodes based on properties and labels.\n\n Args:\n properties (Dict[str, Any]): Properties to filter nodes.\n labels (Optional[List[str]]): Labels to filter nodes.\n\n Returns:\n List[Dict[str, Any]]: A list of nodes matching the query.\n\n Raises:\n StorageError: If there is an issue querying nodes.\n \"\"\"\n pass\n\n @abstractmethod\n 
def execute_query(\n self, \n query: str, \n parameters: Optional[Dict[str, Any]] = None\n ) -> Any:\n \"\"\"\n Executes a raw query against the graph database.\n\n Args:\n query (str): The query string.\n parameters (Optional[Dict[str, Any]]): Parameters for parameterized queries.\n\n Returns:\n Any: The result of the query.\n\n Raises:\n StorageError: If there is an issue executing the query.\n \"\"\"\n pass\n\n @abstractmethod\n def close(self) -> None:\n \"\"\"\n Closes the graph database connection and releases resources.\n \"\"\"\n pass\n
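A concrete backend implements every abstract method above. The sketch below illustrates the pattern on a much-reduced, hypothetical three-method slice of the contract (`MiniGraphDatabase` and `InMemoryGraph` are illustrative names, not aeiva classes), backed by plain in-memory structures:

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple

class MiniGraphDatabase(ABC):
    # Reduced slice of the GraphDatabase contract; the real ABC
    # declares many more abstract methods.
    @abstractmethod
    def add_node(self, node_id: str, properties: Optional[Dict[str, Any]] = None) -> None: ...
    @abstractmethod
    def add_edge(self, source_id: str, target_id: str, relationship: str) -> None: ...
    @abstractmethod
    def get_neighbors(self, node_id: str) -> List[str]: ...

class InMemoryGraph(MiniGraphDatabase):
    def __init__(self) -> None:
        self.nodes: Dict[str, Dict[str, Any]] = {}
        self.edges: List[Tuple[str, str, str]] = []

    def add_node(self, node_id: str, properties: Optional[Dict[str, Any]] = None) -> None:
        self.nodes[node_id] = properties or {}

    def add_edge(self, source_id: str, target_id: str, relationship: str) -> None:
        # Mirrors the NodeNotFoundError contract: both endpoints must exist.
        if source_id not in self.nodes or target_id not in self.nodes:
            raise KeyError("both endpoint nodes must exist before adding an edge")
        self.edges.append((source_id, target_id, relationship))

    def get_neighbors(self, node_id: str) -> List[str]:
        # direction="both": collect targets of outgoing and sources of incoming edges.
        out = [t for s, t, _ in self.edges if s == node_id]
        inc = [s for s, t, _ in self.edges if t == node_id]
        return out + inc

g = InMemoryGraph()
g.add_node("a")
g.add_node("b")
g.add_edge("a", "b", "knows")
print(g.get_neighbors("a"))  # ['b']
```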
"},{"location":"reference/#src.aeiva.storage.graph_database.GraphDatabase.add_edge","title":"add_edge(source_id, target_id, relationship, properties=None)
abstractmethod
","text":"Adds an edge (relationship) between two nodes.
Parameters:
Name Type Description Default source_id
str
Unique identifier of the source node.
required target_id
str
Unique identifier of the target node.
required relationship
str
Type of the relationship.
required properties
Optional[Dict[str, Any]]
Properties associated with the edge.
None
Raises:
Type Description NodeNotFoundError
If either the source or target node does not exist.
StorageError
If there is an issue adding the edge.
Source code in src/aeiva/storage/graph_database.py
@abstractmethod\ndef add_edge(\n self, \n source_id: str, \n target_id: str, \n relationship: str, \n properties: Optional[Dict[str, Any]] = None\n) -> None:\n \"\"\"\n Adds an edge (relationship) between two nodes.\n\n Args:\n source_id (str): Unique identifier of the source node.\n target_id (str): Unique identifier of the target node.\n relationship (str): Type of the relationship.\n properties (Optional[Dict[str, Any]]): Properties associated with the edge.\n\n Raises:\n NodeNotFoundError: If either the source or target node does not exist.\n StorageError: If there is an issue adding the edge.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.graph_database.GraphDatabase.add_node","title":"add_node(node_id, properties=None, labels=None)
abstractmethod
","text":"Adds a node to the graph.
Parameters:
Name Type Description Default node_id
str
Unique identifier for the node.
required properties
Optional[Dict[str, Any]]
Properties associated with the node.
None
labels
Optional[List[str]]
Labels or types associated with the node.
None
Raises:
Type Description StorageError
If there is an issue adding the node.
Source code in src/aeiva/storage/graph_database.py
@abstractmethod\ndef add_node(\n self, \n node_id: str, \n properties: Optional[Dict[str, Any]] = None, \n labels: Optional[List[str]] = None\n) -> None:\n \"\"\"\n Adds a node to the graph.\n\n Args:\n node_id (str): Unique identifier for the node.\n properties (Optional[Dict[str, Any]]): Properties associated with the node.\n labels (Optional[List[str]]): Labels or types associated with the node.\n\n Raises:\n StorageError: If there is an issue adding the node.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.graph_database.GraphDatabase.close","title":"close()
abstractmethod
","text":"Closes the graph database connection and releases resources.
Source code in src/aeiva/storage/graph_database.py
@abstractmethod\ndef close(self) -> None:\n \"\"\"\n Closes the graph database connection and releases resources.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.graph_database.GraphDatabase.delete_all","title":"delete_all()
abstractmethod
","text":"Deletes all nodes and their associated relationships from the graph.
Raises:
Type Description StorageError
If there is an issue deleting all nodes and relationships.
Source code in src/aeiva/storage/graph_database.py
@abstractmethod\ndef delete_all(self) -> None:\n \"\"\"\n Deletes all nodes and their associated relationships from the graph.\n\n Raises:\n StorageError: If there is an issue deleting all nodes and relationships.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.graph_database.GraphDatabase.delete_all_edges","title":"delete_all_edges()
abstractmethod
","text":"Deletes all edges from the graph without deleting the nodes.
Raises:
Type Description StorageError
If there is an issue deleting all relationships.
Source code in src/aeiva/storage/graph_database.py
@abstractmethod\ndef delete_all_edges(self) -> None:\n \"\"\"\n Deletes all edges from the graph without deleting the nodes.\n\n Raises:\n StorageError: If there is an issue deleting all relationships.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.graph_database.GraphDatabase.delete_edge","title":"delete_edge(source_id, target_id, relationship)
abstractmethod
","text":"Deletes a specific relationship between two nodes.
Parameters:
Name Type Description Default source_id
str
Unique identifier of the source node.
required target_id
str
Unique identifier of the target node.
required relationship
str
Type of the relationship to delete.
required Raises:
Type Description RelationshipNotFoundError
If the relationship does not exist.
StorageError
If there is an issue deleting the relationship.
Source code in src/aeiva/storage/graph_database.py
@abstractmethod\ndef delete_edge(\n self, \n source_id: str, \n target_id: str, \n relationship: str\n) -> None:\n \"\"\"\n Deletes a specific relationship between two nodes.\n\n Args:\n source_id (str): Unique identifier of the source node.\n target_id (str): Unique identifier of the target node.\n relationship (str): Type of the relationship to delete.\n\n Raises:\n RelationshipNotFoundError: If the relationship does not exist.\n StorageError: If there is an issue deleting the relationship.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.graph_database.GraphDatabase.delete_node","title":"delete_node(node_id)
abstractmethod
","text":"Deletes a node from the graph.
Parameters:
Name Type Description Default node_id
str
Unique identifier of the node.
required Raises:
Type Description NodeNotFoundError
If the node does not exist.
StorageError
If there is an issue deleting the node.
Source code in src/aeiva/storage/graph_database.py
@abstractmethod\ndef delete_node(self, node_id: str) -> None:\n \"\"\"\n Deletes a node from the graph.\n\n Args:\n node_id (str): Unique identifier of the node.\n\n Raises:\n NodeNotFoundError: If the node does not exist.\n StorageError: If there is an issue deleting the node.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.graph_database.GraphDatabase.delete_relationships_by_type","title":"delete_relationships_by_type(relationship)
abstractmethod
","text":"Deletes all relationships of a specific type from the graph.
Parameters:
Name Type Description Default relationship
str
The type of relationships to delete.
required Raises:
Type Description StorageError
If there is an issue deleting the relationships.
Source code in src/aeiva/storage/graph_database.py
@abstractmethod\ndef delete_relationships_by_type(self, relationship: str) -> None:\n \"\"\"\n Deletes all relationships of a specific type from the graph.\n\n Args:\n relationship (str): The type of relationships to delete.\n\n Raises:\n StorageError: If there is an issue deleting the relationships.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.graph_database.GraphDatabase.execute_query","title":"execute_query(query, parameters=None)
abstractmethod
","text":"Executes a raw query against the graph database.
Parameters:
Name Type Description Default query
str
The query string.
required parameters
Optional[Dict[str, Any]]
Parameters for parameterized queries.
None
Returns:
Name Type Description Any
Any
The result of the query.
Raises:
Type Description StorageError
If there is an issue executing the query.
Source code in src/aeiva/storage/graph_database.py
@abstractmethod\ndef execute_query(\n self, \n query: str, \n parameters: Optional[Dict[str, Any]] = None\n) -> Any:\n \"\"\"\n Executes a raw query against the graph database.\n\n Args:\n query (str): The query string.\n parameters (Optional[Dict[str, Any]]): Parameters for parameterized queries.\n\n Returns:\n Any: The result of the query.\n\n Raises:\n StorageError: If there is an issue executing the query.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.graph_database.GraphDatabase.get_neighbors","title":"get_neighbors(node_id, relationship=None, direction='both')
abstractmethod
","text":"Retrieves neighboring nodes connected by edges.
Parameters:
Name Type Description Default node_id
str
Unique identifier of the node.
required relationship
Optional[str]
Filter by relationship type.
None
direction
str
Direction of the relationships ('in', 'out', 'both').
'both'
Returns:
Type Description List[Dict[str, Any]]
List[Dict[str, Any]]: A list of neighboring nodes.
Raises:
Type Description NodeNotFoundError
If the node does not exist.
StorageError
If there is an issue retrieving neighbors.
Source code in src/aeiva/storage/graph_database.py
@abstractmethod\ndef get_neighbors(\n self, \n node_id: str, \n relationship: Optional[str] = None, \n direction: str = \"both\"\n) -> List[Dict[str, Any]]:\n \"\"\"\n Retrieves neighboring nodes connected by edges.\n\n Args:\n node_id (str): Unique identifier of the node.\n relationship (Optional[str]): Filter by relationship type.\n direction (str): Direction of the relationships ('in', 'out', 'both').\n\n Returns:\n List[Dict[str, Any]]: A list of neighboring nodes.\n\n Raises:\n NodeNotFoundError: If the node does not exist.\n StorageError: If there is an issue retrieving neighbors.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.graph_database.GraphDatabase.get_node","title":"get_node(node_id)
abstractmethod
","text":"Retrieves a node by its identifier.
Parameters:
Name Type Description Default node_id
str
Unique identifier of the node.
required Returns:
Type Description Dict[str, Any]
Dict[str, Any]: A dictionary containing the node's properties and labels.
Raises:
Type Description NodeNotFoundError
If the node does not exist.
StorageError
If there is an issue retrieving the node.
Source code in src/aeiva/storage/graph_database.py
@abstractmethod\ndef get_node(self, node_id: str) -> Dict[str, Any]:\n \"\"\"\n Retrieves a node by its identifier.\n\n Args:\n node_id (str): Unique identifier of the node.\n\n Returns:\n Dict[str, Any]: A dictionary containing the node's properties and labels.\n\n Raises:\n NodeNotFoundError: If the node does not exist.\n StorageError: If there is an issue retrieving the node.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.graph_database.GraphDatabase.get_relationship","title":"get_relationship(source_id, target_id, relationship)
abstractmethod
","text":"Retrieves a specific relationship between two nodes.
Parameters:
Name Type Description Default source_id
str
Unique identifier of the source node.
required target_id
str
Unique identifier of the target node.
required relationship
str
Type of the relationship to retrieve.
required Returns:
Type Description Dict[str, Any]
Dict[str, Any]: A dictionary containing the relationship's properties.
Raises:
Type Description RelationshipNotFoundError
If the relationship does not exist.
StorageError
If there is an issue retrieving the relationship.
Source code in src/aeiva/storage/graph_database.py
@abstractmethod\ndef get_relationship(\n self, \n source_id: str, \n target_id: str, \n relationship: str\n) -> Dict[str, Any]:\n \"\"\"\n Retrieves a specific relationship between two nodes.\n\n Args:\n source_id (str): Unique identifier of the source node.\n target_id (str): Unique identifier of the target node.\n relationship (str): Type of the relationship to retrieve.\n\n Returns:\n Dict[str, Any]: A dictionary containing the relationship's properties.\n\n Raises:\n RelationshipNotFoundError: If the relationship does not exist.\n StorageError: If there is an issue retrieving the relationship.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.graph_database.GraphDatabase.query_nodes","title":"query_nodes(properties, labels=None)
abstractmethod
","text":"Queries nodes based on properties and labels.
Parameters:
Name Type Description Default properties
Dict[str, Any]
Properties to filter nodes.
required labels
Optional[List[str]]
Labels to filter nodes.
None
Returns:
Type Description List[Dict[str, Any]]
List[Dict[str, Any]]: A list of nodes matching the query.
Raises:
Type Description StorageError
If there is an issue querying nodes.
Source code in src/aeiva/storage/graph_database.py
@abstractmethod\ndef query_nodes(\n self, \n properties: Dict[str, Any], \n labels: Optional[List[str]] = None\n) -> List[Dict[str, Any]]:\n \"\"\"\n Queries nodes based on properties and labels.\n\n Args:\n properties (Dict[str, Any]): Properties to filter nodes.\n labels (Optional[List[str]]): Labels to filter nodes.\n\n Returns:\n List[Dict[str, Any]]: A list of nodes matching the query.\n\n Raises:\n StorageError: If there is an issue querying nodes.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.graph_database.GraphDatabase.update_edge","title":"update_edge(source_id, target_id, relationship, properties)
abstractmethod
","text":"Updates properties of a specific relationship between two nodes.
Parameters:
Name Type Description Default source_id
str
Unique identifier of the source node.
required target_id
str
Unique identifier of the target node.
required relationship
str
Type of the relationship to update.
required properties
Dict[str, Any]
Properties to update on the relationship.
required Raises:
Type Description RelationshipNotFoundError
If the relationship does not exist.
StorageError
If there is an issue updating the relationship.
Source code in src/aeiva/storage/graph_database.py
@abstractmethod\ndef update_edge(\n self, \n source_id: str, \n target_id: str, \n relationship: str, \n properties: Dict[str, Any]\n) -> None:\n \"\"\"\n Updates properties of a specific relationship between two nodes.\n\n Args:\n source_id (str): Unique identifier of the source node.\n target_id (str): Unique identifier of the target node.\n relationship (str): Type of the relationship to update.\n properties (Dict[str, Any]): Properties to update on the relationship.\n\n Raises:\n RelationshipNotFoundError: If the relationship does not exist.\n StorageError: If there is an issue updating the relationship.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.graph_database.GraphDatabase.update_node","title":"update_node(node_id, properties)
abstractmethod
","text":"Updates properties of a node.
Parameters:
Name Type Description Default node_id
str
Unique identifier of the node.
required properties
Dict[str, Any]
Properties to update.
required Raises:
Type Description NodeNotFoundError
If the node does not exist.
StorageError
If there is an issue updating the node.
Source code in src/aeiva/storage/graph_database.py
@abstractmethod\ndef update_node(self, node_id: str, properties: Dict[str, Any]) -> None:\n \"\"\"\n Updates properties of a node.\n\n Args:\n node_id (str): Unique identifier of the node.\n properties (Dict[str, Any]): Properties to update.\n\n Raises:\n NodeNotFoundError: If the node does not exist.\n StorageError: If there is an issue updating the node.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.graph_database.NodeNotFoundError","title":"NodeNotFoundError
","text":" Bases: Exception
Exception raised when a node is not found in the graph database.
Source code in src/aeiva/storage/graph_database.py
class NodeNotFoundError(Exception):\n \"\"\"Exception raised when a node is not found in the graph database.\"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.graph_database.RelationshipNotFoundError","title":"RelationshipNotFoundError
","text":" Bases: Exception
Exception raised when a relationship is not found in the graph database.
Source code in src/aeiva/storage/graph_database.py
class RelationshipNotFoundError(Exception):\n \"\"\"Exception raised when a relationship is not found in the graph database.\"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.graph_database.StorageError","title":"StorageError
","text":" Bases: Exception
Exception raised when there is a storage-related error in the graph database.
Source code in src/aeiva/storage/graph_database.py
class StorageError(Exception):\n \"\"\"Exception raised when there is a storage-related error in the graph database.\"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.milvus","title":"milvus
","text":""},{"location":"reference/#src.aeiva.storage.milvus.milvus_config","title":"milvus_config
","text":""},{"location":"reference/#src.aeiva.storage.milvus.milvus_config.MilvusConfig","title":"MilvusConfig
dataclass
","text":" Bases: BaseConfig
Configuration for Milvus vector database.
Source code in src/aeiva/storage/milvus/milvus_config.py
@dataclass\nclass MilvusConfig(BaseConfig):\n \"\"\"\n Configuration for Milvus vector database.\n \"\"\"\n\n uri: str = field(\n default=\"http://localhost:19530\",\n metadata={\"help\": \"Full URL for Milvus server.\"}\n )\n token: Optional[str] = field(\n default=None,\n metadata={\"help\": \"Token for Milvus server authentication (if required).\"}\n )\n collection_name: str = field(\n default=\"mem0\",\n metadata={\"help\": \"Name of the collection.\"}\n )\n embedding_model_dims: int = field(\n default=1536,\n metadata={\"help\": \"Dimensions of the embedding model.\"}\n )\n metric_type: str = field(\n default=\"L2\",\n metadata={\"help\": \"Metric type for similarity search (e.g., 'L2', 'IP', 'COSINE').\"}\n )\n\n def __post_init__(self):\n super().__post_init__()\n # Validate metric_type\n valid_metrics = {\"L2\", \"IP\", \"COSINE\", \"HAMMING\", \"JACCARD\"}\n if self.metric_type not in valid_metrics:\n raise ValueError(f\"Invalid metric_type '{self.metric_type}'. Valid options are {valid_metrics}.\")\n
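The `__post_init__` validation above rejects unknown metrics at construction time. A standalone sketch of just that check (without the `BaseConfig` parent, which is assumed away here; `MilvusConfigSketch` is an illustrative name):

```python
from dataclasses import dataclass

@dataclass
class MilvusConfigSketch:
    uri: str = "http://localhost:19530"
    metric_type: str = "L2"

    def __post_init__(self):
        # Reject unsupported similarity metrics as early as possible.
        valid_metrics = {"L2", "IP", "COSINE", "HAMMING", "JACCARD"}
        if self.metric_type not in valid_metrics:
            raise ValueError(
                f"Invalid metric_type '{self.metric_type}'. Valid options are {valid_metrics}."
            )

cfg = MilvusConfigSketch(metric_type="COSINE")  # accepted
try:
    MilvusConfigSketch(metric_type="EUCLID")    # rejected at construction
except ValueError as e:
    print(e)
```

Failing fast in `__post_init__` means a typo in a config file surfaces when the config object is built, not later when the first search call hits the server.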
"},{"location":"reference/#src.aeiva.storage.milvus.milvus_database","title":"milvus_database
","text":""},{"location":"reference/#src.aeiva.storage.milvus.milvus_database.MilvusDatabase","title":"MilvusDatabase
","text":" Bases: VectorDatabase
Concrete implementation of VectorStoreBase using Milvus.
Source code in src/aeiva/storage/milvus/milvus_database.py
class MilvusDatabase(VectorDatabase):\n \"\"\"\n Concrete implementation of VectorStoreBase using Milvus.\n \"\"\"\n\n def __init__(self, config: Dict[str, Any]) -> None:\n \"\"\"\n Initialize the Milvus vector store.\n\n Args:\n config (Dict[str, Any]): Configuration dictionary.\n \"\"\"\n self.config = config\n self.collection_name = config.get('collection_name')\n self.uri = config.get('uri')\n self.user = config.get('user')\n self.password = config.get('password')\n self.token = config.get('token')\n self.embedding_model_dims = config.get('embedding_model_dims')\n self.metric_type = config.get('metric_type', 'L2') # Default to 'L2' metric\n\n if not all([self.collection_name, self.uri, self.embedding_model_dims]):\n raise ValueError(\"Required configuration parameters are missing.\")\n\n self.create_client(\n uri=self.uri,\n user=self.user,\n password=self.password,\n token=self.token\n )\n self.create_collection(\n collection_name=self.collection_name,\n vector_size=self.embedding_model_dims,\n distance_metric=self.metric_type\n )\n\n def create_client(\n self,\n uri: str,\n user: Optional[str] = None,\n password: Optional[str] = None,\n token: Optional[str] = None,\n **kwargs\n ) -> None:\n \"\"\"\n Initializes the client connection to the Milvus vector store.\n\n Args:\n uri (str): The URI of the vector store instance.\n user (Optional[str]): Username for authentication.\n password (Optional[str]): Password for authentication.\n token (Optional[str]): Access token for authentication.\n **kwargs: Additional parameters.\n \"\"\"\n try:\n connections.connect(\n alias=\"default\",\n uri=uri,\n user=user,\n password=password,\n token=token,\n **kwargs\n )\n logger.info(f\"Connected to Milvus at {uri}.\")\n except MilvusException as e:\n logger.error(f\"Failed to connect to Milvus: {e}\")\n raise ConnectionError(f\"Failed to connect to Milvus: {e}\")\n\n def create_collection(self, collection_name: str, vector_size: int, distance_metric: str) -> None:\n \"\"\"\n 
Create a new vector collection in Milvus.\n\n Args:\n collection_name (str): The name of the collection.\n vector_size (int): The dimensionality of the vectors.\n distance_metric (str): The distance metric to use (e.g., 'L2', 'IP', 'COSINE').\n \"\"\"\n if utility.has_collection(collection_name):\n logger.info(f\"Collection {collection_name} already exists. Skipping creation.\")\n self.collection = Collection(collection_name)\n return\n\n # Define the schema\n fields = [\n FieldSchema(name=\"id\", dtype=DataType.VARCHAR, is_primary=True, auto_id=False, max_length=64),\n FieldSchema(name=\"vector\", dtype=DataType.FLOAT_VECTOR, dim=vector_size),\n FieldSchema(name=\"payload\", dtype=DataType.JSON)\n ]\n schema = CollectionSchema(fields=fields, description=\"Milvus Vector Store Collection\")\n\n # Create the collection\n self.collection = Collection(name=collection_name, schema=schema)\n logger.info(f\"Collection {collection_name} created successfully.\")\n\n # Create index\n index_params = {\n \"metric_type\": distance_metric,\n \"index_type\": \"AUTOINDEX\",\n \"params\": {}\n }\n self.collection.create_index(field_name=\"vector\", index_params=index_params)\n logger.info(f\"Index created on collection {collection_name}.\")\n\n def insert_vectors(\n self,\n collection_name: str,\n vectors: List[List[float]],\n payloads: Optional[List[Dict[str, Any]]] = None,\n ids: Optional[List[str]] = None\n ) -> None:\n \"\"\"\n Insert vectors into a collection.\n\n Args:\n collection_name (str): The name of the collection.\n vectors (List[List[float]]): A list of vectors to insert.\n payloads (Optional[List[Dict[str, Any]]]): Optional metadata associated with each vector.\n ids (Optional[List[str]]): Optional unique identifiers for each vector.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n if ids is None:\n raise ValueError(\"Milvus requires IDs to be provided for each vector.\")\n 
if payloads is None:\n payloads = [{} for _ in range(len(vectors))]\n if not (len(ids) == len(vectors) == len(payloads)):\n raise ValueError(\"Lengths of ids, vectors, and payloads must be equal.\")\n\n data = [\n ids,\n vectors,\n payloads\n ]\n self.collection.insert(data)\n logger.info(f\"Inserted {len(vectors)} vectors into collection {collection_name}.\")\n\n def search_vectors(\n self,\n collection_name: str,\n query_vector: List[float],\n top_k: int = 5,\n filters: Optional[Dict[str, Any]] = None\n ) -> List[Dict[str, Any]]:\n \"\"\"\n Search for similar vectors in a collection.\n\n Args:\n collection_name (str): The name of the collection.\n query_vector (List[float]): The vector to search with.\n top_k (int): The number of top results to return.\n filters (Optional[Dict[str, Any]]): Optional filters to apply to the search.\n\n Returns:\n List[Dict[str, Any]]: A list of search results.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n search_params = {\n \"metric_type\": self.metric_type,\n \"params\": {}\n }\n\n expr = self._build_filter_expression(filters)\n results = self.collection.search(\n data=[query_vector],\n anns_field=\"vector\",\n param=search_params,\n limit=top_k,\n expr=expr,\n output_fields=[\"id\", \"payload\"]\n )\n\n output = []\n for hits in results:\n for hit in hits:\n result = {\n 'id': hit.entity.get('id'),\n 'score': hit.distance,\n 'payload': hit.entity.get('payload')\n }\n output.append(result)\n return output\n\n def _build_filter_expression(self, filters: Optional[Dict[str, Any]]) -> str:\n \"\"\"\n Build an expression string for filtering in Milvus.\n\n Args:\n filters (Optional[Dict[str, Any]]): Filters to apply.\n\n Returns:\n str: The expression string.\n \"\"\"\n if not filters:\n return \"\"\n\n expressions = []\n for key, value in filters.items():\n if isinstance(value, str):\n expressions.append(f'payload[\"{key}\"] == 
\"{value}\"')\n else:\n expressions.append(f'payload[\"{key}\"] == {value}')\n expr = \" and \".join(expressions)\n return expr\n\n def delete_vector(self, collection_name: str, vector_id: str) -> None:\n \"\"\"\n Delete a vector from a collection by its ID.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector to delete.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n expr = f'id == \"{vector_id}\"'\n self.collection.delete(expr)\n logger.info(f\"Deleted vector with ID {vector_id} from collection {collection_name}.\")\n\n def update_vector(\n self,\n collection_name: str,\n vector_id: str,\n vector: Optional[List[float]] = None,\n payload: Optional[Dict[str, Any]] = None\n ) -> None:\n \"\"\"\n Update a vector's data or payload.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector to update.\n vector (Optional[List[float]]): The new vector data.\n payload (Optional[Dict[str, Any]]): The new payload data.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n # Milvus doesn't support direct updates; need to delete and re-insert\n # Fetch existing vector and payload\n expr = f'id == \"{vector_id}\"'\n results = self.collection.query(expr=expr, output_fields=[\"vector\", \"payload\"])\n\n if not results:\n raise KeyError(f\"Vector with ID {vector_id} not found in collection {collection_name}.\")\n\n existing_vector = results[0]['vector']\n existing_payload = results[0]['payload']\n\n new_vector = vector if vector is not None else existing_vector\n new_payload = payload if payload is not None else existing_payload\n\n # Delete the existing vector\n self.collection.delete(expr)\n\n # Re-insert with updated data\n self.insert_vectors(\n 
collection_name=collection_name,\n vectors=[new_vector],\n payloads=[new_payload],\n ids=[vector_id]\n )\n logger.info(f\"Updated vector with ID {vector_id} in collection {collection_name}.\")\n\n def get_vector(self, collection_name: str, vector_id: str) -> Dict[str, Any]:\n \"\"\"\n Retrieve a vector by its ID.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector.\n\n Returns:\n Dict[str, Any]: A dictionary containing the vector data and payload.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n expr = f'id == \"{vector_id}\"'\n results = self.collection.query(expr=expr, output_fields=[\"vector\", \"payload\"])\n\n if not results:\n raise KeyError(f\"Vector with ID {vector_id} not found in collection {collection_name}.\")\n\n vector_data = {\n 'id': vector_id,\n 'vector': results[0]['vector'],\n 'payload': results[0]['payload']\n }\n return vector_data\n\n def list_collections(self) -> List[str]:\n \"\"\"\n List all available vector collections.\n\n Returns:\n List[str]: A list of collection names.\n \"\"\"\n return utility.list_collections()\n\n def delete_collection(self, collection_name: str) -> None:\n \"\"\"\n Delete an entire vector collection.\n\n Args:\n collection_name (str): The name of the collection to delete.\n \"\"\"\n self.collection.drop()\n logger.info(f\"Deleted collection {collection_name}.\")\n\n def get_collection_info(self, collection_name: str) -> Dict[str, Any]:\n \"\"\"\n Get information about a collection.\n\n Args:\n collection_name (str): The name of the collection.\n\n Returns:\n Dict[str, Any]: Information about the collection.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n info = self.collection.describe()\n return info\n\n def __del__(self):\n \"\"\"Clean up resources.\"\"\"\n 
connections.disconnect(\"default\")\n
"},{"location":"reference/#src.aeiva.storage.milvus.milvus_database.MilvusDatabase.__del__","title":"__del__()
","text":"Clean up resources.
Source code in src/aeiva/storage/milvus/milvus_database.py
def __del__(self):\n \"\"\"Clean up resources.\"\"\"\n connections.disconnect(\"default\")\n
"},{"location":"reference/#src.aeiva.storage.milvus.milvus_database.MilvusDatabase.__init__","title":"__init__(config)
","text":"Initialize the Milvus vector store.
Parameters:
Name Type Description Default config
Dict[str, Any]
Configuration dictionary.
required Source code in src/aeiva/storage/milvus/milvus_database.py
def __init__(self, config: Dict[str, Any]) -> None:\n \"\"\"\n Initialize the Milvus vector store.\n\n Args:\n config (Dict[str, Any]): Configuration dictionary.\n \"\"\"\n self.config = config\n self.collection_name = config.get('collection_name')\n self.uri = config.get('uri')\n self.user = config.get('user')\n self.password = config.get('password')\n self.token = config.get('token')\n self.embedding_model_dims = config.get('embedding_model_dims')\n self.metric_type = config.get('metric_type', 'L2') # Default to 'L2' metric\n\n if not all([self.collection_name, self.uri, self.embedding_model_dims]):\n raise ValueError(\"Required configuration parameters are missing.\")\n\n self.create_client(\n uri=self.uri,\n user=self.user,\n password=self.password,\n token=self.token\n )\n self.create_collection(\n collection_name=self.collection_name,\n vector_size=self.embedding_model_dims,\n distance_metric=self.metric_type\n )\n
"},{"location":"reference/#src.aeiva.storage.milvus.milvus_database.MilvusDatabase.create_client","title":"create_client(uri, user=None, password=None, token=None, **kwargs)
","text":"Initializes the client connection to the Milvus vector store.
Parameters:
Name Type Description Default uri
str
The URI of the vector store instance.
required user
Optional[str]
Username for authentication.
None
password
Optional[str]
Password for authentication.
None
token
Optional[str]
Access token for authentication.
None
**kwargs
Additional parameters.
{}
Source code in src/aeiva/storage/milvus/milvus_database.py
def create_client(\n self,\n uri: str,\n user: Optional[str] = None,\n password: Optional[str] = None,\n token: Optional[str] = None,\n **kwargs\n) -> None:\n \"\"\"\n Initializes the client connection to the Milvus vector store.\n\n Args:\n uri (str): The URI of the vector store instance.\n user (Optional[str]): Username for authentication.\n password (Optional[str]): Password for authentication.\n token (Optional[str]): Access token for authentication.\n **kwargs: Additional parameters.\n \"\"\"\n try:\n connections.connect(\n alias=\"default\",\n uri=uri,\n user=user,\n password=password,\n token=token,\n **kwargs\n )\n logger.info(f\"Connected to Milvus at {uri}.\")\n except MilvusException as e:\n logger.error(f\"Failed to connect to Milvus: {e}\")\n raise ConnectionError(f\"Failed to connect to Milvus: {e}\")\n
"},{"location":"reference/#src.aeiva.storage.milvus.milvus_database.MilvusDatabase.create_collection","title":"create_collection(collection_name, vector_size, distance_metric)
","text":"Create a new vector collection in Milvus.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required vector_size
int
The dimensionality of the vectors.
required distance_metric
str
The distance metric to use (e.g., 'L2', 'IP', 'COSINE').
required Source code in src/aeiva/storage/milvus/milvus_database.py
def create_collection(self, collection_name: str, vector_size: int, distance_metric: str) -> None:\n \"\"\"\n Create a new vector collection in Milvus.\n\n Args:\n collection_name (str): The name of the collection.\n vector_size (int): The dimensionality of the vectors.\n distance_metric (str): The distance metric to use (e.g., 'L2', 'IP', 'COSINE').\n \"\"\"\n if utility.has_collection(collection_name):\n logger.info(f\"Collection {collection_name} already exists. Skipping creation.\")\n self.collection = Collection(collection_name)\n return\n\n # Define the schema\n fields = [\n FieldSchema(name=\"id\", dtype=DataType.VARCHAR, is_primary=True, auto_id=False, max_length=64),\n FieldSchema(name=\"vector\", dtype=DataType.FLOAT_VECTOR, dim=vector_size),\n FieldSchema(name=\"payload\", dtype=DataType.JSON)\n ]\n schema = CollectionSchema(fields=fields, description=\"Milvus Vector Store Collection\")\n\n # Create the collection\n self.collection = Collection(name=collection_name, schema=schema)\n logger.info(f\"Collection {collection_name} created successfully.\")\n\n # Create index\n index_params = {\n \"metric_type\": distance_metric,\n \"index_type\": \"AUTOINDEX\",\n \"params\": {}\n }\n self.collection.create_index(field_name=\"vector\", index_params=index_params)\n logger.info(f\"Index created on collection {collection_name}.\")\n
"},{"location":"reference/#src.aeiva.storage.milvus.milvus_database.MilvusDatabase.delete_collection","title":"delete_collection(collection_name)
","text":"Delete an entire vector collection.
Parameters:
Name Type Description Default collection_name
str
The name of the collection to delete.
required Source code in src/aeiva/storage/milvus/milvus_database.py
def delete_collection(self, collection_name: str) -> None:\n \"\"\"\n Delete an entire vector collection.\n\n Args:\n collection_name (str): The name of the collection to delete.\n \"\"\"\n self.collection.drop()\n logger.info(f\"Deleted collection {collection_name}.\")\n
"},{"location":"reference/#src.aeiva.storage.milvus.milvus_database.MilvusDatabase.delete_vector","title":"delete_vector(collection_name, vector_id)
","text":"Delete a vector from a collection by its ID.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required vector_id
str
The unique identifier of the vector to delete.
required Source code in src/aeiva/storage/milvus/milvus_database.py
def delete_vector(self, collection_name: str, vector_id: str) -> None:\n \"\"\"\n Delete a vector from a collection by its ID.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector to delete.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n expr = f'id == \"{vector_id}\"'\n self.collection.delete(expr)\n logger.info(f\"Deleted vector with ID {vector_id} from collection {collection_name}.\")\n
"},{"location":"reference/#src.aeiva.storage.milvus.milvus_database.MilvusDatabase.get_collection_info","title":"get_collection_info(collection_name)
","text":"Get information about a collection.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required Returns:
Type Description Dict[str, Any]
Dict[str, Any]: Information about the collection.
Source code in src/aeiva/storage/milvus/milvus_database.py
def get_collection_info(self, collection_name: str) -> Dict[str, Any]:\n \"\"\"\n Get information about a collection.\n\n Args:\n collection_name (str): The name of the collection.\n\n Returns:\n Dict[str, Any]: Information about the collection.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n info = self.collection.describe()\n return info\n
"},{"location":"reference/#src.aeiva.storage.milvus.milvus_database.MilvusDatabase.get_vector","title":"get_vector(collection_name, vector_id)
","text":"Retrieve a vector by its ID.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required vector_id
str
The unique identifier of the vector.
required Returns:
Type Description Dict[str, Any]
Dict[str, Any]: A dictionary containing the vector data and payload.
Source code in src/aeiva/storage/milvus/milvus_database.py
def get_vector(self, collection_name: str, vector_id: str) -> Dict[str, Any]:\n \"\"\"\n Retrieve a vector by its ID.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector.\n\n Returns:\n Dict[str, Any]: A dictionary containing the vector data and payload.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n expr = f'id == \"{vector_id}\"'\n results = self.collection.query(expr=expr, output_fields=[\"vector\", \"payload\"])\n\n if not results:\n raise KeyError(f\"Vector with ID {vector_id} not found in collection {collection_name}.\")\n\n vector_data = {\n 'id': vector_id,\n 'vector': results[0]['vector'],\n 'payload': results[0]['payload']\n }\n return vector_data\n
"},{"location":"reference/#src.aeiva.storage.milvus.milvus_database.MilvusDatabase.insert_vectors","title":"insert_vectors(collection_name, vectors, payloads=None, ids=None)
","text":"Insert vectors into a collection.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required vectors
List[List[float]]
A list of vectors to insert.
required payloads
Optional[List[Dict[str, Any]]]
Optional metadata associated with each vector.
None
ids
Optional[List[str]]
Optional unique identifiers for each vector.
None
Source code in src/aeiva/storage/milvus/milvus_database.py
def insert_vectors(\n self,\n collection_name: str,\n vectors: List[List[float]],\n payloads: Optional[List[Dict[str, Any]]] = None,\n ids: Optional[List[str]] = None\n) -> None:\n \"\"\"\n Insert vectors into a collection.\n\n Args:\n collection_name (str): The name of the collection.\n vectors (List[List[float]]): A list of vectors to insert.\n payloads (Optional[List[Dict[str, Any]]]): Optional metadata associated with each vector.\n ids (Optional[List[str]]): Optional unique identifiers for each vector.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n if ids is None:\n raise ValueError(\"Milvus requires IDs to be provided for each vector.\")\n if payloads is None:\n payloads = [{} for _ in range(len(vectors))]\n if not (len(ids) == len(vectors) == len(payloads)):\n raise ValueError(\"Lengths of ids, vectors, and payloads must be equal.\")\n\n data = [\n ids,\n vectors,\n payloads\n ]\n self.collection.insert(data)\n logger.info(f\"Inserted {len(vectors)} vectors into collection {collection_name}.\")\n
"},{"location":"reference/#src.aeiva.storage.milvus.milvus_database.MilvusDatabase.list_collections","title":"list_collections()
","text":"List all available vector collections.
Returns:
Type Description List[str]
List[str]: A list of collection names.
Source code in src/aeiva/storage/milvus/milvus_database.py
def list_collections(self) -> List[str]:\n \"\"\"\n List all available vector collections.\n\n Returns:\n List[str]: A list of collection names.\n \"\"\"\n return utility.list_collections()\n
"},{"location":"reference/#src.aeiva.storage.milvus.milvus_database.MilvusDatabase.search_vectors","title":"search_vectors(collection_name, query_vector, top_k=5, filters=None)
","text":"Search for similar vectors in a collection.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required query_vector
List[float]
The vector to search with.
required top_k
int
The number of top results to return.
5
filters
Optional[Dict[str, Any]]
Optional filters to apply to the search.
None
Returns:
Type Description List[Dict[str, Any]]
List[Dict[str, Any]]: A list of search results.
Source code in src/aeiva/storage/milvus/milvus_database.py
def search_vectors(\n self,\n collection_name: str,\n query_vector: List[float],\n top_k: int = 5,\n filters: Optional[Dict[str, Any]] = None\n) -> List[Dict[str, Any]]:\n \"\"\"\n Search for similar vectors in a collection.\n\n Args:\n collection_name (str): The name of the collection.\n query_vector (List[float]): The vector to search with.\n top_k (int): The number of top results to return.\n filters (Optional[Dict[str, Any]]): Optional filters to apply to the search.\n\n Returns:\n List[Dict[str, Any]]: A list of search results.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n search_params = {\n \"metric_type\": self.metric_type,\n \"params\": {}\n }\n\n expr = self._build_filter_expression(filters)\n results = self.collection.search(\n data=[query_vector],\n anns_field=\"vector\",\n param=search_params,\n limit=top_k,\n expr=expr,\n output_fields=[\"id\", \"payload\"]\n )\n\n output = []\n for hits in results:\n for hit in hits:\n result = {\n 'id': hit.entity.get('id'),\n 'score': hit.distance,\n 'payload': hit.entity.get('payload')\n }\n output.append(result)\n return output\n
"},{"location":"reference/#src.aeiva.storage.milvus.milvus_database.MilvusDatabase.update_vector","title":"update_vector(collection_name, vector_id, vector=None, payload=None)
","text":"Update a vector's data or payload.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required vector_id
str
The unique identifier of the vector to update.
required vector
Optional[List[float]]
The new vector data.
None
payload
Optional[Dict[str, Any]]
The new payload data.
None
Source code in src/aeiva/storage/milvus/milvus_database.py
def update_vector(\n self,\n collection_name: str,\n vector_id: str,\n vector: Optional[List[float]] = None,\n payload: Optional[Dict[str, Any]] = None\n) -> None:\n \"\"\"\n Update a vector's data or payload.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector to update.\n vector (Optional[List[float]]): The new vector data.\n payload (Optional[Dict[str, Any]]): The new payload data.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n # Milvus doesn't support direct updates; need to delete and re-insert\n # Fetch existing vector and payload\n expr = f'id == \"{vector_id}\"'\n results = self.collection.query(expr=expr, output_fields=[\"vector\", \"payload\"])\n\n if not results:\n raise KeyError(f\"Vector with ID {vector_id} not found in collection {collection_name}.\")\n\n existing_vector = results[0]['vector']\n existing_payload = results[0]['payload']\n\n new_vector = vector if vector is not None else existing_vector\n new_payload = payload if payload is not None else existing_payload\n\n # Delete the existing vector\n self.collection.delete(expr)\n\n # Re-insert with updated data\n self.insert_vectors(\n collection_name=collection_name,\n vectors=[new_vector],\n payloads=[new_payload],\n ids=[vector_id]\n )\n logger.info(f\"Updated vector with ID {vector_id} in collection {collection_name}.\")\n
"},{"location":"reference/#src.aeiva.storage.neo4jdb","title":"neo4jdb
","text":""},{"location":"reference/#src.aeiva.storage.neo4jdb.neo4j_config","title":"neo4j_config
","text":""},{"location":"reference/#src.aeiva.storage.neo4jdb.neo4j_config.Neo4jConfig","title":"Neo4jConfig
dataclass
","text":" Bases: BaseConfig
Configuration for Neo4j graph database.
Source code in src/aeiva/storage/neo4jdb/neo4j_config.py
@dataclass\nclass Neo4jConfig(BaseConfig):\n \"\"\"\n Configuration for Neo4j graph database.\n \"\"\"\n\n uri: str = field(\n default=\"bolt://localhost:7687\",\n metadata={\"help\": \"URI for connecting to Neo4j (e.g., 'bolt://localhost:7687').\"}\n )\n user: Optional[str] = field(\n default=None,\n metadata={\"help\": \"Username for Neo4j authentication.\"}\n )\n password: Optional[str] = field(\n default=None,\n metadata={\"help\": \"Password for Neo4j authentication.\"}\n )\n database: Optional[str] = field(\n default=\"neo4j\",\n metadata={\"help\": \"Neo4j database name.\"}\n )\n encrypted: bool = field(\n default=True,\n metadata={\"help\": \"Whether to use encrypted connection (True or False).\"}\n )\n\n def __post_init__(self):\n super().__post_init__()\n if not self.user or not self.password:\n raise ValueError(\"Both 'user' and 'password' must be provided for Neo4j authentication.\")\n
"},{"location":"reference/#src.aeiva.storage.neo4jdb.neo4j_database","title":"neo4j_database
","text":""},{"location":"reference/#src.aeiva.storage.neo4jdb.neo4j_database.Neo4jDatabase","title":"Neo4jDatabase
","text":" Bases: GraphDatabase
Concrete implementation of GraphStoreBase using Neo4j.
Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
class Neo4jDatabase(GraphDatabase):\n \"\"\"\n Concrete implementation of GraphStoreBase using Neo4j.\n \"\"\"\n\n def __init__(self, config: Dict[str, Any]) -> None:\n \"\"\"\n Initialize the Neo4j graph database connection.\n\n Args:\n config (Dict[str, Any]): Configuration dictionary.\n \"\"\"\n self.config = config\n self.uri = config.get('uri')\n self.user = config.get('user')\n self.password = config.get('password')\n self.database = config.get('database', 'neo4j')\n self.encrypted = config.get('encrypted', True)\n\n if not all([self.uri, self.user, self.password]):\n raise ValueError(\"Required configuration parameters 'uri', 'user', and 'password' are missing.\")\n\n self.create_client(\n uri=self.uri,\n user=self.user,\n password=self.password,\n encrypted=self.encrypted\n )\n\n def create_client(\n self,\n uri: str,\n user: str,\n password: str,\n encrypted: bool = True,\n **kwargs\n ) -> None:\n \"\"\"\n Initializes the client connection to the Neo4j graph database.\n\n Args:\n uri (str): The URI of the Neo4j instance.\n user (str): Username for authentication.\n password (str): Password for authentication.\n encrypted (bool): Whether to use encrypted connection.\n **kwargs: Additional parameters.\n\n Raises:\n ConnectionError: If the client fails to connect to the graph database.\n \"\"\"\n try:\n auth = basic_auth(user, password)\n self.driver = Neo4jGraphDatabase.driver(uri, auth=auth, encrypted=encrypted, **kwargs)\n self.session = self.driver.session(database=self.database)\n logger.info(f\"Connected to Neo4j at {uri}.\")\n except exceptions.Neo4jError as e:\n logger.error(f\"Failed to connect to Neo4j: {e}\")\n raise ConnectionError(f\"Failed to connect to Neo4j: {e}\")\n\n def add_node(\n self,\n node_id: str,\n properties: Optional[Dict[str, Any]] = None,\n labels: Optional[List[str]] = None\n ) -> None:\n \"\"\"\n Adds a node to the graph.\n\n Args:\n node_id (str): Unique identifier for the node.\n properties (Optional[Dict[str, Any]]): 
Properties associated with the node.\n labels (Optional[List[str]]): Labels or types associated with the node.\n\n Raises:\n StorageError: If there is an issue adding the node.\n \"\"\"\n properties = properties or {}\n labels = labels or []\n labels_str = ':' + ':'.join(labels) if labels else ''\n cypher = f\"MERGE (n{labels_str} {{id: $node_id}}) SET n += $properties\"\n params = {\n 'node_id': node_id,\n 'properties': properties\n }\n try:\n self.session.run(cypher, params)\n logger.info(f\"Node with id '{node_id}' added to the graph.\")\n except exceptions.Neo4jError as e:\n logger.error(f\"Failed to add node: {e}\")\n raise StorageError(f\"Failed to add node: {e}\")\n\n def add_edge(\n self,\n source_id: str,\n target_id: str,\n relationship: str,\n properties: Optional[Dict[str, Any]] = None\n ) -> None:\n \"\"\"\n Adds an edge (relationship) between two nodes.\n\n Args:\n source_id (str): Unique identifier of the source node.\n target_id (str): Unique identifier of the target node.\n relationship (str): Type of the relationship.\n properties (Optional[Dict[str, Any]]): Properties associated with the edge.\n\n Raises:\n NodeNotFoundError: If either the source or target node does not exist.\n StorageError: If there is an issue adding the edge.\n \"\"\"\n properties = properties or {}\n # First, check if both nodes exist\n cypher_check = \"MATCH (a {id: $source_id}), (b {id: $target_id}) RETURN a, b\"\n params = {\n 'source_id': source_id,\n 'target_id': target_id\n }\n try:\n result = self.session.run(cypher_check, params)\n record = result.single()\n if not record:\n missing_nodes = []\n # Check if source node exists\n node_a_exists = self.session.run(\"MATCH (a {id: $source_id}) RETURN a\", {'source_id': source_id}).single()\n if not node_a_exists:\n missing_nodes.append(source_id)\n # Check if target node exists\n node_b_exists = self.session.run(\"MATCH (b {id: $target_id}) RETURN b\", {'target_id': target_id}).single()\n if not node_b_exists:\n 
missing_nodes.append(target_id)\n logger.warning(f\"Node(s) with id(s) {missing_nodes} not found.\")\n raise NodeNotFoundError(f\"Node(s) with id(s) {missing_nodes} not found.\")\n # Proceed to add the edge\n cypher_edge = (\n \"MATCH (a {id: $source_id}), (b {id: $target_id}) \"\n f\"MERGE (a)-[r:{relationship}]->(b) \"\n \"SET r += $properties\"\n )\n params['properties'] = properties\n self.session.run(cypher_edge, params)\n logger.info(f\"Relationship '{relationship}' added between '{source_id}' and '{target_id}'.\")\n except NodeNotFoundError:\n raise\n except exceptions.Neo4jError as e:\n logger.error(f\"Failed to add edge: {e}\")\n raise StorageError(f\"Failed to add edge: {e}\")\n\n def get_node(self, node_id: str) -> Dict[str, Any]:\n \"\"\"\n Retrieves a node by its identifier.\n\n Args:\n node_id (str): Unique identifier of the node.\n\n Returns:\n Dict[str, Any]: A dictionary containing the node's properties and labels.\n\n Raises:\n NodeNotFoundError: If the node does not exist.\n StorageError: If there is an issue retrieving the node.\n \"\"\"\n cypher = \"MATCH (n {id: $node_id}) RETURN n\"\n params = {'node_id': node_id}\n try:\n result = self.session.run(cypher, params)\n record = result.single()\n if record:\n node = record['n']\n node_data = {\n 'id': node['id'],\n 'properties': {k: v for k, v in node.items() if k != 'id'},\n 'labels': list(node.labels)\n }\n logger.info(f\"Node with id '{node_id}' retrieved.\")\n return node_data\n else:\n logger.warning(f\"Node with id '{node_id}' not found.\")\n raise NodeNotFoundError(f\"Node with id '{node_id}' not found.\")\n except exceptions.Neo4jError as e:\n logger.error(f\"Failed to get node: {e}\")\n raise StorageError(f\"Failed to get node: {e}\")\n\n def update_node(self, node_id: str, properties: Dict[str, Any]) -> None:\n \"\"\"\n Updates properties of a node.\n\n Args:\n node_id (str): Unique identifier of the node.\n properties (Dict[str, Any]): Properties to update.\n\n Raises:\n 
NodeNotFoundError: If the node does not exist.\n StorageError: If there is an issue updating the node.\n \"\"\"\n cypher = \"MATCH (n {id: $node_id}) SET n += $properties RETURN n\"\n params = {\n 'node_id': node_id,\n 'properties': properties\n }\n try:\n result = self.session.run(cypher, params)\n record = result.single()\n if record:\n logger.info(f\"Node with id '{node_id}' updated.\")\n else:\n logger.warning(f\"Node with id '{node_id}' not found.\")\n raise NodeNotFoundError(f\"Node with id '{node_id}' not found.\")\n except exceptions.Neo4jError as e:\n logger.error(f\"Failed to update node: {e}\")\n raise StorageError(f\"Failed to update node: {e}\")\n\n def delete_node(self, node_id: str) -> None:\n \"\"\"\n Deletes a node from the graph.\n\n Args:\n node_id (str): Unique identifier of the node.\n\n Raises:\n NodeNotFoundError: If the node does not exist.\n StorageError: If there is an issue deleting the node.\n \"\"\"\n cypher = \"MATCH (n {id: $node_id}) DETACH DELETE n RETURN COUNT(n) AS count\"\n params = {'node_id': node_id}\n try:\n result = self.session.run(cypher, params)\n record = result.single()\n if record and record['count'] > 0:\n logger.info(f\"Node with id '{node_id}' deleted.\")\n else:\n logger.warning(f\"Node with id '{node_id}' not found.\")\n raise NodeNotFoundError(f\"Node with id '{node_id}' not found.\")\n except exceptions.Neo4jError as e:\n logger.error(f\"Failed to delete node: {e}\")\n raise StorageError(f\"Failed to delete node: {e}\")\n\n def delete_edge(\n self,\n source_id: str,\n target_id: str,\n relationship: str\n ) -> None:\n \"\"\"\n Deletes a specific relationship between two nodes.\n\n Args:\n source_id (str): Unique identifier of the source node.\n target_id (str): Unique identifier of the target node.\n relationship (str): Type of the relationship to delete.\n\n Raises:\n StorageError: If there is an issue deleting the relationship.\n \"\"\"\n cypher = (\n \"MATCH (a {id: $source_id})-[r:%s]->(b {id: $target_id}) 
\"\n \"DELETE r\"\n ) % relationship\n params = {\n 'source_id': source_id,\n 'target_id': target_id\n }\n try:\n result = self.session.run(cypher, params)\n if result.consume().counters.relationships_deleted == 0:\n logger.warning(f\"No relationship '{relationship}' found between '{source_id}' and '{target_id}'.\")\n raise StorageError(f\"Relationship '{relationship}' between '{source_id}' and '{target_id}' not found.\")\n logger.info(f\"Relationship '{relationship}' between '{source_id}' and '{target_id}' deleted.\")\n except exceptions.Neo4jError as e:\n logger.error(f\"Failed to delete relationship: {e}\")\n raise StorageError(f\"Failed to delete relationship: {e}\")\n\n def update_edge(\n self,\n source_id: str,\n target_id: str,\n relationship: str,\n properties: Dict[str, Any]\n ) -> None:\n \"\"\"\n Updates properties of a specific relationship between two nodes.\n\n Args:\n source_id (str): Unique identifier of the source node.\n target_id (str): Unique identifier of the target node.\n relationship (str): Type of the relationship to update.\n properties (Dict[str, Any]): Properties to update on the relationship.\n\n Raises:\n StorageError: If there is an issue updating the relationship.\n \"\"\"\n cypher = (\n \"MATCH (a {id: $source_id})-[r:%s]->(b {id: $target_id}) \"\n \"SET r += $properties RETURN r\"\n ) % relationship\n params = {\n 'source_id': source_id,\n 'target_id': target_id,\n 'properties': properties\n }\n try:\n result = self.session.run(cypher, params)\n record = result.single()\n if record:\n logger.info(f\"Relationship '{relationship}' between '{source_id}' and '{target_id}' updated with properties {properties}.\")\n else:\n logger.warning(f\"Relationship '{relationship}' between '{source_id}' and '{target_id}' not found.\")\n raise StorageError(f\"Relationship '{relationship}' between '{source_id}' and '{target_id}' not found.\")\n except exceptions.Neo4jError as e:\n logger.error(f\"Failed to update relationship: {e}\")\n raise 
StorageError(f\"Failed to update relationship: {e}\")\n\n def get_relationship(\n self,\n source_id: str,\n target_id: str,\n relationship: str\n ) -> Dict[str, Any]:\n \"\"\"\n Retrieves a specific relationship between two nodes.\n\n Args:\n source_id (str): Unique identifier of the source node.\n target_id (str): Unique identifier of the target node.\n relationship (str): Type of the relationship to retrieve.\n\n Returns:\n Dict[str, Any]: A dictionary containing the relationship's properties.\n\n Raises:\n StorageError: If there is an issue retrieving the relationship.\n \"\"\"\n cypher = (\n \"MATCH (a {id: $source_id})-[r:%s]->(b {id: $target_id}) \"\n \"RETURN r\"\n ) % relationship\n params = {\n 'source_id': source_id,\n 'target_id': target_id\n }\n try:\n result = self.session.run(cypher, params)\n record = result.single()\n if record:\n relationship_data = record['r']\n properties = dict(relationship_data)\n properties['type'] = relationship_data.type # Include relationship type\n logger.info(f\"Relationship '{relationship}' between '{source_id}' and '{target_id}' retrieved.\")\n return properties\n else:\n logger.warning(f\"Relationship '{relationship}' between '{source_id}' and '{target_id}' not found.\")\n raise StorageError(f\"Relationship '{relationship}' between '{source_id}' and '{target_id}' not found.\")\n except exceptions.Neo4jError as e:\n logger.error(f\"Failed to retrieve relationship: {e}\")\n raise StorageError(f\"Failed to retrieve relationship: {e}\")\n\n def delete_all_edges(self) -> None:\n \"\"\"\n Deletes all relationships from the Neo4j graph database without deleting nodes.\n\n Raises:\n StorageError: If there is an issue deleting relationships.\n \"\"\"\n cypher = \"MATCH ()-[r]->() DELETE r\"\n try:\n self.session.run(cypher)\n logger.info(\"All relationships have been deleted from Neo4j.\")\n except exceptions.Neo4jError as e:\n logger.error(f\"Failed to delete all relationships: {e}\")\n raise StorageError(f\"Failed to delete all 
relationships: {e}\")\n\n def delete_relationships_by_type(self, relationship: str) -> None:\n \"\"\"\n Deletes all relationships of a specific type from the Neo4j graph database.\n\n Args:\n relationship (str): The type of relationships to delete.\n\n Raises:\n StorageError: If there is an issue deleting the relationships.\n \"\"\"\n cypher = f\"MATCH ()-[r:{relationship}]->() DELETE r\"\n try:\n self.session.run(cypher)\n logger.info(f\"All relationships of type '{relationship}' have been deleted from Neo4j.\")\n except exceptions.Neo4jError as e:\n logger.error(f\"Failed to delete relationships of type '{relationship}': {e}\")\n raise StorageError(f\"Failed to delete relationships of type '{relationship}': {e}\")\n\n def delete_all(self) -> None:\n \"\"\"\n Deletes all nodes and relationships from the Neo4j graph database.\n\n Raises:\n StorageError: If there is an issue deleting all nodes and relationships.\n \"\"\"\n cypher = \"MATCH (n) DETACH DELETE n\"\n try:\n self.session.run(cypher)\n logger.info(\"All nodes and relationships have been deleted from Neo4j.\")\n except exceptions.Neo4jError as e:\n logger.error(f\"Failed to delete all nodes and relationships: {e}\")\n raise StorageError(f\"Failed to delete all nodes and relationships: {e}\")\n\n def get_neighbors(\n self,\n node_id: str,\n relationship: Optional[str] = None,\n direction: str = \"both\"\n ) -> List[Dict[str, Any]]:\n \"\"\"\n Retrieves neighboring nodes connected by edges.\n\n Args:\n node_id (str): Unique identifier of the node.\n relationship (Optional[str]): Filter by relationship type.\n direction (str): Direction of the relationships ('in', 'out', 'both').\n\n Returns:\n List[Dict[str, Any]]: A list of neighboring nodes.\n\n Raises:\n NodeNotFoundError: If the node does not exist.\n StorageError: If there is an issue retrieving neighbors.\n \"\"\"\n if direction not in [\"in\", \"out\", \"both\"]:\n raise ValueError(\"Invalid direction. 
Must be 'in', 'out', or 'both'.\")\n\n rel_type = f\":{relationship}\" if relationship else ''\n if direction == \"in\":\n pattern = f\"<-[r{rel_type}]-\"\n elif direction == \"out\":\n pattern = f\"-[r{rel_type}]->\"\n else: # both\n pattern = f\"-[r{rel_type}]-\"\n\n cypher = f\"MATCH (n {{id: $node_id}}){pattern}(neighbor) RETURN neighbor\"\n params = {'node_id': node_id}\n try:\n # First, check if the node exists\n node_exists_query = \"MATCH (n {id: $node_id}) RETURN n\"\n node_result = self.session.run(node_exists_query, params)\n if not node_result.single():\n logger.warning(f\"Node with id '{node_id}' not found.\")\n raise NodeNotFoundError(f\"Node with id '{node_id}' not found.\")\n # Get neighbors\n result = self.session.run(cypher, params)\n neighbors = []\n for record in result:\n node = record['neighbor']\n neighbor_data = {\n 'id': node['id'],\n 'properties': {k: v for k, v in node.items() if k != 'id'},\n 'labels': list(node.labels)\n }\n neighbors.append(neighbor_data)\n logger.info(f\"Neighbors of node '{node_id}' retrieved.\")\n return neighbors\n except NodeNotFoundError:\n raise\n except exceptions.Neo4jError as e:\n logger.error(f\"Failed to get neighbors: {e}\")\n raise StorageError(f\"Failed to get neighbors: {e}\")\n\n def query_nodes(\n self,\n properties: Dict[str, Any],\n labels: Optional[List[str]] = None\n ) -> List[Dict[str, Any]]:\n \"\"\"\n Queries nodes based on properties and labels.\n\n Args:\n properties (Dict[str, Any]): Properties to filter nodes.\n labels (Optional[List[str]]): Labels to filter nodes.\n\n Returns:\n List[Dict[str, Any]]: A list of nodes matching the query.\n\n Raises:\n StorageError: If there is an issue querying nodes.\n \"\"\"\n labels_str = ':' + ':'.join(labels) if labels else ''\n params = {}\n cypher = f\"MATCH (n{labels_str})\"\n\n if properties:\n props_conditions = ' AND '.join([f\"n.{key} = ${key}\" for key in properties.keys()])\n cypher += f\" WHERE {props_conditions}\"\n 
params.update(properties)\n\n cypher += \" RETURN n\"\n\n try:\n result = self.session.run(cypher, params)\n nodes = []\n for record in result:\n node = record['n']\n node_data = {\n 'id': node['id'],\n 'properties': {k: v for k, v in node.items() if k != 'id'},\n 'labels': list(node.labels)\n }\n nodes.append(node_data)\n logger.info(f\"Query returned {len(nodes)} nodes.\")\n return nodes\n except exceptions.Neo4jError as e:\n logger.error(f\"Failed to query nodes: {e}\")\n raise StorageError(f\"Failed to query nodes: {e}\")\n\n def execute_query(self, query: str, parameters: Optional[Dict[str, Any]] = None) -> Any:\n \"\"\"\n Executes a raw query against the graph database.\n\n Args:\n query (str): The query string.\n parameters (Optional[Dict[str, Any]]): Parameters for parameterized queries.\n\n Returns:\n Any: The result of the query.\n\n Raises:\n StorageError: If there is an issue executing the query.\n \"\"\"\n try:\n result = self.session.run(query, parameters)\n records = [record.data() for record in result]\n logger.info(f\"Executed query: {query}\")\n return records\n except exceptions.Neo4jError as e:\n logger.error(f\"Failed to execute query: {e}\")\n raise StorageError(f\"Failed to execute query: {e}\")\n\n def close(self) -> None:\n \"\"\"\n Closes the graph database connection and releases resources.\n \"\"\"\n if hasattr(self, 'session') and self.session:\n self.session.close()\n if hasattr(self, 'driver') and self.driver:\n self.driver.close()\n logger.info(\"Closed connection to Neo4j database.\")\n\n def __del__(self):\n \"\"\"Destructor to ensure resources are cleaned up.\"\"\"\n self.close()\n
"},{"location":"reference/#src.aeiva.storage.neo4jdb.neo4j_database.Neo4jDatabase.__del__","title":"__del__()
","text":"Destructor to ensure resources are cleaned up.
Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
def __del__(self):\n \"\"\"Destructor to ensure resources are cleaned up.\"\"\"\n self.close()\n
"},{"location":"reference/#src.aeiva.storage.neo4jdb.neo4j_database.Neo4jDatabase.__init__","title":"__init__(config)
","text":"Initialize the Neo4j graph database connection.
Parameters:
Name Type Description Default config
Dict[str, Any]
Configuration dictionary.
required Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
def __init__(self, config: Dict[str, Any]) -> None:\n \"\"\"\n Initialize the Neo4j graph database connection.\n\n Args:\n config (Dict[str, Any]): Configuration dictionary.\n \"\"\"\n self.config = config\n self.uri = config.get('uri')\n self.user = config.get('user')\n self.password = config.get('password')\n self.database = config.get('database', 'neo4j')\n self.encrypted = config.get('encrypted', True)\n\n if not all([self.uri, self.user, self.password]):\n raise ValueError(\"Required configuration parameters 'uri', 'user', and 'password' are missing.\")\n\n self.create_client(\n uri=self.uri,\n user=self.user,\n password=self.password,\n encrypted=self.encrypted\n )\n
"},{"location":"reference/#src.aeiva.storage.neo4jdb.neo4j_database.Neo4jDatabase.add_edge","title":"add_edge(source_id, target_id, relationship, properties=None)
","text":"Adds an edge (relationship) between two nodes.
Parameters:
Name Type Description Default source_id
str
Unique identifier of the source node.
required target_id
str
Unique identifier of the target node.
required relationship
str
Type of the relationship.
required properties
Optional[Dict[str, Any]]
Properties associated with the edge.
None
Raises:
Type Description NodeNotFoundError
If either the source or target node does not exist.
StorageError
If there is an issue adding the edge.
Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
def add_edge(\n self,\n source_id: str,\n target_id: str,\n relationship: str,\n properties: Optional[Dict[str, Any]] = None\n) -> None:\n \"\"\"\n Adds an edge (relationship) between two nodes.\n\n Args:\n source_id (str): Unique identifier of the source node.\n target_id (str): Unique identifier of the target node.\n relationship (str): Type of the relationship.\n properties (Optional[Dict[str, Any]]): Properties associated with the edge.\n\n Raises:\n NodeNotFoundError: If either the source or target node does not exist.\n StorageError: If there is an issue adding the edge.\n \"\"\"\n properties = properties or {}\n # First, check if both nodes exist\n cypher_check = \"MATCH (a {id: $source_id}), (b {id: $target_id}) RETURN a, b\"\n params = {\n 'source_id': source_id,\n 'target_id': target_id\n }\n try:\n result = self.session.run(cypher_check, params)\n record = result.single()\n if not record:\n missing_nodes = []\n # Check if source node exists\n node_a_exists = self.session.run(\"MATCH (a {id: $source_id}) RETURN a\", {'source_id': source_id}).single()\n if not node_a_exists:\n missing_nodes.append(source_id)\n # Check if target node exists\n node_b_exists = self.session.run(\"MATCH (b {id: $target_id}) RETURN b\", {'target_id': target_id}).single()\n if not node_b_exists:\n missing_nodes.append(target_id)\n logger.warning(f\"Node(s) with id(s) {missing_nodes} not found.\")\n raise NodeNotFoundError(f\"Node(s) with id(s) {missing_nodes} not found.\")\n # Proceed to add the edge\n cypher_edge = (\n \"MATCH (a {id: $source_id}), (b {id: $target_id}) \"\n f\"MERGE (a)-[r:{relationship}]->(b) \"\n \"SET r += $properties\"\n )\n params['properties'] = properties\n self.session.run(cypher_edge, params)\n logger.info(f\"Relationship '{relationship}' added between '{source_id}' and '{target_id}'.\")\n except NodeNotFoundError:\n raise\n except exceptions.Neo4jError as e:\n logger.error(f\"Failed to add edge: {e}\")\n raise StorageError(f\"Failed to add edge: 
{e}\")\n
"},{"location":"reference/#src.aeiva.storage.neo4jdb.neo4j_database.Neo4jDatabase.add_node","title":"add_node(node_id, properties=None, labels=None)
","text":"Adds a node to the graph.
Parameters:
Name Type Description Default node_id
str
Unique identifier for the node.
required properties
Optional[Dict[str, Any]]
Properties associated with the node.
None
labels
Optional[List[str]]
Labels or types associated with the node.
None
Raises:
Type Description StorageError
If there is an issue adding the node.
Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
def add_node(\n self,\n node_id: str,\n properties: Optional[Dict[str, Any]] = None,\n labels: Optional[List[str]] = None\n) -> None:\n \"\"\"\n Adds a node to the graph.\n\n Args:\n node_id (str): Unique identifier for the node.\n properties (Optional[Dict[str, Any]]): Properties associated with the node.\n labels (Optional[List[str]]): Labels or types associated with the node.\n\n Raises:\n StorageError: If there is an issue adding the node.\n \"\"\"\n properties = properties or {}\n labels = labels or []\n labels_str = ':' + ':'.join(labels) if labels else ''\n cypher = f\"MERGE (n{labels_str} {{id: $node_id}}) SET n += $properties\"\n params = {\n 'node_id': node_id,\n 'properties': properties\n }\n try:\n self.session.run(cypher, params)\n logger.info(f\"Node with id '{node_id}' added to the graph.\")\n except exceptions.Neo4jError as e:\n logger.error(f\"Failed to add node: {e}\")\n raise StorageError(f\"Failed to add node: {e}\")\n
"},{"location":"reference/#src.aeiva.storage.neo4jdb.neo4j_database.Neo4jDatabase.close","title":"close()
","text":"Closes the graph database connection and releases resources.
Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
def close(self) -> None:\n \"\"\"\n Closes the graph database connection and releases resources.\n \"\"\"\n if hasattr(self, 'session') and self.session:\n self.session.close()\n if hasattr(self, 'driver') and self.driver:\n self.driver.close()\n logger.info(\"Closed connection to Neo4j database.\")\n
"},{"location":"reference/#src.aeiva.storage.neo4jdb.neo4j_database.Neo4jDatabase.create_client","title":"create_client(uri, user, password, encrypted=True, **kwargs)
","text":"Initializes the client connection to the Neo4j graph database.
Parameters:
Name Type Description Default uri
str
The URI of the Neo4j instance.
required user
str
Username for authentication.
required password
str
Password for authentication.
required encrypted
bool
Whether to use encrypted connection.
True
**kwargs
Additional parameters.
{}
Raises:
Type Description ConnectionError
If the client fails to connect to the graph database.
Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
def create_client(\n self,\n uri: str,\n user: str,\n password: str,\n encrypted: bool = True,\n **kwargs\n) -> None:\n \"\"\"\n Initializes the client connection to the Neo4j graph database.\n\n Args:\n uri (str): The URI of the Neo4j instance.\n user (str): Username for authentication.\n password (str): Password for authentication.\n encrypted (bool): Whether to use encrypted connection.\n **kwargs: Additional parameters.\n\n Raises:\n ConnectionError: If the client fails to connect to the graph database.\n \"\"\"\n try:\n auth = basic_auth(user, password)\n self.driver = Neo4jGraphDatabase.driver(uri, auth=auth, encrypted=encrypted, **kwargs)\n self.session = self.driver.session(database=self.database)\n logger.info(f\"Connected to Neo4j at {uri}.\")\n except exceptions.Neo4jError as e:\n logger.error(f\"Failed to connect to Neo4j: {e}\")\n raise ConnectionError(f\"Failed to connect to Neo4j: {e}\")\n
"},{"location":"reference/#src.aeiva.storage.neo4jdb.neo4j_database.Neo4jDatabase.delete_all","title":"delete_all()
","text":"Deletes all nodes and relationships from the Neo4j graph database.
Raises:
Type Description StorageError
If there is an issue deleting all nodes and relationships.
Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
def delete_all(self) -> None:\n \"\"\"\n Deletes all nodes and relationships from the Neo4j graph database.\n\n Raises:\n StorageError: If there is an issue deleting all nodes and relationships.\n \"\"\"\n cypher = \"MATCH (n) DETACH DELETE n\"\n try:\n self.session.run(cypher)\n logger.info(\"All nodes and relationships have been deleted from Neo4j.\")\n except exceptions.Neo4jError as e:\n logger.error(f\"Failed to delete all nodes and relationships: {e}\")\n raise StorageError(f\"Failed to delete all nodes and relationships: {e}\")\n
"},{"location":"reference/#src.aeiva.storage.neo4jdb.neo4j_database.Neo4jDatabase.delete_all_edges","title":"delete_all_edges()
","text":"Deletes all relationships from the Neo4j graph database without deleting nodes.
Raises:
Type Description StorageError
If there is an issue deleting relationships.
Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
def delete_all_edges(self) -> None:\n \"\"\"\n Deletes all relationships from the Neo4j graph database without deleting nodes.\n\n Raises:\n StorageError: If there is an issue deleting relationships.\n \"\"\"\n cypher = \"MATCH ()-[r]->() DELETE r\"\n try:\n self.session.run(cypher)\n logger.info(\"All relationships have been deleted from Neo4j.\")\n except exceptions.Neo4jError as e:\n logger.error(f\"Failed to delete all relationships: {e}\")\n raise StorageError(f\"Failed to delete all relationships: {e}\")\n
"},{"location":"reference/#src.aeiva.storage.neo4jdb.neo4j_database.Neo4jDatabase.delete_edge","title":"delete_edge(source_id, target_id, relationship)
","text":"Deletes a specific relationship between two nodes.
Parameters:
Name Type Description Default source_id
str
Unique identifier of the source node.
required target_id
str
Unique identifier of the target node.
required relationship
str
Type of the relationship to delete.
required Raises:
Type Description StorageError
If there is an issue deleting the relationship.
Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
def delete_edge(\n self,\n source_id: str,\n target_id: str,\n relationship: str\n) -> None:\n \"\"\"\n Deletes a specific relationship between two nodes.\n\n Args:\n source_id (str): Unique identifier of the source node.\n target_id (str): Unique identifier of the target node.\n relationship (str): Type of the relationship to delete.\n\n Raises:\n StorageError: If there is an issue deleting the relationship.\n \"\"\"\n cypher = (\n \"MATCH (a {id: $source_id})-[r:%s]->(b {id: $target_id}) \"\n \"DELETE r\"\n ) % relationship\n params = {\n 'source_id': source_id,\n 'target_id': target_id\n }\n try:\n result = self.session.run(cypher, params)\n if result.consume().counters.relationships_deleted == 0:\n logger.warning(f\"No relationship '{relationship}' found between '{source_id}' and '{target_id}'.\")\n raise StorageError(f\"Relationship '{relationship}' between '{source_id}' and '{target_id}' not found.\")\n logger.info(f\"Relationship '{relationship}' between '{source_id}' and '{target_id}' deleted.\")\n except exceptions.Neo4jError as e:\n logger.error(f\"Failed to delete relationship: {e}\")\n raise StorageError(f\"Failed to delete relationship: {e}\")\n
"},{"location":"reference/#src.aeiva.storage.neo4jdb.neo4j_database.Neo4jDatabase.delete_node","title":"delete_node(node_id)
","text":"Deletes a node from the graph.
Parameters:
Name Type Description Default node_id
str
Unique identifier of the node.
required Raises:
Type Description NodeNotFoundError
If the node does not exist.
StorageError
If there is an issue deleting the node.
Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
def delete_node(self, node_id: str) -> None:\n \"\"\"\n Deletes a node from the graph.\n\n Args:\n node_id (str): Unique identifier of the node.\n\n Raises:\n NodeNotFoundError: If the node does not exist.\n StorageError: If there is an issue deleting the node.\n \"\"\"\n cypher = \"MATCH (n {id: $node_id}) DETACH DELETE n RETURN COUNT(n) AS count\"\n params = {'node_id': node_id}\n try:\n result = self.session.run(cypher, params)\n record = result.single()\n if record and record['count'] > 0:\n logger.info(f\"Node with id '{node_id}' deleted.\")\n else:\n logger.warning(f\"Node with id '{node_id}' not found.\")\n raise NodeNotFoundError(f\"Node with id '{node_id}' not found.\")\n except exceptions.Neo4jError as e:\n logger.error(f\"Failed to delete node: {e}\")\n raise StorageError(f\"Failed to delete node: {e}\")\n
"},{"location":"reference/#src.aeiva.storage.neo4jdb.neo4j_database.Neo4jDatabase.delete_relationships_by_type","title":"delete_relationships_by_type(relationship)
","text":"Deletes all relationships of a specific type from the Neo4j graph database.
Parameters:
Name Type Description Default relationship
str
The type of relationships to delete.
required Raises:
Type Description StorageError
If there is an issue deleting the relationships.
Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
def delete_relationships_by_type(self, relationship: str) -> None:\n \"\"\"\n Deletes all relationships of a specific type from the Neo4j graph database.\n\n Args:\n relationship (str): The type of relationships to delete.\n\n Raises:\n StorageError: If there is an issue deleting the relationships.\n \"\"\"\n cypher = f\"MATCH ()-[r:{relationship}]->() DELETE r\"\n try:\n self.session.run(cypher)\n logger.info(f\"All relationships of type '{relationship}' have been deleted from Neo4j.\")\n except exceptions.Neo4jError as e:\n logger.error(f\"Failed to delete relationships of type '{relationship}': {e}\")\n raise StorageError(f\"Failed to delete relationships of type '{relationship}': {e}\")\n
"},{"location":"reference/#src.aeiva.storage.neo4jdb.neo4j_database.Neo4jDatabase.execute_query","title":"execute_query(query, parameters=None)
","text":"Executes a raw query against the graph database.
Parameters:
Name Type Description Default query
str
The query string.
required parameters
Optional[Dict[str, Any]]
Parameters for parameterized queries.
None
Returns:
Name Type Description Any
Any
The result of the query.
Raises:
Type Description StorageError
If there is an issue executing the query.
Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
def execute_query(self, query: str, parameters: Optional[Dict[str, Any]] = None) -> Any:\n \"\"\"\n Executes a raw query against the graph database.\n\n Args:\n query (str): The query string.\n parameters (Optional[Dict[str, Any]]): Parameters for parameterized queries.\n\n Returns:\n Any: The result of the query.\n\n Raises:\n StorageError: If there is an issue executing the query.\n \"\"\"\n try:\n result = self.session.run(query, parameters)\n records = [record.data() for record in result]\n logger.info(f\"Executed query: {query}\")\n return records\n except exceptions.Neo4jError as e:\n logger.error(f\"Failed to execute query: {e}\")\n raise StorageError(f\"Failed to execute query: {e}\")\n
"},{"location":"reference/#src.aeiva.storage.neo4jdb.neo4j_database.Neo4jDatabase.get_neighbors","title":"get_neighbors(node_id, relationship=None, direction='both')
","text":"Retrieves neighboring nodes connected by edges.
Parameters:
Name Type Description Default node_id
str
Unique identifier of the node.
required relationship
Optional[str]
Filter by relationship type.
None
direction
str
Direction of the relationships ('in', 'out', 'both').
'both'
Returns:
Type Description List[Dict[str, Any]]
List[Dict[str, Any]]: A list of neighboring nodes.
Raises:
Type Description NodeNotFoundError
If the node does not exist.
StorageError
If there is an issue retrieving neighbors.
Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
def get_neighbors(\n self,\n node_id: str,\n relationship: Optional[str] = None,\n direction: str = \"both\"\n) -> List[Dict[str, Any]]:\n \"\"\"\n Retrieves neighboring nodes connected by edges.\n\n Args:\n node_id (str): Unique identifier of the node.\n relationship (Optional[str]): Filter by relationship type.\n direction (str): Direction of the relationships ('in', 'out', 'both').\n\n Returns:\n List[Dict[str, Any]]: A list of neighboring nodes.\n\n Raises:\n NodeNotFoundError: If the node does not exist.\n StorageError: If there is an issue retrieving neighbors.\n \"\"\"\n if direction not in [\"in\", \"out\", \"both\"]:\n raise ValueError(\"Invalid direction. Must be 'in', 'out', or 'both'.\")\n\n rel_type = f\":{relationship}\" if relationship else ''\n if direction == \"in\":\n pattern = f\"<-[r{rel_type}]-\"\n elif direction == \"out\":\n pattern = f\"-[r{rel_type}]->\"\n else: # both\n pattern = f\"-[r{rel_type}]-\"\n\n cypher = f\"MATCH (n {{id: $node_id}}){pattern}(neighbor) RETURN neighbor\"\n params = {'node_id': node_id}\n try:\n # First, check if the node exists\n node_exists_query = \"MATCH (n {id: $node_id}) RETURN n\"\n node_result = self.session.run(node_exists_query, params)\n if not node_result.single():\n logger.warning(f\"Node with id '{node_id}' not found.\")\n raise NodeNotFoundError(f\"Node with id '{node_id}' not found.\")\n # Get neighbors\n result = self.session.run(cypher, params)\n neighbors = []\n for record in result:\n node = record['neighbor']\n neighbor_data = {\n 'id': node['id'],\n 'properties': {k: v for k, v in node.items() if k != 'id'},\n 'labels': list(node.labels)\n }\n neighbors.append(neighbor_data)\n logger.info(f\"Neighbors of node '{node_id}' retrieved.\")\n return neighbors\n except NodeNotFoundError:\n raise\n except exceptions.Neo4jError as e:\n logger.error(f\"Failed to get neighbors: {e}\")\n raise StorageError(f\"Failed to get neighbors: {e}\")\n
"},{"location":"reference/#src.aeiva.storage.neo4jdb.neo4j_database.Neo4jDatabase.get_node","title":"get_node(node_id)
","text":"Retrieves a node by its identifier.
Parameters:
Name Type Description Default node_id
str
Unique identifier of the node.
required Returns:
Type Description Dict[str, Any]
Dict[str, Any]: A dictionary containing the node's properties and labels.
Raises:
Type Description NodeNotFoundError
If the node does not exist.
StorageError
If there is an issue retrieving the node.
Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
def get_node(self, node_id: str) -> Dict[str, Any]:\n \"\"\"\n Retrieves a node by its identifier.\n\n Args:\n node_id (str): Unique identifier of the node.\n\n Returns:\n Dict[str, Any]: A dictionary containing the node's properties and labels.\n\n Raises:\n NodeNotFoundError: If the node does not exist.\n StorageError: If there is an issue retrieving the node.\n \"\"\"\n cypher = \"MATCH (n {id: $node_id}) RETURN n\"\n params = {'node_id': node_id}\n try:\n result = self.session.run(cypher, params)\n record = result.single()\n if record:\n node = record['n']\n node_data = {\n 'id': node['id'],\n 'properties': {k: v for k, v in node.items() if k != 'id'},\n 'labels': list(node.labels)\n }\n logger.info(f\"Node with id '{node_id}' retrieved.\")\n return node_data\n else:\n logger.warning(f\"Node with id '{node_id}' not found.\")\n raise NodeNotFoundError(f\"Node with id '{node_id}' not found.\")\n except exceptions.Neo4jError as e:\n logger.error(f\"Failed to get node: {e}\")\n raise StorageError(f\"Failed to get node: {e}\")\n
"},{"location":"reference/#src.aeiva.storage.neo4jdb.neo4j_database.Neo4jDatabase.get_relationship","title":"get_relationship(source_id, target_id, relationship)
","text":"Retrieves a specific relationship between two nodes.
Parameters:
Name Type Description Default source_id
str
Unique identifier of the source node.
required target_id
str
Unique identifier of the target node.
required relationship
str
Type of the relationship to retrieve.
required Returns:
Type Description Dict[str, Any]
Dict[str, Any]: A dictionary containing the relationship's properties.
Raises:
Type Description StorageError
If there is an issue retrieving the relationship.
Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
def get_relationship(\n    self,\n    source_id: str,\n    target_id: str,\n    relationship: str\n) -> Dict[str, Any]:\n    \"\"\"\n    Retrieves a specific relationship between two nodes.\n\n    Args:\n        source_id (str): Unique identifier of the source node.\n        target_id (str): Unique identifier of the target node.\n        relationship (str): Type of the relationship to retrieve.\n\n    Returns:\n        Dict[str, Any]: A dictionary containing the relationship's properties.\n\n    Raises:\n        StorageError: If there is an issue retrieving the relationship.\n    \"\"\"\n    cypher = (\n        \"MATCH (a {id: $source_id})-[r:%s]->(b {id: $target_id}) \"\n        \"RETURN r\"\n    ) % relationship\n    params = {\n        'source_id': source_id,\n        'target_id': target_id\n    }\n    try:\n        result = self.session.run(cypher, params)\n        record = result.single()\n        if record:\n            relationship_data = record['r']\n            properties = dict(relationship_data)\n            properties['type'] = relationship  # Include relationship type (the type string itself)\n            logger.info(f\"Relationship '{relationship}' between '{source_id}' and '{target_id}' retrieved.\")\n            return properties\n        else:\n            logger.warning(f\"Relationship '{relationship}' between '{source_id}' and '{target_id}' not found.\")\n            raise StorageError(f\"Relationship '{relationship}' between '{source_id}' and '{target_id}' not found.\")\n    except exceptions.Neo4jError as e:\n        logger.error(f\"Failed to retrieve relationship: {e}\")\n        raise StorageError(f\"Failed to retrieve relationship: {e}\")\n
"},{"location":"reference/#src.aeiva.storage.neo4jdb.neo4j_database.Neo4jDatabase.query_nodes","title":"query_nodes(properties, labels=None)
","text":"Queries nodes based on properties and labels.
Parameters:
Name Type Description Default properties
Dict[str, Any]
Properties to filter nodes.
required labels
Optional[List[str]]
Labels to filter nodes.
None
Returns:
Type Description List[Dict[str, Any]]
List[Dict[str, Any]]: A list of nodes matching the query.
Raises:
Type Description StorageError
If there is an issue querying nodes.
Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
def query_nodes(\n self,\n properties: Dict[str, Any],\n labels: Optional[List[str]] = None\n) -> List[Dict[str, Any]]:\n \"\"\"\n Queries nodes based on properties and labels.\n\n Args:\n properties (Dict[str, Any]): Properties to filter nodes.\n labels (Optional[List[str]]): Labels to filter nodes.\n\n Returns:\n List[Dict[str, Any]]: A list of nodes matching the query.\n\n Raises:\n StorageError: If there is an issue querying nodes.\n \"\"\"\n labels_str = ':' + ':'.join(labels) if labels else ''\n params = {}\n cypher = f\"MATCH (n{labels_str})\"\n\n if properties:\n props_conditions = ' AND '.join([f\"n.{key} = ${key}\" for key in properties.keys()])\n cypher += f\" WHERE {props_conditions}\"\n params.update(properties)\n\n cypher += \" RETURN n\"\n\n try:\n result = self.session.run(cypher, params)\n nodes = []\n for record in result:\n node = record['n']\n node_data = {\n 'id': node['id'],\n 'properties': {k: v for k, v in node.items() if k != 'id'},\n 'labels': list(node.labels)\n }\n nodes.append(node_data)\n logger.info(f\"Query returned {len(nodes)} nodes.\")\n return nodes\n except exceptions.Neo4jError as e:\n logger.error(f\"Failed to query nodes: {e}\")\n raise StorageError(f\"Failed to query nodes: {e}\")\n
"},{"location":"reference/#src.aeiva.storage.neo4jdb.neo4j_database.Neo4jDatabase.update_edge","title":"update_edge(source_id, target_id, relationship, properties)
","text":"Updates properties of a specific relationship between two nodes.
Parameters:
Name Type Description Default source_id
str
Unique identifier of the source node.
required target_id
str
Unique identifier of the target node.
required relationship
str
Type of the relationship to update.
required properties
Dict[str, Any]
Properties to update on the relationship.
required Raises:
Type Description StorageError
If there is an issue updating the relationship.
Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
def update_edge(\n self,\n source_id: str,\n target_id: str,\n relationship: str,\n properties: Dict[str, Any]\n) -> None:\n \"\"\"\n Updates properties of a specific relationship between two nodes.\n\n Args:\n source_id (str): Unique identifier of the source node.\n target_id (str): Unique identifier of the target node.\n relationship (str): Type of the relationship to update.\n properties (Dict[str, Any]): Properties to update on the relationship.\n\n Raises:\n StorageError: If there is an issue updating the relationship.\n \"\"\"\n cypher = (\n \"MATCH (a {id: $source_id})-[r:%s]->(b {id: $target_id}) \"\n \"SET r += $properties RETURN r\"\n ) % relationship\n params = {\n 'source_id': source_id,\n 'target_id': target_id,\n 'properties': properties\n }\n try:\n result = self.session.run(cypher, params)\n record = result.single()\n if record:\n logger.info(f\"Relationship '{relationship}' between '{source_id}' and '{target_id}' updated with properties {properties}.\")\n else:\n logger.warning(f\"Relationship '{relationship}' between '{source_id}' and '{target_id}' not found.\")\n raise StorageError(f\"Relationship '{relationship}' between '{source_id}' and '{target_id}' not found.\")\n except exceptions.Neo4jError as e:\n logger.error(f\"Failed to update relationship: {e}\")\n raise StorageError(f\"Failed to update relationship: {e}\")\n
"},{"location":"reference/#src.aeiva.storage.neo4jdb.neo4j_database.Neo4jDatabase.update_node","title":"update_node(node_id, properties)
","text":"Updates properties of a node.
Parameters:
Name Type Description Default node_id
str
Unique identifier of the node.
required properties
Dict[str, Any]
Properties to update.
required Raises:
Type Description NodeNotFoundError
If the node does not exist.
StorageError
If there is an issue updating the node.
Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
def update_node(self, node_id: str, properties: Dict[str, Any]) -> None:\n \"\"\"\n Updates properties of a node.\n\n Args:\n node_id (str): Unique identifier of the node.\n properties (Dict[str, Any]): Properties to update.\n\n Raises:\n NodeNotFoundError: If the node does not exist.\n StorageError: If there is an issue updating the node.\n \"\"\"\n cypher = \"MATCH (n {id: $node_id}) SET n += $properties RETURN n\"\n params = {\n 'node_id': node_id,\n 'properties': properties\n }\n try:\n result = self.session.run(cypher, params)\n record = result.single()\n if record:\n logger.info(f\"Node with id '{node_id}' updated.\")\n else:\n logger.warning(f\"Node with id '{node_id}' not found.\")\n raise NodeNotFoundError(f\"Node with id '{node_id}' not found.\")\n except exceptions.Neo4jError as e:\n logger.error(f\"Failed to update node: {e}\")\n raise StorageError(f\"Failed to update node: {e}\")\n
"},{"location":"reference/#src.aeiva.storage.neo4jdb.neo4j_database.NodeNotFoundError","title":"NodeNotFoundError
","text":" Bases: Exception
Exception raised when a node is not found in the graph database.
Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
class NodeNotFoundError(Exception):\n \"\"\"Exception raised when a node is not found in the graph database.\"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.neo4jdb.neo4j_database.StorageError","title":"StorageError
","text":" Bases: Exception
Exception raised when there is a storage-related error in the graph database.
Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
class StorageError(Exception):\n \"\"\"Exception raised when there is a storage-related error in the graph database.\"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.pgvector","title":"pgvector
","text":""},{"location":"reference/#src.aeiva.storage.pgvector.pgvector_config","title":"pgvector_config
","text":""},{"location":"reference/#src.aeiva.storage.pgvector.pgvector_config.PGVectorConfig","title":"PGVectorConfig
dataclass
","text":" Bases: BaseConfig
Configuration for PGVector (PostgreSQL with vector extension).
Source code in src/aeiva/storage/pgvector/pgvector_config.py
@dataclass\nclass PGVectorConfig(BaseConfig):\n \"\"\"\n Configuration for PGVector (PostgreSQL with vector extension).\n \"\"\"\n\n dbname: str = field(\n default=\"postgres\",\n metadata={\"help\": \"Name of the database.\"}\n )\n collection_name: str = field(\n default=\"mem0\",\n metadata={\"help\": \"Name of the collection (table name).\"}\n )\n embedding_model_dims: int = field(\n default=1536,\n metadata={\"help\": \"Dimensions of the embedding model.\"}\n )\n user: Optional[str] = field(\n default=None,\n metadata={\"help\": \"Database user.\"}\n )\n password: Optional[str] = field(\n default=None,\n metadata={\"help\": \"Database password.\"}\n )\n host: str = field(\n default=\"localhost\",\n metadata={\"help\": \"Database host.\"}\n )\n port: int = field(\n default=5432,\n metadata={\"help\": \"Database port.\"}\n )\n use_diskann: bool = field(\n default=True,\n metadata={\"help\": \"Whether to use diskann for approximate nearest neighbors search.\"}\n )\n\n def __post_init__(self):\n super().__post_init__()\n # Validate that user and password are provided\n if not self.user or not self.password:\n raise ValueError(\"Both 'user' and 'password' must be provided.\")\n
"},{"location":"reference/#src.aeiva.storage.pgvector.pgvector_database","title":"pgvector_database
","text":""},{"location":"reference/#src.aeiva.storage.pgvector.pgvector_database.PGVectorDatabase","title":"PGVectorDatabase
","text":" Bases: VectorDatabase
Concrete implementation of VectorDatabase using PGVector.
Source code in src/aeiva/storage/pgvector/pgvector_database.py
class PGVectorDatabase(VectorDatabase):\n \"\"\"\n Concrete implementation of VectorStoreBase using PGVector.\n \"\"\"\n\n def __init__(self, config: Dict[str, Any]) -> None:\n \"\"\"\n Initialize the PGVector vector store.\n\n Args:\n config (Dict[str, Any]): Configuration dictionary.\n \"\"\"\n self.config = config\n self.collection_name = config.get('collection_name')\n self.dbname = config.get('dbname')\n self.user = config.get('user')\n self.password = config.get('password')\n self.host = config.get('host', 'localhost')\n self.port = config.get('port', 5432)\n self.embedding_model_dims = config.get('embedding_model_dims')\n self.use_diskann = config.get('use_diskann', False)\n\n if not all([self.collection_name, self.dbname, self.user, self.password, self.embedding_model_dims]):\n raise ValueError(\"Required configuration parameters are missing.\")\n\n self.create_client()\n self.create_collection(\n collection_name=self.collection_name,\n vector_size=self.embedding_model_dims,\n distance_metric='cosine' # PGVector uses cosine by default\n )\n\n def create_client(self, **kwargs) -> None:\n \"\"\"\n Initializes the client connection to the PGVector database.\n\n Args:\n **kwargs: Additional parameters.\n \"\"\"\n try:\n self.conn = psycopg2.connect(\n dbname=self.dbname,\n user=self.user,\n password=self.password,\n host=self.host,\n port=self.port,\n **kwargs\n )\n self.cur = self.conn.cursor()\n logger.info(\"Connected to PGVector database.\")\n except psycopg2.Error as e:\n logger.error(f\"Failed to connect to PGVector database: {e}\")\n raise ConnectionError(f\"Failed to connect to PGVector database: {e}\")\n\n def create_collection(self, collection_name: str, vector_size: int, distance_metric: str) -> None:\n \"\"\"\n Create a new vector collection (table) in PGVector.\n\n Args:\n collection_name (str): The name of the collection.\n vector_size (int): The dimensionality of the vectors.\n distance_metric (str): The distance metric to use (e.g., 'cosine').\n 
\"\"\"\n # Check if table exists\n self.cur.execute(\n \"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name=%s);\",\n (collection_name,)\n )\n exists = self.cur.fetchone()[0]\n if exists:\n logger.info(f\"Table {collection_name} already exists. Skipping creation.\")\n return\n\n # Create table\n create_table_query = f\"\"\"\n CREATE TABLE {collection_name} (\n id VARCHAR(64) PRIMARY KEY,\n vector vector({vector_size}),\n payload JSONB\n );\n \"\"\"\n self.cur.execute(create_table_query)\n self.conn.commit()\n logger.info(f\"Table {collection_name} created successfully.\")\n\n # Create index if use_diskann is True\n if self.use_diskann:\n create_index_query = f\"\"\"\n CREATE INDEX {collection_name}_vector_idx\n ON {collection_name}\n USING ivfflat (vector vector_cosine_ops)\n WITH (lists = 100);\n \"\"\"\n self.cur.execute(create_index_query)\n self.conn.commit()\n logger.info(f\"Index created on table {collection_name}.\")\n\n def insert_vectors(\n self,\n collection_name: str,\n vectors: List[List[float]],\n payloads: Optional[List[Dict[str, Any]]] = None,\n ids: Optional[List[str]] = None\n ) -> None:\n \"\"\"\n Insert vectors into a collection.\n\n Args:\n collection_name (str): The name of the collection.\n vectors (List[List[float]]): A list of vectors to insert.\n payloads (Optional[List[Dict[str, Any]]]): Optional metadata associated with each vector.\n ids (Optional[List[str]]): Optional unique identifiers for each vector.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n if ids is None:\n raise ValueError(\"PGVector requires IDs to be provided for each vector.\")\n if payloads is None:\n payloads = [{} for _ in range(len(vectors))]\n if not (len(ids) == len(vectors) == len(payloads)):\n raise ValueError(\"Lengths of ids, vectors, and payloads must be equal.\")\n\n records = [\n (id_, vector, Json(payload))\n for id_, vector, payload in zip(ids, 
vectors, payloads)\n ]\n insert_query = f\"INSERT INTO {collection_name} (id, vector, payload) VALUES %s;\"\n execute_values(self.cur, insert_query, records)\n self.conn.commit()\n logger.info(f\"Inserted {len(vectors)} vectors into table {collection_name}.\")\n\n def search_vectors(\n self,\n collection_name: str,\n query_vector: List[float],\n top_k: int = 5,\n filters: Optional[Dict[str, Any]] = None\n ) -> List[Dict[str, Any]]:\n \"\"\"\n Search for similar vectors in a collection.\n\n Args:\n collection_name (str): The name of the collection.\n query_vector (List[float]): The vector to search with.\n top_k (int): The number of top results to return.\n filters (Optional[Dict[str, Any]]): Optional filters to apply to the search.\n\n Returns:\n List[Dict[str, Any]]: A list of search results.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n filter_clause = \"\"\n params = [query_vector]\n\n if filters:\n filter_conditions = []\n for key, value in filters.items():\n filter_conditions.append(f\"payload ->> %s = %s\")\n params.extend([key, str(value)])\n filter_clause = \"WHERE \" + \" AND \".join(filter_conditions)\n\n # <=> is pgvector's cosine distance operator; it matches the\n # vector_cosine_ops index created above, and 1 - distance is\n # the cosine similarity reported as the score.\n search_query = f\"\"\"\n SELECT id, vector, payload, 1 - (vector <=> %s::vector) AS score\n FROM {collection_name}\n {filter_clause}\n ORDER BY vector <=> %s::vector\n LIMIT %s;\n \"\"\"\n params.extend([query_vector, top_k])\n self.cur.execute(search_query, params)\n results = self.cur.fetchall()\n\n output = []\n for row in results:\n result = {\n 'id': row[0],\n 'score': row[3],\n 'payload': row[2]\n }\n output.append(result)\n return output\n\n def delete_vector(self, collection_name: str, vector_id: str) -> None:\n \"\"\"\n Delete a vector from a collection by its ID.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector to delete.\n \"\"\"\n if collection_name != self.collection_name:\n 
raise ValueError(\"Collection name does not match initialized collection name.\")\n\n delete_query = f\"DELETE FROM {collection_name} WHERE id = %s;\"\n self.cur.execute(delete_query, (vector_id,))\n self.conn.commit()\n logger.info(f\"Deleted vector with ID {vector_id} from table {collection_name}.\")\n\n def update_vector(\n self,\n collection_name: str,\n vector_id: str,\n vector: Optional[List[float]] = None,\n payload: Optional[Dict[str, Any]] = None\n ) -> None:\n \"\"\"\n Update a vector's data or payload.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector to update.\n vector (Optional[List[float]]): The new vector data.\n payload (Optional[Dict[str, Any]]): The new payload data.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n if vector is not None:\n update_query = f\"UPDATE {collection_name} SET vector = %s WHERE id = %s;\"\n self.cur.execute(update_query, (vector, vector_id))\n if payload is not None:\n update_query = f\"UPDATE {collection_name} SET payload = %s WHERE id = %s;\"\n self.cur.execute(update_query, (Json(payload), vector_id))\n self.conn.commit()\n logger.info(f\"Updated vector with ID {vector_id} in table {collection_name}.\")\n\n def get_vector(self, collection_name: str, vector_id: str) -> Dict[str, Any]:\n \"\"\"\n Retrieve a vector by its ID.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector.\n\n Returns:\n Dict[str, Any]: A dictionary containing the vector data and payload.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n select_query = f\"SELECT id, vector, payload FROM {collection_name} WHERE id = %s;\"\n self.cur.execute(select_query, (vector_id,))\n result = self.cur.fetchone()\n\n if not result:\n raise 
KeyError(f\"Vector with ID {vector_id} not found in table {collection_name}.\")\n\n vector_data = {\n 'id': result[0],\n 'vector': result[1],\n 'payload': result[2]\n }\n return vector_data\n\n def list_collections(self) -> List[str]:\n \"\"\"\n List all available vector collections (tables).\n\n Returns:\n List[str]: A list of collection names.\n \"\"\"\n self.cur.execute(\n \"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';\"\n )\n tables = self.cur.fetchall()\n return [table[0] for table in tables]\n\n def delete_collection(self, collection_name: str) -> None:\n \"\"\"\n Delete an entire vector collection.\n\n Args:\n collection_name (str): The name of the collection to delete.\n \"\"\"\n drop_query = f\"DROP TABLE IF EXISTS {collection_name};\"\n self.cur.execute(drop_query)\n self.conn.commit()\n logger.info(f\"Deleted table {collection_name}.\")\n\n def get_collection_info(self, collection_name: str) -> Dict[str, Any]:\n \"\"\"\n Get information about a collection.\n\n Args:\n collection_name (str): The name of the collection.\n\n Returns:\n Dict[str, Any]: Information about the collection.\n \"\"\"\n self.cur.execute(\n \"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = %s;\",\n (collection_name,)\n )\n columns = self.cur.fetchall()\n info = {\n 'name': collection_name,\n 'columns': {column[0]: column[1] for column in columns}\n }\n return info\n\n def __del__(self):\n \"\"\"Clean up resources.\"\"\"\n if hasattr(self, 'cur') and self.cur:\n self.cur.close()\n if hasattr(self, 'conn') and self.conn:\n self.conn.close()\n logger.info(\"Closed connection to PGVector database.\")\n
"},{"location":"reference/#src.aeiva.storage.pgvector.pgvector_database.PGVectorDatabase.__del__","title":"__del__()
","text":"Clean up resources.
Source code in src/aeiva/storage/pgvector/pgvector_database.py
def __del__(self):\n \"\"\"Clean up resources.\"\"\"\n if hasattr(self, 'cur') and self.cur:\n self.cur.close()\n if hasattr(self, 'conn') and self.conn:\n self.conn.close()\n logger.info(\"Closed connection to PGVector database.\")\n
"},{"location":"reference/#src.aeiva.storage.pgvector.pgvector_database.PGVectorDatabase.__init__","title":"__init__(config)
","text":"Initialize the PGVector vector store.
Parameters:
Name Type Description Default config
Dict[str, Any]
Configuration dictionary.
required Source code in src/aeiva/storage/pgvector/pgvector_database.py
def __init__(self, config: Dict[str, Any]) -> None:\n \"\"\"\n Initialize the PGVector vector store.\n\n Args:\n config (Dict[str, Any]): Configuration dictionary.\n \"\"\"\n self.config = config\n self.collection_name = config.get('collection_name')\n self.dbname = config.get('dbname')\n self.user = config.get('user')\n self.password = config.get('password')\n self.host = config.get('host', 'localhost')\n self.port = config.get('port', 5432)\n self.embedding_model_dims = config.get('embedding_model_dims')\n self.use_diskann = config.get('use_diskann', False)\n\n if not all([self.collection_name, self.dbname, self.user, self.password, self.embedding_model_dims]):\n raise ValueError(\"Required configuration parameters are missing.\")\n\n self.create_client()\n self.create_collection(\n collection_name=self.collection_name,\n vector_size=self.embedding_model_dims,\n distance_metric='cosine' # PGVector uses cosine by default\n )\n
"},{"location":"reference/#src.aeiva.storage.pgvector.pgvector_database.PGVectorDatabase.create_client","title":"create_client(**kwargs)
","text":"Initializes the client connection to the PGVector database.
Parameters:
Name Type Description Default **kwargs
Additional parameters.
{}
Source code in src/aeiva/storage/pgvector/pgvector_database.py
def create_client(self, **kwargs) -> None:\n \"\"\"\n Initializes the client connection to the PGVector database.\n\n Args:\n **kwargs: Additional parameters.\n \"\"\"\n try:\n self.conn = psycopg2.connect(\n dbname=self.dbname,\n user=self.user,\n password=self.password,\n host=self.host,\n port=self.port,\n **kwargs\n )\n self.cur = self.conn.cursor()\n logger.info(\"Connected to PGVector database.\")\n except psycopg2.Error as e:\n logger.error(f\"Failed to connect to PGVector database: {e}\")\n raise ConnectionError(f\"Failed to connect to PGVector database: {e}\")\n
"},{"location":"reference/#src.aeiva.storage.pgvector.pgvector_database.PGVectorDatabase.create_collection","title":"create_collection(collection_name, vector_size, distance_metric)
","text":"Create a new vector collection (table) in PGVector.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required vector_size
int
The dimensionality of the vectors.
required distance_metric
str
The distance metric to use (e.g., 'cosine').
required Source code in src/aeiva/storage/pgvector/pgvector_database.py
def create_collection(self, collection_name: str, vector_size: int, distance_metric: str) -> None:\n \"\"\"\n Create a new vector collection (table) in PGVector.\n\n Args:\n collection_name (str): The name of the collection.\n vector_size (int): The dimensionality of the vectors.\n distance_metric (str): The distance metric to use (e.g., 'cosine').\n \"\"\"\n # Check if table exists\n self.cur.execute(\n \"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name=%s);\",\n (collection_name,)\n )\n exists = self.cur.fetchone()[0]\n if exists:\n logger.info(f\"Table {collection_name} already exists. Skipping creation.\")\n return\n\n # Create table\n create_table_query = f\"\"\"\n CREATE TABLE {collection_name} (\n id VARCHAR(64) PRIMARY KEY,\n vector vector({vector_size}),\n payload JSONB\n );\n \"\"\"\n self.cur.execute(create_table_query)\n self.conn.commit()\n logger.info(f\"Table {collection_name} created successfully.\")\n\n # Create index if use_diskann is True\n if self.use_diskann:\n create_index_query = f\"\"\"\n CREATE INDEX {collection_name}_vector_idx\n ON {collection_name}\n USING ivfflat (vector vector_cosine_ops)\n WITH (lists = 100);\n \"\"\"\n self.cur.execute(create_index_query)\n self.conn.commit()\n logger.info(f\"Index created on table {collection_name}.\")\n
"},{"location":"reference/#src.aeiva.storage.pgvector.pgvector_database.PGVectorDatabase.delete_collection","title":"delete_collection(collection_name)
","text":"Delete an entire vector collection.
Parameters:
Name Type Description Default collection_name
str
The name of the collection to delete.
required Source code in src/aeiva/storage/pgvector/pgvector_database.py
def delete_collection(self, collection_name: str) -> None:\n \"\"\"\n Delete an entire vector collection.\n\n Args:\n collection_name (str): The name of the collection to delete.\n \"\"\"\n drop_query = f\"DROP TABLE IF EXISTS {collection_name};\"\n self.cur.execute(drop_query)\n self.conn.commit()\n logger.info(f\"Deleted table {collection_name}.\")\n
"},{"location":"reference/#src.aeiva.storage.pgvector.pgvector_database.PGVectorDatabase.delete_vector","title":"delete_vector(collection_name, vector_id)
","text":"Delete a vector from a collection by its ID.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required vector_id
str
The unique identifier of the vector to delete.
required Source code in src/aeiva/storage/pgvector/pgvector_database.py
def delete_vector(self, collection_name: str, vector_id: str) -> None:\n \"\"\"\n Delete a vector from a collection by its ID.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector to delete.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n delete_query = f\"DELETE FROM {collection_name} WHERE id = %s;\"\n self.cur.execute(delete_query, (vector_id,))\n self.conn.commit()\n logger.info(f\"Deleted vector with ID {vector_id} from table {collection_name}.\")\n
"},{"location":"reference/#src.aeiva.storage.pgvector.pgvector_database.PGVectorDatabase.get_collection_info","title":"get_collection_info(collection_name)
","text":"Get information about a collection.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required Returns:
Type Description Dict[str, Any]
Dict[str, Any]: Information about the collection.
Source code in src/aeiva/storage/pgvector/pgvector_database.py
def get_collection_info(self, collection_name: str) -> Dict[str, Any]:\n \"\"\"\n Get information about a collection.\n\n Args:\n collection_name (str): The name of the collection.\n\n Returns:\n Dict[str, Any]: Information about the collection.\n \"\"\"\n self.cur.execute(\n \"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = %s;\",\n (collection_name,)\n )\n columns = self.cur.fetchall()\n info = {\n 'name': collection_name,\n 'columns': {column[0]: column[1] for column in columns}\n }\n return info\n
"},{"location":"reference/#src.aeiva.storage.pgvector.pgvector_database.PGVectorDatabase.get_vector","title":"get_vector(collection_name, vector_id)
","text":"Retrieve a vector by its ID.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required vector_id
str
The unique identifier of the vector.
required Returns:
Type Description Dict[str, Any]
Dict[str, Any]: A dictionary containing the vector data and payload.
Source code in src/aeiva/storage/pgvector/pgvector_database.py
def get_vector(self, collection_name: str, vector_id: str) -> Dict[str, Any]:\n \"\"\"\n Retrieve a vector by its ID.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector.\n\n Returns:\n Dict[str, Any]: A dictionary containing the vector data and payload.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n select_query = f\"SELECT id, vector, payload FROM {collection_name} WHERE id = %s;\"\n self.cur.execute(select_query, (vector_id,))\n result = self.cur.fetchone()\n\n if not result:\n raise KeyError(f\"Vector with ID {vector_id} not found in table {collection_name}.\")\n\n vector_data = {\n 'id': result[0],\n 'vector': result[1],\n 'payload': result[2]\n }\n return vector_data\n
"},{"location":"reference/#src.aeiva.storage.pgvector.pgvector_database.PGVectorDatabase.insert_vectors","title":"insert_vectors(collection_name, vectors, payloads=None, ids=None)
","text":"Insert vectors into a collection.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required vectors
List[List[float]]
A list of vectors to insert.
required payloads
Optional[List[Dict[str, Any]]]
Optional metadata associated with each vector.
None
ids
Optional[List[str]]
Optional unique identifiers for each vector.
None
Source code in src/aeiva/storage/pgvector/pgvector_database.py
def insert_vectors(\n self,\n collection_name: str,\n vectors: List[List[float]],\n payloads: Optional[List[Dict[str, Any]]] = None,\n ids: Optional[List[str]] = None\n) -> None:\n \"\"\"\n Insert vectors into a collection.\n\n Args:\n collection_name (str): The name of the collection.\n vectors (List[List[float]]): A list of vectors to insert.\n payloads (Optional[List[Dict[str, Any]]]): Optional metadata associated with each vector.\n ids (Optional[List[str]]): Optional unique identifiers for each vector.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n if ids is None:\n raise ValueError(\"PGVector requires IDs to be provided for each vector.\")\n if payloads is None:\n payloads = [{} for _ in range(len(vectors))]\n if not (len(ids) == len(vectors) == len(payloads)):\n raise ValueError(\"Lengths of ids, vectors, and payloads must be equal.\")\n\n records = [\n (id_, vector, Json(payload))\n for id_, vector, payload in zip(ids, vectors, payloads)\n ]\n insert_query = f\"INSERT INTO {collection_name} (id, vector, payload) VALUES %s;\"\n execute_values(self.cur, insert_query, records)\n self.conn.commit()\n logger.info(f\"Inserted {len(vectors)} vectors into table {collection_name}.\")\n
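A sketch of the calling contract: `ids` are mandatory for PGVector, and the three lists must be the same length. The data below is hypothetical (3-dimensional vectors for brevity; a real collection would use `embedding_model_dims`-sized vectors), and the call itself is commented out because it needs a live `PGVectorDatabase` instance:

```python
# Hypothetical data for insert_vectors; a live PGVectorDatabase instance
# (here `db`) and an existing collection are assumed.
ids = ['a1', 'a2']
vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
payloads = [{'source': 'doc1'}, {'source': 'doc2'}]

# insert_vectors raises ValueError unless these lengths agree:
assert len(ids) == len(vectors) == len(payloads)

# db.insert_vectors('documents', vectors, payloads=payloads, ids=ids)
```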
"},{"location":"reference/#src.aeiva.storage.pgvector.pgvector_database.PGVectorDatabase.list_collections","title":"list_collections()
","text":"List all available vector collections (tables).
Returns:
Type Description List[str]
List[str]: A list of collection names.
Source code in src/aeiva/storage/pgvector/pgvector_database.py
def list_collections(self) -> List[str]:\n \"\"\"\n List all available vector collections (tables).\n\n Returns:\n List[str]: A list of collection names.\n \"\"\"\n self.cur.execute(\n \"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';\"\n )\n tables = self.cur.fetchall()\n return [table[0] for table in tables]\n
"},{"location":"reference/#src.aeiva.storage.pgvector.pgvector_database.PGVectorDatabase.search_vectors","title":"search_vectors(collection_name, query_vector, top_k=5, filters=None)
","text":"Search for similar vectors in a collection.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required query_vector
List[float]
The vector to search with.
required top_k
int
The number of top results to return.
5
filters
Optional[Dict[str, Any]]
Optional filters to apply to the search.
None
Returns:
Type Description List[Dict[str, Any]]
List[Dict[str, Any]]: A list of search results.
Source code in src/aeiva/storage/pgvector/pgvector_database.py
def search_vectors(\n self,\n collection_name: str,\n query_vector: List[float],\n top_k: int = 5,\n filters: Optional[Dict[str, Any]] = None\n) -> List[Dict[str, Any]]:\n \"\"\"\n Search for similar vectors in a collection.\n\n Args:\n collection_name (str): The name of the collection.\n query_vector (List[float]): The vector to search with.\n top_k (int): The number of top results to return.\n filters (Optional[Dict[str, Any]]): Optional filters to apply to the search.\n\n Returns:\n List[Dict[str, Any]]: A list of search results.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n filter_clause = \"\"\n params = [query_vector]\n\n if filters:\n filter_conditions = []\n for key, value in filters.items():\n filter_conditions.append(f\"payload ->> %s = %s\")\n params.extend([key, str(value)])\n filter_clause = \"WHERE \" + \" AND \".join(filter_conditions)\n\n # <=> is pgvector's cosine distance operator; it matches the\n # vector_cosine_ops index created above, and 1 - distance is\n # the cosine similarity reported as the score.\n search_query = f\"\"\"\n SELECT id, vector, payload, 1 - (vector <=> %s::vector) AS score\n FROM {collection_name}\n {filter_clause}\n ORDER BY vector <=> %s::vector\n LIMIT %s;\n \"\"\"\n params.extend([query_vector, top_k])\n self.cur.execute(search_query, params)\n results = self.cur.fetchall()\n\n output = []\n for row in results:\n result = {\n 'id': row[0],\n 'score': row[3],\n 'payload': row[2]\n }\n output.append(result)\n return output\n
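The query reports `score` as one minus the vector's distance to the query; with the cosine metric this class configures, that is cosine similarity. A pure-Python mirror of that scoring, for intuition only (the helper below is hypothetical, not part of the class):

```python
import math

# Cosine similarity, as mirrored by the SQL score expression
# (score = 1 - cosine_distance). Pure Python, for intuition only.
def cosine_similarity(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    return dot / (norm_a * norm_b)

# Identical vectors score 1.0; orthogonal vectors score 0.0.
assert abs(cosine_similarity([1.0, 0.0], [1.0, 0.0]) - 1.0) < 1e-9
assert abs(cosine_similarity([1.0, 0.0], [0.0, 1.0])) < 1e-9
```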
"},{"location":"reference/#src.aeiva.storage.pgvector.pgvector_database.PGVectorDatabase.update_vector","title":"update_vector(collection_name, vector_id, vector=None, payload=None)
","text":"Update a vector's data or payload.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required vector_id
str
The unique identifier of the vector to update.
required vector
Optional[List[float]]
The new vector data.
None
payload
Optional[Dict[str, Any]]
The new payload data.
None
Source code in src/aeiva/storage/pgvector/pgvector_database.py
def update_vector(\n self,\n collection_name: str,\n vector_id: str,\n vector: Optional[List[float]] = None,\n payload: Optional[Dict[str, Any]] = None\n) -> None:\n \"\"\"\n Update a vector's data or payload.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector to update.\n vector (Optional[List[float]]): The new vector data.\n payload (Optional[Dict[str, Any]]): The new payload data.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n if vector is not None:\n update_query = f\"UPDATE {collection_name} SET vector = %s WHERE id = %s;\"\n self.cur.execute(update_query, (vector, vector_id))\n if payload is not None:\n update_query = f\"UPDATE {collection_name} SET payload = %s WHERE id = %s;\"\n self.cur.execute(update_query, (Json(payload), vector_id))\n self.conn.commit()\n logger.info(f\"Updated vector with ID {vector_id} in table {collection_name}.\")\n
"},{"location":"reference/#src.aeiva.storage.postgresql","title":"postgresql
","text":""},{"location":"reference/#src.aeiva.storage.postgresql.postgresql_config","title":"postgresql_config
","text":""},{"location":"reference/#src.aeiva.storage.postgresql.postgresql_config.PostgreSQLConfig","title":"PostgreSQLConfig
dataclass
","text":" Bases: BaseConfig
Configuration for PostgreSQL database.
Source code in src/aeiva/storage/postgresql/postgresql_config.py
@dataclass\nclass PostgreSQLConfig(BaseConfig):\n \"\"\"\n Configuration for PostgreSQL database.\n \"\"\"\n dbname: str = field(\n default='postgres',\n metadata={\"help\": \"Name of the PostgreSQL database.\"}\n )\n user: str = field(\n default='postgres',\n metadata={\"help\": \"Username for PostgreSQL authentication.\"}\n )\n password: str = field(\n default='',\n metadata={\"help\": \"Password for PostgreSQL authentication.\"}\n )\n host: str = field(\n default='localhost',\n metadata={\"help\": \"Host address for PostgreSQL server.\"}\n )\n port: int = field(\n default=5432,\n metadata={\"help\": \"Port number for PostgreSQL server.\"}\n )\n
"},{"location":"reference/#src.aeiva.storage.postgresql.postgresql_database","title":"postgresql_database
","text":""},{"location":"reference/#src.aeiva.storage.postgresql.postgresql_database.PostgreSQLDatabase","title":"PostgreSQLDatabase
","text":" Bases: RelationalDatabase
Concrete implementation of RelationalDatabase using PostgreSQL.
Source code in src/aeiva/storage/postgresql/postgresql_database.py
class PostgreSQLDatabase(RelationalDatabase):\n \"\"\"\n Concrete implementation of RelationalStoreBase using PostgreSQL.\n \"\"\"\n\n def __init__(self, config: Dict[str, Any]) -> None:\n \"\"\"\n Initialize the PostgreSQL database connection.\n\n Args:\n config (Dict[str, Any]): Configuration dictionary.\n \"\"\"\n self.config = config\n self.connection = None\n self.cursor = None\n self.connect()\n\n def connect(self) -> None:\n \"\"\"\n Establishes a connection to the PostgreSQL database.\n \"\"\"\n try:\n self.connection = psycopg2.connect(\n dbname=self.config.get('dbname'),\n user=self.config.get('user'),\n password=self.config.get('password'),\n host=self.config.get('host'),\n port=self.config.get('port')\n )\n self.connection.autocommit = True # Enable autocommit for DDL statements\n self.cursor = self.connection.cursor(cursor_factory=RealDictCursor)\n except psycopg2.Error as e:\n raise ConnectionError(f\"Failed to connect to PostgreSQL database: {e}\")\n\n def close(self) -> None:\n \"\"\"\n Closes the database connection and releases resources.\n \"\"\"\n if self.cursor:\n self.cursor.close()\n if self.connection:\n self.connection.close()\n\n def insert_record(self, table: str, record: Dict[str, Any]) -> Any:\n \"\"\"\n Inserts a record into a table.\n\n Args:\n table (str): The name of the table.\n record (Dict[str, Any]): A dictionary representing the record to insert.\n\n Returns:\n Any: The primary key of the inserted record.\n\n Raises:\n StorageError: If there is an issue inserting the record.\n \"\"\"\n try:\n columns = ', '.join(record.keys())\n placeholders = ', '.join(f\"%({key})s\" for key in record.keys())\n sql = f\"INSERT INTO {table} ({columns}) VALUES ({placeholders}) RETURNING id\"\n self.cursor.execute(sql, record)\n result = self.cursor.fetchone()\n return result['id']\n except psycopg2.IntegrityError as e:\n self.connection.rollback()\n raise StorageError(f\"Integrity error: {e}\")\n except psycopg2.Error as e:\n 
self.connection.rollback()\n raise StorageError(f\"Failed to insert record: {e}\")\n\n def get_record(self, table: str, primary_key: Any) -> Dict[str, Any]:\n \"\"\"\n Retrieves a record by its primary key.\n\n Args:\n table (str): The name of the table.\n primary_key (Any): The primary key of the record.\n\n Returns:\n Dict[str, Any]: The retrieved record.\n\n Raises:\n RecordNotFoundError: If the record does not exist.\n StorageError: If there is an issue retrieving the record.\n \"\"\"\n try:\n sql = f\"SELECT * FROM {table} WHERE id = %s\"\n self.cursor.execute(sql, (primary_key,))\n row = self.cursor.fetchone()\n if row is None:\n raise RecordNotFoundError(f\"Record with primary key {primary_key} not found in table '{table}'.\")\n return dict(row)\n except psycopg2.Error as e:\n raise StorageError(f\"Failed to get record: {e}\")\n\n def update_record(self, table: str, primary_key: Any, updates: Dict[str, Any]) -> None:\n \"\"\"\n Updates a record in a table.\n\n Args:\n table (str): The name of the table.\n primary_key (Any): The primary key of the record.\n updates (Dict[str, Any]): A dictionary of fields to update.\n\n Raises:\n RecordNotFoundError: If the record does not exist.\n StorageError: If there is an issue updating the record.\n \"\"\"\n try:\n set_clause = ', '.join(f\"{key} = %({key})s\" for key in updates.keys())\n sql = f\"UPDATE {table} SET {set_clause} WHERE id = %(id)s\"\n updates['id'] = primary_key\n self.cursor.execute(sql, updates)\n if self.cursor.rowcount == 0:\n self.connection.rollback()\n raise RecordNotFoundError(f\"Record with primary key {primary_key} not found in table '{table}'.\")\n except psycopg2.Error as e:\n self.connection.rollback()\n raise StorageError(f\"Failed to update record: {e}\")\n\n def delete_record(self, table: str, primary_key: Any) -> None:\n \"\"\"\n Deletes a record from a table.\n\n Args:\n table (str): The name of the table.\n primary_key (Any): The primary key of the record.\n\n Raises:\n 
RecordNotFoundError: If the record does not exist.\n StorageError: If there is an issue deleting the record.\n \"\"\"\n try:\n sql = f\"DELETE FROM {table} WHERE id = %s\"\n self.cursor.execute(sql, (primary_key,))\n if self.cursor.rowcount == 0:\n self.connection.rollback()\n raise RecordNotFoundError(f\"Record with primary key {primary_key} not found in table '{table}'.\")\n except psycopg2.Error as e:\n self.connection.rollback()\n raise StorageError(f\"Failed to delete record: {e}\")\n\n def query_records(\n self,\n table: str,\n conditions: Optional[Dict[str, Any]] = None,\n limit: Optional[int] = None,\n offset: Optional[int] = None\n ) -> List[Dict[str, Any]]:\n \"\"\"\n Queries records from a table based on conditions.\n\n Args:\n table (str): The name of the table.\n conditions (Optional[Dict[str, Any]]): Conditions to filter records.\n limit (Optional[int]): Maximum number of records to return.\n offset (Optional[int]): Number of records to skip.\n\n Returns:\n List[Dict[str, Any]]: A list of records matching the query.\n\n Raises:\n StorageError: If there is an issue querying records.\n \"\"\"\n try:\n sql = f\"SELECT * FROM {table}\"\n params = {}\n if conditions:\n where_clause = ' AND '.join(f\"{key} = %({key})s\" for key in conditions.keys())\n sql += f\" WHERE {where_clause}\"\n params.update(conditions)\n if limit is not None:\n sql += f\" LIMIT {limit}\"\n if offset is not None:\n sql += f\" OFFSET {offset}\"\n self.cursor.execute(sql, params)\n rows = self.cursor.fetchall()\n return [dict(row) for row in rows]\n except psycopg2.Error as e:\n raise StorageError(f\"Failed to query records: {e}\")\n\n def execute_sql(self, query: str, parameters: Optional[Dict[str, Any]] = None) -> Any:\n \"\"\"\n Executes a raw SQL query.\n\n Args:\n query (str): The SQL query string.\n parameters (Optional[Dict[str, Any]]): Parameters for parameterized queries.\n\n Returns:\n Any: The result of the query.\n\n Raises:\n StorageError: If there is an issue executing 
the query.\n \"\"\"\n cursor = self.connection.cursor(cursor_factory=psycopg2.extras.DictCursor)\n try:\n if parameters:\n cursor.execute(query, parameters)\n else:\n cursor.execute(query)\n if query.strip().upper().startswith(\"SELECT\"):\n return cursor\n else:\n self.connection.commit()\n return cursor\n except psycopg2.Error as e:\n self.connection.rollback()\n raise StorageError(f\"Failed to execute SQL query: {e}\")\n\n def begin_transaction(self) -> None:\n \"\"\"\n Begins a transaction.\n \"\"\"\n self.connection.autocommit = False\n\n def commit_transaction(self) -> None:\n \"\"\"\n Commits the current transaction.\n \"\"\"\n self.connection.commit()\n self.connection.autocommit = True\n\n def rollback_transaction(self) -> None:\n \"\"\"\n Rolls back the current transaction.\n \"\"\"\n self.connection.rollback()\n self.connection.autocommit = True\n
"},{"location":"reference/#src.aeiva.storage.postgresql.postgresql_database.PostgreSQLDatabase.__init__","title":"__init__(config)
","text":"Initialize the PostgreSQL database connection.
Parameters:
Name Type Description Default config
Dict[str, Any]
Configuration dictionary.
required Source code in src/aeiva/storage/postgresql/postgresql_database.py
def __init__(self, config: Dict[str, Any]) -> None:\n \"\"\"\n Initialize the PostgreSQL database connection.\n\n Args:\n config (Dict[str, Any]): Configuration dictionary.\n \"\"\"\n self.config = config\n self.connection = None\n self.cursor = None\n self.connect()\n
"},{"location":"reference/#src.aeiva.storage.postgresql.postgresql_database.PostgreSQLDatabase.begin_transaction","title":"begin_transaction()
","text":"Begins a transaction.
Source code in src/aeiva/storage/postgresql/postgresql_database.py
def begin_transaction(self) -> None:\n \"\"\"\n Begins a transaction.\n \"\"\"\n self.connection.autocommit = False\n
"},{"location":"reference/#src.aeiva.storage.postgresql.postgresql_database.PostgreSQLDatabase.close","title":"close()
","text":"Closes the database connection and releases resources.
Source code in src/aeiva/storage/postgresql/postgresql_database.py
def close(self) -> None:\n \"\"\"\n Closes the database connection and releases resources.\n \"\"\"\n if self.cursor:\n self.cursor.close()\n if self.connection:\n self.connection.close()\n
"},{"location":"reference/#src.aeiva.storage.postgresql.postgresql_database.PostgreSQLDatabase.commit_transaction","title":"commit_transaction()
","text":"Commits the current transaction.
Source code in src/aeiva/storage/postgresql/postgresql_database.py
def commit_transaction(self) -> None:\n \"\"\"\n Commits the current transaction.\n \"\"\"\n self.connection.commit()\n self.connection.autocommit = True\n
"},{"location":"reference/#src.aeiva.storage.postgresql.postgresql_database.PostgreSQLDatabase.connect","title":"connect()
","text":"Establishes a connection to the PostgreSQL database.
Source code in src/aeiva/storage/postgresql/postgresql_database.py
def connect(self) -> None:\n \"\"\"\n Establishes a connection to the PostgreSQL database.\n \"\"\"\n try:\n self.connection = psycopg2.connect(\n dbname=self.config.get('dbname'),\n user=self.config.get('user'),\n password=self.config.get('password'),\n host=self.config.get('host'),\n port=self.config.get('port')\n )\n self.connection.autocommit = True # Enable autocommit for DDL statements\n self.cursor = self.connection.cursor(cursor_factory=RealDictCursor)\n except psycopg2.Error as e:\n raise ConnectionError(f\"Failed to connect to PostgreSQL database: {e}\")\n
"},{"location":"reference/#src.aeiva.storage.postgresql.postgresql_database.PostgreSQLDatabase.delete_record","title":"delete_record(table, primary_key)
","text":"Deletes a record from a table.
Parameters:
Name Type Description Default table
str
The name of the table.
required primary_key
Any
The primary key of the record.
required Raises:
Type Description RecordNotFoundError
If the record does not exist.
StorageError
If there is an issue deleting the record.
Source code in src/aeiva/storage/postgresql/postgresql_database.py
def delete_record(self, table: str, primary_key: Any) -> None:\n \"\"\"\n Deletes a record from a table.\n\n Args:\n table (str): The name of the table.\n primary_key (Any): The primary key of the record.\n\n Raises:\n RecordNotFoundError: If the record does not exist.\n StorageError: If there is an issue deleting the record.\n \"\"\"\n try:\n sql = f\"DELETE FROM {table} WHERE id = %s\"\n self.cursor.execute(sql, (primary_key,))\n if self.cursor.rowcount == 0:\n self.connection.rollback()\n raise RecordNotFoundError(f\"Record with primary key {primary_key} not found in table '{table}'.\")\n except psycopg2.Error as e:\n self.connection.rollback()\n raise StorageError(f\"Failed to delete record: {e}\")\n
"},{"location":"reference/#src.aeiva.storage.postgresql.postgresql_database.PostgreSQLDatabase.execute_sql","title":"execute_sql(query, parameters=None)
","text":"Executes a raw SQL query.
Parameters:
Name Type Description Default query
str
The SQL query string.
required parameters
Optional[Dict[str, Any]]
Parameters for parameterized queries.
None
Returns:
Name Type Description Any
Any
The result of the query.
Raises:
Type Description StorageError
If there is an issue executing the query.
Source code in src/aeiva/storage/postgresql/postgresql_database.py
def execute_sql(self, query: str, parameters: Optional[Dict[str, Any]] = None) -> Any:\n \"\"\"\n Executes a raw SQL query.\n\n Args:\n query (str): The SQL query string.\n parameters (Optional[Dict[str, Any]]): Parameters for parameterized queries.\n\n Returns:\n Any: The result of the query.\n\n Raises:\n StorageError: If there is an issue executing the query.\n \"\"\"\n cursor = self.connection.cursor(cursor_factory=psycopg2.extras.DictCursor)\n try:\n if parameters:\n cursor.execute(query, parameters)\n else:\n cursor.execute(query)\n if query.strip().upper().startswith(\"SELECT\"):\n return cursor\n else:\n self.connection.commit()\n return cursor\n except psycopg2.Error as e:\n self.connection.rollback()\n raise StorageError(f\"Failed to execute SQL query: {e}\")\n
"},{"location":"reference/#src.aeiva.storage.postgresql.postgresql_database.PostgreSQLDatabase.get_record","title":"get_record(table, primary_key)
","text":"Retrieves a record by its primary key.
Parameters:
Name Type Description Default table
str
The name of the table.
required primary_key
Any
The primary key of the record.
required Returns:
Type Description Dict[str, Any]
Dict[str, Any]: The retrieved record.
Raises:
Type Description RecordNotFoundError
If the record does not exist.
StorageError
If there is an issue retrieving the record.
Source code in src/aeiva/storage/postgresql/postgresql_database.py
def get_record(self, table: str, primary_key: Any) -> Dict[str, Any]:\n \"\"\"\n Retrieves a record by its primary key.\n\n Args:\n table (str): The name of the table.\n primary_key (Any): The primary key of the record.\n\n Returns:\n Dict[str, Any]: The retrieved record.\n\n Raises:\n RecordNotFoundError: If the record does not exist.\n StorageError: If there is an issue retrieving the record.\n \"\"\"\n try:\n sql = f\"SELECT * FROM {table} WHERE id = %s\"\n self.cursor.execute(sql, (primary_key,))\n row = self.cursor.fetchone()\n if row is None:\n raise RecordNotFoundError(f\"Record with primary key {primary_key} not found in table '{table}'.\")\n return dict(row)\n except psycopg2.Error as e:\n raise StorageError(f\"Failed to get record: {e}\")\n
"},{"location":"reference/#src.aeiva.storage.postgresql.postgresql_database.PostgreSQLDatabase.insert_record","title":"insert_record(table, record)
","text":"Inserts a record into a table.
Parameters:
Name Type Description Default table
str
The name of the table.
required record
Dict[str, Any]
A dictionary representing the record to insert.
required Returns:
Name Type Description Any
Any
The primary key of the inserted record.
Raises:
Type Description StorageError
If there is an issue inserting the record.
Source code in src/aeiva/storage/postgresql/postgresql_database.py
def insert_record(self, table: str, record: Dict[str, Any]) -> Any:\n \"\"\"\n Inserts a record into a table.\n\n Args:\n table (str): The name of the table.\n record (Dict[str, Any]): A dictionary representing the record to insert.\n\n Returns:\n Any: The primary key of the inserted record.\n\n Raises:\n StorageError: If there is an issue inserting the record.\n \"\"\"\n try:\n columns = ', '.join(record.keys())\n placeholders = ', '.join(f\"%({key})s\" for key in record.keys())\n sql = f\"INSERT INTO {table} ({columns}) VALUES ({placeholders}) RETURNING id\"\n self.cursor.execute(sql, record)\n result = self.cursor.fetchone()\n return result['id']\n except psycopg2.IntegrityError as e:\n self.connection.rollback()\n raise StorageError(f\"Integrity error: {e}\")\n except psycopg2.Error as e:\n self.connection.rollback()\n raise StorageError(f\"Failed to insert record: {e}\")\n
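The named-placeholder construction used by insert_record can be sketched standalone (table and column names below are illustrative, not from the source):

```python
def build_insert_sql(table, record):
    # Mirror the style used by insert_record: one %(key)s named
    # placeholder per column, executed later with the record dict.
    columns = ', '.join(record.keys())
    placeholders = ', '.join(f'%({key})s' for key in record.keys())
    return f'INSERT INTO {table} ({columns}) VALUES ({placeholders}) RETURNING id'

sql = build_insert_sql('users', {'name': 'alice', 'age': 30})
# INSERT INTO users (name, age) VALUES (%(name)s, %(age)s) RETURNING id
```

Binding values through named placeholders rather than string interpolation lets the driver handle quoting and escaping of the values.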
"},{"location":"reference/#src.aeiva.storage.postgresql.postgresql_database.PostgreSQLDatabase.query_records","title":"query_records(table, conditions=None, limit=None, offset=None)
","text":"Queries records from a table based on conditions.
Parameters:
Name Type Description Default table
str
The name of the table.
required conditions
Optional[Dict[str, Any]]
Conditions to filter records.
None
limit
Optional[int]
Maximum number of records to return.
None
offset
Optional[int]
Number of records to skip.
None
Returns:
Type Description List[Dict[str, Any]]
List[Dict[str, Any]]: A list of records matching the query.
Raises:
Type Description StorageError
If there is an issue querying records.
Source code in src/aeiva/storage/postgresql/postgresql_database.py
def query_records(\n self,\n table: str,\n conditions: Optional[Dict[str, Any]] = None,\n limit: Optional[int] = None,\n offset: Optional[int] = None\n) -> List[Dict[str, Any]]:\n \"\"\"\n Queries records from a table based on conditions.\n\n Args:\n table (str): The name of the table.\n conditions (Optional[Dict[str, Any]]): Conditions to filter records.\n limit (Optional[int]): Maximum number of records to return.\n offset (Optional[int]): Number of records to skip.\n\n Returns:\n List[Dict[str, Any]]: A list of records matching the query.\n\n Raises:\n StorageError: If there is an issue querying records.\n \"\"\"\n try:\n sql = f\"SELECT * FROM {table}\"\n params = {}\n if conditions:\n where_clause = ' AND '.join(f\"{key} = %({key})s\" for key in conditions.keys())\n sql += f\" WHERE {where_clause}\"\n params.update(conditions)\n if limit is not None:\n sql += f\" LIMIT {limit}\"\n if offset is not None:\n sql += f\" OFFSET {offset}\"\n self.cursor.execute(sql, params)\n rows = self.cursor.fetchall()\n return [dict(row) for row in rows]\n except psycopg2.Error as e:\n raise StorageError(f\"Failed to query records: {e}\")\n
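The incremental SQL assembly in query_records (optional WHERE, then LIMIT/OFFSET only when given) can be illustrated as a pure string-building helper; the table name and conditions are hypothetical:

```python
def build_query_sql(table, conditions=None, limit=None, offset=None):
    # Mirrors query_records: WHERE built from named placeholders,
    # LIMIT and OFFSET appended only when explicitly provided.
    sql = f'SELECT * FROM {table}'
    params = {}
    if conditions:
        where_clause = ' AND '.join(f'{key} = %({key})s' for key in conditions)
        sql += f' WHERE {where_clause}'
        params.update(conditions)
    if limit is not None:
        sql += f' LIMIT {limit}'
    if offset is not None:
        sql += f' OFFSET {offset}'
    return sql, params

sql, params = build_query_sql('users', {'age': 30}, limit=10, offset=5)
# sql == 'SELECT * FROM users WHERE age = %(age)s LIMIT 10 OFFSET 5'
```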
"},{"location":"reference/#src.aeiva.storage.postgresql.postgresql_database.PostgreSQLDatabase.rollback_transaction","title":"rollback_transaction()
","text":"Rolls back the current transaction.
Source code in src/aeiva/storage/postgresql/postgresql_database.py
def rollback_transaction(self) -> None:\n \"\"\"\n Rolls back the current transaction.\n \"\"\"\n self.connection.rollback()\n self.connection.autocommit = True\n
"},{"location":"reference/#src.aeiva.storage.postgresql.postgresql_database.PostgreSQLDatabase.update_record","title":"update_record(table, primary_key, updates)
","text":"Updates a record in a table.
Parameters:
Name Type Description Default table
str
The name of the table.
required primary_key
Any
The primary key of the record.
required updates
Dict[str, Any]
A dictionary of fields to update.
required Raises:
Type Description RecordNotFoundError
If the record does not exist.
StorageError
If there is an issue updating the record.
Source code in src/aeiva/storage/postgresql/postgresql_database.py
def update_record(self, table: str, primary_key: Any, updates: Dict[str, Any]) -> None:\n \"\"\"\n Updates a record in a table.\n\n Args:\n table (str): The name of the table.\n primary_key (Any): The primary key of the record.\n updates (Dict[str, Any]): A dictionary of fields to update.\n\n Raises:\n RecordNotFoundError: If the record does not exist.\n StorageError: If there is an issue updating the record.\n \"\"\"\n try:\n set_clause = ', '.join(f\"{key} = %({key})s\" for key in updates.keys())\n sql = f\"UPDATE {table} SET {set_clause} WHERE id = %(id)s\"\n updates['id'] = primary_key\n self.cursor.execute(sql, updates)\n if self.cursor.rowcount == 0:\n self.connection.rollback()\n raise RecordNotFoundError(f\"Record with primary key {primary_key} not found in table '{table}'.\")\n except psycopg2.Error as e:\n self.connection.rollback()\n raise StorageError(f\"Failed to update record: {e}\")\n
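The SET-clause construction in update_record, where the primary key is bound into the same parameter dict under the reserved key id, can be sketched standalone (names below are illustrative):

```python
def build_update_sql(table, primary_key, updates):
    # Mirrors update_record: SET from named placeholders, with the
    # primary key merged into the parameter dict for the WHERE clause.
    set_clause = ', '.join(f'{key} = %({key})s' for key in updates)
    params = dict(updates)
    params['id'] = primary_key
    sql = f'UPDATE {table} SET {set_clause} WHERE id = %(id)s'
    return sql, params

sql, params = build_update_sql('users', 7, {'name': 'bob'})
# sql == 'UPDATE users SET name = %(name)s WHERE id = %(id)s'
```

Note this pattern implies an updates dict may not itself contain an id key, since that slot is reserved for the WHERE binding.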
"},{"location":"reference/#src.aeiva.storage.postgresql.postgresql_database.RecordNotFoundError","title":"RecordNotFoundError
","text":" Bases: Exception
Exception raised when a record is not found in the database.
Source code in src/aeiva/storage/postgresql/postgresql_database.py
class RecordNotFoundError(Exception):\n \"\"\"Exception raised when a record is not found in the database.\"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.postgresql.postgresql_database.StorageError","title":"StorageError
","text":" Bases: Exception
Exception raised when there is a storage-related error in the database.
Source code in src/aeiva/storage/postgresql/postgresql_database.py
class StorageError(Exception):\n \"\"\"Exception raised when there is a storage-related error in the database.\"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.postgresql.test","title":"test
","text":""},{"location":"reference/#src.aeiva.storage.postgresql.test.RecordNotFoundError","title":"RecordNotFoundError
","text":" Bases: Exception
Exception raised when a record is not found in the database.
Source code in src/aeiva/storage/postgresql/test.py
class RecordNotFoundError(Exception):\n \"\"\"Exception raised when a record is not found in the database.\"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.qdrant","title":"qdrant
","text":""},{"location":"reference/#src.aeiva.storage.qdrant.qdrant_config","title":"qdrant_config
","text":""},{"location":"reference/#src.aeiva.storage.qdrant.qdrant_config.QdrantConfig","title":"QdrantConfig
dataclass
","text":" Bases: BaseConfig
Configuration for Qdrant vector database.
Source code in src/aeiva/storage/qdrant/qdrant_config.py
@dataclass\nclass QdrantConfig(BaseConfig):\n \"\"\"\n Configuration for Qdrant vector database.\n \"\"\"\n\n collection_name: str = field(\n default=\"mem0\",\n metadata={\"help\": \"Name of the collection.\"}\n )\n embedding_model_dims: int = field(\n default=1536,\n metadata={\"help\": \"Dimensions of the embedding model.\"}\n )\n client: Optional[Any] = field(\n default=None,\n metadata={\"help\": \"Existing Qdrant client instance (if any).\"}\n )\n host: Optional[str] = field(\n default=None,\n metadata={\"help\": \"Host address for Qdrant server.\"}\n )\n port: Optional[int] = field(\n default=None,\n metadata={\"help\": \"Port for Qdrant server.\"}\n )\n path: Optional[str] = field(\n default=None,\n metadata={\"help\": \"Path for local Qdrant database storage.\"}\n )\n url: Optional[str] = field(\n default=None,\n metadata={\"help\": \"Full URL for Qdrant server.\"}\n )\n api_key: Optional[str] = field(\n default=None,\n metadata={\"help\": \"API key for Qdrant server authentication.\"}\n )\n on_disk: bool = field(\n default=False,\n metadata={\"help\": \"Whether to enable persistent storage on disk.\"}\n )\n\n def __post_init__(self):\n super().__post_init__()\n # Validate that connection parameters are provided\n if not self.path and not ((self.host and self.port) or (self.url and self.api_key)):\n raise ValueError(\"Provide 'path' for local storage, or 'host' and 'port', or 'url' and 'api_key' for remote connection.\")\n
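The validation rule in __post_init__ accepts exactly three connection modes: a local path, host plus port, or url plus api_key. The same predicate, extracted for illustration (a simplified sketch, not part of the library):

```python
def qdrant_config_is_valid(path=None, host=None, port=None, url=None, api_key=None):
    # Same rule as QdrantConfig.__post_init__: local path, or
    # host and port together, or url and api_key together.
    return bool(path) or bool(host and port) or bool(url and api_key)

assert qdrant_config_is_valid(path='./qdrant_db')
assert qdrant_config_is_valid(host='localhost', port=6333)
assert not qdrant_config_is_valid(host='localhost')  # port missing
```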
"},{"location":"reference/#src.aeiva.storage.qdrant.qdrant_database","title":"qdrant_database
","text":""},{"location":"reference/#src.aeiva.storage.qdrant.qdrant_database.QdrantDatabase","title":"QdrantDatabase
","text":" Bases: VectorDatabase
Concrete implementation of VectorStoreBase using Qdrant.
Source code in src/aeiva/storage/qdrant/qdrant_database.py
class QdrantDatabase(VectorDatabase):\n \"\"\"\n Concrete implementation of VectorStoreBase using Qdrant.\n \"\"\"\n\n def __init__(self, config: Dict[str, Any]) -> None:\n \"\"\"\n Initialize the Qdrant vector store.\n\n Args:\n config (Dict[str, Any]): Configuration dictionary.\n \"\"\"\n self.config = config\n self.collection_name = config.get('collection_name')\n self.embedding_model_dims = config.get('embedding_model_dims')\n self.client = config.get('client')\n self.host = config.get('host')\n self.port = config.get('port')\n self.path = config.get('path')\n self.url = config.get('url')\n self.api_key = config.get('api_key')\n self.on_disk = config.get('on_disk', False)\n\n if not all([self.collection_name, self.embedding_model_dims]):\n raise ValueError(\"Required configuration parameters are missing.\")\n\n self.create_client()\n self.create_collection(\n collection_name=self.collection_name,\n vector_size=self.embedding_model_dims,\n distance_metric='COSINE'\n )\n\n def create_client(self, **kwargs) -> None:\n \"\"\"\n Initializes the client connection to the Qdrant vector store.\n\n Args:\n **kwargs: Additional parameters.\n \"\"\"\n if self.client:\n return # Client already provided\n\n client_params = {}\n if self.api_key:\n client_params['api_key'] = self.api_key\n if self.url:\n client_params['url'] = self.url\n elif self.host and self.port:\n client_params['host'] = self.host\n client_params['port'] = self.port\n else:\n client_params['path'] = self.path\n\n self.client = QdrantClient(**client_params)\n logger.info(\"Qdrant client initialized.\")\n\n def create_collection(self, collection_name: str, vector_size: int, distance_metric: str) -> None:\n \"\"\"\n Create a new vector collection in Qdrant.\n\n Args:\n collection_name (str): The name of the collection.\n vector_size (int): The dimensionality of the vectors.\n distance_metric (str): The distance metric to use (e.g., 'COSINE').\n \"\"\"\n # Check if collection exists\n collections = 
self.list_collections()\n if collection_name in collections:\n logger.info(f\"Collection {collection_name} already exists. Skipping creation.\")\n return\n\n vector_params = VectorParams(\n size=vector_size,\n distance=getattr(Distance, distance_metric.upper()),\n on_disk=self.on_disk\n )\n self.client.create_collection(\n collection_name=collection_name,\n vectors_config=vector_params\n )\n logger.info(f\"Collection {collection_name} created successfully.\")\n\n def insert_vectors(\n self,\n collection_name: str,\n vectors: List[List[float]],\n payloads: Optional[List[Dict[str, Any]]] = None,\n ids: Optional[List[str]] = None\n ) -> None:\n \"\"\"\n Insert vectors into a collection.\n\n Args:\n collection_name (str): The name of the collection.\n vectors (List[List[float]]): A list of vectors to insert.\n payloads (Optional[List[Dict[str, Any]]]): Optional metadata associated with each vector.\n ids (Optional[List[str]]): Optional unique identifiers for each vector.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n if ids is None:\n ids = [i for i in range(len(vectors))]\n if payloads is None:\n payloads = [{} for _ in range(len(vectors))]\n if not (len(ids) == len(vectors) == len(payloads)):\n raise ValueError(\"Lengths of ids, vectors, and payloads must be equal.\")\n\n points = [\n PointStruct(\n id=id_,\n vector=vector,\n payload=payload\n )\n for id_, vector, payload in zip(ids, vectors, payloads)\n ]\n self.client.upsert(\n collection_name=collection_name,\n points=points\n )\n logger.info(f\"Inserted {len(vectors)} vectors into collection {collection_name}.\")\n\n def search_vectors(\n self,\n collection_name: str,\n query_vector: List[float],\n top_k: int = 5,\n filters: Optional[Dict[str, Any]] = None\n ) -> List[Dict[str, Any]]:\n \"\"\"\n Search for similar vectors in a collection.\n\n Args:\n collection_name (str): The name of the collection.\n query_vector 
(List[float]): The vector to search with.\n top_k (int): The number of top results to return.\n filters (Optional[Dict[str, Any]]): Optional filters to apply to the search.\n\n Returns:\n List[Dict[str, Any]]: A list of search results.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n query_filter = self._build_filter(filters)\n results = self.client.search(\n collection_name=collection_name,\n query_vector=query_vector,\n limit=top_k,\n query_filter=query_filter\n )\n\n output = []\n for hit in results:\n result = {\n 'id': hit.id,\n 'score': hit.score,\n 'payload': hit.payload\n }\n output.append(result)\n return output\n\n def _build_filter(self, filters: Optional[Dict[str, Any]]) -> Optional[Filter]:\n \"\"\"\n Build a Qdrant filter object from a dictionary.\n\n Args:\n filters (Optional[Dict[str, Any]]): Filters to apply.\n\n Returns:\n Optional[Filter]: A Qdrant Filter object.\n \"\"\"\n if not filters:\n return None\n\n conditions = []\n for key, value in filters.items():\n conditions.append(\n FieldCondition(\n key=key,\n match=MatchValue(value=value)\n )\n )\n return Filter(must=conditions)\n\n def delete_vector(self, collection_name: str, vector_id: str) -> None:\n \"\"\"\n Delete a vector from a collection by its ID.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector to delete.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n self.client.delete(\n collection_name=collection_name,\n points_selector=[vector_id]\n )\n logger.info(f\"Deleted vector with ID {vector_id} from collection {collection_name}.\")\n\n def update_vector(\n self,\n collection_name: str,\n vector_id: str,\n vector: Optional[List[float]] = None,\n payload: Optional[Dict[str, Any]] = None\n ) -> None:\n \"\"\"\n Update a vector's data or 
payload.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector to update.\n vector (Optional[List[float]]): The new vector data.\n payload (Optional[Dict[str, Any]]): The new payload data.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n point = PointStruct(\n id=vector_id,\n vector=vector,\n payload=payload\n )\n self.client.upsert(\n collection_name=collection_name,\n points=[point]\n )\n logger.info(f\"Updated vector with ID {vector_id} in collection {collection_name}.\")\n\n def get_vector(self, collection_name: str, vector_id: str) -> Dict[str, Any]:\n \"\"\"\n Retrieve a vector by its ID.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector.\n\n Returns:\n Dict[str, Any]: A dictionary containing the vector data and payload.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n result = self.client.retrieve(\n collection_name=collection_name,\n ids=[vector_id]\n )\n if not result:\n raise KeyError(f\"Vector with ID {vector_id} not found in collection {collection_name}.\")\n\n point = result[0]\n vector_data = {\n 'id': point.id,\n 'vector': point.vector,\n 'payload': point.payload\n }\n return vector_data\n\n def list_collections(self) -> List[str]:\n \"\"\"\n List all available vector collections.\n\n Returns:\n List[str]: A list of collection names.\n \"\"\"\n collections = self.client.get_collections().collections\n return [collection.name for collection in collections]\n\n def delete_collection(self, collection_name: str) -> None:\n \"\"\"\n Delete an entire vector collection.\n\n Args:\n collection_name (str): The name of the collection to delete.\n \"\"\"\n self.client.delete_collection(collection_name=collection_name)\n logger.info(f\"Deleted 
collection {collection_name}.\")\n\n def get_collection_info(self, collection_name: str) -> Dict[str, Any]:\n \"\"\"\n Get information about a collection.\n\n Args:\n collection_name (str): The name of the collection.\n\n Returns:\n Dict[str, Any]: Information about the collection.\n \"\"\"\n info = self.client.get_collection(collection_name=collection_name)\n return info.dict()\n
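The connection-parameter precedence in create_client (url wins over host/port, which win over a local path) can be sketched as a standalone helper; the hostnames and key below are hypothetical:

```python
def build_client_params(api_key=None, url=None, host=None, port=None, path=None):
    # Same precedence as QdrantDatabase.create_client: a url takes
    # priority over host/port, which take priority over a local path.
    params = {}
    if api_key:
        params['api_key'] = api_key
    if url:
        params['url'] = url
    elif host and port:
        params['host'] = host
        params['port'] = port
    else:
        params['path'] = path
    return params

build_client_params(url='https://example.com:6333', api_key='secret')
# {'api_key': 'secret', 'url': 'https://example.com:6333'}
```

One consequence of this ordering: when both a url and a host/port pair are supplied, the host/port pair is silently ignored.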
"},{"location":"reference/#src.aeiva.storage.qdrant.qdrant_database.QdrantDatabase.__init__","title":"__init__(config)
","text":"Initialize the Qdrant vector store.
Parameters:
Name Type Description Default config
Dict[str, Any]
Configuration dictionary.
required Source code in src/aeiva/storage/qdrant/qdrant_database.py
def __init__(self, config: Dict[str, Any]) -> None:\n \"\"\"\n Initialize the Qdrant vector store.\n\n Args:\n config (Dict[str, Any]): Configuration dictionary.\n \"\"\"\n self.config = config\n self.collection_name = config.get('collection_name')\n self.embedding_model_dims = config.get('embedding_model_dims')\n self.client = config.get('client')\n self.host = config.get('host')\n self.port = config.get('port')\n self.path = config.get('path')\n self.url = config.get('url')\n self.api_key = config.get('api_key')\n self.on_disk = config.get('on_disk', False)\n\n if not all([self.collection_name, self.embedding_model_dims]):\n raise ValueError(\"Required configuration parameters are missing.\")\n\n self.create_client()\n self.create_collection(\n collection_name=self.collection_name,\n vector_size=self.embedding_model_dims,\n distance_metric='COSINE'\n )\n
"},{"location":"reference/#src.aeiva.storage.qdrant.qdrant_database.QdrantDatabase.create_client","title":"create_client(**kwargs)
","text":"Initializes the client connection to the Qdrant vector store.
Parameters:
Name Type Description Default **kwargs
Additional parameters.
{}
Source code in src/aeiva/storage/qdrant/qdrant_database.py
def create_client(self, **kwargs) -> None:\n \"\"\"\n Initializes the client connection to the Qdrant vector store.\n\n Args:\n **kwargs: Additional parameters.\n \"\"\"\n if self.client:\n return # Client already provided\n\n client_params = {}\n if self.api_key:\n client_params['api_key'] = self.api_key\n if self.url:\n client_params['url'] = self.url\n elif self.host and self.port:\n client_params['host'] = self.host\n client_params['port'] = self.port\n else:\n client_params['path'] = self.path\n\n self.client = QdrantClient(**client_params)\n logger.info(\"Qdrant client initialized.\")\n
"},{"location":"reference/#src.aeiva.storage.qdrant.qdrant_database.QdrantDatabase.create_collection","title":"create_collection(collection_name, vector_size, distance_metric)
","text":"Create a new vector collection in Qdrant.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required vector_size
int
The dimensionality of the vectors.
required distance_metric
str
The distance metric to use (e.g., 'COSINE').
required Source code in src/aeiva/storage/qdrant/qdrant_database.py
def create_collection(self, collection_name: str, vector_size: int, distance_metric: str) -> None:\n \"\"\"\n Create a new vector collection in Qdrant.\n\n Args:\n collection_name (str): The name of the collection.\n vector_size (int): The dimensionality of the vectors.\n distance_metric (str): The distance metric to use (e.g., 'COSINE').\n \"\"\"\n # Check if collection exists\n collections = self.list_collections()\n if collection_name in collections:\n logger.info(f\"Collection {collection_name} already exists. Skipping creation.\")\n return\n\n vector_params = VectorParams(\n size=vector_size,\n distance=getattr(Distance, distance_metric.upper()),\n on_disk=self.on_disk\n )\n self.client.create_collection(\n collection_name=collection_name,\n vectors_config=vector_params\n )\n logger.info(f\"Collection {collection_name} created successfully.\")\n
"},{"location":"reference/#src.aeiva.storage.qdrant.qdrant_database.QdrantDatabase.delete_collection","title":"delete_collection(collection_name)
","text":"Delete an entire vector collection.
Parameters:
Name Type Description Default collection_name
str
The name of the collection to delete.
required Source code in src/aeiva/storage/qdrant/qdrant_database.py
def delete_collection(self, collection_name: str) -> None:\n \"\"\"\n Delete an entire vector collection.\n\n Args:\n collection_name (str): The name of the collection to delete.\n \"\"\"\n self.client.delete_collection(collection_name=collection_name)\n logger.info(f\"Deleted collection {collection_name}.\")\n
"},{"location":"reference/#src.aeiva.storage.qdrant.qdrant_database.QdrantDatabase.delete_vector","title":"delete_vector(collection_name, vector_id)
","text":"Delete a vector from a collection by its ID.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required vector_id
str
The unique identifier of the vector to delete.
required Source code in src/aeiva/storage/qdrant/qdrant_database.py
def delete_vector(self, collection_name: str, vector_id: str) -> None:\n \"\"\"\n Delete a vector from a collection by its ID.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector to delete.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n self.client.delete(\n collection_name=collection_name,\n points_selector=[vector_id]\n )\n logger.info(f\"Deleted vector with ID {vector_id} from collection {collection_name}.\")\n
"},{"location":"reference/#src.aeiva.storage.qdrant.qdrant_database.QdrantDatabase.get_collection_info","title":"get_collection_info(collection_name)
","text":"Get information about a collection.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required Returns:
Type Description Dict[str, Any]
Dict[str, Any]: Information about the collection.
Source code in src/aeiva/storage/qdrant/qdrant_database.py
def get_collection_info(self, collection_name: str) -> Dict[str, Any]:\n \"\"\"\n Get information about a collection.\n\n Args:\n collection_name (str): The name of the collection.\n\n Returns:\n Dict[str, Any]: Information about the collection.\n \"\"\"\n info = self.client.get_collection(collection_name=collection_name)\n return info.dict()\n
"},{"location":"reference/#src.aeiva.storage.qdrant.qdrant_database.QdrantDatabase.get_vector","title":"get_vector(collection_name, vector_id)
","text":"Retrieve a vector by its ID.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required vector_id
str
The unique identifier of the vector.
required Returns:
Type Description Dict[str, Any]
Dict[str, Any]: A dictionary containing the vector data and payload.
Source code in src/aeiva/storage/qdrant/qdrant_database.py
def get_vector(self, collection_name: str, vector_id: str) -> Dict[str, Any]:\n \"\"\"\n Retrieve a vector by its ID.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector.\n\n Returns:\n Dict[str, Any]: A dictionary containing the vector data and payload.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n result = self.client.retrieve(\n collection_name=collection_name,\n ids=[vector_id]\n )\n if not result:\n raise KeyError(f\"Vector with ID {vector_id} not found in collection {collection_name}.\")\n\n point = result[0]\n vector_data = {\n 'id': point.id,\n 'vector': point.vector,\n 'payload': point.payload\n }\n return vector_data\n
"},{"location":"reference/#src.aeiva.storage.qdrant.qdrant_database.QdrantDatabase.insert_vectors","title":"insert_vectors(collection_name, vectors, payloads=None, ids=None)
","text":"Insert vectors into a collection.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required vectors
List[List[float]]
A list of vectors to insert.
required payloads
Optional[List[Dict[str, Any]]]
Optional metadata associated with each vector.
None
ids
Optional[List[str]]
Optional unique identifiers for each vector.
None
Source code in src/aeiva/storage/qdrant/qdrant_database.py
def insert_vectors(\n self,\n collection_name: str,\n vectors: List[List[float]],\n payloads: Optional[List[Dict[str, Any]]] = None,\n ids: Optional[List[str]] = None\n) -> None:\n \"\"\"\n Insert vectors into a collection.\n\n Args:\n collection_name (str): The name of the collection.\n vectors (List[List[float]]): A list of vectors to insert.\n payloads (Optional[List[Dict[str, Any]]]): Optional metadata associated with each vector.\n ids (Optional[List[str]]): Optional unique identifiers for each vector.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n if ids is None:\n ids = [i for i in range(len(vectors))]\n if payloads is None:\n payloads = [{} for _ in range(len(vectors))]\n if not (len(ids) == len(vectors) == len(payloads)):\n raise ValueError(\"Lengths of ids, vectors, and payloads must be equal.\")\n\n points = [\n PointStruct(\n id=id_,\n vector=vector,\n payload=payload\n )\n for id_, vector, payload in zip(ids, vectors, payloads)\n ]\n self.client.upsert(\n collection_name=collection_name,\n points=points\n )\n logger.info(f\"Inserted {len(vectors)} vectors into collection {collection_name}.\")\n
"},{"location":"reference/#src.aeiva.storage.qdrant.qdrant_database.QdrantDatabase.list_collections","title":"list_collections()
","text":"List all available vector collections.
Returns:
Type Description List[str]
List[str]: A list of collection names.
Source code in src/aeiva/storage/qdrant/qdrant_database.py
def list_collections(self) -> List[str]:\n \"\"\"\n List all available vector collections.\n\n Returns:\n List[str]: A list of collection names.\n \"\"\"\n collections = self.client.get_collections().collections\n return [collection.name for collection in collections]\n
"},{"location":"reference/#src.aeiva.storage.qdrant.qdrant_database.QdrantDatabase.search_vectors","title":"search_vectors(collection_name, query_vector, top_k=5, filters=None)
","text":"Search for similar vectors in a collection.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required query_vector
List[float]
The vector to search with.
required top_k
int
The number of top results to return.
5
filters
Optional[Dict[str, Any]]
Optional filters to apply to the search.
None
Returns:
Type Description List[Dict[str, Any]]
List[Dict[str, Any]]: A list of search results.
Source code in src/aeiva/storage/qdrant/qdrant_database.py
def search_vectors(\n self,\n collection_name: str,\n query_vector: List[float],\n top_k: int = 5,\n filters: Optional[Dict[str, Any]] = None\n) -> List[Dict[str, Any]]:\n \"\"\"\n Search for similar vectors in a collection.\n\n Args:\n collection_name (str): The name of the collection.\n query_vector (List[float]): The vector to search with.\n top_k (int): The number of top results to return.\n filters (Optional[Dict[str, Any]]): Optional filters to apply to the search.\n\n Returns:\n List[Dict[str, Any]]: A list of search results.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n query_filter = self._build_filter(filters)\n results = self.client.search(\n collection_name=collection_name,\n query_vector=query_vector,\n limit=top_k,\n query_filter=query_filter\n )\n\n output = []\n for hit in results:\n result = {\n 'id': hit.id,\n 'score': hit.score,\n 'payload': hit.payload\n }\n output.append(result)\n return output\n
"},{"location":"reference/#src.aeiva.storage.qdrant.qdrant_database.QdrantDatabase.update_vector","title":"update_vector(collection_name, vector_id, vector=None, payload=None)
","text":"Update a vector's data or payload.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required vector_id
str
The unique identifier of the vector to update.
required vector
Optional[List[float]]
The new vector data.
None
payload
Optional[Dict[str, Any]]
The new payload data.
None
Source code in src/aeiva/storage/qdrant/qdrant_database.py
def update_vector(\n self,\n collection_name: str,\n vector_id: str,\n vector: Optional[List[float]] = None,\n payload: Optional[Dict[str, Any]] = None\n) -> None:\n \"\"\"\n Update a vector's data or payload.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector to update.\n vector (Optional[List[float]]): The new vector data.\n payload (Optional[Dict[str, Any]]): The new payload data.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n point = PointStruct(\n id=vector_id,\n vector=vector,\n payload=payload\n )\n self.client.upsert(\n collection_name=collection_name,\n points=[point]\n )\n logger.info(f\"Updated vector with ID {vector_id} in collection {collection_name}.\")\n
"},{"location":"reference/#src.aeiva.storage.relational_database","title":"relational_database
","text":""},{"location":"reference/#src.aeiva.storage.relational_database.RelationalDatabase","title":"RelationalDatabase
","text":" Bases: ABC
Abstract base class for relational database operations.
Source code in src/aeiva/storage/relational_database.py
class RelationalDatabase(ABC):\n \"\"\"\n Abstract base class for relational database operations.\n \"\"\"\n\n @abstractmethod\n def insert_record(self, table: str, record: Dict[str, Any]) -> Any:\n \"\"\"\n Inserts a record into a table.\n\n Args:\n table (str): The name of the table.\n record (Dict[str, Any]): A dictionary representing the record to insert.\n\n Returns:\n Any: The primary key of the inserted record.\n\n Raises:\n StorageError: If there is an issue inserting the record.\n \"\"\"\n pass\n\n @abstractmethod\n def get_record(self, table: str, primary_key: Any) -> Dict[str, Any]:\n \"\"\"\n Retrieves a record by its primary key.\n\n Args:\n table (str): The name of the table.\n primary_key (Any): The primary key of the record.\n\n Returns:\n Dict[str, Any]: The retrieved record.\n\n Raises:\n RecordNotFoundError: If the record does not exist.\n StorageError: If there is an issue retrieving the record.\n \"\"\"\n pass\n\n @abstractmethod\n def update_record(self, table: str, primary_key: Any, updates: Dict[str, Any]) -> None:\n \"\"\"\n Updates a record in a table.\n\n Args:\n table (str): The name of the table.\n primary_key (Any): The primary key of the record.\n updates (Dict[str, Any]): A dictionary of fields to update.\n\n Raises:\n RecordNotFoundError: If the record does not exist.\n StorageError: If there is an issue updating the record.\n \"\"\"\n pass\n\n @abstractmethod\n def delete_record(self, table: str, primary_key: Any) -> None:\n \"\"\"\n Deletes a record from a table.\n\n Args:\n table (str): The name of the table.\n primary_key (Any): The primary key of the record.\n\n Raises:\n RecordNotFoundError: If the record does not exist.\n StorageError: If there is an issue deleting the record.\n \"\"\"\n pass\n\n @abstractmethod\n def query_records(self, table: str, conditions: Optional[Dict[str, Any]] = None, limit: Optional[int] = None, offset: Optional[int] = None) -> List[Dict[str, Any]]:\n \"\"\"\n Queries records from a table based on conditions.\n\n Args:\n table (str): The name of the table.\n conditions (Optional[Dict[str, Any]]): Conditions to filter records.\n limit (Optional[int]): Maximum number of records to return.\n offset (Optional[int]): Number of records to skip.\n\n Returns:\n List[Dict[str, Any]]: A list of records matching the query.\n\n Raises:\n StorageError: If there is an issue querying records.\n \"\"\"\n pass\n\n @abstractmethod\n def execute_sql(self, query: str, parameters: Optional[Dict[str, Any]] = None) -> Any:\n \"\"\"\n Executes a raw SQL query.\n\n Args:\n query (str): The SQL query string.\n parameters (Optional[Dict[str, Any]]): Parameters for parameterized queries.\n\n Returns:\n Any: The result of the query.\n\n Raises:\n StorageError: If there is an issue executing the query.\n \"\"\"\n pass\n\n @abstractmethod\n def begin_transaction(self) -> None:\n \"\"\"\n Begins a transaction.\n \"\"\"\n pass\n\n @abstractmethod\n def commit_transaction(self) -> None:\n \"\"\"\n Commits the current transaction.\n \"\"\"\n pass\n\n @abstractmethod\n def rollback_transaction(self) -> None:\n \"\"\"\n Rolls back the current transaction.\n \"\"\"\n pass\n\n @abstractmethod\n def close(self) -> None:\n \"\"\"\n Closes the database connection and releases resources.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.relational_database.RelationalDatabase.begin_transaction","title":"begin_transaction()
abstractmethod
","text":"Begins a transaction.
Source code in src/aeiva/storage/relational_database.py
@abstractmethod\ndef begin_transaction(self) -> None:\n \"\"\"\n Begins a transaction.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.relational_database.RelationalDatabase.close","title":"close()
abstractmethod
","text":"Closes the database connection and releases resources.
Source code in src/aeiva/storage/relational_database.py
@abstractmethod\ndef close(self) -> None:\n \"\"\"\n Closes the database connection and releases resources.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.relational_database.RelationalDatabase.commit_transaction","title":"commit_transaction()
abstractmethod
","text":"Commits the current transaction.
Source code in src/aeiva/storage/relational_database.py
@abstractmethod\ndef commit_transaction(self) -> None:\n \"\"\"\n Commits the current transaction.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.relational_database.RelationalDatabase.delete_record","title":"delete_record(table, primary_key)
abstractmethod
","text":"Deletes a record from a table.
Parameters:
Name Type Description Default table
str
The name of the table.
required primary_key
Any
The primary key of the record.
required Raises:
Type Description RecordNotFoundError
If the record does not exist.
StorageError
If there is an issue deleting the record.
Source code in src/aeiva/storage/relational_database.py
@abstractmethod\ndef delete_record(self, table: str, primary_key: Any) -> None:\n \"\"\"\n Deletes a record from a table.\n\n Args:\n table (str): The name of the table.\n primary_key (Any): The primary key of the record.\n\n Raises:\n RecordNotFoundError: If the record does not exist.\n StorageError: If there is an issue deleting the record.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.relational_database.RelationalDatabase.execute_sql","title":"execute_sql(query, parameters=None)
abstractmethod
","text":"Executes a raw SQL query.
Parameters:
Name Type Description Default query
str
The SQL query string.
required parameters
Optional[Dict[str, Any]]
Parameters for parameterized queries.
None
Returns:
Name Type Description Any
Any
The result of the query.
Raises:
Type Description StorageError
If there is an issue executing the query.
Source code in src/aeiva/storage/relational_database.py
@abstractmethod\ndef execute_sql(self, query: str, parameters: Optional[Dict[str, Any]] = None) -> Any:\n \"\"\"\n Executes a raw SQL query.\n\n Args:\n query (str): The SQL query string.\n parameters (Optional[Dict[str, Any]]): Parameters for parameterized queries.\n\n Returns:\n Any: The result of the query.\n\n Raises:\n StorageError: If there is an issue executing the query.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.relational_database.RelationalDatabase.get_record","title":"get_record(table, primary_key)
abstractmethod
","text":"Retrieves a record by its primary key.
Parameters:
Name Type Description Default table
str
The name of the table.
required primary_key
Any
The primary key of the record.
required Returns:
Type Description Dict[str, Any]
Dict[str, Any]: The retrieved record.
Raises:
Type Description RecordNotFoundError
If the record does not exist.
StorageError
If there is an issue retrieving the record.
Source code in src/aeiva/storage/relational_database.py
@abstractmethod\ndef get_record(self, table: str, primary_key: Any) -> Dict[str, Any]:\n \"\"\"\n Retrieves a record by its primary key.\n\n Args:\n table (str): The name of the table.\n primary_key (Any): The primary key of the record.\n\n Returns:\n Dict[str, Any]: The retrieved record.\n\n Raises:\n RecordNotFoundError: If the record does not exist.\n StorageError: If there is an issue retrieving the record.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.relational_database.RelationalDatabase.insert_record","title":"insert_record(table, record)
abstractmethod
","text":"Inserts a record into a table.
Parameters:
Name Type Description Default table
str
The name of the table.
required record
Dict[str, Any]
A dictionary representing the record to insert.
required Returns:
Name Type Description Any
Any
The primary key of the inserted record.
Raises:
Type Description StorageError
If there is an issue inserting the record.
Source code in src/aeiva/storage/relational_database.py
@abstractmethod\ndef insert_record(self, table: str, record: Dict[str, Any]) -> Any:\n \"\"\"\n Inserts a record into a table.\n\n Args:\n table (str): The name of the table.\n record (Dict[str, Any]): A dictionary representing the record to insert.\n\n Returns:\n Any: The primary key of the inserted record.\n\n Raises:\n StorageError: If there is an issue inserting the record.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.relational_database.RelationalDatabase.query_records","title":"query_records(table, conditions=None, limit=None, offset=None)
abstractmethod
","text":"Queries records from a table based on conditions.
Parameters:
Name Type Description Default table
str
The name of the table.
required conditions
Optional[Dict[str, Any]]
Conditions to filter records.
None
limit
Optional[int]
Maximum number of records to return.
None
offset
Optional[int]
Number of records to skip.
None
Returns:
Type Description List[Dict[str, Any]]
List[Dict[str, Any]]: A list of records matching the query.
Raises:
Type Description StorageError
If there is an issue querying records.
Source code in src/aeiva/storage/relational_database.py
@abstractmethod\ndef query_records(self, table: str, conditions: Optional[Dict[str, Any]] = None, limit: Optional[int] = None, offset: Optional[int] = None) -> List[Dict[str, Any]]:\n \"\"\"\n Queries records from a table based on conditions.\n\n Args:\n table (str): The name of the table.\n conditions (Optional[Dict[str, Any]]): Conditions to filter records.\n limit (Optional[int]): Maximum number of records to return.\n offset (Optional[int]): Number of records to skip.\n\n Returns:\n List[Dict[str, Any]]: A list of records matching the query.\n\n Raises:\n StorageError: If there is an issue querying records.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.relational_database.RelationalDatabase.rollback_transaction","title":"rollback_transaction()
abstractmethod
","text":"Rolls back the current transaction.
Source code in src/aeiva/storage/relational_database.py
@abstractmethod\ndef rollback_transaction(self) -> None:\n \"\"\"\n Rolls back the current transaction.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.relational_database.RelationalDatabase.update_record","title":"update_record(table, primary_key, updates)
abstractmethod
","text":"Updates a record in a table.
Parameters:
Name Type Description Default table
str
The name of the table.
required primary_key
Any
The primary key of the record.
required updates
Dict[str, Any]
A dictionary of fields to update.
required Raises:
Type Description RecordNotFoundError
If the record does not exist.
StorageError
If there is an issue updating the record.
Source code in src/aeiva/storage/relational_database.py
@abstractmethod\ndef update_record(self, table: str, primary_key: Any, updates: Dict[str, Any]) -> None:\n \"\"\"\n Updates a record in a table.\n\n Args:\n table (str): The name of the table.\n primary_key (Any): The primary key of the record.\n updates (Dict[str, Any]): A dictionary of fields to update.\n\n Raises:\n RecordNotFoundError: If the record does not exist.\n StorageError: If there is an issue updating the record.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.sqlite","title":"sqlite
","text":""},{"location":"reference/#src.aeiva.storage.sqlite.sqlite_config","title":"sqlite_config
","text":""},{"location":"reference/#src.aeiva.storage.sqlite.sqlite_config.SQLiteConfig","title":"SQLiteConfig
dataclass
","text":" Bases: BaseConfig
Configuration for SQLite database.
Source code in src/aeiva/storage/sqlite/sqlite_config.py
@dataclass\nclass SQLiteConfig(BaseConfig):\n \"\"\"\n Configuration for SQLite database.\n \"\"\"\n database: str = field(\n default=':memory:',\n metadata={\"help\": \"Path to the SQLite database file. Use ':memory:' for an in-memory database.\"}\n )\n
"},{"location":"reference/#src.aeiva.storage.sqlite.sqlite_database","title":"sqlite_database
","text":""},{"location":"reference/#src.aeiva.storage.sqlite.sqlite_database.RecordNotFoundError","title":"RecordNotFoundError
","text":" Bases: Exception
Exception raised when a record is not found in the database.
Source code in src/aeiva/storage/sqlite/sqlite_database.py
class RecordNotFoundError(Exception):\n \"\"\"Exception raised when a record is not found in the database.\"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.sqlite.sqlite_database.SQLiteDatabase","title":"SQLiteDatabase
","text":" Bases: RelationalDatabase
Concrete implementation of RelationalStoreBase using SQLite.
Source code in src/aeiva/storage/sqlite/sqlite_database.py
class SQLiteDatabase(RelationalDatabase):\n \"\"\"\n Concrete implementation of RelationalStoreBase using SQLite.\n \"\"\"\n\n def __init__(self, config: Dict[str, Any]) -> None:\n \"\"\"\n Initialize the SQLite database connection.\n\n Args:\n config (Dict[str, Any]): Configuration dictionary.\n \"\"\"\n self.config = config\n self.database = config.get('database', ':memory:')\n self.connection = None\n self.cursor = None\n self.connect()\n\n def connect(self) -> None:\n \"\"\"\n Establishes a connection to the SQLite database.\n \"\"\"\n try:\n self.connection = sqlite3.connect(self.database)\n self.connection.row_factory = sqlite3.Row # To get dict-like rows\n self.cursor = self.connection.cursor()\n # self.connection.execute('PRAGMA foreign_keys = ON') # Enable foreign key support\n except sqlite3.Error as e:\n raise ConnectionError(f\"Failed to connect to SQLite database: {e}\")\n\n def close(self) -> None:\n \"\"\"\n Closes the database connection and releases resources.\n \"\"\"\n if self.cursor:\n self.cursor.close()\n if self.connection:\n self.connection.close()\n\n def insert_record(self, table: str, record: Dict[str, Any]) -> Any:\n \"\"\"\n Inserts a record into a table.\n\n Args:\n table (str): The name of the table.\n record (Dict[str, Any]): A dictionary representing the record to insert.\n\n Returns:\n Any: The primary key of the inserted record.\n\n Raises:\n StorageError: If there is an issue inserting the record.\n \"\"\"\n try:\n columns = ', '.join(record.keys())\n placeholders = ', '.join('?' for _ in record)\n sql = f\"INSERT INTO {table} ({columns}) VALUES ({placeholders})\"\n values = list(record.values())\n self.cursor.execute(sql, values)\n self.connection.commit()\n return self.cursor.lastrowid\n except sqlite3.IntegrityError as e:\n self.connection.rollback()\n raise StorageError(f\"Integrity error: {e}\")\n except sqlite3.Error as e:\n self.connection.rollback()\n raise StorageError(f\"Failed to insert record: {e}\")\n\n def get_record(self, table: str, primary_key: Any) -> Dict[str, Any]:\n \"\"\"\n Retrieves a record by its primary key.\n\n Args:\n table (str): The name of the table.\n primary_key (Any): The primary key of the record.\n\n Returns:\n Dict[str, Any]: The retrieved record.\n\n Raises:\n RecordNotFoundError: If the record does not exist.\n StorageError: If there is an issue retrieving the record.\n \"\"\"\n try:\n sql = f\"SELECT * FROM {table} WHERE id = ?\"\n self.cursor.execute(sql, (primary_key,))\n row = self.cursor.fetchone()\n if row is None:\n raise RecordNotFoundError(f\"Record with primary key {primary_key} not found in table '{table}'.\")\n return dict(row)\n except sqlite3.Error as e:\n raise StorageError(f\"Failed to get record: {e}\")\n\n def update_record(self, table: str, primary_key: Any, updates: Dict[str, Any]) -> None:\n \"\"\"\n Updates a record in a table.\n\n Args:\n table (str): The name of the table.\n primary_key (Any): The primary key of the record.\n updates (Dict[str, Any]): A dictionary of fields to update.\n\n Raises:\n RecordNotFoundError: If the record does not exist.\n StorageError: If there is an issue updating the record.\n \"\"\"\n try:\n set_clause = ', '.join(f\"{key} = ?\" for key in updates.keys())\n sql = f\"UPDATE {table} SET {set_clause} WHERE id = ?\"\n values = list(updates.values()) + [primary_key]\n self.cursor.execute(sql, values)\n if self.cursor.rowcount == 0:\n self.connection.rollback()\n raise RecordNotFoundError(f\"Record with primary key {primary_key} not found in table '{table}'.\")\n self.connection.commit()\n except sqlite3.Error as e:\n self.connection.rollback()\n raise StorageError(f\"Failed to update record: {e}\")\n\n def delete_record(self, table: str, primary_key: Any) -> None:\n \"\"\"\n Deletes a record from a table.\n\n Args:\n table (str): The name of the table.\n primary_key (Any): The primary key of the record.\n\n Raises:\n RecordNotFoundError: If the record does not exist.\n StorageError: If there is an issue deleting the record.\n \"\"\"\n try:\n sql = f\"DELETE FROM {table} WHERE id = ?\"\n self.cursor.execute(sql, (primary_key,))\n if self.cursor.rowcount == 0:\n self.connection.rollback()\n raise RecordNotFoundError(f\"Record with primary key {primary_key} not found in table '{table}'.\")\n self.connection.commit()\n except sqlite3.Error as e:\n self.connection.rollback()\n raise StorageError(f\"Failed to delete record: {e}\")\n\n def query_records(\n self,\n table: str,\n conditions: Optional[Dict[str, Any]] = None,\n limit: Optional[int] = None,\n offset: Optional[int] = None\n ) -> List[Dict[str, Any]]:\n \"\"\"\n Queries records from a table based on conditions.\n\n Args:\n table (str): The name of the table.\n conditions (Optional[Dict[str, Any]]): Conditions to filter records.\n limit (Optional[int]): Maximum number of records to return.\n offset (Optional[int]): Number of records to skip.\n\n Returns:\n List[Dict[str, Any]]: A list of records matching the query.\n\n Raises:\n StorageError: If there is an issue querying records.\n \"\"\"\n try:\n sql = f\"SELECT * FROM {table}\"\n params = []\n if conditions:\n where_clause = ' AND '.join(f\"{key} = ?\" for key in conditions.keys())\n sql += f\" WHERE {where_clause}\"\n params.extend(conditions.values())\n if limit is not None:\n sql += f\" LIMIT {limit}\"\n if offset is not None:\n sql += f\" OFFSET {offset}\"\n self.cursor.execute(sql, params)\n rows = self.cursor.fetchall()\n return [dict(row) for row in rows]\n except sqlite3.Error as e:\n raise StorageError(f\"Failed to query records: {e}\")\n\n def execute_sql(self, query: str, params: Optional[Tuple] = None):\n \"\"\"\n Executes a SQL query and returns the cursor.\n\n Args:\n query (str): The SQL query to execute.\n params (Optional[Tuple]): Parameters to substitute into the query.\n\n Returns:\n sqlite3.Cursor: The cursor after executing the query.\n \"\"\"\n cursor = self.connection.cursor()\n try:\n if params:\n cursor.execute(query, params)\n else:\n cursor.execute(query)\n # For SELECT queries, do not commit. For INSERT/UPDATE/DELETE, you may need to commit.\n if query.strip().upper().startswith(\"SELECT\"):\n return cursor\n else:\n self.connection.commit()\n return cursor\n except sqlite3.Error as e:\n print(f\"SQLite query failed: {e}\")\n raise e\n\n def begin_transaction(self) -> None:\n \"\"\"\n Begins a transaction.\n \"\"\"\n self.connection.isolation_level = None\n self.cursor.execute('BEGIN')\n\n def commit_transaction(self) -> None:\n \"\"\"\n Commits the current transaction.\n \"\"\"\n self.connection.commit()\n self.connection.isolation_level = None\n\n def rollback_transaction(self) -> None:\n \"\"\"\n Rolls back the current transaction.\n \"\"\"\n self.connection.rollback()\n self.connection.isolation_level = None\n
"},{"location":"reference/#src.aeiva.storage.sqlite.sqlite_database.SQLiteDatabase.__init__","title":"__init__(config)
","text":"Initialize the SQLite database connection.
Parameters:
Name Type Description Default config
Dict[str, Any]
Configuration dictionary.
required Source code in src/aeiva/storage/sqlite/sqlite_database.py
def __init__(self, config: Dict[str, Any]) -> None:\n \"\"\"\n Initialize the SQLite database connection.\n\n Args:\n config (Dict[str, Any]): Configuration dictionary.\n \"\"\"\n self.config = config\n self.database = config.get('database', ':memory:')\n self.connection = None\n self.cursor = None\n self.connect()\n
"},{"location":"reference/#src.aeiva.storage.sqlite.sqlite_database.SQLiteDatabase.begin_transaction","title":"begin_transaction()
","text":"Begins a transaction.
Source code in src/aeiva/storage/sqlite/sqlite_database.py
def begin_transaction(self) -> None:\n \"\"\"\n Begins a transaction.\n \"\"\"\n self.connection.isolation_level = None\n self.cursor.execute('BEGIN')\n
"},{"location":"reference/#src.aeiva.storage.sqlite.sqlite_database.SQLiteDatabase.close","title":"close()
","text":"Closes the database connection and releases resources.
Source code in src/aeiva/storage/sqlite/sqlite_database.py
def close(self) -> None:\n \"\"\"\n Closes the database connection and releases resources.\n \"\"\"\n if self.cursor:\n self.cursor.close()\n if self.connection:\n self.connection.close()\n
"},{"location":"reference/#src.aeiva.storage.sqlite.sqlite_database.SQLiteDatabase.commit_transaction","title":"commit_transaction()
","text":"Commits the current transaction.
Source code in src/aeiva/storage/sqlite/sqlite_database.py
def commit_transaction(self) -> None:\n \"\"\"\n Commits the current transaction.\n \"\"\"\n self.connection.commit()\n self.connection.isolation_level = None\n
"},{"location":"reference/#src.aeiva.storage.sqlite.sqlite_database.SQLiteDatabase.connect","title":"connect()
","text":"Establishes a connection to the SQLite database.
Source code in src/aeiva/storage/sqlite/sqlite_database.py
def connect(self) -> None:\n \"\"\"\n Establishes a connection to the SQLite database.\n \"\"\"\n try:\n self.connection = sqlite3.connect(self.database)\n self.connection.row_factory = sqlite3.Row # To get dict-like rows\n self.cursor = self.connection.cursor()\n # self.connection.execute('PRAGMA foreign_keys = ON') # Enable foreign key support\n except sqlite3.Error as e:\n raise ConnectionError(f\"Failed to connect to SQLite database: {e}\")\n
"},{"location":"reference/#src.aeiva.storage.sqlite.sqlite_database.SQLiteDatabase.delete_record","title":"delete_record(table, primary_key)
","text":"Deletes a record from a table.
Parameters:
Name Type Description Default table
str
The name of the table.
required primary_key
Any
The primary key of the record.
required Raises:
Type Description RecordNotFoundError
If the record does not exist.
StorageError
If there is an issue deleting the record.
Source code in src/aeiva/storage/sqlite/sqlite_database.py
def delete_record(self, table: str, primary_key: Any) -> None:\n \"\"\"\n Deletes a record from a table.\n\n Args:\n table (str): The name of the table.\n primary_key (Any): The primary key of the record.\n\n Raises:\n RecordNotFoundError: If the record does not exist.\n StorageError: If there is an issue deleting the record.\n \"\"\"\n try:\n sql = f\"DELETE FROM {table} WHERE id = ?\"\n self.cursor.execute(sql, (primary_key,))\n if self.cursor.rowcount == 0:\n self.connection.rollback()\n raise RecordNotFoundError(f\"Record with primary key {primary_key} not found in table '{table}'.\")\n self.connection.commit()\n except sqlite3.Error as e:\n self.connection.rollback()\n raise StorageError(f\"Failed to delete record: {e}\")\n
"},{"location":"reference/#src.aeiva.storage.sqlite.sqlite_database.SQLiteDatabase.execute_sql","title":"execute_sql(query, params=None)
","text":"Executes a SQL query and returns the cursor.
Parameters:
Name Type Description Default query
str
The SQL query to execute.
required params
Optional[Tuple]
Parameters to substitute into the query.
None
Returns:
Type Description sqlite3.Cursor: The cursor after executing the query.
Source code in src/aeiva/storage/sqlite/sqlite_database.py
def execute_sql(self, query: str, params: Optional[Tuple] = None):\n \"\"\"\n Executes a SQL query and returns the cursor.\n\n Args:\n query (str): The SQL query to execute.\n params (Optional[Tuple]): Parameters to substitute into the query.\n\n Returns:\n sqlite3.Cursor: The cursor after executing the query.\n \"\"\"\n cursor = self.connection.cursor()\n try:\n if params:\n cursor.execute(query, params)\n else:\n cursor.execute(query)\n # For SELECT queries, do not commit. For INSERT/UPDATE/DELETE, you may need to commit.\n if query.strip().upper().startswith(\"SELECT\"):\n return cursor\n else:\n self.connection.commit()\n return cursor\n except sqlite3.Error as e:\n print(f\"SQLite query failed: {e}\")\n raise e\n
"},{"location":"reference/#src.aeiva.storage.sqlite.sqlite_database.SQLiteDatabase.get_record","title":"get_record(table, primary_key)
","text":"Retrieves a record by its primary key.
Parameters:
Name Type Description Default table
str
The name of the table.
required primary_key
Any
The primary key of the record.
required Returns:
Type Description Dict[str, Any]
Dict[str, Any]: The retrieved record.
Raises:
Type Description RecordNotFoundError
If the record does not exist.
StorageError
If there is an issue retrieving the record.
Source code in src/aeiva/storage/sqlite/sqlite_database.py
def get_record(self, table: str, primary_key: Any) -> Dict[str, Any]:\n \"\"\"\n Retrieves a record by its primary key.\n\n Args:\n table (str): The name of the table.\n primary_key (Any): The primary key of the record.\n\n Returns:\n Dict[str, Any]: The retrieved record.\n\n Raises:\n RecordNotFoundError: If the record does not exist.\n StorageError: If there is an issue retrieving the record.\n \"\"\"\n try:\n sql = f\"SELECT * FROM {table} WHERE id = ?\"\n self.cursor.execute(sql, (primary_key,))\n row = self.cursor.fetchone()\n if row is None:\n raise RecordNotFoundError(f\"Record with primary key {primary_key} not found in table '{table}'.\")\n return dict(row)\n except sqlite3.Error as e:\n raise StorageError(f\"Failed to get record: {e}\")\n
"},{"location":"reference/#src.aeiva.storage.sqlite.sqlite_database.SQLiteDatabase.insert_record","title":"insert_record(table, record)
","text":"Inserts a record into a table.
Parameters:
Name Type Description Default table
str
The name of the table.
required record
Dict[str, Any]
A dictionary representing the record to insert.
required Returns:
Name Type Description Any
Any
The primary key of the inserted record.
Raises:
Type Description StorageError
If there is an issue inserting the record.
Source code in src/aeiva/storage/sqlite/sqlite_database.py
def insert_record(self, table: str, record: Dict[str, Any]) -> Any:\n \"\"\"\n Inserts a record into a table.\n\n Args:\n table (str): The name of the table.\n record (Dict[str, Any]): A dictionary representing the record to insert.\n\n Returns:\n Any: The primary key of the inserted record.\n\n Raises:\n StorageError: If there is an issue inserting the record.\n \"\"\"\n try:\n columns = ', '.join(record.keys())\n placeholders = ', '.join('?' for _ in record)\n sql = f\"INSERT INTO {table} ({columns}) VALUES ({placeholders})\"\n values = list(record.values())\n self.cursor.execute(sql, values)\n self.connection.commit()\n return self.cursor.lastrowid\n except sqlite3.IntegrityError as e:\n self.connection.rollback()\n raise StorageError(f\"Integrity error: {e}\")\n except sqlite3.Error as e:\n self.connection.rollback()\n raise StorageError(f\"Failed to insert record: {e}\")\n
"},{"location":"reference/#src.aeiva.storage.sqlite.sqlite_database.SQLiteDatabase.query_records","title":"query_records(table, conditions=None, limit=None, offset=None)
","text":"Queries records from a table based on conditions.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `table` | `str` | The name of the table. | required |
| `conditions` | `Optional[Dict[str, Any]]` | Conditions to filter records. | `None` |
| `limit` | `Optional[int]` | Maximum number of records to return. | `None` |
| `offset` | `Optional[int]` | Number of records to skip. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `List[Dict[str, Any]]` | A list of records matching the query. |

Raises:

| Type | Description |
| --- | --- |
| `StorageError` | If there is an issue querying records. |
Source code in src/aeiva/storage/sqlite/sqlite_database.py
def query_records(\n self,\n table: str,\n conditions: Optional[Dict[str, Any]] = None,\n limit: Optional[int] = None,\n offset: Optional[int] = None\n) -> List[Dict[str, Any]]:\n \"\"\"\n Queries records from a table based on conditions.\n\n Args:\n table (str): The name of the table.\n conditions (Optional[Dict[str, Any]]): Conditions to filter records.\n limit (Optional[int]): Maximum number of records to return.\n offset (Optional[int]): Number of records to skip.\n\n Returns:\n List[Dict[str, Any]]: A list of records matching the query.\n\n Raises:\n StorageError: If there is an issue querying records.\n \"\"\"\n try:\n sql = f\"SELECT * FROM {table}\"\n params = []\n if conditions:\n where_clause = ' AND '.join(f\"{key} = ?\" for key in conditions.keys())\n sql += f\" WHERE {where_clause}\"\n params.extend(conditions.values())\n if limit is not None:\n sql += f\" LIMIT {limit}\"\n if offset is not None:\n sql += f\" OFFSET {offset}\"\n self.cursor.execute(sql, params)\n rows = self.cursor.fetchall()\n return [dict(row) for row in rows]\n except sqlite3.Error as e:\n raise StorageError(f\"Failed to query records: {e}\")\n
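The incremental query-building shown above (appending `WHERE`, `LIMIT`, and `OFFSET` clauses only when the corresponding argument is given) can be exercised standalone with `sqlite3`. A minimal sketch with a hypothetical `users` table:

```python
import sqlite3

def query_records(conn, table, conditions=None, limit=None, offset=None):
    # Compose "SELECT * FROM t [WHERE ...] [LIMIT n] [OFFSET m]" with bound params.
    sql, params = f"SELECT * FROM {table}", []
    if conditions:
        sql += " WHERE " + " AND ".join(f"{k} = ?" for k in conditions)
        params.extend(conditions.values())
    if limit is not None:
        sql += f" LIMIT {limit}"
    if offset is not None:
        sql += f" OFFSET {offset}"
    conn.row_factory = sqlite3.Row  # rows become dict-convertible
    return [dict(r) for r in conn.execute(sql, params).fetchall()]

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT, age INTEGER)")
conn.executemany("INSERT INTO users (name, age) VALUES (?, ?)",
                 [("Alice", 30), ("Bob", 25), ("Cara", 30)])
matching = query_records(conn, "users", conditions={"age": 30})
```

Note that condition *values* are bound as parameters while column names and `LIMIT`/`OFFSET` integers are interpolated, which mirrors the source above.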
"},{"location":"reference/#src.aeiva.storage.sqlite.sqlite_database.SQLiteDatabase.rollback_transaction","title":"rollback_transaction()
","text":"Rolls back the current transaction.
Source code in src/aeiva/storage/sqlite/sqlite_database.py
def rollback_transaction(self) -> None:\n \"\"\"\n Rolls back the current transaction.\n \"\"\"\n self.connection.rollback()\n self.connection.isolation_level = None\n
"},{"location":"reference/#src.aeiva.storage.sqlite.sqlite_database.SQLiteDatabase.update_record","title":"update_record(table, primary_key, updates)
","text":"Updates a record in a table.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `table` | `str` | The name of the table. | required |
| `primary_key` | `Any` | The primary key of the record. | required |
| `updates` | `Dict[str, Any]` | A dictionary of fields to update. | required |

Raises:

| Type | Description |
| --- | --- |
| `RecordNotFoundError` | If the record does not exist. |
| `StorageError` | If there is an issue updating the record. |
Source code in src/aeiva/storage/sqlite/sqlite_database.py
def update_record(self, table: str, primary_key: Any, updates: Dict[str, Any]) -> None:\n \"\"\"\n Updates a record in a table.\n\n Args:\n table (str): The name of the table.\n primary_key (Any): The primary key of the record.\n updates (Dict[str, Any]): A dictionary of fields to update.\n\n Raises:\n RecordNotFoundError: If the record does not exist.\n StorageError: If there is an issue updating the record.\n \"\"\"\n try:\n set_clause = ', '.join(f\"{key} = ?\" for key in updates.keys())\n sql = f\"UPDATE {table} SET {set_clause} WHERE id = ?\"\n values = list(updates.values()) + [primary_key]\n self.cursor.execute(sql, values)\n if self.cursor.rowcount == 0:\n self.connection.rollback()\n raise RecordNotFoundError(f\"Record with primary key {primary_key} not found in table '{table}'.\")\n self.connection.commit()\n except sqlite3.Error as e:\n self.connection.rollback()\n raise StorageError(f\"Failed to update record: {e}\")\n
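The key detail in the source above is the `rowcount == 0` check: an UPDATE that matches no rows succeeds silently in SQL, so the method turns that into an explicit not-found error. A standalone sketch of the same pattern (using `KeyError` in place of aeiva's `RecordNotFoundError`, and a hypothetical `users` table):

```python
import sqlite3

def update_record(conn, table, primary_key, updates):
    # SET clause from the update keys; rowcount == 0 means no such row.
    set_clause = ", ".join(f"{k} = ?" for k in updates)
    cur = conn.execute(f"UPDATE {table} SET {set_clause} WHERE id = ?",
                       list(updates.values()) + [primary_key])
    if cur.rowcount == 0:
        conn.rollback()
        raise KeyError(f"no record {primary_key} in table {table}")
    conn.commit()

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT, age INTEGER)")
conn.execute("INSERT INTO users (name, age) VALUES ('Alice', 30)")
update_record(conn, "users", 1, {"age": 31})
```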
"},{"location":"reference/#src.aeiva.storage.sqlite.sqlite_database.StorageError","title":"StorageError
","text":" Bases: Exception
Exception raised when there is a storage-related error in the database.
Source code in src/aeiva/storage/sqlite/sqlite_database.py
class StorageError(Exception):\n \"\"\"Exception raised when there is a storage-related error in the database.\"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.sqlite.test","title":"test
","text":""},{"location":"reference/#src.aeiva.storage.sqlite.test.RecordNotFoundError","title":"RecordNotFoundError
","text":" Bases: Exception
Exception raised when a record is not found in the database.
Source code in src/aeiva/storage/sqlite/test.py
class RecordNotFoundError(Exception):\n \"\"\"Exception raised when a record is not found in the database.\"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.test","title":"test
","text":""},{"location":"reference/#src.aeiva.storage.test.main","title":"main()
","text":"Main function to run tests for Milvus, Neo4j, and SQLite databases.
Source code in src/aeiva/storage/test.py
def main():\n \"\"\"\n Main function to run tests for Milvus, Neo4j, and SQLite databases.\n \"\"\"\n test_milvus()\n test_neo4j()\n test_sqlite()\n
"},{"location":"reference/#src.aeiva.storage.test.test_milvus","title":"test_milvus()
","text":"Test the DatabaseFactory and DatabaseConfigFactory with Milvus database.
Source code in src/aeiva/storage/test.py
def test_milvus():\n \"\"\"\n Test the DatabaseFactory and DatabaseConfigFactory with Milvus database.\n \"\"\"\n print(\"\\n--- Testing Milvus Database ---\")\n # Create configuration for Milvus\n milvus_config = DatabaseConfigFactory.create(\n 'milvus',\n # uri='tcp://localhost:19530',\n uri='storage/milvus_demo.db',\n collection_name='test_collection',\n embedding_model_dims=128,\n metric_type='COSINE',\n )\n\n # Create Milvus database instance\n milvus_db = DatabaseFactory.create('milvus', milvus_config)\n\n try:\n # Prepare sample data\n vector_dimension = milvus_config.embedding_model_dims\n vectors = [\n [float(i) for i in range(vector_dimension)], # Sample vector 1\n [float(i + 1) for i in range(vector_dimension)], # Sample vector 2\n ]\n payloads = [\n {'name': 'Vector 1', 'description': 'First test vector.'},\n {'name': 'Vector 2', 'description': 'Second test vector.'},\n ]\n ids = [str(uuid.uuid4()), str(uuid.uuid4())] # Generate unique IDs\n\n # Insert vectors into the collection\n milvus_db.insert_vectors(\n collection_name=milvus_config.collection_name,\n vectors=vectors,\n payloads=payloads,\n ids=ids\n )\n logging.info(f\"Inserted vectors with IDs: {ids}\")\n\n # Search for similar vectors\n query_vector = [float(i + 0.5) for i in range(vector_dimension)] # Query vector\n search_results = milvus_db.search_vectors(\n collection_name=milvus_config.collection_name,\n query_vector=query_vector,\n top_k=2\n )\n print(f\"Milvus Search results:\\n{search_results}\")\n\n except Exception as e:\n logging.error(f\"An error occurred while testing Milvus: {e}\")\n finally:\n # Close the connection\n del milvus_db\n
"},{"location":"reference/#src.aeiva.storage.test.test_neo4j","title":"test_neo4j()
","text":"Test the DatabaseFactory and DatabaseConfigFactory with Neo4j database.
Source code in src/aeiva/storage/test.py
def test_neo4j():\n \"\"\"\n Test the DatabaseFactory and DatabaseConfigFactory with Neo4j database.\n \"\"\"\n print(\"\\n--- Testing Neo4j Database ---\")\n # Create configuration for Neo4j\n neo4j_config = DatabaseConfigFactory.create(\n 'neo4j',\n uri='bolt://localhost:7687',\n user='neo4j',\n password='cf57bwP9pcdcEK3', # Replace with your actual password\n database='neo4j',\n encrypted=False,\n )\n\n # Create Neo4j database instance\n neo4j_db = DatabaseFactory.create('neo4j', neo4j_config)\n\n try:\n # Add a node\n node_id = 'node1'\n neo4j_db.add_node(\n node_id=node_id,\n properties={'name': 'Alice', 'age': 30},\n labels=['Person']\n )\n logging.info(f\"Added node with ID: {node_id}\")\n\n # Retrieve the node\n node_data = neo4j_db.get_node(node_id)\n print(f\"Neo4j Node data: {node_data}\")\n\n # Add another node and create a relationship\n node_id2 = 'node2'\n neo4j_db.add_node(\n node_id=node_id2,\n properties={'name': 'Bob', 'age': 25},\n labels=['Person']\n )\n neo4j_db.add_edge(\n source_id=node_id,\n target_id=node_id2,\n relationship='KNOWS',\n properties={'since': 2020}\n )\n logging.info(f\"Added edge between {node_id} and {node_id2}\")\n\n # Get neighbors\n neighbors = neo4j_db.get_neighbors(node_id, relationship='KNOWS', direction='out')\n print(f\"Neo4j Neighbors of {node_id}: {neighbors}\")\n\n except Exception as e:\n logging.error(f\"An error occurred while testing Neo4j: {e}\")\n finally:\n # Close the connection\n neo4j_db.close()\n
"},{"location":"reference/#src.aeiva.storage.test.test_sqlite","title":"test_sqlite()
","text":"Test the DatabaseFactory and DatabaseConfigFactory with SQLite database.
Source code in src/aeiva/storage/test.py
def test_sqlite():\n \"\"\"\n Test the DatabaseFactory and DatabaseConfigFactory with SQLite database.\n \"\"\"\n print(\"\\n--- Testing SQLite Database ---\")\n # Create configuration for SQLite\n sqlite_config = DatabaseConfigFactory.create(\n 'sqlite',\n database='storage/test_database.db' # Use a file-based database for persistence\n )\n\n # Create SQLite database instance\n sqlite_db = DatabaseFactory.create('sqlite', sqlite_config)\n\n try:\n # Create a sample table\n create_table_sql = \"\"\"\n CREATE TABLE IF NOT EXISTS users (\n id INTEGER PRIMARY KEY AUTOINCREMENT,\n name TEXT NOT NULL,\n age INTEGER,\n email TEXT UNIQUE\n );\n \"\"\"\n sqlite_db.execute_sql(create_table_sql)\n logging.info(\"Created table 'users' in SQLite database.\")\n\n # Insert a record\n record = {'name': 'Alice', 'age': 30, 'email': 'alice@example.com'}\n user_id = sqlite_db.insert_record('users', record)\n logging.info(f\"Inserted user with ID: {user_id}\")\n\n # Retrieve the record\n retrieved_record = sqlite_db.get_record('users', user_id)\n print(f\"SQLite Retrieved record: {retrieved_record}\")\n\n # Update the record\n updates = {'age': 31}\n sqlite_db.update_record('users', user_id, updates)\n logging.info(f\"Updated user with ID: {user_id}\")\n\n # Query records\n conditions = {'age': 31}\n users = sqlite_db.query_records('users', conditions)\n print(f\"SQLite Users with age 31: {users}\")\n\n except Exception as e:\n logging.error(f\"An error occurred while testing SQLite: {e}\")\n finally:\n # Close the database connection\n sqlite_db.close()\n
"},{"location":"reference/#src.aeiva.storage.vector_database","title":"vector_database
","text":""},{"location":"reference/#src.aeiva.storage.vector_database.VectorDatabase","title":"VectorDatabase
","text":" Bases: ABC
Abstract base class for vector storage operations.
Source code in src/aeiva/storage/vector_database.py
class VectorDatabase(ABC):\n \"\"\"\n Abstract base class for vector storage operations.\n \"\"\"\n\n @abstractmethod\n def create_client(\n self,\n uri: str,\n user: Optional[str] = None,\n password: Optional[str] = None,\n db_name: Optional[str] = None,\n token: Optional[str] = None,\n timeout: Optional[float] = None,\n **kwargs\n ) -> None:\n \"\"\"\n Initializes the client connection to the vector store.\n\n Args:\n uri (str): The URI of the vector store instance.\n user (Optional[str]): Username for authentication.\n password (Optional[str]): Password for authentication.\n db_name (Optional[str]): Name of the database.\n token (Optional[str]): Access token for authentication.\n timeout (Optional[float]): Timeout duration for operations.\n **kwargs: Additional implementation-specific parameters.\n\n Raises:\n ConnectionError: If the client fails to connect to the vector store.\n \"\"\"\n pass\n\n @abstractmethod\n def create_collection(self, collection_name: str, vector_size: int, distance_metric: str) -> None:\n \"\"\"\n Create a new vector collection.\n\n Args:\n collection_name (str): The name of the collection.\n vector_size (int): The dimensionality of the vectors.\n distance_metric (str): The distance metric to use (e.g., 'euclidean', 'cosine').\n\n Raises:\n CollectionAlreadyExistsError: If a collection with the given name already exists.\n StorageError: If there is an issue creating the collection.\n \"\"\"\n pass\n\n @abstractmethod\n def insert_vectors(self, collection_name: str, vectors: List[List[float]], payloads: Optional[List[Dict[str, Any]]] = None, ids: Optional[List[str]] = None) -> None:\n \"\"\"\n Insert vectors into a collection.\n\n Args:\n collection_name (str): The name of the collection.\n vectors (List[List[float]]): A list of vectors to insert.\n payloads (Optional[List[Dict[str, Any]]]): Optional metadata associated with each vector.\n ids (Optional[List[str]]): Optional unique identifiers for each vector.\n\n Raises:\n 
CollectionNotFoundError: If the specified collection does not exist.\n StorageError: If there is an issue inserting the vectors.\n \"\"\"\n pass\n\n @abstractmethod\n def search_vectors(self, collection_name: str, query_vector: List[float], top_k: int = 5, filters: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:\n \"\"\"\n Search for similar vectors in a collection.\n\n Args:\n collection_name (str): The name of the collection.\n query_vector (List[float]): The vector to search with.\n top_k (int): The number of top results to return.\n filters (Optional[Dict[str, Any]]): Optional filters to apply to the search.\n\n Returns:\n List[Dict[str, Any]]: A list of search results, each containing the vector ID, score, and payload.\n\n Raises:\n CollectionNotFoundError: If the specified collection does not exist.\n StorageError: If there is an issue performing the search.\n \"\"\"\n pass\n\n @abstractmethod\n def delete_vector(self, collection_name: str, vector_id: str) -> None:\n \"\"\"\n Delete a vector from a collection by its ID.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector to delete.\n\n Raises:\n CollectionNotFoundError: If the specified collection does not exist.\n VectorNotFoundError: If the vector with the specified ID does not exist.\n StorageError: If there is an issue deleting the vector.\n \"\"\"\n pass\n\n @abstractmethod\n def update_vector(self, collection_name: str, vector_id: str, vector: Optional[List[float]] = None, payload: Optional[Dict[str, Any]] = None) -> None:\n \"\"\"\n Update a vector's data or payload.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector to update.\n vector (Optional[List[float]]): The new vector data.\n payload (Optional[Dict[str, Any]]): The new payload data.\n\n Raises:\n CollectionNotFoundError: If the specified collection does not exist.\n VectorNotFoundError: If the vector 
with the specified ID does not exist.\n StorageError: If there is an issue updating the vector.\n \"\"\"\n pass\n\n @abstractmethod\n def get_vector(self, collection_name: str, vector_id: str) -> Dict[str, Any]:\n \"\"\"\n Retrieve a vector by its ID.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector.\n\n Returns:\n Dict[str, Any]: A dictionary containing the vector data and payload.\n\n Raises:\n CollectionNotFoundError: If the specified collection does not exist.\n VectorNotFoundError: If the vector with the specified ID does not exist.\n StorageError: If there is an issue retrieving the vector.\n \"\"\"\n pass\n\n @abstractmethod\n def list_collections(self) -> List[str]:\n \"\"\"\n List all available vector collections.\n\n Returns:\n List[str]: A list of collection names.\n\n Raises:\n StorageError: If there is an issue retrieving the collection list.\n \"\"\"\n pass\n\n @abstractmethod\n def delete_collection(self, collection_name: str) -> None:\n \"\"\"\n Delete an entire vector collection.\n\n Args:\n collection_name (str): The name of the collection to delete.\n\n Raises:\n CollectionNotFoundError: If the specified collection does not exist.\n StorageError: If there is an issue deleting the collection.\n \"\"\"\n pass\n\n @abstractmethod\n def get_collection_info(self, collection_name: str) -> Dict[str, Any]:\n \"\"\"\n Get information about a collection.\n\n Args:\n collection_name (str): The name of the collection.\n\n Returns:\n Dict[str, Any]: Information about the collection, such as vector size and distance metric.\n\n Raises:\n CollectionNotFoundError: If the specified collection does not exist.\n StorageError: If there is an issue retrieving the collection information.\n \"\"\"\n pass\n
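To make the abstract interface above concrete, here is a toy in-memory store implementing the core `create_collection` / `insert_vectors` / `search_vectors` shape with cosine similarity. This is an illustrative sketch only, not part of aeiva, and it omits payload filtering and error classes:

```python
import math
from typing import Any, Dict, List, Optional

class InMemoryVectorStore:
    """Toy in-memory store illustrating the VectorDatabase interface (cosine search)."""

    def __init__(self) -> None:
        self.collections: Dict[str, Dict[str, tuple]] = {}

    def create_collection(self, name: str, vector_size: int, distance_metric: str) -> None:
        self.collections[name] = {}

    def insert_vectors(self, name: str, vectors: List[List[float]],
                       payloads: Optional[List[Dict[str, Any]]] = None,
                       ids: Optional[List[str]] = None) -> None:
        payloads = payloads or [{} for _ in vectors]
        for vid, vec, pay in zip(ids, vectors, payloads):
            self.collections[name][vid] = (vec, pay)

    def search_vectors(self, name: str, query_vector: List[float],
                       top_k: int = 5) -> List[Dict[str, Any]]:
        def cosine(a: List[float], b: List[float]) -> float:
            # dot(a, b) / (|a| * |b|)
            return sum(x * y for x, y in zip(a, b)) / (math.hypot(*a) * math.hypot(*b))
        hits = [{"id": vid, "score": cosine(query_vector, vec), "payload": pay}
                for vid, (vec, pay) in self.collections[name].items()]
        hits.sort(key=lambda h: h["score"], reverse=True)
        return hits[:top_k]

store = InMemoryVectorStore()
store.create_collection("demo", vector_size=2, distance_metric="cosine")
store.insert_vectors("demo", [[1.0, 0.0], [0.0, 1.0]],
                     payloads=[{"name": "x-axis"}, {"name": "y-axis"}],
                     ids=["a", "b"])
hits = store.search_vectors("demo", [0.9, 0.1], top_k=1)
```

Each search result carries the vector ID, score, and payload, matching the return shape documented for `search_vectors`.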
"},{"location":"reference/#src.aeiva.storage.vector_database.VectorDatabase.create_client","title":"create_client(uri, user=None, password=None, db_name=None, token=None, timeout=None, **kwargs)
abstractmethod
","text":"Initializes the client connection to the vector store.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `uri` | `str` | The URI of the vector store instance. | required |
| `user` | `Optional[str]` | Username for authentication. | `None` |
| `password` | `Optional[str]` | Password for authentication. | `None` |
| `db_name` | `Optional[str]` | Name of the database. | `None` |
| `token` | `Optional[str]` | Access token for authentication. | `None` |
| `timeout` | `Optional[float]` | Timeout duration for operations. | `None` |
| `**kwargs` | | Additional implementation-specific parameters. | `{}` |

Raises:

| Type | Description |
| --- | --- |
| `ConnectionError` | If the client fails to connect to the vector store. |
Source code in src/aeiva/storage/vector_database.py
@abstractmethod\ndef create_client(\n self,\n uri: str,\n user: Optional[str] = None,\n password: Optional[str] = None,\n db_name: Optional[str] = None,\n token: Optional[str] = None,\n timeout: Optional[float] = None,\n **kwargs\n) -> None:\n \"\"\"\n Initializes the client connection to the vector store.\n\n Args:\n uri (str): The URI of the vector store instance.\n user (Optional[str]): Username for authentication.\n password (Optional[str]): Password for authentication.\n db_name (Optional[str]): Name of the database.\n token (Optional[str]): Access token for authentication.\n timeout (Optional[float]): Timeout duration for operations.\n **kwargs: Additional implementation-specific parameters.\n\n Raises:\n ConnectionError: If the client fails to connect to the vector store.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.vector_database.VectorDatabase.create_collection","title":"create_collection(collection_name, vector_size, distance_metric)
abstractmethod
","text":"Create a new vector collection.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `collection_name` | `str` | The name of the collection. | required |
| `vector_size` | `int` | The dimensionality of the vectors. | required |
| `distance_metric` | `str` | The distance metric to use (e.g., 'euclidean', 'cosine'). | required |

Raises:

| Type | Description |
| --- | --- |
| `CollectionAlreadyExistsError` | If a collection with the given name already exists. |
| `StorageError` | If there is an issue creating the collection. |
Source code in src/aeiva/storage/vector_database.py
@abstractmethod\ndef create_collection(self, collection_name: str, vector_size: int, distance_metric: str) -> None:\n \"\"\"\n Create a new vector collection.\n\n Args:\n collection_name (str): The name of the collection.\n vector_size (int): The dimensionality of the vectors.\n distance_metric (str): The distance metric to use (e.g., 'euclidean', 'cosine').\n\n Raises:\n CollectionAlreadyExistsError: If a collection with the given name already exists.\n StorageError: If there is an issue creating the collection.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.vector_database.VectorDatabase.delete_collection","title":"delete_collection(collection_name)
abstractmethod
","text":"Delete an entire vector collection.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `collection_name` | `str` | The name of the collection to delete. | required |

Raises:

| Type | Description |
| --- | --- |
| `CollectionNotFoundError` | If the specified collection does not exist. |
| `StorageError` | If there is an issue deleting the collection. |
Source code in src/aeiva/storage/vector_database.py
@abstractmethod\ndef delete_collection(self, collection_name: str) -> None:\n \"\"\"\n Delete an entire vector collection.\n\n Args:\n collection_name (str): The name of the collection to delete.\n\n Raises:\n CollectionNotFoundError: If the specified collection does not exist.\n StorageError: If there is an issue deleting the collection.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.vector_database.VectorDatabase.delete_vector","title":"delete_vector(collection_name, vector_id)
abstractmethod
","text":"Delete a vector from a collection by its ID.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `collection_name` | `str` | The name of the collection. | required |
| `vector_id` | `str` | The unique identifier of the vector to delete. | required |

Raises:

| Type | Description |
| --- | --- |
| `CollectionNotFoundError` | If the specified collection does not exist. |
| `VectorNotFoundError` | If the vector with the specified ID does not exist. |
| `StorageError` | If there is an issue deleting the vector. |
Source code in src/aeiva/storage/vector_database.py
@abstractmethod\ndef delete_vector(self, collection_name: str, vector_id: str) -> None:\n \"\"\"\n Delete a vector from a collection by its ID.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector to delete.\n\n Raises:\n CollectionNotFoundError: If the specified collection does not exist.\n VectorNotFoundError: If the vector with the specified ID does not exist.\n StorageError: If there is an issue deleting the vector.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.vector_database.VectorDatabase.get_collection_info","title":"get_collection_info(collection_name)
abstractmethod
","text":"Get information about a collection.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `collection_name` | `str` | The name of the collection. | required |

Returns:

| Type | Description |
| --- | --- |
| `Dict[str, Any]` | Information about the collection, such as vector size and distance metric. |

Raises:

| Type | Description |
| --- | --- |
| `CollectionNotFoundError` | If the specified collection does not exist. |
| `StorageError` | If there is an issue retrieving the collection information. |
Source code in src/aeiva/storage/vector_database.py
@abstractmethod\ndef get_collection_info(self, collection_name: str) -> Dict[str, Any]:\n \"\"\"\n Get information about a collection.\n\n Args:\n collection_name (str): The name of the collection.\n\n Returns:\n Dict[str, Any]: Information about the collection, such as vector size and distance metric.\n\n Raises:\n CollectionNotFoundError: If the specified collection does not exist.\n StorageError: If there is an issue retrieving the collection information.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.vector_database.VectorDatabase.get_vector","title":"get_vector(collection_name, vector_id)
abstractmethod
","text":"Retrieve a vector by its ID.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `collection_name` | `str` | The name of the collection. | required |
| `vector_id` | `str` | The unique identifier of the vector. | required |

Returns:

| Type | Description |
| --- | --- |
| `Dict[str, Any]` | A dictionary containing the vector data and payload. |

Raises:

| Type | Description |
| --- | --- |
| `CollectionNotFoundError` | If the specified collection does not exist. |
| `VectorNotFoundError` | If the vector with the specified ID does not exist. |
| `StorageError` | If there is an issue retrieving the vector. |
Source code in src/aeiva/storage/vector_database.py
@abstractmethod\ndef get_vector(self, collection_name: str, vector_id: str) -> Dict[str, Any]:\n \"\"\"\n Retrieve a vector by its ID.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector.\n\n Returns:\n Dict[str, Any]: A dictionary containing the vector data and payload.\n\n Raises:\n CollectionNotFoundError: If the specified collection does not exist.\n VectorNotFoundError: If the vector with the specified ID does not exist.\n StorageError: If there is an issue retrieving the vector.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.vector_database.VectorDatabase.insert_vectors","title":"insert_vectors(collection_name, vectors, payloads=None, ids=None)
abstractmethod
","text":"Insert vectors into a collection.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `collection_name` | `str` | The name of the collection. | required |
| `vectors` | `List[List[float]]` | A list of vectors to insert. | required |
| `payloads` | `Optional[List[Dict[str, Any]]]` | Optional metadata associated with each vector. | `None` |
| `ids` | `Optional[List[str]]` | Optional unique identifiers for each vector. | `None` |

Raises:

| Type | Description |
| --- | --- |
| `CollectionNotFoundError` | If the specified collection does not exist. |
| `StorageError` | If there is an issue inserting the vectors. |
Source code in src/aeiva/storage/vector_database.py
@abstractmethod\ndef insert_vectors(self, collection_name: str, vectors: List[List[float]], payloads: Optional[List[Dict[str, Any]]] = None, ids: Optional[List[str]] = None) -> None:\n \"\"\"\n Insert vectors into a collection.\n\n Args:\n collection_name (str): The name of the collection.\n vectors (List[List[float]]): A list of vectors to insert.\n payloads (Optional[List[Dict[str, Any]]]): Optional metadata associated with each vector.\n ids (Optional[List[str]]): Optional unique identifiers for each vector.\n\n Raises:\n CollectionNotFoundError: If the specified collection does not exist.\n StorageError: If there is an issue inserting the vectors.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.vector_database.VectorDatabase.list_collections","title":"list_collections()
abstractmethod
","text":"List all available vector collections.
Returns:

| Type | Description |
| --- | --- |
| `List[str]` | A list of collection names. |

Raises:

| Type | Description |
| --- | --- |
| `StorageError` | If there is an issue retrieving the collection list. |
Source code in src/aeiva/storage/vector_database.py
@abstractmethod\ndef list_collections(self) -> List[str]:\n \"\"\"\n List all available vector collections.\n\n Returns:\n List[str]: A list of collection names.\n\n Raises:\n StorageError: If there is an issue retrieving the collection list.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.vector_database.VectorDatabase.search_vectors","title":"search_vectors(collection_name, query_vector, top_k=5, filters=None)
abstractmethod
","text":"Search for similar vectors in a collection.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `collection_name` | `str` | The name of the collection. | required |
| `query_vector` | `List[float]` | The vector to search with. | required |
| `top_k` | `int` | The number of top results to return. | `5` |
| `filters` | `Optional[Dict[str, Any]]` | Optional filters to apply to the search. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `List[Dict[str, Any]]` | A list of search results, each containing the vector ID, score, and payload. |

Raises:

| Type | Description |
| --- | --- |
| `CollectionNotFoundError` | If the specified collection does not exist. |
| `StorageError` | If there is an issue performing the search. |
Source code in src/aeiva/storage/vector_database.py
@abstractmethod\ndef search_vectors(self, collection_name: str, query_vector: List[float], top_k: int = 5, filters: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:\n \"\"\"\n Search for similar vectors in a collection.\n\n Args:\n collection_name (str): The name of the collection.\n query_vector (List[float]): The vector to search with.\n top_k (int): The number of top results to return.\n filters (Optional[Dict[str, Any]]): Optional filters to apply to the search.\n\n Returns:\n List[Dict[str, Any]]: A list of search results, each containing the vector ID, score, and payload.\n\n Raises:\n CollectionNotFoundError: If the specified collection does not exist.\n StorageError: If there is an issue performing the search.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.vector_database.VectorDatabase.update_vector","title":"update_vector(collection_name, vector_id, vector=None, payload=None)
abstractmethod
","text":"Update a vector's data or payload.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `collection_name` | `str` | The name of the collection. | required |
| `vector_id` | `str` | The unique identifier of the vector to update. | required |
| `vector` | `Optional[List[float]]` | The new vector data. | `None` |
| `payload` | `Optional[Dict[str, Any]]` | The new payload data. | `None` |

Raises:

| Type | Description |
| --- | --- |
| `CollectionNotFoundError` | If the specified collection does not exist. |
| `VectorNotFoundError` | If the vector with the specified ID does not exist. |
| `StorageError` | If there is an issue updating the vector. |
Source code in src/aeiva/storage/vector_database.py
@abstractmethod\ndef update_vector(self, collection_name: str, vector_id: str, vector: Optional[List[float]] = None, payload: Optional[Dict[str, Any]] = None) -> None:\n \"\"\"\n Update a vector's data or payload.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector to update.\n vector (Optional[List[float]]): The new vector data.\n payload (Optional[Dict[str, Any]]): The new payload data.\n\n Raises:\n CollectionNotFoundError: If the specified collection does not exist.\n VectorNotFoundError: If the vector with the specified ID does not exist.\n StorageError: If there is an issue updating the vector.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.weaviate","title":"weaviate
","text":""},{"location":"reference/#src.aeiva.storage.weaviate.weaviate_config","title":"weaviate_config
","text":""},{"location":"reference/#src.aeiva.storage.weaviate.weaviate_config.WeaviateConfig","title":"WeaviateConfig
dataclass
","text":" Bases: BaseConfig
Configuration for Weaviate vector database.
Source code in src/aeiva/storage/weaviate/weaviate_config.py
@dataclass\nclass WeaviateConfig(BaseConfig):\n \"\"\"\n Configuration for Weaviate vector database.\n \"\"\"\n\n url: str = field(\n default='http://localhost:8080',\n metadata={\"help\": \"URL of the Weaviate instance (e.g., 'http://localhost:8080').\"}\n )\n api_key: Optional[str] = field(\n default=None,\n metadata={\"help\": \"API key for Weaviate authentication (if required).\"}\n )\n auth_client_secret: Optional[Dict[str, Any]] = field(\n default=None,\n metadata={\"help\": \"Authentication client secret for Weaviate (if using OIDC).\"}\n )\n timeout_config: Optional[Tuple[float, float]] = field(\n default=(2, 20),\n metadata={\"help\": \"Timeout configuration for requests (connect timeout, read timeout).\"}\n )\n additional_headers: Optional[Dict[str, str]] = field(\n default=None,\n metadata={\"help\": \"Additional headers to include in requests to Weaviate.\"}\n )\n embedding_model: Optional[str] = field(\n default=None,\n metadata={\"help\": \"Name of the embedding model used (if required).\"}\n )\n index_name: str = field(\n default='MyIndex',\n metadata={\"help\": \"Name of the Weaviate index (class).\"}\n )\n vector_dim: int = field(\n default=512,\n metadata={\"help\": \"Dimensionality of the vectors stored in Weaviate.\"}\n )\n distance_metric: str = field(\n default='cosine',\n metadata={\"help\": \"Distance metric to use (e.g., 'cosine', 'l2-squared', 'dot').\"}\n )\n\n def __post_init__(self):\n super().__post_init__()\n if not self.url:\n raise ValueError(\"The 'url' parameter is required for Weaviate configuration.\")\n
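The config class above follows a dataclass pattern: defaulted fields plus `__post_init__` validation of required values. A standalone sketch of that pattern (a hypothetical stand-in without aeiva's `BaseConfig` or the full field set):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class MiniWeaviateConfig:
    # Minimal stand-in for WeaviateConfig: defaults plus required-field validation.
    url: str = "http://localhost:8080"
    api_key: Optional[str] = None
    index_name: str = "MyIndex"
    vector_dim: int = 512
    distance_metric: str = "cosine"

    def __post_init__(self):
        # Reject an explicitly empty URL even though a default exists.
        if not self.url:
            raise ValueError("The 'url' parameter is required for Weaviate configuration.")

cfg = MiniWeaviateConfig(vector_dim=128)
```

Validation in `__post_init__` (rather than in each consumer) keeps a bad configuration from propagating into the client-creation step.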
"},{"location":"reference/#src.aeiva.storage.weaviate.weaviate_database","title":"weaviate_database
","text":""},{"location":"reference/#src.aeiva.storage.weaviate.weaviate_database.WeaviateDatabase","title":"WeaviateDatabase
","text":" Bases: VectorDatabase
Concrete implementation of VectorStoreBase using Weaviate.
Source code in src/aeiva/storage/weaviate/weaviate_database.py
class WeaviateDatabase(VectorDatabase):\n \"\"\"\n Concrete implementation of VectorStoreBase using Weaviate.\n \"\"\"\n\n def __init__(self, config: Dict[str, Any]) -> None:\n \"\"\"\n Initialize the Weaviate vector store.\n\n Args:\n config (Dict[str, Any]): Configuration dictionary.\n \"\"\"\n self.config = config\n self.url = config.get('url', 'http://localhost:8080')\n self.api_key = config.get('api_key')\n self.auth_client_secret = config.get('auth_client_secret')\n self.timeout_config = config.get('timeout_config', (2, 20))\n self.additional_headers = config.get('additional_headers')\n self.embedding_model = config.get('embedding_model')\n self.index_name = config.get('index_name', 'MyIndex')\n self.vector_dim = config.get('vector_dim', 512)\n self.distance_metric = config.get('distance_metric', 'cosine')\n\n self.client = self.create_client()\n self.create_index(\n index_name=self.index_name,\n vector_dim=self.vector_dim,\n distance_metric=self.distance_metric\n )\n\n def create_client(self) -> Client:\n \"\"\"\n Initializes the client connection to the Weaviate vector store.\n\n Returns:\n Client: The Weaviate client instance.\n\n Raises:\n ConnectionError: If the client fails to connect to the Weaviate instance.\n \"\"\"\n try:\n if self.api_key:\n auth_config = AuthApiKey(api_key=self.api_key)\n elif self.auth_client_secret:\n auth_config = AuthClientPassword(**self.auth_client_secret)\n else:\n auth_config = None\n\n client = weaviate.Client(\n url=self.url,\n auth_client_secret=auth_config,\n timeout_config=self.timeout_config,\n additional_headers=self.additional_headers\n )\n\n if not client.is_ready():\n raise ConnectionError(f\"Weaviate at {self.url} is not ready.\")\n\n logger.info(f\"Connected to Weaviate at {self.url}.\")\n return client\n except Exception as e:\n logger.error(f\"Failed to connect to Weaviate: {e}\")\n raise ConnectionError(f\"Failed to connect to Weaviate: {e}\")\n\n def create_index(self, index_name: str, vector_dim: int, 
distance_metric: str) -> None:\n \"\"\"\n Create a new index (class) in Weaviate.\n\n Args:\n index_name (str): The name of the index.\n vector_dim (int): The dimensionality of the vectors.\n distance_metric (str): The distance metric to use.\n\n Raises:\n WeaviateException: If there is an issue creating the index.\n \"\"\"\n try:\n if self.client.schema.contains(index_name):\n logger.info(f\"Index {index_name} already exists. Skipping creation.\")\n return\n\n class_obj = {\n \"class\": index_name,\n \"vectorizer\": \"none\",\n \"vectorIndexType\": \"hnsw\",\n \"vectorIndexConfig\": {\n \"distance\": distance_metric\n },\n \"properties\": [\n {\n \"name\": \"id\",\n \"dataType\": [\"string\"],\n \"description\": \"Unique identifier\",\n },\n {\n \"name\": \"payload\",\n \"dataType\": [\"blob\"],\n \"description\": \"Payload data\",\n },\n ]\n }\n\n self.client.schema.create_class(class_obj)\n logger.info(f\"Index {index_name} created successfully.\")\n except WeaviateException as e:\n logger.error(f\"Failed to create index: {e}\")\n raise\n\n def insert_vectors(\n self,\n collection_name: str,\n vectors: List[List[float]],\n payloads: Optional[List[Dict[str, Any]]] = None,\n ids: Optional[List[str]] = None\n ) -> None:\n \"\"\"\n Insert vectors into the collection.\n\n Args:\n collection_name (str): The name of the collection (index).\n vectors (List[List[float]]): A list of vectors to insert.\n payloads (Optional[List[Dict[str, Any]]]): Optional metadata associated with each vector.\n ids (Optional[List[str]]): Optional unique identifiers for each vector.\n\n Raises:\n ValueError: If input data is invalid.\n WeaviateException: If there is an issue inserting vectors.\n \"\"\"\n if collection_name != self.index_name:\n raise ValueError(\"Collection name does not match initialized index name.\")\n\n if ids is None:\n raise ValueError(\"Weaviate requires IDs to be provided for each vector.\")\n\n if payloads is None:\n payloads = [{} for _ in range(len(vectors))]\n\n 
if not (len(ids) == len(vectors) == len(payloads)):\n            raise ValueError(\"Lengths of ids, vectors, and payloads must be equal.\")\n\n        try:\n            with self.client.batch(batch_size=100) as batch:\n                for id_, vector, payload in zip(ids, vectors, payloads):\n                    data_object = {\n                        \"id\": id_,\n                        \"payload\": payload\n                    }\n                    batch.add_data_object(\n                        data_object=data_object,\n                        class_name=collection_name,\n                        vector=vector\n                    )\n            logger.info(f\"Inserted {len(vectors)} vectors into index {collection_name}.\")\n        except WeaviateException as e:\n            logger.error(f\"Failed to insert vectors: {e}\")\n            raise\n\n    def search_vectors(\n        self,\n        collection_name: str,\n        query_vector: List[float],\n        top_k: int = 5,\n        filters: Optional[Dict[str, Any]] = None\n    ) -> List[Dict[str, Any]]:\n        \"\"\"\n        Search for similar vectors in the collection.\n\n        Args:\n            collection_name (str): The name of the collection (index).\n            query_vector (List[float]): The vector to search with.\n            top_k (int): The number of top results to return.\n            filters (Optional[Dict[str, Any]]): Optional filters to apply to the search.\n\n        Returns:\n            List[Dict[str, Any]]: A list of search results.\n\n        Raises:\n            ValueError: If collection name does not match.\n            WeaviateException: If there is an issue performing the search.\n        \"\"\"\n        if collection_name != self.index_name:\n            raise ValueError(\"Collection name does not match initialized index name.\")\n\n        try:\n            near_vector = {\n                \"vector\": query_vector,\n            }\n\n            where_filter = self._build_filters(filters)\n\n            query = self.client.query.get(\n                class_name=collection_name,\n                properties=[\"id\", \"payload\"]\n            ).with_near_vector(near_vector).with_additional([\"certainty\"]).with_limit(top_k)\n            # with_where expects a filter dict, so add the clause only when filters were given\n            if where_filter is not None:\n                query = query.with_where(where_filter)\n            result = query.do()\n\n            output = []\n            for item in result[\"data\"][\"Get\"][collection_name]:\n                result_item = {\n                    \"id\": item[\"id\"],\n                    \"score\": item[\"_additional\"][\"certainty\"],  # or distance\n                    \"payload\": item[\"payload\"]\n                }\n                output.append(result_item)\n            return output\n        except WeaviateException as e:\n            
logger.error(f\"Failed to search vectors: {e}\")\n            raise\n\n    def _build_filters(self, filters: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:\n        \"\"\"\n        Build a Weaviate where filter from a dictionary.\n\n        Args:\n            filters (Optional[Dict[str, Any]]): Filters to apply.\n\n        Returns:\n            Optional[Dict[str, Any]]: A Weaviate where filter.\n        \"\"\"\n        if not filters:\n            return None\n\n        conditions = []\n        for key, value in filters.items():\n            condition = {\n                \"path\": [key],\n                \"operator\": \"Equal\",\n                # bool is checked before int because isinstance(True, int) is True in Python\n                \"valueBoolean\": value if isinstance(value, bool) else None,\n                \"valueInt\": value if isinstance(value, int) and not isinstance(value, bool) else None,\n                \"valueString\": value if isinstance(value, str) else None,\n                \"valueNumber\": value if isinstance(value, float) else None,\n            }\n            conditions.append(condition)\n\n        where_filter = {\n            \"operator\": \"And\",\n            \"operands\": conditions\n        }\n\n        return where_filter\n\n    def delete_vector(self, collection_name: str, vector_id: str) -> None:\n        \"\"\"\n        Delete a vector from the collection by its ID.\n\n        Args:\n            collection_name (str): The name of the collection (index).\n            vector_id (str): The unique identifier of the vector to delete.\n\n        Raises:\n            ValueError: If collection name does not match.\n            WeaviateException: If there is an issue deleting the vector.\n        \"\"\"\n        if collection_name != self.index_name:\n            raise ValueError(\"Collection name does not match initialized index name.\")\n\n        try:\n            self.client.data_object.delete(\n                uuid=vector_id,\n                class_name=collection_name\n            )\n            logger.info(f\"Deleted vector with ID {vector_id} from index {collection_name}.\")\n        except WeaviateException as e:\n            logger.error(f\"Failed to delete vector: {e}\")\n            raise\n\n    def update_vector(\n        self,\n        collection_name: str,\n        vector_id: str,\n        vector: Optional[List[float]] = None,\n        payload: Optional[Dict[str, Any]] = None\n    ) -> None:\n        \"\"\"\n        Update a vector's data or payload.\n\n        Args:\n            collection_name (str): The name of the collection (index).\n            vector_id (str): The unique 
identifier of the vector to update.\n vector (Optional[List[float]]): The new vector data.\n payload (Optional[Dict[str, Any]]): The new payload data.\n\n Raises:\n ValueError: If collection name does not match.\n WeaviateException: If there is an issue updating the vector.\n \"\"\"\n if collection_name != self.index_name:\n raise ValueError(\"Collection name does not match initialized index name.\")\n\n try:\n data_object = {}\n if payload is not None:\n data_object[\"payload\"] = payload\n\n self.client.data_object.update(\n data_object=data_object,\n class_name=collection_name,\n uuid=vector_id,\n vector=vector\n )\n logger.info(f\"Updated vector with ID {vector_id} in index {collection_name}.\")\n except WeaviateException as e:\n logger.error(f\"Failed to update vector: {e}\")\n raise\n\n def get_vector(self, collection_name: str, vector_id: str) -> Dict[str, Any]:\n \"\"\"\n Retrieve a vector by its ID.\n\n Args:\n collection_name (str): The name of the collection (index).\n vector_id (str): The unique identifier of the vector.\n\n Returns:\n Dict[str, Any]: A dictionary containing the vector data and payload.\n\n Raises:\n ValueError: If collection name does not match.\n KeyError: If the vector is not found.\n WeaviateException: If there is an issue retrieving the vector.\n \"\"\"\n if collection_name != self.index_name:\n raise ValueError(\"Collection name does not match initialized index name.\")\n\n try:\n result = self.client.data_object.get_by_id(\n uuid=vector_id,\n class_name=collection_name,\n additional_properties=[\"vector\"]\n )\n if result is None:\n raise KeyError(f\"Vector with ID {vector_id} not found in index {collection_name}.\")\n\n vector_data = {\n \"id\": result[\"id\"],\n \"vector\": result[\"vector\"],\n \"payload\": result[\"payload\"]\n }\n return vector_data\n except WeaviateException as e:\n logger.error(f\"Failed to retrieve vector: {e}\")\n raise\n\n def list_collections(self) -> List[str]:\n \"\"\"\n List all available indexes 
(classes).\n\n Returns:\n List[str]: A list of index names.\n \"\"\"\n try:\n schema = self.client.schema.get()\n return [clazz[\"class\"] for clazz in schema[\"classes\"]]\n except WeaviateException as e:\n logger.error(f\"Failed to list collections: {e}\")\n raise\n\n def delete_collection(self, collection_name: str) -> None:\n \"\"\"\n Delete an entire index (class).\n\n Args:\n collection_name (str): The name of the collection (index) to delete.\n\n Raises:\n WeaviateException: If there is an issue deleting the collection.\n \"\"\"\n try:\n self.client.schema.delete_class(collection_name)\n logger.info(f\"Deleted index {collection_name}.\")\n except WeaviateException as e:\n logger.error(f\"Failed to delete collection: {e}\")\n raise\n\n def get_collection_info(self, collection_name: str) -> Dict[str, Any]:\n \"\"\"\n Get information about a collection (index).\n\n Args:\n collection_name (str): The name of the collection (index).\n\n Returns:\n Dict[str, Any]: Information about the collection.\n\n Raises:\n WeaviateException: If there is an issue retrieving the collection info.\n \"\"\"\n try:\n class_schema = self.client.schema.get(class_name=collection_name)\n return class_schema\n except WeaviateException as e:\n logger.error(f\"Failed to get collection info: {e}\")\n raise\n\n def __del__(self):\n \"\"\"Clean up resources.\"\"\"\n if hasattr(self, 'client'):\n self.client.close()\n logger.info(\"Closed connection to Weaviate.\")\n
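The class above is driven entirely by a plain configuration dictionary read with `config.get(...)`. The sketch below shows the expected shape of such a config; the values are hypothetical, no connection is attempted, and actually constructing `WeaviateDatabase` would require a running Weaviate instance:

```python
# Hypothetical configuration for WeaviateDatabase (no connection is made here).
# Any omitted key falls back to the default documented in __init__
# (e.g. url -> http://localhost:8080, vector_dim -> 512, index_name -> MyIndex).
config = {
    "url": "http://localhost:8080",   # Weaviate endpoint
    "index_name": "Documents",        # class name used for the index
    "vector_dim": 512,                # must match the embedding model's output size
    "distance_metric": "cosine",
    "timeout_config": (2, 20),        # (connect, read) timeouts in seconds
}

# db = WeaviateDatabase(config)  # would connect and create the index
```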
"},{"location":"reference/#src.aeiva.storage.weaviate.weaviate_database.WeaviateDatabase.__del__","title":"__del__()
","text":"Clean up resources.
Source code in src/aeiva/storage/weaviate/weaviate_database.py
def __del__(self):\n \"\"\"Clean up resources.\"\"\"\n if hasattr(self, 'client'):\n self.client.close()\n logger.info(\"Closed connection to Weaviate.\")\n
"},{"location":"reference/#src.aeiva.storage.weaviate.weaviate_database.WeaviateDatabase.__init__","title":"__init__(config)
","text":"Initialize the Weaviate vector store.
Parameters:
Name Type Description Default config
Dict[str, Any]
Configuration dictionary.
required Source code in src/aeiva/storage/weaviate/weaviate_database.py
def __init__(self, config: Dict[str, Any]) -> None:\n \"\"\"\n Initialize the Weaviate vector store.\n\n Args:\n config (Dict[str, Any]): Configuration dictionary.\n \"\"\"\n self.config = config\n self.url = config.get('url', 'http://localhost:8080')\n self.api_key = config.get('api_key')\n self.auth_client_secret = config.get('auth_client_secret')\n self.timeout_config = config.get('timeout_config', (2, 20))\n self.additional_headers = config.get('additional_headers')\n self.embedding_model = config.get('embedding_model')\n self.index_name = config.get('index_name', 'MyIndex')\n self.vector_dim = config.get('vector_dim', 512)\n self.distance_metric = config.get('distance_metric', 'cosine')\n\n self.client = self.create_client()\n self.create_index(\n index_name=self.index_name,\n vector_dim=self.vector_dim,\n distance_metric=self.distance_metric\n )\n
"},{"location":"reference/#src.aeiva.storage.weaviate.weaviate_database.WeaviateDatabase.create_client","title":"create_client()
","text":"Initializes the client connection to the Weaviate vector store.
Returns:
Name Type Description Client
Client
The Weaviate client instance.
Raises:
Type Description ConnectionError
If the client fails to connect to the Weaviate instance.
Source code in src/aeiva/storage/weaviate/weaviate_database.py
def create_client(self) -> Client:\n \"\"\"\n Initializes the client connection to the Weaviate vector store.\n\n Returns:\n Client: The Weaviate client instance.\n\n Raises:\n ConnectionError: If the client fails to connect to the Weaviate instance.\n \"\"\"\n try:\n if self.api_key:\n auth_config = AuthApiKey(api_key=self.api_key)\n elif self.auth_client_secret:\n auth_config = AuthClientPassword(**self.auth_client_secret)\n else:\n auth_config = None\n\n client = weaviate.Client(\n url=self.url,\n auth_client_secret=auth_config,\n timeout_config=self.timeout_config,\n additional_headers=self.additional_headers\n )\n\n if not client.is_ready():\n raise ConnectionError(f\"Weaviate at {self.url} is not ready.\")\n\n logger.info(f\"Connected to Weaviate at {self.url}.\")\n return client\n except Exception as e:\n logger.error(f\"Failed to connect to Weaviate: {e}\")\n raise ConnectionError(f\"Failed to connect to Weaviate: {e}\")\n
"},{"location":"reference/#src.aeiva.storage.weaviate.weaviate_database.WeaviateDatabase.create_index","title":"create_index(index_name, vector_dim, distance_metric)
","text":"Create a new index (class) in Weaviate.
Parameters:
Name Type Description Default index_name
str
The name of the index.
required vector_dim
int
The dimensionality of the vectors.
required distance_metric
str
The distance metric to use.
required Raises:
Type Description WeaviateException
If there is an issue creating the index.
Source code in src/aeiva/storage/weaviate/weaviate_database.py
def create_index(self, index_name: str, vector_dim: int, distance_metric: str) -> None:\n \"\"\"\n Create a new index (class) in Weaviate.\n\n Args:\n index_name (str): The name of the index.\n vector_dim (int): The dimensionality of the vectors.\n distance_metric (str): The distance metric to use.\n\n Raises:\n WeaviateException: If there is an issue creating the index.\n \"\"\"\n try:\n if self.client.schema.contains(index_name):\n logger.info(f\"Index {index_name} already exists. Skipping creation.\")\n return\n\n class_obj = {\n \"class\": index_name,\n \"vectorizer\": \"none\",\n \"vectorIndexType\": \"hnsw\",\n \"vectorIndexConfig\": {\n \"distance\": distance_metric\n },\n \"properties\": [\n {\n \"name\": \"id\",\n \"dataType\": [\"string\"],\n \"description\": \"Unique identifier\",\n },\n {\n \"name\": \"payload\",\n \"dataType\": [\"blob\"],\n \"description\": \"Payload data\",\n },\n ]\n }\n\n self.client.schema.create_class(class_obj)\n logger.info(f\"Index {index_name} created successfully.\")\n except WeaviateException as e:\n logger.error(f\"Failed to create index: {e}\")\n raise\n
"},{"location":"reference/#src.aeiva.storage.weaviate.weaviate_database.WeaviateDatabase.delete_collection","title":"delete_collection(collection_name)
","text":"Delete an entire index (class).
Parameters:
Name Type Description Default collection_name
str
The name of the collection (index) to delete.
required Raises:
Type Description WeaviateException
If there is an issue deleting the collection.
Source code in src/aeiva/storage/weaviate/weaviate_database.py
def delete_collection(self, collection_name: str) -> None:\n \"\"\"\n Delete an entire index (class).\n\n Args:\n collection_name (str): The name of the collection (index) to delete.\n\n Raises:\n WeaviateException: If there is an issue deleting the collection.\n \"\"\"\n try:\n self.client.schema.delete_class(collection_name)\n logger.info(f\"Deleted index {collection_name}.\")\n except WeaviateException as e:\n logger.error(f\"Failed to delete collection: {e}\")\n raise\n
"},{"location":"reference/#src.aeiva.storage.weaviate.weaviate_database.WeaviateDatabase.delete_vector","title":"delete_vector(collection_name, vector_id)
","text":"Delete a vector from the collection by its ID.
Parameters:
Name Type Description Default collection_name
str
The name of the collection (index).
required vector_id
str
The unique identifier of the vector to delete.
required Raises:
Type Description ValueError
If collection name does not match.
WeaviateException
If there is an issue deleting the vector.
Source code in src/aeiva/storage/weaviate/weaviate_database.py
def delete_vector(self, collection_name: str, vector_id: str) -> None:\n \"\"\"\n Delete a vector from the collection by its ID.\n\n Args:\n collection_name (str): The name of the collection (index).\n vector_id (str): The unique identifier of the vector to delete.\n\n Raises:\n ValueError: If collection name does not match.\n WeaviateException: If there is an issue deleting the vector.\n \"\"\"\n if collection_name != self.index_name:\n raise ValueError(\"Collection name does not match initialized index name.\")\n\n try:\n self.client.data_object.delete(\n uuid=vector_id,\n class_name=collection_name\n )\n logger.info(f\"Deleted vector with ID {vector_id} from index {collection_name}.\")\n except WeaviateException as e:\n logger.error(f\"Failed to delete vector: {e}\")\n raise\n
"},{"location":"reference/#src.aeiva.storage.weaviate.weaviate_database.WeaviateDatabase.get_collection_info","title":"get_collection_info(collection_name)
","text":"Get information about a collection (index).
Parameters:
Name Type Description Default collection_name
str
The name of the collection (index).
required Returns:
Type Description Dict[str, Any]
Dict[str, Any]: Information about the collection.
Raises:
Type Description WeaviateException
If there is an issue retrieving the collection info.
Source code in src/aeiva/storage/weaviate/weaviate_database.py
def get_collection_info(self, collection_name: str) -> Dict[str, Any]:\n \"\"\"\n Get information about a collection (index).\n\n Args:\n collection_name (str): The name of the collection (index).\n\n Returns:\n Dict[str, Any]: Information about the collection.\n\n Raises:\n WeaviateException: If there is an issue retrieving the collection info.\n \"\"\"\n try:\n class_schema = self.client.schema.get(class_name=collection_name)\n return class_schema\n except WeaviateException as e:\n logger.error(f\"Failed to get collection info: {e}\")\n raise\n
"},{"location":"reference/#src.aeiva.storage.weaviate.weaviate_database.WeaviateDatabase.get_vector","title":"get_vector(collection_name, vector_id)
","text":"Retrieve a vector by its ID.
Parameters:
Name Type Description Default collection_name
str
The name of the collection (index).
required vector_id
str
The unique identifier of the vector.
required Returns:
Type Description Dict[str, Any]
Dict[str, Any]: A dictionary containing the vector data and payload.
Raises:
Type Description ValueError
If collection name does not match.
KeyError
If the vector is not found.
WeaviateException
If there is an issue retrieving the vector.
Source code in src/aeiva/storage/weaviate/weaviate_database.py
def get_vector(self, collection_name: str, vector_id: str) -> Dict[str, Any]:\n \"\"\"\n Retrieve a vector by its ID.\n\n Args:\n collection_name (str): The name of the collection (index).\n vector_id (str): The unique identifier of the vector.\n\n Returns:\n Dict[str, Any]: A dictionary containing the vector data and payload.\n\n Raises:\n ValueError: If collection name does not match.\n KeyError: If the vector is not found.\n WeaviateException: If there is an issue retrieving the vector.\n \"\"\"\n if collection_name != self.index_name:\n raise ValueError(\"Collection name does not match initialized index name.\")\n\n try:\n result = self.client.data_object.get_by_id(\n uuid=vector_id,\n class_name=collection_name,\n additional_properties=[\"vector\"]\n )\n if result is None:\n raise KeyError(f\"Vector with ID {vector_id} not found in index {collection_name}.\")\n\n vector_data = {\n \"id\": result[\"id\"],\n \"vector\": result[\"vector\"],\n \"payload\": result[\"payload\"]\n }\n return vector_data\n except WeaviateException as e:\n logger.error(f\"Failed to retrieve vector: {e}\")\n raise\n
"},{"location":"reference/#src.aeiva.storage.weaviate.weaviate_database.WeaviateDatabase.insert_vectors","title":"insert_vectors(collection_name, vectors, payloads=None, ids=None)
","text":"Insert vectors into the collection.
Parameters:
Name Type Description Default collection_name
str
The name of the collection (index).
required vectors
List[List[float]]
A list of vectors to insert.
required payloads
Optional[List[Dict[str, Any]]]
Optional metadata associated with each vector.
None
ids
Optional[List[str]]
Optional unique identifiers for each vector.
None
Raises:
Type Description ValueError
If input data is invalid.
WeaviateException
If there is an issue inserting vectors.
Source code in src/aeiva/storage/weaviate/weaviate_database.py
def insert_vectors(\n self,\n collection_name: str,\n vectors: List[List[float]],\n payloads: Optional[List[Dict[str, Any]]] = None,\n ids: Optional[List[str]] = None\n) -> None:\n \"\"\"\n Insert vectors into the collection.\n\n Args:\n collection_name (str): The name of the collection (index).\n vectors (List[List[float]]): A list of vectors to insert.\n payloads (Optional[List[Dict[str, Any]]]): Optional metadata associated with each vector.\n ids (Optional[List[str]]): Optional unique identifiers for each vector.\n\n Raises:\n ValueError: If input data is invalid.\n WeaviateException: If there is an issue inserting vectors.\n \"\"\"\n if collection_name != self.index_name:\n raise ValueError(\"Collection name does not match initialized index name.\")\n\n if ids is None:\n raise ValueError(\"Weaviate requires IDs to be provided for each vector.\")\n\n if payloads is None:\n payloads = [{} for _ in range(len(vectors))]\n\n if not (len(ids) == len(vectors) == len(payloads)):\n raise ValueError(\"Lengths of ids, vectors, and payloads must be equal.\")\n\n try:\n with self.client.batch(batch_size=100) as batch:\n for id_, vector, payload in zip(ids, vectors, payloads):\n data_object = {\n \"id\": id_,\n \"payload\": payload\n }\n batch.add_data_object(\n data_object=data_object,\n class_name=collection_name,\n vector=vector\n )\n logger.info(f\"Inserted {len(vectors)} vectors into index {collection_name}.\")\n except WeaviateException as e:\n logger.error(f\"Failed to insert vectors: {e}\")\n raise\n
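As the method requires, IDs must be supplied by the caller. A common pattern (illustrative, not from the source) is to mint UUID strings up front and keep the three lists aligned before calling `insert_vectors`:

```python
import uuid

vectors = [[0.1, 0.2], [0.3, 0.4]]
payloads = [{"text": "first"}, {"text": "second"}]
ids = [str(uuid.uuid4()) for _ in vectors]  # Weaviate expects UUID strings

# insert_vectors rejects mismatched lengths before any network call:
assert len(ids) == len(vectors) == len(payloads)
```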
"},{"location":"reference/#src.aeiva.storage.weaviate.weaviate_database.WeaviateDatabase.list_collections","title":"list_collections()
","text":"List all available indexes (classes).
Returns:
Type Description List[str]
List[str]: A list of index names.
Source code in src/aeiva/storage/weaviate/weaviate_database.py
def list_collections(self) -> List[str]:\n \"\"\"\n List all available indexes (classes).\n\n Returns:\n List[str]: A list of index names.\n \"\"\"\n try:\n schema = self.client.schema.get()\n return [clazz[\"class\"] for clazz in schema[\"classes\"]]\n except WeaviateException as e:\n logger.error(f\"Failed to list collections: {e}\")\n raise\n
"},{"location":"reference/#src.aeiva.storage.weaviate.weaviate_database.WeaviateDatabase.search_vectors","title":"search_vectors(collection_name, query_vector, top_k=5, filters=None)
","text":"Search for similar vectors in the collection.
Parameters:
Name Type Description Default collection_name
str
The name of the collection (index).
required query_vector
List[float]
The vector to search with.
required top_k
int
The number of top results to return.
5
filters
Optional[Dict[str, Any]]
Optional filters to apply to the search.
None
Returns:
Type Description List[Dict[str, Any]]
List[Dict[str, Any]]: A list of search results.
Raises:
Type Description ValueError
If collection name does not match.
WeaviateException
If there is an issue performing the search.
Source code in src/aeiva/storage/weaviate/weaviate_database.py
def search_vectors(\n    self,\n    collection_name: str,\n    query_vector: List[float],\n    top_k: int = 5,\n    filters: Optional[Dict[str, Any]] = None\n) -> List[Dict[str, Any]]:\n    \"\"\"\n    Search for similar vectors in the collection.\n\n    Args:\n        collection_name (str): The name of the collection (index).\n        query_vector (List[float]): The vector to search with.\n        top_k (int): The number of top results to return.\n        filters (Optional[Dict[str, Any]]): Optional filters to apply to the search.\n\n    Returns:\n        List[Dict[str, Any]]: A list of search results.\n\n    Raises:\n        ValueError: If collection name does not match.\n        WeaviateException: If there is an issue performing the search.\n    \"\"\"\n    if collection_name != self.index_name:\n        raise ValueError(\"Collection name does not match initialized index name.\")\n\n    try:\n        near_vector = {\n            \"vector\": query_vector,\n        }\n\n        where_filter = self._build_filters(filters)\n\n        query = self.client.query.get(\n            class_name=collection_name,\n            properties=[\"id\", \"payload\"]\n        ).with_near_vector(near_vector).with_additional([\"certainty\"]).with_limit(top_k)\n        # with_where expects a filter dict, so add the clause only when filters were given\n        if where_filter is not None:\n            query = query.with_where(where_filter)\n        result = query.do()\n\n        output = []\n        for item in result[\"data\"][\"Get\"][collection_name]:\n            result_item = {\n                \"id\": item[\"id\"],\n                \"score\": item[\"_additional\"][\"certainty\"],  # or distance\n                \"payload\": item[\"payload\"]\n            }\n            output.append(result_item)\n        return output\n    except WeaviateException as e:\n        logger.error(f\"Failed to search vectors: {e}\")\n        raise\n
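The where clause used above is assembled by the private `_build_filters` helper. A runnable standalone sketch of that construction follows; the boolean check is placed before the integer check because `bool` is a subclass of `int` in Python, so a `True` value would otherwise also populate `valueInt`:

```python
from typing import Any, Dict, Optional

def build_filters(filters: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    """Standalone mirror of the where-filter construction, for illustration."""
    if not filters:
        return None
    conditions = []
    for key, value in filters.items():
        conditions.append({
            "path": [key],
            "operator": "Equal",
            # bool before int: isinstance(True, int) is True in Python
            "valueBoolean": value if isinstance(value, bool) else None,
            "valueInt": value if isinstance(value, int) and not isinstance(value, bool) else None,
            "valueString": value if isinstance(value, str) else None,
            "valueNumber": value if isinstance(value, float) else None,
        })
    return {"operator": "And", "operands": conditions}

f = build_filters({"category": "news", "published": True})
```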
"},{"location":"reference/#src.aeiva.storage.weaviate.weaviate_database.WeaviateDatabase.update_vector","title":"update_vector(collection_name, vector_id, vector=None, payload=None)
","text":"Update a vector's data or payload.
Parameters:
Name Type Description Default collection_name
str
The name of the collection (index).
required vector_id
str
The unique identifier of the vector to update.
required vector
Optional[List[float]]
The new vector data.
None
payload
Optional[Dict[str, Any]]
The new payload data.
None
Raises:
Type Description ValueError
If collection name does not match.
WeaviateException
If there is an issue updating the vector.
Source code in src/aeiva/storage/weaviate/weaviate_database.py
def update_vector(\n self,\n collection_name: str,\n vector_id: str,\n vector: Optional[List[float]] = None,\n payload: Optional[Dict[str, Any]] = None\n) -> None:\n \"\"\"\n Update a vector's data or payload.\n\n Args:\n collection_name (str): The name of the collection (index).\n vector_id (str): The unique identifier of the vector to update.\n vector (Optional[List[float]]): The new vector data.\n payload (Optional[Dict[str, Any]]): The new payload data.\n\n Raises:\n ValueError: If collection name does not match.\n WeaviateException: If there is an issue updating the vector.\n \"\"\"\n if collection_name != self.index_name:\n raise ValueError(\"Collection name does not match initialized index name.\")\n\n try:\n data_object = {}\n if payload is not None:\n data_object[\"payload\"] = payload\n\n self.client.data_object.update(\n data_object=data_object,\n class_name=collection_name,\n uuid=vector_id,\n vector=vector\n )\n logger.info(f\"Updated vector with ID {vector_id} in index {collection_name}.\")\n except WeaviateException as e:\n logger.error(f\"Failed to update vector: {e}\")\n raise\n
"},{"location":"reference/#src.aeiva.tool","title":"tool
","text":""},{"location":"reference/#src.aeiva.tool.api_server","title":"api_server
","text":""},{"location":"reference/#src.aeiva.tool.api_server.call_api_action","title":"call_api_action(api_name, action_name, request)
async
","text":"Endpoint to dynamically call an action within a specified API.
Parameters:
Name Type Description Default api_name
str
The name of the API.
required action_name
str
The name of the action/function to execute.
required request
Request
The incoming HTTP request.
required Returns:
Name Type Description dict
The result of the action or an error message.
Source code in src/aeiva/tool/api_server.py
@app.get(\"/api/{api_name}/{action_name}\")\nasync def call_api_action(api_name: str, action_name: str, request: Request):\n \"\"\"\n Endpoint to dynamically call an action within a specified API.\n\n Args:\n api_name (str): The name of the API.\n action_name (str): The name of the action/function to execute.\n request (Request): The incoming HTTP request.\n\n Returns:\n dict: The result of the action or an error message.\n \"\"\"\n try:\n logger.info(f\"Starting call_api_action for API '{api_name}', Action '{action_name}'\")\n\n # Load the API module\n module = load_api_module(api_name)\n\n # Retrieve the action function\n try:\n action = getattr(module, action_name)\n logger.info(f\"Retrieved action '{action_name}' from API '{api_name}'\")\n except AttributeError:\n logger.error(f\"Action '{action_name}' not found in API '{api_name}'\")\n raise HTTPException(status_code=404, detail=f\"Action '{action_name}' not found in API '{api_name}'\")\n\n # Extract parameters based on request method\n params = {}\n if request.method in [\"POST\", \"PUT\", \"PATCH\"]:\n try:\n params = await request.json()\n logger.info(f\"Received JSON payload: {params}\")\n except json.JSONDecodeError:\n logger.error(\"Invalid JSON payload\")\n raise HTTPException(status_code=400, detail=\"Invalid JSON payload\")\n else:\n # For GET requests, extract query parameters\n params = dict(request.query_params)\n logger.info(f\"Received query parameters: {params}\")\n\n # Get the function signature\n sig = signature(action)\n logger.info(f\"Function signature for '{action_name}': {sig}\")\n\n # Prepare to collect converted parameters\n converted_params = {}\n\n for param_name, param in sig.parameters.items():\n if param_name in params:\n value = params[param_name]\n param_type = param.annotation if param.annotation != Parameter.empty else str\n try:\n if param_type == bool:\n # Convert to boolean\n if isinstance(value, bool):\n converted_value = value\n elif isinstance(value, str):\n 
converted_value = value.lower() in (\"true\", \"1\", \"yes\")\n else:\n converted_value = bool(value)\n elif param_type in [int, float, str]:\n converted_value = param_type(value)\n elif param_type == list or param_type == dict:\n converted_value = json.loads(value)\n else:\n # For more complex types, assume Pydantic models or custom parsing\n converted_value = param_type(value)\n converted_params[param_name] = converted_value\n logger.debug(f\"Converted parameter '{param_name}': {converted_value} (Type: {param_type})\")\n except (ValueError, json.JSONDecodeError, TypeError) as e:\n logger.error(f\"Invalid value for parameter '{param_name}': {value} ({e})\")\n raise HTTPException(\n status_code=400,\n detail=f\"Invalid value for parameter '{param_name}': {value}. Expected type {param_type.__name__}.\"\n )\n else:\n if param.default == Parameter.empty:\n logger.error(f\"Missing required parameter: {param_name}\")\n raise HTTPException(status_code=400, detail=f\"Missing required parameter: {param_name}\")\n else:\n # Use default value\n converted_params[param_name] = param.default\n logger.debug(f\"Using default value for parameter '{param_name}': {param.default}\")\n\n # Determine if the action is asynchronous\n if asyncio.iscoroutinefunction(action):\n logger.info(f\"Action '{action_name}' is asynchronous. Awaiting execution.\")\n result = await action(**converted_params)\n else:\n logger.info(f\"Action '{action_name}' is synchronous. 
Executing directly.\")\n result = action(**converted_params)\n\n logger.info(f\"Action '{action_name}' executed successfully with result: {result}\")\n return {\"result\": result}\n\n except FileNotFoundError as e:\n logger.error(f\"API module not found: {e}\")\n raise HTTPException(status_code=404, detail=str(e))\n except HTTPException as he:\n # Re-raise HTTP exceptions to be handled by FastAPI\n raise he\n except Exception as e:\n logger.error(f\"Unhandled exception in call_api_action: {e}\", exc_info=True)\n raise HTTPException(status_code=500, detail=\"Internal Server Error\")\n
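The parameter-conversion loop above can be reduced to a small standalone function. The sketch below (simplified: no FastAPI, no async handling) applies the same coercion rules to a dict of string parameters against a function's signature; `example_action` is a hypothetical target:

```python
import json
from inspect import signature, Parameter
from typing import Any, Dict

def convert_params(func, params: Dict[str, str]) -> Dict[str, Any]:
    """Coerce string parameters to the annotated types, as the endpoint does."""
    converted = {}
    for name, param in signature(func).parameters.items():
        if name in params:
            value = params[name]
            ptype = param.annotation if param.annotation != Parameter.empty else str
            if ptype is bool:
                # Accept common truthy strings, mirroring the endpoint's rule
                converted[name] = value.lower() in ("true", "1", "yes") if isinstance(value, str) else bool(value)
            elif ptype in (int, float, str):
                converted[name] = ptype(value)
            elif ptype in (list, dict):
                converted[name] = json.loads(value)
            else:
                converted[name] = ptype(value)
        elif param.default is not Parameter.empty:
            converted[name] = param.default
        else:
            raise ValueError(f"Missing required parameter: {name}")
    return converted

def example_action(count: int, verbose: bool = False) -> int:
    return count * (2 if verbose else 1)

args = convert_params(example_action, {"count": "3", "verbose": "yes"})
```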
"},{"location":"reference/#src.aeiva.tool.api_server.load_api_module","title":"load_api_module(api_name)
","text":"Dynamically load the API module for the given api_name.
Parameters:
Name Type Description Default api_name
str
The name of the API.
required Returns:
Name Type Description module
The loaded API module.
Raises:
Type Description FileNotFoundError
If the API module does not exist.
ImportError
If the module cannot be imported.
Source code in src/aeiva/tool/api_server.py
def load_api_module(api_name: str):\n \"\"\"\n Dynamically load the API module for the given api_name.\n\n Args:\n api_name (str): The name of the API.\n\n Returns:\n module: The loaded API module.\n\n Raises:\n FileNotFoundError: If the API module does not exist.\n ImportError: If the module cannot be imported.\n \"\"\"\n # Construct the path to the API module\n api_path = BASE_DIR / \"api\" / api_name / \"api.py\"\n\n if not api_path.exists():\n logger.error(f\"API module not found at path: {api_path}\")\n raise FileNotFoundError(f\"API module not found at path: {api_path}\")\n\n module_name = f\"aeiva.tool.api.{api_name}.api\"\n spec = importlib.util.spec_from_file_location(module_name, str(api_path))\n module = importlib.util.module_from_spec(spec)\n try:\n spec.loader.exec_module(module)\n logger.info(f\"Successfully loaded module '{module_name}'\")\n except Exception as e:\n logger.error(f\"Failed to load module '{module_name}': {e}\")\n raise ImportError(f\"Failed to load module '{module_name}': {e}\")\n return module\n
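The same `importlib.util` pattern can be demonstrated without the aeiva package layout. The sketch below writes a throwaway module to a temporary directory and loads it by path exactly as `load_api_module` does; the module name and contents are illustrative:

```python
import importlib.util
import tempfile
from pathlib import Path

def load_module_from_path(module_name: str, path: Path):
    """Load a Python module from an explicit file path, as load_api_module does."""
    if not path.exists():
        raise FileNotFoundError(f"API module not found at path: {path}")
    spec = importlib.util.spec_from_file_location(module_name, str(path))
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # raises if the source fails to execute
    return module

with tempfile.TemporaryDirectory() as tmp:
    api_path = Path(tmp) / "api.py"
    api_path.write_text("def greet(name: str) -> str:\n    return f'hello {name}'\n")
    mod = load_module_from_path("demo.api", api_path)
    result = mod.greet("aeiva")
```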
"},{"location":"reference/#src.aeiva.tool.api_server.root","title":"root()
async
","text":"Root endpoint to confirm the API server is running.
Source code in src/aeiva/tool/api_server.py
@app.get(\"/\")\nasync def root():\n \"\"\"\n Root endpoint to confirm the API server is running.\n \"\"\"\n return {\"message\": \"Welcome to the AI Agent API system!\"}\n
"},{"location":"reference/#src.aeiva.tool.tool","title":"tool
","text":""},{"location":"reference/#src.aeiva.tool.tool.Tool","title":"Tool
","text":"Source code in src/aeiva/tool/tool.py
class Tool:\n def __init__(self, api_name: str):\n \"\"\"\n Initialize the tool, determining whether it should run locally or via an external service.\n Args:\n api_name (str): The name of the tool API (matches the function name).\n \"\"\"\n self.api_name = api_name\n self.schema = self.load_tool_schema(api_name)\n\n @classmethod\n def load_tool_schema(cls, api_name: str) -> dict:\n \"\"\"\n Load the tool's schema from the JSON file.\n Args:\n api_name (str): The name of the API or function.\n Returns:\n dict: The loaded schema from the JSON file.\n \"\"\"\n current_path = os.path.dirname(os.path.abspath(__file__))\n project_root = os.path.abspath(os.path.join(current_path, \"../../..\"))\n path = os.path.join(\n project_root,\n f\"src/aeiva/tool/api/{api_name}/{api_name}.json\",\n )\n with open(path, \"r\") as file:\n return json.load(file)\n\n async def aexecute(self, params: dict) -> Any:\n \"\"\"\n Execute the tool by calling the corresponding function (whether it's for a local function or encapsulated API call).\n Args:\n params (dict): Parameters to pass to the tool.\n Returns:\n Any: The result of the tool execution.\n \"\"\"\n function_module = f\"aeiva.tool.api.{self.api_name}.api\"\n func_module = import_module(function_module)\n\n # Check if the function is async\n function: Callable = getattr(func_module, self.api_name)\n if asyncio.iscoroutinefunction(function):\n return await function(**params)\n else:\n return function(**params)\n\n def execute(self, params: dict) -> Any:\n \"\"\"\n Execute the tool synchronously by calling the corresponding function.\n\n Args:\n params (dict): Parameters to pass to the tool.\n\n Returns:\n Any: The result of the tool execution.\n \"\"\"\n function_module = f\"aeiva.tool.api.{self.api_name}.api\"\n func_module = import_module(function_module)\n\n function: Callable = getattr(func_module, self.api_name)\n if asyncio.iscoroutinefunction(function):\n # If the function is async, attempt to run it in an event loop\n 
try:\n asyncio.get_running_loop()\n except RuntimeError:\n # No event loop running, use asyncio.run\n return asyncio.run(function(**params))\n # An event loop is already running in this thread; run the coroutine\n # in a fresh loop on a worker thread instead of run_until_complete,\n # which would raise \"This event loop is already running\".\n import concurrent.futures\n with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:\n return pool.submit(asyncio.run, function(**params)).result()\n else:\n # If the function is synchronous, call it directly\n return function(**params)\n
"},{"location":"reference/#src.aeiva.tool.tool.Tool.__init__","title":"__init__(api_name)
","text":"Initialize the tool, determining whether it should run locally or via an external service. Args: api_name (str): The name of the tool API (matches the function name).
Source code in src/aeiva/tool/tool.py
def __init__(self, api_name: str):\n \"\"\"\n Initialize the tool, determining whether it should run locally or via an external service.\n Args:\n api_name (str): The name of the tool API (matches the function name).\n \"\"\"\n self.api_name = api_name\n self.schema = self.load_tool_schema(api_name)\n
"},{"location":"reference/#src.aeiva.tool.tool.Tool.aexecute","title":"aexecute(params)
async
","text":"Execute the tool by calling the corresponding function (whether it's for a local function or encapsulated API call). Args: params (dict): Parameters to pass to the tool. Returns: Any: The result of the tool execution.
Source code in src/aeiva/tool/tool.py
async def aexecute(self, params: dict) -> Any:\n \"\"\"\n Execute the tool by calling the corresponding function (whether it's for a local function or encapsulated API call).\n Args:\n params (dict): Parameters to pass to the tool.\n Returns:\n Any: The result of the tool execution.\n \"\"\"\n function_module = f\"aeiva.tool.api.{self.api_name}.api\"\n func_module = import_module(function_module)\n\n # Check if the function is async\n function: Callable = getattr(func_module, self.api_name)\n if asyncio.iscoroutinefunction(function):\n return await function(**params)\n else:\n return function(**params)\n
"},{"location":"reference/#src.aeiva.tool.tool.Tool.execute","title":"execute(params)
","text":"Execute the tool synchronously by calling the corresponding function.
Parameters:
Name Type Description Default params
dict
Parameters to pass to the tool.
required Returns:
Name Type Description Any
Any
The result of the tool execution.
Source code in src/aeiva/tool/tool.py
def execute(self, params: dict) -> Any:\n \"\"\"\n Execute the tool synchronously by calling the corresponding function.\n\n Args:\n params (dict): Parameters to pass to the tool.\n\n Returns:\n Any: The result of the tool execution.\n \"\"\"\n function_module = f\"aeiva.tool.api.{self.api_name}.api\"\n func_module = import_module(function_module)\n\n function: Callable = getattr(func_module, self.api_name)\n if asyncio.iscoroutinefunction(function):\n # If the function is async, attempt to run it in an event loop\n try:\n asyncio.get_running_loop()\n except RuntimeError:\n # No event loop running, use asyncio.run\n return asyncio.run(function(**params))\n # An event loop is already running in this thread; run the coroutine\n # in a fresh loop on a worker thread instead of run_until_complete,\n # which would raise \"This event loop is already running\".\n import concurrent.futures\n with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:\n return pool.submit(asyncio.run, function(**params)).result()\n else:\n # If the function is synchronous, call it directly\n return function(**params)\n
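The sync-to-async bridge that execute performs can be sketched in isolation. The snippet below (illustrative, not the library's exact code; fake_tool and run_sync are invented names) shows the general pattern: use asyncio.run when no event loop is running in the current thread, and fall back to a worker thread otherwise, since run_until_complete cannot be called on a loop that is already running:

```python
import asyncio
import concurrent.futures

async def fake_tool(x):
    # Stand-in for an async tool API function
    await asyncio.sleep(0)
    return x * 2

def run_sync(coro_fn, *args):
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop in this thread: asyncio.run is safe
        return asyncio.run(coro_fn(*args))
    # A loop is already running: run the coroutine to completion
    # on a worker thread that gets its own fresh event loop
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, coro_fn(*args)).result()

print(run_sync(fake_tool, 21))  # 42
```

The worker-thread fallback matters when execute is called from code that is itself running inside an event loop (e.g. a request handler); in that situation asyncio.run would refuse to start a nested loop.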
"},{"location":"reference/#src.aeiva.tool.tool.Tool.load_tool_schema","title":"load_tool_schema(api_name)
classmethod
","text":"Load the tool's schema from the JSON file. Args: api_name (str): The name of the API or function. Returns: dict: The loaded schema from the JSON file.
Source code in src/aeiva/tool/tool.py
@classmethod\ndef load_tool_schema(cls, api_name: str) -> dict:\n \"\"\"\n Load the tool's schema from the JSON file.\n Args:\n api_name (str): The name of the API or function.\n Returns:\n dict: The loaded schema from the JSON file.\n \"\"\"\n current_path = os.path.dirname(os.path.abspath(__file__))\n project_root = os.path.abspath(os.path.join(current_path, \"../../..\"))\n path = os.path.join(\n project_root,\n f\"src/aeiva/tool/api/{api_name}/{api_name}.json\",\n )\n with open(path, \"r\") as file:\n return json.load(file)\n
"},{"location":"reference/#src.aeiva.tool.toolkit","title":"toolkit
","text":""},{"location":"reference/#src.aeiva.tool.toolkit.arxiv_toolkit","title":"arxiv_toolkit
","text":""},{"location":"reference/#src.aeiva.tool.toolkit.arxiv_toolkit.ArxivToolkit","title":"ArxivToolkit
","text":" Bases: Toolkit
A toolkit for interacting with the arXiv API.
Source code in src/aeiva/tool/toolkit/arxiv_toolkit.py
class ArxivToolkit(Toolkit):\n \"\"\"\n A toolkit for interacting with the arXiv API.\n \"\"\"\n\n def __init__(self, config=None):\n super().__init__(\n name=\"ArxivToolkit\",\n tool_names=[\n \"download_arxiv_papers\",\n \"search_arxiv_papers\"\n ],\n config=config\n )\n
"},{"location":"reference/#src.aeiva.tool.toolkit.auto_ui_toolkit","title":"auto_ui_toolkit
","text":""},{"location":"reference/#src.aeiva.tool.toolkit.auto_ui_toolkit.AutoUIToolkit","title":"AutoUIToolkit
","text":" Bases: Toolkit
A toolkit for automating GUI interactions.
Source code in src/aeiva/tool/toolkit/auto_ui_toolkit.py
class AutoUIToolkit(Toolkit):\n \"\"\"\n A toolkit for automating GUI interactions.\n \"\"\"\n\n def __init__(self, config=None):\n super().__init__(\n name=\"AutoUIToolkit\",\n tool_names=[\n \"analyze_gui\",\n \"analyze_gui_by_ocr\",\n \"click_mouse\",\n \"click_on_element\",\n \"move_mouse\",\n \"operate_computer\",\n \"scroll\",\n \"type_into_element\",\n \"type_keyboard\"\n ],\n config=config\n )\n
"},{"location":"reference/#src.aeiva.tool.toolkit.docx_toolkit","title":"docx_toolkit
","text":""},{"location":"reference/#src.aeiva.tool.toolkit.docx_toolkit.DocxToolkit","title":"DocxToolkit
","text":" Bases: Toolkit
A toolkit for interacting with Docx files.
Source code in src/aeiva/tool/toolkit/docx_toolkit.py
class DocxToolkit(Toolkit):\n \"\"\"\n A toolkit for interacting with Docx files.\n \"\"\"\n\n def __init__(self, config=None):\n super().__init__(\n name=\"DocxToolkit\",\n tool_names=[\n \"create_docx\",\n \"docx2html\",\n \"docx2images\",\n \"docx2markdown\",\n \"docx2metadata\",\n \"docx2pdf\",\n \"docx2text\",\n \"modify_docx\"\n ],\n config=config\n )\n
"},{"location":"reference/#src.aeiva.tool.toolkit.file_toolkit","title":"file_toolkit
","text":""},{"location":"reference/#src.aeiva.tool.toolkit.file_toolkit.FileToolkit","title":"FileToolkit
","text":" Bases: Toolkit
A toolkit for file-related operations.
Source code in src/aeiva/tool/toolkit/file_toolkit.py
class FileToolkit(Toolkit):\n \"\"\"\n A toolkit for file-related operations.\n \"\"\"\n\n def __init__(self, config=None):\n super().__init__(\n name=\"FileToolkit\",\n tool_names=[\n \"create_file_or_folder\",\n \"open_file_or_folder\",\n \"search_file_or_folder\",\n \"copy_file_or_folder\",\n \"move_file_or_folder\",\n \"change_permissions\",\n \"get_file_metadata\",\n \"delete_file\",\n \"edit_file\",\n \"find_file\",\n \"list_files\",\n \"read_file\",\n \"rename_file\",\n \"write_file\"\n ],\n config=config\n )\n
"},{"location":"reference/#src.aeiva.tool.toolkit.git_toolkit","title":"git_toolkit
","text":""},{"location":"reference/#src.aeiva.tool.toolkit.git_toolkit.GitToolkit","title":"GitToolkit
","text":" Bases: Toolkit
A toolkit for interacting with Git repositories.
Source code in src/aeiva/tool/toolkit/git_toolkit.py
class GitToolkit(Toolkit):\n \"\"\"\n A toolkit for interacting with Git repositories.\n \"\"\"\n\n def __init__(self, config=None):\n super().__init__(\n name=\"GitToolkit\",\n tool_names=[\n \"git_apply_patch\",\n \"git_clone\",\n \"git_custom\",\n \"git_patch\",\n \"git_repo_tree\"\n ],\n config=config\n )\n
"},{"location":"reference/#src.aeiva.tool.toolkit.math_toolkit","title":"math_toolkit
","text":""},{"location":"reference/#src.aeiva.tool.toolkit.math_toolkit.MathToolkit","title":"MathToolkit
","text":" Bases: Toolkit
A toolkit for mathematical operations.
Source code in src/aeiva/tool/toolkit/math_toolkit.py
class MathToolkit(Toolkit):\n \"\"\"\n A toolkit for mathematical operations.\n \"\"\"\n\n def __init__(self, config=None):\n super().__init__(\n name=\"MathToolkit\",\n tool_names=[\"calculator\"],\n config=config\n )\n
"},{"location":"reference/#src.aeiva.tool.toolkit.pdf_toolkit","title":"pdf_toolkit
","text":""},{"location":"reference/#src.aeiva.tool.toolkit.pdf_toolkit.PdfToolkit","title":"PdfToolkit
","text":" Bases: Toolkit
A toolkit for interacting with PDF files.
Source code in src/aeiva/tool/toolkit/pdf_toolkit.py
class PdfToolkit(Toolkit):\n \"\"\"\n A toolkit for interacting with PDF files.\n \"\"\"\n\n def __init__(self, config=None):\n super().__init__(\n name=\"PdfToolkit\",\n tool_names=[\n \"pdf2markdown\",\n \"pdf2text\",\n \"pdf2tables\",\n \"pdf2images\",\n \"pdf2metadata\",\n \"pdf2ocr\"\n ],\n config=config\n )\n
"},{"location":"reference/#src.aeiva.tool.toolkit.rbac","title":"rbac
","text":""},{"location":"reference/#src.aeiva.tool.toolkit.rbac.PermissionError","title":"PermissionError
","text":" Bases: Exception
Custom exception for permission-related errors.
Source code in src/aeiva/tool/toolkit/rbac.py
class PermissionError(Exception):\n \"\"\"Custom exception for permission-related errors.\"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.tool.toolkit.rbac.check_permission","title":"check_permission(user_role, api_name, config)
","text":"Check if the user_role has permission to execute the given api_name.
Parameters:
Name Type Description Default user_role
str
The role of the user.
required api_name
str
The name of the API function.
required config
ToolkitConfig
The toolkit configuration containing role permissions.
required Returns:
Name Type Description bool
bool
True if permitted, False otherwise.
Source code in src/aeiva/tool/toolkit/rbac.py
def check_permission(user_role: str, api_name: str, config: ToolkitConfig) -> bool:\n \"\"\"\n Check if the user_role has permission to execute the given api_name.\n\n Args:\n user_role (str): The role of the user.\n api_name (str): The name of the API function.\n config (ToolkitConfig): The toolkit configuration containing role permissions.\n\n Returns:\n bool: True if permitted, False otherwise.\n \"\"\"\n allowed_apis: List[str] = config.role_permissions.get(user_role, [])\n return api_name in allowed_apis\n
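The permission check is a plain allow-list lookup. A minimal self-contained sketch, with an ordinary dict standing in for ToolkitConfig.role_permissions (the role and API names here are made up):

```python
# Hypothetical role -> allowed-APIs mapping, mimicking role_permissions
role_permissions = {
    'admin': ['read_file', 'write_file', 'delete_file'],
    'viewer': ['read_file'],
}

def check_permission(user_role, api_name, permissions):
    # Permission is a membership test over the role's allow-list;
    # unknown roles get an empty list, i.e. no permissions at all
    return api_name in permissions.get(user_role, [])

print(check_permission('viewer', 'read_file', role_permissions))    # True
print(check_permission('viewer', 'delete_file', role_permissions))  # False
```

Note the lookup itself only returns a boolean; it is the caller (Toolkit.execute) that raises PermissionError on a False result.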
"},{"location":"reference/#src.aeiva.tool.toolkit.security","title":"security
","text":""},{"location":"reference/#src.aeiva.tool.toolkit.security.sanitize_file_path","title":"sanitize_file_path(file_path, config)
","text":"Sanitize the file path to prevent directory traversal attacks.
Parameters:
Name Type Description Default file_path
str
The input file path.
required config
ToolkitConfig
The configuration instance.
required Returns:
Name Type Description str
str
The sanitized absolute file path.
Raises:
Type Description ValueError
If the file path is not within allowed directories.
Source code in src/aeiva/tool/toolkit/security.py
def sanitize_file_path(file_path: str, config: ToolkitConfig) -> str:\n \"\"\"\n Sanitize the file path to prevent directory traversal attacks.\n\n Args:\n file_path (str): The input file path.\n config (ToolkitConfig): The configuration instance.\n\n Returns:\n str: The sanitized absolute file path.\n\n Raises:\n ValueError: If the file path is not within allowed directories.\n \"\"\"\n # Resolve the absolute path\n try:\n absolute_path = Path(file_path).resolve(strict=False)\n except Exception as e:\n logger.error(f\"Error resolving file path '{file_path}': {e}\")\n raise ValueError(f\"Invalid file path: {e}\")\n\n # Check if the path is within allowed directories\n allowed = False\n for dir_path in config.allowed_directories:\n try:\n allowed_dir = Path(dir_path).resolve(strict=False)\n if allowed_dir in absolute_path.parents or allowed_dir == absolute_path.parent:\n allowed = True\n break\n except Exception as e:\n logger.error(f\"Error resolving allowed directory '{dir_path}': {e}\")\n continue\n\n if not allowed:\n logger.error(f\"Unauthorized file path access attempt: {absolute_path}\")\n raise ValueError(\"Unauthorized file path.\")\n\n return str(absolute_path)\n
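The core of the check is directory containment after path resolution, which is what defeats '..' traversal. A pathlib-only sketch of that containment test (the allowed directories below are made up; sanitize_file_path additionally logs and raises ValueError instead of returning False):

```python
from pathlib import Path

def is_within_allowed(file_path, allowed_dirs):
    # Resolve first so that '..' segments and symlinks cannot
    # smuggle the path outside the allowed roots
    target = Path(file_path).resolve(strict=False)
    for dir_path in allowed_dirs:
        allowed = Path(dir_path).resolve(strict=False)
        # Permitted if an allowed directory is the parent or any ancestor
        if allowed == target.parent or allowed in target.parents:
            return True
    return False

print(is_within_allowed('/tmp/data/report.txt', ['/tmp/data']))     # True
print(is_within_allowed('/tmp/data/../secret.txt', ['/tmp/data']))  # False
```

The second call shows why resolving happens before the comparison: lexically the path starts with '/tmp/data', but it resolves to '/tmp/secret.txt', which is outside the allowed root.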
"},{"location":"reference/#src.aeiva.tool.toolkit.shell_toolkit","title":"shell_toolkit
","text":""},{"location":"reference/#src.aeiva.tool.toolkit.shell_toolkit.ShellToolkit","title":"ShellToolkit
","text":" Bases: Toolkit
A toolkit for shell and terminal operations.
Source code in src/aeiva/tool/toolkit/shell_toolkit.py
class ShellToolkit(Toolkit):\n \"\"\"\n A toolkit for shell and terminal operations.\n \"\"\"\n\n def __init__(self, config=None):\n super().__init__(\n name=\"ShellToolkit\",\n tool_names=[\n \"chwdir\",\n \"execute_bash_command\",\n \"execute_script\",\n \"grep\",\n \"create_new_shell_session\"\n ],\n config=config\n )\n
"},{"location":"reference/#src.aeiva.tool.toolkit.system_toolkit","title":"system_toolkit
","text":""},{"location":"reference/#src.aeiva.tool.toolkit.system_toolkit.SystemToolkit","title":"SystemToolkit
","text":" Bases: Toolkit
A toolkit for interacting with system-level operations.
Source code in src/aeiva/tool/toolkit/system_toolkit.py
class SystemToolkit(Toolkit):\n \"\"\"\n A toolkit for interacting with system-level operations.\n \"\"\"\n\n def __init__(self, config=None):\n super().__init__(\n name=\"SystemToolkit\",\n tool_names=[\n \"get_system_info\",\n \"get_package_root\",\n \"get_user_home_path\",\n \"open_application\",\n \"close_application\",\n \"percept_terminal_input\",\n \"play_music\",\n \"stop_music\",\n \"take_screenshot\",\n \"list_processes\",\n \"kill_process\",\n \"monitor_process\",\n \"get_network_info\",\n \"check_internet_connection\",\n \"get_disk_usage\",\n \"clean_temp_files\",\n \"list_drives\",\n \"get_env_var\",\n \"set_env_var\",\n \"update_system_packages\",\n \"install_package\",\n \"create_user\",\n \"delete_user\",\n \"change_user_password\",\n \"view_system_logs\",\n \"monitor_system_resources\",\n ],\n config=config\n )\n
"},{"location":"reference/#src.aeiva.tool.toolkit.toolkit","title":"toolkit
","text":""},{"location":"reference/#src.aeiva.tool.toolkit.toolkit.Toolkit","title":"Toolkit
","text":"Toolkit class that manages multiple Tool instances, handles validation, enforces RBAC, and manages shared resources.
Source code in src/aeiva/tool/toolkit/toolkit.py
class Toolkit:\n \"\"\"\n Toolkit class that manages multiple Tool instances, handles validation,\n enforces RBAC, and manages shared resources.\n \"\"\"\n\n subclasses: Dict[str, Type['Toolkit']] = {}\n\n def __init_subclass__(cls, **kwargs):\n \"\"\"\n Automatically register subclasses in the Toolkit's subclasses dictionary.\n \"\"\"\n super().__init_subclass__(**kwargs)\n Toolkit.subclasses[cls.__name__] = cls\n\n def __init__(self, name: str, tool_names: List[str], config: Optional[ToolkitConfig] = None):\n \"\"\"\n Initialize the Toolkit with a name, list of tool names, and optional configuration.\n\n Args:\n name (str): The name of the toolkit.\n tool_names (List[str]): The names of tools to be managed by the toolkit.\n config (Optional[ToolkitConfig]): Configuration for security and roles.\n \"\"\"\n self.toolkit_name = name\n self.tool_names = tool_names\n self.config = config\n self.tools: Dict[str, Tool] = {}\n self.tool_schemas: Dict[str, Dict] = {}\n self.tool_models: Dict[str, Tuple[Type[BaseModel], Type[BaseModel]]] = {}\n self.shared_resources = None # Placeholder for shared resources\n\n # Setup the toolkit\n self.setup()\n\n def setup(self):\n \"\"\"\n Setup the toolkit by loading tools, their schemas, and initializing shared resources.\n \"\"\"\n logger.info(f\"Setting up toolkit '{self.toolkit_name}'.\")\n\n # Load tools and their schemas\n for tool_name in self.tool_names:\n tool = Tool(api_name=tool_name)\n self.tools[tool_name] = tool\n schema = tool.load_tool_schema(tool_name)\n self.tool_schemas[tool_name] = schema\n logger.debug(f\"Loaded schema for tool '{tool_name}': {schema}\")\n\n # Load Pydantic models for validation\n self.load_pydantic_models_for_all_tools()\n\n # Initialize shared resources\n self.init_shared_resources()\n\n def load_pydantic_models_for_all_tools(self):\n \"\"\"\n Load Pydantic models (Params and Result) for all tools for validation.\n \"\"\"\n logger.info(\"Loading Pydantic models for all tools.\")\n for tool_name 
in self.tool_names:\n try:\n param_model, result_model = self.load_pydantic_models_for_tool(tool_name)\n self.tool_models[tool_name] = (param_model, result_model)\n logger.debug(f\"Loaded models for tool '{tool_name}': Params={param_model}, Result={result_model}\")\n except Exception as e:\n logger.error(f\"Failed to load models for tool '{tool_name}': {e}\")\n raise\n\n def load_pydantic_models_for_tool(self, api_name: str) -> Tuple[Type[BaseModel], Type[BaseModel]]:\n \"\"\"\n Load the parameter and result Pydantic models for the given API.\n\n Args:\n api_name (str): The name of the API function.\n\n Returns:\n Tuple[Type[BaseModel], Type[BaseModel]]: The parameter and result model classes.\n\n Raises:\n ValueError: If models cannot be loaded.\n \"\"\"\n module_path = f\"aeiva.tool.api.{api_name}.model\" # Adjusted as per user's path\n try:\n models_module = importlib.import_module(module_path)\n param_model_class = getattr(models_module, f\"{snake_to_camel(api_name)}Params\", None)\n result_model_class = getattr(models_module, f\"{snake_to_camel(api_name)}Result\", None)\n if not (param_model_class and issubclass(param_model_class, BaseModel)):\n logger.error(f\"Param model class '{snake_to_camel(api_name)}Params' not found in '{module_path}'.\")\n raise ValueError(f\"Param model class '{snake_to_camel(api_name)}Params' not found in '{module_path}'.\")\n if not (result_model_class and issubclass(result_model_class, BaseModel)):\n logger.error(f\"Result model class '{snake_to_camel(api_name)}Result' not found in '{module_path}'.\")\n raise ValueError(f\"Result model class '{snake_to_camel(api_name)}Result' not found in '{module_path}'.\")\n return param_model_class, result_model_class\n except ImportError as e:\n logger.error(f\"Error importing models from '{module_path}': {e}\")\n raise ImportError(f\"Error importing models from '{module_path}': {e}\")\n except AttributeError as e:\n logger.error(f\"Error accessing model classes in '{module_path}': {e}\")\n 
raise ValueError(f\"Error accessing model classes in '{module_path}': {e}\")\n\n def init_shared_resources(self):\n \"\"\"\n Initialize shared resources required by the toolkit.\n Override this method in subclasses if needed.\n \"\"\"\n logger.info(\"Initializing shared resources.\")\n # Placeholder for initializing shared resources like databases, servers, etc.\n # Example:\n # self.shared_resources = initialize_database_connection()\n pass\n\n def teardown(self):\n \"\"\"\n Teardown the toolkit by unloading tools, their schemas, and cleaning up shared resources.\n \"\"\"\n logger.info(f\"Tearing down toolkit '{self.toolkit_name}'.\")\n\n # Clean up shared resources\n self.teardown_shared_resources()\n\n # Clear loaded data\n self.tools.clear()\n self.tool_schemas.clear()\n self.tool_models.clear()\n\n def teardown_shared_resources(self):\n \"\"\"\n Teardown shared resources.\n Override this method in subclasses if needed.\n \"\"\"\n logger.info(\"Tearing down shared resources.\")\n # Placeholder for tearing down shared resources\n # Example:\n # if self.shared_resources:\n # self.shared_resources.close()\n pass\n\n @asynccontextmanager\n async def acontext(self):\n \"\"\"\n Asynchronous context manager to handle setup and teardown of shared resources.\n\n Usage:\n async with toolkit.acontent():\n # Execute tools\n \"\"\"\n try:\n await self.asetup()\n yield self\n finally:\n await self.ateardown()\n\n @contextmanager\n def context(self):\n \"\"\"\n Synchronous context manager to handle setup and teardown of shared resources.\n\n Usage:\n with toolkit.context():\n # Execute tools\n \"\"\"\n try:\n self.setup()\n yield self\n finally:\n self.teardown()\n\n async def asetup(self):\n \"\"\"\n Asynchronously setup shared resources.\n \"\"\"\n logger.info(f\"Asynchronously setting up toolkit '{self.toolkit_name}'.\")\n # Override in subclasses if asynchronous setup is required\n pass\n\n async def ateardown(self):\n \"\"\"\n Asynchronously teardown shared resources.\n 
\"\"\"\n logger.info(f\"Asynchronously tearing down toolkit '{self.toolkit_name}'.\")\n # Override in subclasses if asynchronous teardown is required\n self.teardown()\n\n def execute(self, api_name: str, params: Dict[str, Any]) -> Any:\n \"\"\"\n Synchronously execute a tool's API function with validation and RBAC checks.\n\n Args:\n api_name (str): The name of the API function to execute.\n params (Dict[str, Any]): The parameters for the API function.\n\n Returns:\n Any: The result of the tool execution.\n\n Raises:\n ValueError: If tool not found or parameter validation fails.\n PermissionError: If user does not have permission.\n RuntimeError: If tool execution fails.\n \"\"\"\n tool = self.tools.get(api_name)\n if not tool:\n logger.error(f\"Tool '{api_name}' not found in toolkit '{self.toolkit_name}'.\")\n raise ValueError(f\"Tool '{api_name}' not found in toolkit '{self.toolkit_name}'.\")\n\n # Perform RBAC check if config is provided\n if self.config:\n # Automatically retrieve user role from OS\n os_user = get_os_user()\n user_role = self.config.user_role_mapping.get(os_user)\n if not user_role:\n logger.error(f\"OS user '{os_user}' does not have an assigned role.\")\n raise ValueError(f\"OS user '{os_user}' does not have an assigned role.\")\n if not check_permission(user_role, api_name, self.config):\n logger.error(f\"User role '{user_role}' does not have permission to execute '{api_name}'.\")\n raise PermissionError(f\"User role '{user_role}' does not have permission to execute '{api_name}'.\")\n\n # Load the Pydantic models for validation\n param_model, result_model = self.tool_models.get(api_name, (None, None))\n if not param_model or not result_model:\n logger.error(f\"Pydantic models for tool '{api_name}' are not loaded.\")\n raise ValueError(f\"Pydantic models for tool '{api_name}' are not loaded.\")\n\n # Instantiate and validate the parameter model\n try:\n param_instance = param_model(**params)\n logger.debug(f\"Validated input parameters for 
'{api_name}': {param_instance}\")\n except Exception as e:\n logger.error(f\"Error parsing parameters for '{api_name}': {e}\")\n raise ValueError(f\"Invalid parameters for '{api_name}': {e}\")\n\n # Perform security checks on parameters if needed\n param_instance = self.perform_security_checks(param_instance)\n\n # Execute the API function via the Tool\n try:\n raw_result = tool.execute(param_instance.dict())\n logger.debug(f\"Raw result from '{api_name}': {raw_result}\")\n except Exception as e:\n logger.error(f\"Error executing tool '{api_name}': {e}\")\n raise RuntimeError(f\"Error executing tool '{api_name}': {e}\")\n\n # Validate the result using the result model\n try:\n result_instance = result_model(**raw_result)\n logger.info(f\"Execution of '{api_name}' successful with result: {result_instance}\")\n return result_instance\n except Exception as e:\n logger.error(f\"Error parsing result for '{api_name}': {e}\")\n raise ValueError(f\"Invalid result from '{api_name}': {e}\")\n\n async def aexecute(self, api_name: str, params: Dict[str, Any]) -> Any:\n \"\"\"\n Asynchronously execute a tool's API function with validation and RBAC checks.\n\n Args:\n api_name (str): The name of the API function to execute.\n params (Dict[str, Any]): The parameters for the API function.\n\n Returns:\n Any: The result of the tool execution.\n\n Raises:\n ValueError: If tool not found or parameter validation fails.\n PermissionError: If user does not have permission.\n RuntimeError: If tool execution fails.\n \"\"\"\n tool = self.tools.get(api_name)\n if not tool:\n logger.error(f\"Tool '{api_name}' not found in toolkit '{self.toolkit_name}'.\")\n raise ValueError(f\"Tool '{api_name}' not found in toolkit '{self.toolkit_name}'.\")\n\n # Perform RBAC check if config is provided\n if self.config:\n # Automatically retrieve user role from OS\n os_user = get_os_user()\n user_role = self.config.user_role_mapping.get(os_user)\n if not user_role:\n logger.error(f\"OS user '{os_user}' 
does not have an assigned role.\")\n raise ValueError(f\"OS user '{os_user}' does not have an assigned role.\")\n if not check_permission(user_role, api_name, self.config):\n logger.error(f\"User role '{user_role}' does not have permission to execute '{api_name}'.\")\n raise PermissionError(f\"User role '{user_role}' does not have permission to execute '{api_name}'.\")\n\n # Load the Pydantic models for validation\n param_model, result_model = self.tool_models.get(api_name, (None, None))\n if not param_model or not result_model:\n logger.error(f\"Pydantic models for tool '{api_name}' are not loaded.\")\n raise ValueError(f\"Pydantic models for tool '{api_name}' are not loaded.\")\n\n # Instantiate and validate the parameter model\n try:\n param_instance = param_model(**params)\n logger.debug(f\"Validated input parameters for '{api_name}': {param_instance}\")\n except Exception as e:\n logger.error(f\"Error parsing parameters for '{api_name}': {e}\")\n raise ValueError(f\"Invalid parameters for '{api_name}': {e}\")\n\n # Perform security checks on parameters if needed\n param_instance = self.perform_security_checks(param_instance)\n\n # Execute the API function via the Tool\n try:\n raw_result = await tool.aexecute(param_instance.dict())\n logger.debug(f\"Raw result from '{api_name}': {raw_result}\")\n except Exception as e:\n logger.error(f\"Error executing tool '{api_name}': {e}\")\n raise RuntimeError(f\"Error executing tool '{api_name}': {e}\")\n\n # Validate the result using the result model\n try:\n result_instance = result_model(**raw_result)\n logger.info(f\"Execution of '{api_name}' successful with result: {result_instance}\")\n return result_instance\n except Exception as e:\n logger.error(f\"Error parsing result for '{api_name}': {e}\")\n raise ValueError(f\"Invalid result from '{api_name}': {e}\")\n\n def perform_security_checks(self, param_instance: BaseModel) -> BaseModel:\n \"\"\"\n Perform security checks on parameters that require sanitization.\n\n 
Args:\n param_instance (BaseModel): The validated parameter instance.\n\n Returns:\n BaseModel: The sanitized parameter instance.\n\n Raises:\n ValueError: If sanitization fails for any field or if config is required but not provided.\n \"\"\"\n sanitized_params = param_instance.dict()\n\n for field_name, field in param_instance.__fields__.items():\n sanitize = field.field_info.extra.get('sanitize', False)\n if not sanitize:\n continue # Skip fields that do not require sanitization\n\n field_type = field.type_\n origin = get_origin(field_type)\n args = get_args(field_type)\n\n # Determine if the field is a string type or contains string types\n is_string_field = False\n\n if field_type == str:\n is_string_field = True\n elif origin is Union and str in args:\n is_string_field = True\n elif origin is list and len(args) == 1 and args[0] == str:\n is_string_field = True\n elif origin is Optional and str in args:\n is_string_field = True\n # Add more conditions here if there are other complex types containing strings\n\n if is_string_field:\n original_value = sanitized_params.get(field_name)\n if original_value is None:\n continue # Skip if the field value is None\n\n if self.config is None:\n logger.error(\n f\"Configuration is required to sanitize field '{field_name}', \"\n f\"but config is not provided.\"\n )\n raise ValueError(\n f\"Configuration is required to sanitize field '{field_name}', \"\n f\"but config is not provided.\"\n )\n\n try:\n # If the field is a list of strings, sanitize each path individually\n if origin is list and len(args) == 1 and args[0] == str:\n if not isinstance(original_value, list):\n logger.error(\n f\"Expected a list for field '{field_name}', \"\n f\"got {type(original_value)}.\"\n )\n raise ValueError(\n f\"Expected a list for field '{field_name}'.\"\n )\n sanitized_list = []\n for idx, item in enumerate(original_value):\n sanitized_item = sanitize_file_path(item, self.config)\n sanitized_list.append(sanitized_item)\n logger.debug(\n 
f\"Sanitized '{field_name}[{idx}]': '{item}' -> '{sanitized_item}'\"\n )\n sanitized_params[field_name] = sanitized_list\n else:\n # Sanitize single string path\n sanitized_path = sanitize_file_path(original_value, self.config)\n sanitized_params[field_name] = sanitized_path\n logger.debug(\n f\"Sanitized '{field_name}': '{original_value}' -> '{sanitized_path}'\"\n )\n except ValueError as ve:\n logger.error(\n f\"Sanitization failed for field '{field_name}': {ve}\"\n )\n raise\n\n # Create a new instance of the parameter model with sanitized parameters\n sanitized_instance = param_instance.copy(update=sanitized_params)\n\n return sanitized_instance\n\n def generate_documentation(self) -> str:\n \"\"\"\n Generate documentation for all functions managed by this toolkit based on their schemas.\n\n Returns:\n str: Generated documentation as a markdown string.\n \"\"\"\n doc = f\"# Toolkit: {self.toolkit_name}\\n\\n\"\n for api_name, tool in self.tools.items():\n schema = self.tool_schemas.get(api_name, {})\n if not schema:\n continue\n doc += f\"## Function: {api_name}\\n\\n\"\n doc += f\"**Description:** {schema.get('description', 'No description provided.')}\\n\\n\"\n doc += \"### Parameters:\\n\\n\"\n parameters = schema.get(\"parameters\", {})\n for prop, details in parameters.get(\"properties\", {}).items():\n req = \" (required)\" if prop in parameters.get(\"required\", []) else \"\"\n description = details.get(\"description\", \"\")\n default = f\" (default: {details.get('default')})\" if \"default\" in details else \"\"\n doc += f\"- **{prop}** ({details.get('type', 'any')}): {description}{default}{req}\\n\"\n doc += \"\\n### Example:\\n\\n\"\n example = schema.get(\"example\", \"No example provided.\")\n if isinstance(example, dict):\n example = json.dumps(example, indent=4)\n doc += f\"```json\\n{example}\\n```\\n\\n\"\n return doc\n
"},{"location":"reference/#src.aeiva.tool.toolkit.toolkit.Toolkit.__init__","title":"__init__(name, tool_names, config=None)
","text":"Initialize the Toolkit with a name, list of tool names, and optional configuration.
Parameters:
Name Type Description Default name
str
The name of the toolkit.
required tool_names
List[str]
The names of tools to be managed by the toolkit.
required config
Optional[ToolkitConfig]
Configuration for security and roles.
None
Source code in src/aeiva/tool/toolkit/toolkit.py
def __init__(self, name: str, tool_names: List[str], config: Optional[ToolkitConfig] = None):\n \"\"\"\n Initialize the Toolkit with a name, list of tool names, and optional configuration.\n\n Args:\n name (str): The name of the toolkit.\n tool_names (List[str]): The names of tools to be managed by the toolkit.\n config (Optional[ToolkitConfig]): Configuration for security and roles.\n \"\"\"\n self.toolkit_name = name\n self.tool_names = tool_names\n self.config = config\n self.tools: Dict[str, Tool] = {}\n self.tool_schemas: Dict[str, Dict] = {}\n self.tool_models: Dict[str, Tuple[Type[BaseModel], Type[BaseModel]]] = {}\n self.shared_resources = None # Placeholder for shared resources\n\n # Setup the toolkit\n self.setup()\n
"},{"location":"reference/#src.aeiva.tool.toolkit.toolkit.Toolkit.__init_subclass__","title":"__init_subclass__(**kwargs)
","text":"Automatically register subclasses in the Toolkit's subclasses dictionary.
Source code in src/aeiva/tool/toolkit/toolkit.py
def __init_subclass__(cls, **kwargs):\n \"\"\"\n Automatically register subclasses in the Toolkit's subclasses dictionary.\n \"\"\"\n super().__init_subclass__(**kwargs)\n Toolkit.subclasses[cls.__name__] = cls\n
"},{"location":"reference/#src.aeiva.tool.toolkit.toolkit.Toolkit.acontext","title":"acontext()
async
","text":"Asynchronous context manager to handle setup and teardown of shared resources.
Usage async with toolkit.acontext(): # Execute tools
Source code in src/aeiva/tool/toolkit/toolkit.py
@asynccontextmanager\nasync def acontext(self):\n \"\"\"\n Asynchronous context manager to handle setup and teardown of shared resources.\n\n Usage:\n async with toolkit.acontext():\n # Execute tools\n \"\"\"\n try:\n await self.asetup()\n yield self\n finally:\n await self.ateardown()\n
"},{"location":"reference/#src.aeiva.tool.toolkit.toolkit.Toolkit.aexecute","title":"aexecute(api_name, params)
async
","text":"Asynchronously execute a tool's API function with validation and RBAC checks.
Parameters:
Name Type Description Default api_name
str
The name of the API function to execute.
required params
Dict[str, Any]
The parameters for the API function.
required Returns:
Name Type Description Any
Any
The result of the tool execution.
Raises:
Type Description ValueError
If tool not found or parameter validation fails.
PermissionError
If user does not have permission.
RuntimeError
If tool execution fails.
Source code in src/aeiva/tool/toolkit/toolkit.py
async def aexecute(self, api_name: str, params: Dict[str, Any]) -> Any:\n \"\"\"\n Asynchronously execute a tool's API function with validation and RBAC checks.\n\n Args:\n api_name (str): The name of the API function to execute.\n params (Dict[str, Any]): The parameters for the API function.\n\n Returns:\n Any: The result of the tool execution.\n\n Raises:\n ValueError: If tool not found or parameter validation fails.\n PermissionError: If user does not have permission.\n RuntimeError: If tool execution fails.\n \"\"\"\n tool = self.tools.get(api_name)\n if not tool:\n logger.error(f\"Tool '{api_name}' not found in toolkit '{self.toolkit_name}'.\")\n raise ValueError(f\"Tool '{api_name}' not found in toolkit '{self.toolkit_name}'.\")\n\n # Perform RBAC check if config is provided\n if self.config:\n # Automatically retrieve user role from OS\n os_user = get_os_user()\n user_role = self.config.user_role_mapping.get(os_user)\n if not user_role:\n logger.error(f\"OS user '{os_user}' does not have an assigned role.\")\n raise ValueError(f\"OS user '{os_user}' does not have an assigned role.\")\n if not check_permission(user_role, api_name, self.config):\n logger.error(f\"User role '{user_role}' does not have permission to execute '{api_name}'.\")\n raise PermissionError(f\"User role '{user_role}' does not have permission to execute '{api_name}'.\")\n\n # Load the Pydantic models for validation\n param_model, result_model = self.tool_models.get(api_name, (None, None))\n if not param_model or not result_model:\n logger.error(f\"Pydantic models for tool '{api_name}' are not loaded.\")\n raise ValueError(f\"Pydantic models for tool '{api_name}' are not loaded.\")\n\n # Instantiate and validate the parameter model\n try:\n param_instance = param_model(**params)\n logger.debug(f\"Validated input parameters for '{api_name}': {param_instance}\")\n except Exception as e:\n logger.error(f\"Error parsing parameters for '{api_name}': {e}\")\n raise ValueError(f\"Invalid 
parameters for '{api_name}': {e}\")\n\n # Perform security checks on parameters if needed\n param_instance = self.perform_security_checks(param_instance)\n\n # Execute the API function via the Tool\n try:\n raw_result = await tool.aexecute(param_instance.dict())\n logger.debug(f\"Raw result from '{api_name}': {raw_result}\")\n except Exception as e:\n logger.error(f\"Error executing tool '{api_name}': {e}\")\n raise RuntimeError(f\"Error executing tool '{api_name}': {e}\")\n\n # Validate the result using the result model\n try:\n result_instance = result_model(**raw_result)\n logger.info(f\"Execution of '{api_name}' successful with result: {result_instance}\")\n return result_instance\n except Exception as e:\n logger.error(f\"Error parsing result for '{api_name}': {e}\")\n raise ValueError(f\"Invalid result from '{api_name}': {e}\")\n
"},{"location":"reference/#src.aeiva.tool.toolkit.toolkit.Toolkit.asetup","title":"asetup()
async
","text":"Asynchronously set up shared resources.
Source code in src/aeiva/tool/toolkit/toolkit.py
async def asetup(self):\n \"\"\"\n Asynchronously set up shared resources.\n \"\"\"\n logger.info(f\"Asynchronously setting up toolkit '{self.toolkit_name}'.\")\n # Override in subclasses if asynchronous setup is required\n pass\n
"},{"location":"reference/#src.aeiva.tool.toolkit.toolkit.Toolkit.ateardown","title":"ateardown()
async
","text":"Asynchronously tear down shared resources.
Source code in src/aeiva/tool/toolkit/toolkit.py
async def ateardown(self):\n \"\"\"\n Asynchronously tear down shared resources.\n \"\"\"\n logger.info(f\"Asynchronously tearing down toolkit '{self.toolkit_name}'.\")\n # Override in subclasses if asynchronous teardown is required\n self.teardown()\n
"},{"location":"reference/#src.aeiva.tool.toolkit.toolkit.Toolkit.context","title":"context()
","text":"Synchronous context manager to handle setup and teardown of shared resources.
Usage with toolkit.context(): # Execute tools
Source code in src/aeiva/tool/toolkit/toolkit.py
@contextmanager\ndef context(self):\n \"\"\"\n Synchronous context manager to handle setup and teardown of shared resources.\n\n Usage:\n with toolkit.context():\n # Execute tools\n \"\"\"\n try:\n self.setup()\n yield self\n finally:\n self.teardown()\n
"},{"location":"reference/#src.aeiva.tool.toolkit.toolkit.Toolkit.execute","title":"execute(api_name, params)
","text":"Synchronously execute a tool's API function with validation and RBAC checks.
Parameters:
Name Type Description Default api_name
str
The name of the API function to execute.
required params
Dict[str, Any]
The parameters for the API function.
required Returns:
Name Type Description Any
Any
The result of the tool execution.
Raises:
Type Description ValueError
If tool not found or parameter validation fails.
PermissionError
If user does not have permission.
RuntimeError
If tool execution fails.
Source code in src/aeiva/tool/toolkit/toolkit.py
def execute(self, api_name: str, params: Dict[str, Any]) -> Any:\n \"\"\"\n Synchronously execute a tool's API function with validation and RBAC checks.\n\n Args:\n api_name (str): The name of the API function to execute.\n params (Dict[str, Any]): The parameters for the API function.\n\n Returns:\n Any: The result of the tool execution.\n\n Raises:\n ValueError: If tool not found or parameter validation fails.\n PermissionError: If user does not have permission.\n RuntimeError: If tool execution fails.\n \"\"\"\n tool = self.tools.get(api_name)\n if not tool:\n logger.error(f\"Tool '{api_name}' not found in toolkit '{self.toolkit_name}'.\")\n raise ValueError(f\"Tool '{api_name}' not found in toolkit '{self.toolkit_name}'.\")\n\n # Perform RBAC check if config is provided\n if self.config:\n # Automatically retrieve user role from OS\n os_user = get_os_user()\n user_role = self.config.user_role_mapping.get(os_user)\n if not user_role:\n logger.error(f\"OS user '{os_user}' does not have an assigned role.\")\n raise ValueError(f\"OS user '{os_user}' does not have an assigned role.\")\n if not check_permission(user_role, api_name, self.config):\n logger.error(f\"User role '{user_role}' does not have permission to execute '{api_name}'.\")\n raise PermissionError(f\"User role '{user_role}' does not have permission to execute '{api_name}'.\")\n\n # Load the Pydantic models for validation\n param_model, result_model = self.tool_models.get(api_name, (None, None))\n if not param_model or not result_model:\n logger.error(f\"Pydantic models for tool '{api_name}' are not loaded.\")\n raise ValueError(f\"Pydantic models for tool '{api_name}' are not loaded.\")\n\n # Instantiate and validate the parameter model\n try:\n param_instance = param_model(**params)\n logger.debug(f\"Validated input parameters for '{api_name}': {param_instance}\")\n except Exception as e:\n logger.error(f\"Error parsing parameters for '{api_name}': {e}\")\n raise ValueError(f\"Invalid parameters for 
'{api_name}': {e}\")\n\n # Perform security checks on parameters if needed\n param_instance = self.perform_security_checks(param_instance)\n\n # Execute the API function via the Tool\n try:\n raw_result = tool.execute(param_instance.dict())\n logger.debug(f\"Raw result from '{api_name}': {raw_result}\")\n except Exception as e:\n logger.error(f\"Error executing tool '{api_name}': {e}\")\n raise RuntimeError(f\"Error executing tool '{api_name}': {e}\")\n\n # Validate the result using the result model\n try:\n result_instance = result_model(**raw_result)\n logger.info(f\"Execution of '{api_name}' successful with result: {result_instance}\")\n return result_instance\n except Exception as e:\n logger.error(f\"Error parsing result for '{api_name}': {e}\")\n raise ValueError(f\"Invalid result from '{api_name}': {e}\")\n
"},{"location":"reference/#src.aeiva.tool.toolkit.toolkit.Toolkit.generate_documentation","title":"generate_documentation()
","text":"Generate documentation for all functions managed by this toolkit based on their schemas.
Returns:
Name Type Description str
str
Generated documentation as a markdown string.
Source code in src/aeiva/tool/toolkit/toolkit.py
def generate_documentation(self) -> str:\n \"\"\"\n Generate documentation for all functions managed by this toolkit based on their schemas.\n\n Returns:\n str: Generated documentation as a markdown string.\n \"\"\"\n doc = f\"# Toolkit: {self.toolkit_name}\\n\\n\"\n for api_name, tool in self.tools.items():\n schema = self.tool_schemas.get(api_name, {})\n if not schema:\n continue\n doc += f\"## Function: {api_name}\\n\\n\"\n doc += f\"**Description:** {schema.get('description', 'No description provided.')}\\n\\n\"\n doc += \"### Parameters:\\n\\n\"\n parameters = schema.get(\"parameters\", {})\n for prop, details in parameters.get(\"properties\", {}).items():\n req = \" (required)\" if prop in parameters.get(\"required\", []) else \"\"\n description = details.get(\"description\", \"\")\n default = f\" (default: {details.get('default')})\" if \"default\" in details else \"\"\n doc += f\"- **{prop}** ({details.get('type', 'any')}): {description}{default}{req}\\n\"\n doc += \"\\n### Example:\\n\\n\"\n example = schema.get(\"example\", \"No example provided.\")\n if isinstance(example, dict):\n example = json.dumps(example, indent=4)\n doc += f\"```json\\n{example}\\n```\\n\\n\"\n return doc\n
"},{"location":"reference/#src.aeiva.tool.toolkit.toolkit.Toolkit.init_shared_resources","title":"init_shared_resources()
","text":"Initialize shared resources required by the toolkit. Override this method in subclasses if needed.
Source code in src/aeiva/tool/toolkit/toolkit.py
def init_shared_resources(self):\n \"\"\"\n Initialize shared resources required by the toolkit.\n Override this method in subclasses if needed.\n \"\"\"\n logger.info(\"Initializing shared resources.\")\n # Placeholder for initializing shared resources like databases, servers, etc.\n # Example:\n # self.shared_resources = initialize_database_connection()\n pass\n
"},{"location":"reference/#src.aeiva.tool.toolkit.toolkit.Toolkit.load_pydantic_models_for_all_tools","title":"load_pydantic_models_for_all_tools()
","text":"Load Pydantic models (Params and Result) for all tools for validation.
Source code in src/aeiva/tool/toolkit/toolkit.py
def load_pydantic_models_for_all_tools(self):\n \"\"\"\n Load Pydantic models (Params and Result) for all tools for validation.\n \"\"\"\n logger.info(\"Loading Pydantic models for all tools.\")\n for tool_name in self.tool_names:\n try:\n param_model, result_model = self.load_pydantic_models_for_tool(tool_name)\n self.tool_models[tool_name] = (param_model, result_model)\n logger.debug(f\"Loaded models for tool '{tool_name}': Params={param_model}, Result={result_model}\")\n except Exception as e:\n logger.error(f\"Failed to load models for tool '{tool_name}': {e}\")\n raise\n
"},{"location":"reference/#src.aeiva.tool.toolkit.toolkit.Toolkit.load_pydantic_models_for_tool","title":"load_pydantic_models_for_tool(api_name)
","text":"Load the parameter and result Pydantic models for the given API.
Parameters:
Name Type Description Default api_name
str
The name of the API function.
required Returns:
Type Description Tuple[Type[BaseModel], Type[BaseModel]]
Tuple[Type[BaseModel], Type[BaseModel]]: The parameter and result model classes.
Raises:
Type Description ValueError
If models cannot be loaded.
Source code in src/aeiva/tool/toolkit/toolkit.py
def load_pydantic_models_for_tool(self, api_name: str) -> Tuple[Type[BaseModel], Type[BaseModel]]:\n \"\"\"\n Load the parameter and result Pydantic models for the given API.\n\n Args:\n api_name (str): The name of the API function.\n\n Returns:\n Tuple[Type[BaseModel], Type[BaseModel]]: The parameter and result model classes.\n\n Raises:\n ValueError: If models cannot be loaded.\n \"\"\"\n module_path = f\"aeiva.tool.api.{api_name}.model\" # Adjusted as per user's path\n try:\n models_module = importlib.import_module(module_path)\n param_model_class = getattr(models_module, f\"{snake_to_camel(api_name)}Params\", None)\n result_model_class = getattr(models_module, f\"{snake_to_camel(api_name)}Result\", None)\n if not (param_model_class and issubclass(param_model_class, BaseModel)):\n logger.error(f\"Param model class '{snake_to_camel(api_name)}Params' not found in '{module_path}'.\")\n raise ValueError(f\"Param model class '{snake_to_camel(api_name)}Params' not found in '{module_path}'.\")\n if not (result_model_class and issubclass(result_model_class, BaseModel)):\n logger.error(f\"Result model class '{snake_to_camel(api_name)}Result' not found in '{module_path}'.\")\n raise ValueError(f\"Result model class '{snake_to_camel(api_name)}Result' not found in '{module_path}'.\")\n return param_model_class, result_model_class\n except ImportError as e:\n logger.error(f\"Error importing models from '{module_path}': {e}\")\n raise ImportError(f\"Error importing models from '{module_path}': {e}\")\n except AttributeError as e:\n logger.error(f\"Error accessing model classes in '{module_path}': {e}\")\n raise ValueError(f\"Error accessing model classes in '{module_path}': {e}\")\n
"},{"location":"reference/#src.aeiva.tool.toolkit.toolkit.Toolkit.perform_security_checks","title":"perform_security_checks(param_instance)
","text":"Perform security checks on parameters that require sanitization.
Parameters:
Name Type Description Default param_instance
BaseModel
The validated parameter instance.
required Returns:
Name Type Description BaseModel
BaseModel
The sanitized parameter instance.
Raises:
Type Description ValueError
If sanitization fails for any field or if config is required but not provided.
Source code in src/aeiva/tool/toolkit/toolkit.py
def perform_security_checks(self, param_instance: BaseModel) -> BaseModel:\n \"\"\"\n Perform security checks on parameters that require sanitization.\n\n Args:\n param_instance (BaseModel): The validated parameter instance.\n\n Returns:\n BaseModel: The sanitized parameter instance.\n\n Raises:\n ValueError: If sanitization fails for any field or if config is required but not provided.\n \"\"\"\n sanitized_params = param_instance.dict()\n\n for field_name, field in param_instance.__fields__.items():\n sanitize = field.field_info.extra.get('sanitize', False)\n if not sanitize:\n continue # Skip fields that do not require sanitization\n\n field_type = field.type_\n origin = get_origin(field_type)\n args = get_args(field_type)\n\n # Determine if the field is a string type or contains string types\n is_string_field = False\n\n if field_type == str:\n is_string_field = True\n elif origin is Union and str in args:\n is_string_field = True\n elif origin is list and len(args) == 1 and args[0] == str:\n is_string_field = True\n elif origin is Optional and str in args:\n is_string_field = True\n # Add more conditions here if there are other complex types containing strings\n\n if is_string_field:\n original_value = sanitized_params.get(field_name)\n if original_value is None:\n continue # Skip if the field value is None\n\n if self.config is None:\n logger.error(\n f\"Configuration is required to sanitize field '{field_name}', \"\n f\"but config is not provided.\"\n )\n raise ValueError(\n f\"Configuration is required to sanitize field '{field_name}', \"\n f\"but config is not provided.\"\n )\n\n try:\n # If the field is a list of strings, sanitize each path individually\n if origin is list and len(args) == 1 and args[0] == str:\n if not isinstance(original_value, list):\n logger.error(\n f\"Expected a list for field '{field_name}', \"\n f\"got {type(original_value)}.\"\n )\n raise ValueError(\n f\"Expected a list for field '{field_name}'.\"\n )\n sanitized_list = []\n for 
idx, item in enumerate(original_value):\n sanitized_item = sanitize_file_path(item, self.config)\n sanitized_list.append(sanitized_item)\n logger.debug(\n f\"Sanitized '{field_name}[{idx}]': '{item}' -> '{sanitized_item}'\"\n )\n sanitized_params[field_name] = sanitized_list\n else:\n # Sanitize single string path\n sanitized_path = sanitize_file_path(original_value, self.config)\n sanitized_params[field_name] = sanitized_path\n logger.debug(\n f\"Sanitized '{field_name}': '{original_value}' -> '{sanitized_path}'\"\n )\n except ValueError as ve:\n logger.error(\n f\"Sanitization failed for field '{field_name}': {ve}\"\n )\n raise\n\n # Create a new instance of the parameter model with sanitized parameters\n sanitized_instance = param_instance.copy(update=sanitized_params)\n\n return sanitized_instance\n
"},{"location":"reference/#src.aeiva.tool.toolkit.toolkit.Toolkit.setup","title":"setup()
","text":"Set up the toolkit by loading tools, their schemas, and initializing shared resources.
Source code in src/aeiva/tool/toolkit/toolkit.py
def setup(self):\n \"\"\"\n Set up the toolkit by loading tools, their schemas, and initializing shared resources.\n \"\"\"\n logger.info(f\"Setting up toolkit '{self.toolkit_name}'.\")\n\n # Load tools and their schemas\n for tool_name in self.tool_names:\n tool = Tool(api_name=tool_name)\n self.tools[tool_name] = tool\n schema = tool.load_tool_schema(tool_name)\n self.tool_schemas[tool_name] = schema\n logger.debug(f\"Loaded schema for tool '{tool_name}': {schema}\")\n\n # Load Pydantic models for validation\n self.load_pydantic_models_for_all_tools()\n\n # Initialize shared resources\n self.init_shared_resources()\n
"},{"location":"reference/#src.aeiva.tool.toolkit.toolkit.Toolkit.teardown","title":"teardown()
","text":"Tear down the toolkit by unloading tools, their schemas, and cleaning up shared resources.
Source code in src/aeiva/tool/toolkit/toolkit.py
def teardown(self):\n \"\"\"\n Tear down the toolkit by unloading tools, their schemas, and cleaning up shared resources.\n \"\"\"\n logger.info(f\"Tearing down toolkit '{self.toolkit_name}'.\")\n\n # Clean up shared resources\n self.teardown_shared_resources()\n\n # Clear loaded data\n self.tools.clear()\n self.tool_schemas.clear()\n self.tool_models.clear()\n
"},{"location":"reference/#src.aeiva.tool.toolkit.toolkit.Toolkit.teardown_shared_resources","title":"teardown_shared_resources()
","text":"Tear down shared resources. Override this method in subclasses if needed.
Source code in src/aeiva/tool/toolkit/toolkit.py
def teardown_shared_resources(self):\n \"\"\"\n Tear down shared resources.\n Override this method in subclasses if needed.\n \"\"\"\n logger.info(\"Tearing down shared resources.\")\n # Placeholder for tearing down shared resources\n # Example:\n # if self.shared_resources:\n # self.shared_resources.close()\n pass\n
"},{"location":"reference/#src.aeiva.tool.toolkit.toolkit_config","title":"toolkit_config
","text":""},{"location":"reference/#src.aeiva.tool.toolkit.toolkit_config.ToolkitConfig","title":"ToolkitConfig
dataclass
","text":" Bases: BaseConfig
Configuration for the Toolkit.
Source code in src/aeiva/tool/toolkit/toolkit_config.py
@dataclass\nclass ToolkitConfig(BaseConfig):\n \"\"\"\n Configuration for the Toolkit.\n \"\"\"\n\n allowed_directories: List[str] = field(\n default_factory=lambda: [\"/tmp/\", \"/home/user/allowed_directory/\"],\n metadata={\"help\": \"Directories that tools are allowed to access.\"}\n )\n # Mapping from OS usernames to roles\n user_role_mapping: Dict[str, str] = field(\n default_factory=lambda: {\n \"admin_user\": \"admin\",\n \"regular_user\": \"user\"\n # Add more user-role mappings as needed\n },\n metadata={\"help\": \"Mapping of OS usernames to their roles.\"}\n )\n # Define permissions for each role\n role_permissions: Dict[str, List[str]] = field(\n default_factory=lambda: {\n \"admin\": [\"delete_file\", \"view_file\", \"create_file\"],\n \"user\": [\"view_file\", \"create_file\"]\n },\n metadata={\"help\": \"Mapping of roles to allowed API functions.\"}\n )\n
"},{"location":"reference/#src.aeiva.tool.toolkit.web_toolkit","title":"web_toolkit
","text":""},{"location":"reference/#src.aeiva.tool.toolkit.web_toolkit.WebToolkit","title":"WebToolkit
","text":" Bases: Toolkit
A toolkit for interacting with web pages.
Source code in src/aeiva/tool/toolkit/web_toolkit.py
class WebToolkit(Toolkit):\n \"\"\"\n A toolkit for interacting with web pages.\n \"\"\"\n\n def __init__(self, config=None):\n super().__init__(\n name=\"WebToolkit\",\n tool_names=[\n \"click_webpage_element\",\n \"crawl\",\n \"execute_js_script_on_webpage\",\n \"get_webpage_details\",\n \"get_webpage_elements\",\n \"navigate_browser_history\",\n \"navigate_to_webpage\",\n \"refresh_webpage\",\n \"scrape\",\n \"scroll_webpage\",\n \"type_text_in_webpage_element\",\n \"web_search\"\n ],\n config=config\n )\n
"},{"location":"reference/#src.aeiva.trainer","title":"trainer
","text":""},{"location":"reference/#src.aeiva.trainer.pl_trainer","title":"pl_trainer
","text":""},{"location":"reference/#src.aeiva.trainer.pl_trainer.LightningTrainer","title":"LightningTrainer
","text":" Bases: LightningModule
Source code in src/aeiva/trainer/pl_trainer.py
class LightningTrainer(pl.LightningModule):\n def __init__(self, model, tokenizer, config):\n super().__init__()\n self.model = model\n self.tokenizer = tokenizer\n self.config = config\n\n def forward(self, batch):\n outputs = self.model(batch)\n return outputs\n\n def training_step(self, batch, batch_idx):\n outputs = self(batch)\n loss = outputs.loss\n return {\"loss\": loss}\n\n def validation_step(self, batch, batch_idx):\n outputs = self(batch)\n loss = outputs.loss\n return {\"loss\": loss}\n\n def test_step(self, batch, batch_idx):\n outputs = self(batch)\n loss = outputs.loss\n return {\"loss\": loss}\n\n def configure_optimizers(self):\n \"\"\"\n Function to prepare the optimizer and learning rate scheduler for model training.\n This function separates model parameters into two categories: parameters that will experience weight decay, \n and those that will not (e.g., bias and layernorm weights). \n\n Returns:\n Tuple[Optimizer, Scheduler]: Tuple containing the optimizer and learning rate scheduler.\n \"\"\"\n\n # List of module types that will be subjected to weight decay.\n whitelist_weight_modules = (torch.nn.Linear, ) \n\n # List of module types that will not be subjected to weight decay.\n blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)\n\n # Parameter sets for decay and no decay.\n decay = set()\n no_decay = set()\n\n # Populate the decay and no_decay sets. \n # Loop over all modules to get module name (mn) and module (m).\n # !!!! 
revise later.\n # for mn, m in self.model.named_modules():\n # for pn, p in m.named_parameters():\n # fpn = '%s.%s' % (mn, pn) if mn else pn \n\n # if 'bias' in pn:\n # no_decay.add(fpn)\n # elif 'weight' in pn:\n # decay.add(fpn)\n\n param_dict = {pn: p for pn, p in self.model.named_parameters()}\n\n for mn, m in self.model.named_modules():\n for pn, p in m.named_parameters():\n fpn = '%s.%s' % (mn, pn) if mn else pn # full param name\n # random note: because named_modules and named_parameters are recursive\n # we will see the same tensors p many many times. but doing it this way\n # allows us to know which parent module any tensor p belongs to...\n # Adding new condition to check for the 'class_embedding' and 'logit_scale' parameters\n if pn.endswith('bias') or 'class_embedding' in pn or 'logit_scale' in pn:\n no_decay.add(fpn)\n elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):\n no_decay.add(fpn)\n for pn, p in param_dict.items():\n if pn not in no_decay:\n decay.add(pn)\n\n\n # # After this loop, print out all parameters in the intersection of decay and no_decay:\n # print(\"decay: \", decay)\n # print(\"no_decay: \", no_decay)\n # print(\"intersection: \", decay.intersection(no_decay))\n\n # print(\"difference: \", param_dict.keys() - (decay | no_decay))\n\n\n # # 'lm_head.weight' is tied to 'model.embed_tokens.weight', so it should not be decayed. 
\n # # This ensures that the same tensor is not optimized in different ways.\n # decay.remove('llm.lm_head.weight')\n\n # Validate that we considered every parameter.\n param_dict = {pn: p for pn, p in self.model.named_parameters()}\n assert len(decay & no_decay) == 0, \"Some parameters are in both decay and no_decay sets!\"\n assert len(param_dict.keys() - (decay | no_decay)) == 0, \"Some parameters are in neither decay nor no_decay sets!\"\n\n # Create the PyTorch optimizer object.\n optim_groups = [\n {\"params\": [param_dict[pn] for pn in sorted(list(decay))], \"weight_decay\": self.config.weight_decay},\n {\"params\": [param_dict[pn] for pn in sorted(list(no_decay))], \"weight_decay\": 0.0},\n ]\n # new PyTorch nightly has a new 'fused' option for AdamW that is much faster\n use_fused = (self.config.device == 'cuda') and (\n 'fused' in inspect.signature(torch.optim.AdamW).parameters)\n print(f\"using fused AdamW: {use_fused}\")\n extra_args = dict(fused=True) if use_fused else dict()\n optimizer = torch.optim.AdamW(\n optim_groups, lr=self.config.learning_rate, betas=(self.config.adam_beta1, self.config.adam_beta2), **extra_args)\n\n # Prepare learning rate scheduler.\n total_steps = self.config.max_steps\n pct_start = self.config.warmup_steps / total_steps\n final_div_factor = self.config.learning_rate / self.config.min_lr\n\n scheduler = {\n 'scheduler': torch.optim.lr_scheduler.OneCycleLR(\n optimizer,\n max_lr=self.config.learning_rate,\n total_steps=total_steps,\n pct_start=pct_start,\n final_div_factor=final_div_factor,\n div_factor=1.0, # No additional scaling for the initial learning rate\n anneal_strategy='cos', # Use cosine annealing\n cycle_momentum=False, # Disable momentum cycling\n ),\n 'interval': 'step',\n 'frequency': 1\n }\n\n return [optimizer], [scheduler]\n\n\n def get_num_params(self, non_embedding=True):\n \"\"\"\n Return the number of parameters in the model.\n For non-embedding count (default), the position embeddings get subtracted.\n 
The token embeddings would too, except due to the parameter sharing these\n params are actually used as weights in the final layer, so we include them.\n \"\"\"\n n_params = sum(p.numel() for p in self.model.parameters())\n if non_embedding:\n embedding_params = sum(p.numel() for m in self.model.modules() if isinstance(m, nn.Embedding) for p in m.parameters())\n n_params -= embedding_params\n return n_params\n\n def estimate_mfu(self, fwdbwd_per_iter, dt):\n \"\"\" estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS \"\"\"\n # first estimate the number of flops we do per iteration.\n # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311\n N = self.get_num_params()\n cfg = self.config\n L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size\n flops_per_token = 6*N + 12*L*H*Q*T\n flops_per_fwdbwd = flops_per_token * T\n flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter\n # express our flops throughput as ratio of A100 bfloat16 peak flops\n flops_achieved = flops_per_iter * (1.0/dt) # per second\n flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS\n mfu = flops_achieved / flops_promised\n return mfu\n
"},{"location":"reference/#src.aeiva.trainer.pl_trainer.LightningTrainer.configure_optimizers","title":"configure_optimizers()
","text":"Function to prepare the optimizer and learning rate scheduler for model training. This function separates model parameters into two categories: parameters that will experience weight decay, and those that will not (e.g., bias and layernorm weights).
Returns:
Type Description Tuple[Optimizer, Scheduler]: Tuple containing the optimizer and learning rate scheduler.
Source code in src/aeiva/trainer/pl_trainer.py
def configure_optimizers(self):\n \"\"\"\n Function to prepare the optimizer and learning rate scheduler for model training.\n This function separates model parameters into two categories: parameters that will experience weight decay, \n and those that will not (e.g., bias and layernorm weights). \n\n Returns:\n Tuple[Optimizer, Scheduler]: Tuple containing the optimizer and learning rate scheduler.\n \"\"\"\n\n # List of module types that will be subjected to weight decay.\n whitelist_weight_modules = (torch.nn.Linear, ) \n\n # List of module types that will not be subjected to weight decay.\n blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)\n\n # Parameter sets for decay and no decay.\n decay = set()\n no_decay = set()\n\n # Populate the decay and no_decay sets. \n # Loop over all modules to get module name (mn) and module (m).\n # !!!! revise later.\n # for mn, m in self.model.named_modules():\n # for pn, p in m.named_parameters():\n # fpn = '%s.%s' % (mn, pn) if mn else pn \n\n # if 'bias' in pn:\n # no_decay.add(fpn)\n # elif 'weight' in pn:\n # decay.add(fpn)\n\n param_dict = {pn: p for pn, p in self.model.named_parameters()}\n\n for mn, m in self.model.named_modules():\n for pn, p in m.named_parameters():\n fpn = '%s.%s' % (mn, pn) if mn else pn # full param name\n # random note: because named_modules and named_parameters are recursive\n # we will see the same tensors p many many times. 
but doing it this way\n # allows us to know which parent module any tensor p belongs to...\n # Adding new condition to check for the 'class_embedding' and 'logit_scale' parameters\n if pn.endswith('bias') or 'class_embedding' in pn or 'logit_scale' in pn:\n no_decay.add(fpn)\n elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):\n no_decay.add(fpn)\n for pn, p in param_dict.items():\n if pn not in no_decay:\n decay.add(pn)\n\n\n # # After this loop, print out all parameters in the intersection of decay and no_decay:\n # print(\"decay: \", decay)\n # print(\"no_decay: \", no_decay)\n # print(\"intersection: \", decay.intersection(no_decay))\n\n # print(\"difference: \", param_dict.keys() - (decay | no_decay))\n\n\n # # 'lm_head.weight' is tied to 'model.embed_tokens.weight', so it should not be decayed. \n # # This ensures that the same tensor is not optimized in different ways.\n # decay.remove('llm.lm_head.weight')\n\n # Validate that we considered every parameter.\n param_dict = {pn: p for pn, p in self.model.named_parameters()}\n assert len(decay & no_decay) == 0, \"Some parameters are in both decay and no_decay sets!\"\n assert len(param_dict.keys() - (decay | no_decay)) == 0, \"Some parameters are in neither decay nor no_decay sets!\"\n\n # Create the PyTorch optimizer object.\n optim_groups = [\n {\"params\": [param_dict[pn] for pn in sorted(list(decay))], \"weight_decay\": self.config.weight_decay},\n {\"params\": [param_dict[pn] for pn in sorted(list(no_decay))], \"weight_decay\": 0.0},\n ]\n # new PyTorch nightly has a new 'fused' option for AdamW that is much faster\n use_fused = (self.config.device == 'cuda') and (\n 'fused' in inspect.signature(torch.optim.AdamW).parameters)\n print(f\"using fused AdamW: {use_fused}\")\n extra_args = dict(fused=True) if use_fused else dict()\n optimizer = torch.optim.AdamW(\n optim_groups, lr=self.config.learning_rate, betas=(self.config.adam_beta1, self.config.adam_beta2), **extra_args)\n\n # Prepare 
learning rate scheduler.\n total_steps = self.config.max_steps\n pct_start = self.config.warmup_steps / total_steps\n final_div_factor = self.config.learning_rate / self.config.min_lr\n\n scheduler = {\n 'scheduler': torch.optim.lr_scheduler.OneCycleLR(\n optimizer,\n max_lr=self.config.learning_rate,\n total_steps=total_steps,\n pct_start=pct_start,\n final_div_factor=final_div_factor,\n div_factor=1.0, # No additional scaling for the initial learning rate\n anneal_strategy='cos', # Use cosine annealing\n cycle_momentum=False, # Disable momentum cycling\n ),\n 'interval': 'step',\n 'frequency': 1\n }\n\n return [optimizer], [scheduler]\n
"},{"location":"reference/#src.aeiva.trainer.pl_trainer.LightningTrainer.estimate_mfu","title":"estimate_mfu(fwdbwd_per_iter, dt)
","text":"estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS
Source code in src/aeiva/trainer/pl_trainer.py
def estimate_mfu(self, fwdbwd_per_iter, dt):\n \"\"\" estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS \"\"\"\n # first estimate the number of flops we do per iteration.\n # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311\n N = self.get_num_params()\n cfg = self.config\n L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size\n flops_per_token = 6*N + 12*L*H*Q*T\n flops_per_fwdbwd = flops_per_token * T\n flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter\n # express our flops throughput as ratio of A100 bfloat16 peak flops\n flops_achieved = flops_per_iter * (1.0/dt) # per second\n flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS\n mfu = flops_achieved / flops_promised\n return mfu\n
"},{"location":"reference/#src.aeiva.trainer.pl_trainer.LightningTrainer.get_num_params","title":"get_num_params(non_embedding=True)
","text":"Return the number of parameters in the model. For non-embedding count (default), the position embeddings get subtracted. The token embeddings would too, except due to the parameter sharing these params are actually used as weights in the final layer, so we include them.
Source code in src/aeiva/trainer/pl_trainer.py
def get_num_params(self, non_embedding=True):\n \"\"\"\n Return the number of parameters in the model.\n For non-embedding count (default), the position embeddings get subtracted.\n The token embeddings would too, except due to the parameter sharing these\n params are actually used as weights in the final layer, so we include them.\n \"\"\"\n n_params = sum(p.numel() for p in self.model.parameters())\n if non_embedding:\n embedding_params = sum(p.numel() for m in self.model.modules() if isinstance(m, nn.Embedding) for p in m.parameters())\n n_params -= embedding_params\n return n_params\n
"},{"location":"reference/#src.aeiva.util","title":"util
","text":""},{"location":"reference/#src.aeiva.util.file_utils","title":"file_utils
","text":""},{"location":"reference/#src.aeiva.util.file_utils.from_json_or_yaml","title":"from_json_or_yaml(filepath)
","text":"Load configuration from a JSON or YAML file based on the file extension.
Parameters:
Name Type Description Default filepath
str
The path to the configuration file.
required Returns:
Name Type Description dict
dict
The configuration dictionary.
Raises:
Type Description FileNotFoundError
If the file does not exist.
ValueError
If the file extension is unsupported or if parsing fails.
Source code in src/aeiva/util/file_utils.py
def from_json_or_yaml(filepath: str) -> dict:\n \"\"\"\n Load configuration from a JSON or YAML file based on the file extension.\n\n Args:\n filepath (str): The path to the configuration file.\n\n Returns:\n dict: The configuration dictionary.\n\n Raises:\n FileNotFoundError: If the file does not exist.\n ValueError: If the file extension is unsupported or if parsing fails.\n \"\"\"\n if not os.path.exists(filepath):\n logger.error(f\"Configuration file not found at path: {filepath}\")\n raise FileNotFoundError(f\"Configuration file not found at path: {filepath}\")\n\n _, ext = os.path.splitext(filepath)\n ext = ext.lower()\n\n try:\n with open(filepath, 'r', encoding='utf-8') as f:\n if ext == '.json':\n config = json.load(f)\n logger.info(f\"Loaded JSON configuration from {filepath}.\")\n return config\n elif ext in ['.yaml', '.yml']:\n config = yaml.safe_load(f)\n logger.info(f\"Loaded YAML configuration from {filepath}.\")\n return config\n else:\n logger.error(f\"Unsupported configuration file format: {ext}\")\n raise ValueError(f\"Unsupported configuration file format: {ext}\")\n except (json.JSONDecodeError, yaml.YAMLError) as e:\n logger.error(f\"Error parsing configuration file '{filepath}': {e}\")\n raise ValueError(f\"Error parsing configuration file '{filepath}': {e}\")\n
"},{"location":"reference/#src.aeiva.util.os_utils","title":"os_utils
","text":""},{"location":"reference/#src.aeiva.util.os_utils.get_os_user","title":"get_os_user()
","text":"Retrieve the current OS username.
Returns:
Name Type Description str
str
The current OS user's name.
Source code in src/aeiva/util/os_utils.py
def get_os_user() -> str:\n \"\"\"\n Retrieve the current OS username.\n\n Returns:\n str: The current OS user's name.\n \"\"\"\n return getpass.getuser()\n
"},{"location":"reference/#src.aeiva.util.path_utils","title":"path_utils
","text":""},{"location":"reference/#src.aeiva.util.path_utils.get_package_root","title":"get_package_root(package_name)
","text":"Obtain the root directory of a given package.
Parameters:
Name Type Description Default package_name
str
The name of the package.
required Returns:
Name Type Description str
str
The absolute path to the package root directory.
Source code in src/aeiva/util/path_utils.py
def get_package_root(package_name: str) -> str:\n \"\"\"\n Obtain the root directory of a given package.\n\n Args:\n package_name (str): The name of the package.\n\n Returns:\n str: The absolute path to the package root directory.\n \"\"\"\n spec = importlib.util.find_spec(package_name)\n if spec is None or spec.origin is None:\n raise ImportError(f\"Cannot find package '{package_name}'\")\n package_path = os.path.dirname(os.path.abspath(spec.origin))\n return package_path\n
"},{"location":"reference/#src.aeiva.util.path_utils.get_user_home_path","title":"get_user_home_path()
","text":"Retrieves the home directory of the current user across different platforms.
Supported Platforms - Windows
- macOS
- Linux
- iOS (best-effort)
- Android (best-effort)
Returns:
Name Type Description Path
Path
A Path
object representing the user's home directory.
Raises:
Type Description EnvironmentError
If the home directory cannot be determined.
Source code in src/aeiva/util/path_utils.py
def get_user_home_path() -> Path:\n \"\"\"\n Retrieves the home directory of the current user across different platforms.\n\n Supported Platforms:\n - Windows\n - macOS\n - Linux\n - iOS (best-effort)\n - Android (best-effort)\n\n Returns:\n Path: A `Path` object representing the user's home directory.\n\n Raises:\n EnvironmentError: If the home directory cannot be determined.\n \"\"\"\n system = platform.system()\n logger.info(f\"Detected operating system: {system}\")\n\n try:\n if system == \"Windows\":\n # Windows: Use USERPROFILE or combine HOMEDRIVE and HOMEPATH\n home = os.environ.get('USERPROFILE') or os.path.join(os.environ.get('HOMEDRIVE', ''), os.environ.get('HOMEPATH', ''))\n logger.debug(f\"Windows home directory: {home}\")\n elif system in [\"Linux\", \"Darwin\"]: # Darwin is macOS\n # Unix-like systems: Use expanduser\n home = os.path.expanduser(\"~\")\n logger.debug(f\"Unix-like home directory: {home}\")\n elif system == \"Java\": # Potentially Android (e.g., running via Jython or similar)\n # Android typically uses /data/user/0/<package_name>/ or /sdcard/\n # However, accessing these paths may require specific permissions\n # Here, we attempt to use the HOME environment variable\n home = os.environ.get('HOME') or '/sdcard/'\n logger.debug(f\"Android home directory (best-effort): {home}\")\n elif system == \"iOS\":\n # iOS applications are sandboxed; home directory is typically the app's sandbox\n # Accessing it might require specific APIs or configurations\n # Here, we return the current working directory as a placeholder\n home = Path.cwd()\n logger.debug(f\"iOS home directory (best-effort): {home}\")\n else:\n # Fallback for unknown systems\n home = os.path.expanduser(\"~\")\n logger.warning(f\"Unknown system '{system}'. 
Falling back to expanduser: {home}\")\n\n if home and os.path.isdir(home):\n return Path(home)\n else:\n raise EnvironmentError(\"Determined home directory does not exist or is not a directory.\")\n\n except Exception as e:\n logger.error(f\"Failed to determine the user's home directory: {e}\")\n raise EnvironmentError(\"Cannot determine the user's home directory.\") from e\n
"},{"location":"reference/#src.aeiva.util.path_utils.snake_to_camel","title":"snake_to_camel(snake_str)
","text":"Convert a snake_case string to CamelCase.
Parameters:
Name Type Description Default snake_str
str
The snake_case string.
required Returns:
Name Type Description str
str
The CamelCase string.
Source code in src/aeiva/util/path_utils.py
def snake_to_camel(snake_str: str) -> str:\n \"\"\"\n Convert a snake_case string to CamelCase.\n\n Args:\n snake_str (str): The snake_case string.\n\n Returns:\n str: The CamelCase string.\n \"\"\"\n components = snake_str.split('_')\n # Capitalize the first letter of each component\n return ''.join(x.title() for x in components)\n
"},{"location":"reference/#src.aeiva.util.token_utils","title":"token_utils
","text":""},{"location":"reference/#src.aeiva.util.token_utils.pad_or_truncate_tokens","title":"pad_or_truncate_tokens(tokens, max_length, pad_token_id)
","text":"This function aims to pad or truncate tokens to max_length.
Parameters:
Name Type Description Default tokens
list
the list of tokens.
required max_length
int
the max length of tokens.
required pad_token_id
int
the id of pad token.
required Returns:
Name Type Description tokens
list
the list of tokens after padding or truncating.
Source code in src/aeiva/util/token_utils.py
def pad_or_truncate_tokens(tokens, max_length, pad_token_id):\n \"\"\" This function aims to pad or truncate tokens to max_length.\n\n Args:\n tokens (list): the list of tokens.\n max_length (int): the max length of tokens.\n pad_token_id (int): the id of pad token.\n\n Returns:\n tokens (list): the list of tokens after padding or truncating.\n \"\"\"\n if len(tokens) > max_length:\n tokens = tokens[:max_length]\n elif len(tokens) < max_length:\n tokens = tokens + [pad_token_id] * (max_length - len(tokens))\n return tokens\n
"},{"location":"tutorials/","title":"Tutorials","text":"Here we summarize some experience we learned during developing Aeiva.
How to generate project documentation automatically from docstrings
"},{"location":"blogs/Thoughts_on_key_concepts_for_agentic_intelligence/","title":"Thoughts on Several Key Concepts for Agentic Intelligence","text":"Author: Bang Liu
Date: 2023-10-21
In building an intelligent agent system, especially one designed to perform complex tasks and learn from experience, it is crucial to clearly define core concepts that guide its behavior. These concepts shape how the agent interacts with its environment, executes tasks, learns from past experiences, and acquires new knowledge. Below are my thoughts on several key concepts, enriched with examples to make them more tangible.
"},{"location":"blogs/Thoughts_on_key_concepts_for_agentic_intelligence/#1-what-is-a-plan","title":"1. What is a Plan?","text":"A Plan is a structured, goal-driven roadmap for an agent to achieve a specific task. The key feature of a Plan is that it decomposes the primary task into subtasks, forming a hierarchical structure. The agent follows this roadmap, completing one subtask after another. Since a plan ultimately governs execution, it must be well-structured\u2014most naturally as a Directed Acyclic Graph (DAG).
Each node in the DAG represents a Task or subtask, and the edges describe dependencies between them. This ensures a logical, stepwise execution where subtasks cannot begin until their dependencies are satisfied.
- Example: Consider an agent tasked with preparing a meal. The plan breaks the main task (\"Cook meal\") into subtasks like \"Chop vegetables,\" \"Boil water,\" \"Cook rice,\" and \"Serve meal.\" Some tasks must precede others (e.g., \"Boil water\" must happen before \"Cook rice\"). This structure forms a DAG, ensuring tasks are completed in the correct order without cycles or deadlocks.
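The Plan-as-DAG idea above can be sketched in a few lines of standard-library Python. The subtask names follow the cooking example; this is an illustrative sketch, not Aeiva's actual plan representation:

```python
from graphlib import TopologicalSorter  # stdlib, Python 3.9+

# Each key is a subtask; its value is the set of subtasks it depends on.
cook_meal_plan = {
    "Chop vegetables": set(),
    "Boil water": set(),
    "Cook rice": {"Boil water"},
    "Serve meal": {"Chop vegetables", "Cook rice"},
}

# static_order() yields an execution order that respects every dependency,
# and raises CycleError if the plan contains a cycle (i.e., is not a DAG).
execution_order = list(TopologicalSorter(cook_meal_plan).static_order())
print(execution_order)
```

Any order it produces places "Boil water" before "Cook rice", and both prerequisite chains before "Serve meal", which is exactly the deadlock-free guarantee the DAG structure provides.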
"},{"location":"blogs/Thoughts_on_key_concepts_for_agentic_intelligence/#2-what-is-a-task","title":"2. What is a Task?","text":"A Task is the fundamental unit of work in a plan. Each task has a clear status, which can be one of: - Not Executed: The task is yet to be started. - Executing: The task is currently being performed. - Success: The task has been completed successfully. - Fail: The task has failed, possibly requiring intervention or retry.
Tasks can have meta-data such as the task owner, creation time, priority, or other relevant attributes. A task also needs a mechanism to check whether it has been completed successfully, which might involve running tests or checking outputs against expectations.
- Example: In a factory, an agent may have a task like \"Assemble component A.\" The task could have metadata such as who is responsible (agent A or robot arm B), creation time (timestamp when this task was queued), and a priority level (perhaps \"high\" because component A is needed soon). After execution, the task might check the assembled part for defects before marking itself as \"Success.\"
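A minimal sketch of a task with the four statuses and the metadata described above. The field names (owner, priority) and the completion check are illustrative, not Aeiva's actual Task class:

```python
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum

class TaskStatus(Enum):
    NOT_EXECUTED = "not_executed"
    EXECUTING = "executing"
    SUCCESS = "success"
    FAIL = "fail"

@dataclass
class Task:
    name: str
    owner: str = "unassigned"
    priority: str = "normal"
    created_at: datetime = field(default_factory=datetime.now)
    status: TaskStatus = TaskStatus.NOT_EXECUTED

    def check_completion(self, passed_tests: bool) -> None:
        # A task needs some mechanism to verify success, e.g. running
        # tests or checking outputs against expectations.
        self.status = TaskStatus.SUCCESS if passed_tests else TaskStatus.FAIL

task = Task(name="Assemble component A", owner="robot arm B", priority="high")
task.status = TaskStatus.EXECUTING
task.check_completion(passed_tests=True)
```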
"},{"location":"blogs/Thoughts_on_key_concepts_for_agentic_intelligence/#3-what-is-a-tool","title":"3. What is a Tool?","text":"A Tool provides functionality that the agent can use to perform actions. In modern software, a tool often takes the form of an API\u2014a set of operations that accept inputs (parameters) and return outputs (results).
Tools can be seen as atomic units of functionality that are executed in isolation, but their results can influence the broader task or plan. Tools are often reusable across different tasks or plans.
- Example: Consider a research assistant agent that interacts with a remote API to retrieve scientific papers. Here, the \"Arxiv API\" is a tool. The agent calls this API (providing search parameters), and the tool returns a list of papers in a structured format. The agent uses this tool to complete tasks like \"Find papers related to quantum computing.\"
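A tool can be modeled as a named, callable unit of functionality. The sketch below uses an in-memory stub in place of a real paper-search service like the Arxiv API; the `Tool` class and `search_papers` function are illustrative, not part of Aeiva:

```python
from dataclasses import dataclass
from typing import Any, Callable

@dataclass
class Tool:
    name: str
    description: str
    func: Callable[..., Any]

    def __call__(self, **params: Any) -> Any:
        # Tools are atomic: parameters in, results out, no hidden state.
        return self.func(**params)

def search_papers(query: str, max_results: int = 5) -> list:
    # Stub standing in for a real API call; a production tool would
    # query a remote service here instead of a local corpus.
    corpus = [
        "Quantum computing survey",
        "Graph neural networks",
        "Quantum error correction",
    ]
    keyword = query.split()[0].lower()
    return [title for title in corpus if keyword in title.lower()][:max_results]

paper_tool = Tool("paper_search", "Find papers by keyword", search_papers)
results = paper_tool(query="quantum computing")
print(results)  # prints ['Quantum computing survey', 'Quantum error correction']
```

Because the tool is just a callable with metadata, the same instance can be reused by any task or plan that needs paper search.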
"},{"location":"blogs/Thoughts_on_key_concepts_for_agentic_intelligence/#4-what-is-an-action","title":"4. What is an Action?","text":"An Action is a higher-level operation the agent can take. While it may use a tool (or multiple tools), it is broader than just invoking a function. An Action might involve decision-making, performing logic internally, or combining the output of multiple tools.
Whereas tools are about \"doing one thing well,\" actions are more about how the agent decides to use tools or perform processes. Some actions may not even require external tools but might involve manipulating data internally.
- Example: A warehouse robot's action could be \"Pick up an item from shelf A and place it in bin B.\" The action uses the robot\u2019s sensors and movement tools, but the decision-making on how to execute it\u2014like which arm to use or which path to follow\u2014is part of the action.
"},{"location":"blogs/Thoughts_on_key_concepts_for_agentic_intelligence/#5-what-is-a-skill","title":"5. What is a Skill?","text":"A Skill is the procedural knowledge an agent uses to complete a task. It represents a series of actions or steps the agent follows to solve a problem. Skills can be encoded as DAGs, with each node representing an action, and the edges defining the flow or dependencies between actions.
What distinguishes a Skill from hardcoded instructions is its flexibility. For instance, a skill may allow for different actions to be taken in varying orders, or certain parameters may be adjusted dynamically. In other words, a skill isn\u2019t rigid but adaptable to different contexts or environments.
- Example: An agent trained to clean a room could have a \"Cleaning skill.\" It involves subtasks like \"vacuum the floor,\" \"wipe the table,\" and \"empty the trash.\" In some cases, the agent may vacuum first and then wipe the table, but in others, it may reverse the order depending on room conditions. The ability to adapt while following a general cleaning procedure is what makes it a skill.
"},{"location":"blogs/Thoughts_on_key_concepts_for_agentic_intelligence/#6-what-is-an-experience","title":"6. What is an Experience?","text":"An Experience is a personal record of how an agent solved a particular task. While the structure of an Experience may resemble that of a Skill, it is tied to a specific instance.
The main distinction is that Experiences are not generalized. Instead, they capture the details of how a task was solved under particular circumstances, including all the decisions, parameters, and actions taken during the process. Over time, multiple experiences can be analyzed to derive common patterns, which evolve into Skills.
- Example: After attempting to solve several puzzles, an agent might log each experience\u2014how it solved the puzzle, what tools it used, how long it took, etc. After analyzing several such experiences, the agent may extract a general strategy (skill) for solving puzzles of this type.
"},{"location":"blogs/Thoughts_on_key_concepts_for_agentic_intelligence/#7-what-is-memory","title":"7. What is Memory?","text":"Memory is the broader concept that includes all the data an agent remembers about its past actions, interactions, and decisions. Memory could encompass many forms, including: - Experiential memory: Specific memories about how the agent solved tasks (as described in Experience). - Episodic memory: Memory of specific events or interactions the agent has been part of. - Semantic memory: Knowledge the agent has learned about its environment or domain.
Memory plays a critical role in making an agent \"intelligent,\" as it allows the agent to learn from past mistakes, reuse successful strategies, and adapt to new situations by recalling prior experiences.
- Example: A personal assistant agent might have episodic memory of the last time it scheduled a meeting for the user. The next time the user asks it to schedule a meeting, it can reference that memory to understand the user's preferences, such as their preferred meeting time.
"},{"location":"blogs/Thoughts_on_key_concepts_for_agentic_intelligence/#8-what-is-knowledge","title":"8. What is Knowledge?","text":"Knowledge is a validated, generalizable form of learning. While an experience is a personal, one-off record, knowledge has been abstracted and validated across multiple situations. Knowledge allows the agent to generalize beyond specific experiences and apply what it has learned to new, similar tasks.
In many cases, a Skill represents a particular type of knowledge\u2014the procedural knowledge required to complete a task. Knowledge might also be sharable between agents, or taught from one agent to another, making it reusable.
- Example: An agent that has learned to solve various types of math problems can generalize its knowledge into a set of skills. When faced with a new math problem, it can apply this knowledge, even if the problem differs slightly from the ones it has solved before.
"},{"location":"blogs/Thoughts_on_key_concepts_for_agentic_intelligence/#closing-thoughts","title":"Closing Thoughts","text":"These key concepts\u2014Plan, Task, Tool, Action, Skill, Experience, Memory, and Knowledge\u2014form the foundation of agentic intelligence. Together, they allow an agent to: - Decompose tasks into executable steps (Plan), - Perform specific actions (Task, Action, Tool), - Learn from both immediate tasks and general experiences (Experience, Memory), - Generalize that learning into knowledge that improves future performance (Knowledge, Skill).
By keeping these concepts clear and well-defined, an agent can operate in a structured, intelligent way, continually learning and improving over time.
"},{"location":"blogs/unity_dev_notes/","title":"Notes about how I developed the AI Konoha Village","text":""},{"location":"blogs/unity_dev_notes/#obtain-the-hidden-leaf-village-3d-model","title":"Obtain the \"hidden leaf village \" 3D model","text":"I downloaded it from here:
https://mega.nz/file/vkcHSYLT#t5gG06y65gEp8g3U8N8Yic5BijvZ0PA_7UstCmnoG38
https://www.deviantart.com/naruko-uzumaki-the-9/art/Hidden-Leaf-Village-Complete-DL-Fixed-809223977
"},{"location":"blogs/unity_dev_notes/#import-to-blender","title":"Import to blender","text":"In my case, I cannot directly open the files. But I can import the .fbx file in blender 3.6 (mac M1). Change the Render Engine from Eevee to Workbench, and then at the Color drop menu, select Texture. Then press \"Z\" and select \"render\" model. You will see colored model there.
"},{"location":"blogs/unity_dev_notes/#import-export-fbx-file-from-blender","title":"import & export .fbx file from blender","text":"When export .fbx file from blender and load to unity, it may encounter errors like mesh tangents or self intersection warning. The way to solve this is: 1. Install Better FBX Importer Exporter plugin for blender (it solves the mesh tangent problem); 2. When export using the plugin, select triangulate (it solves the intersection problem).
"},{"location":"blogs/unity_dev_notes/#import-fbx-or-dae-file-to-unity","title":"import .fbx or .dae file to unity","text":"I found the best way is directly drag the whole folder including the materials/textures to the asset folder of the unity project. Then unity will load the assets in the folder and generate .meta data. After that, we can drag the assets to the project from the \"project\" window. Note that seems unity 2022 doesn't show project and inspector windows by default. But unity 2021 can show the windows. For unity 2022, we can select \"Window -> Layouts -> Default\" to get the desired layout.
I also compared the .dae and .fbx files for the Hidden Leaf Village model; the \"Hidden Leaf Village - Complete.dae\" file looks better in Unity.
"},{"location":"tutorials/How_to_Make_a_Python_Package/","title":"How to Make Your Python Project a Pip-Installable Package","text":"Author: Bang Liu
Date: 2024-11-23
This guide walks you through the process of creating a Python package that others can install using pip
.
"},{"location":"tutorials/How_to_Make_a_Python_Package/#step-1-structure-your-project","title":"Step 1: Structure Your Project","text":"Organize your project with a proper directory structure:
your_project/\n\u251c\u2500\u2500 src/\n\u2502 \u2514\u2500\u2500 your_project/\n\u2502 \u251c\u2500\u2500 __init__.py # Makes this a package\n\u2502 \u251c\u2500\u2500 module.py # Your module files\n\u251c\u2500\u2500 setup.py # Metadata and build script\n\u251c\u2500\u2500 README.md # Project description\n\u251c\u2500\u2500 LICENSE # License file (optional but recommended)\n\u251c\u2500\u2500 requirements.txt # Dependency file (optional)\n
src/your_project/
: Contains your package code. __init__.py
: Makes the folder a Python package. setup.py
: Defines metadata and installation behavior.
"},{"location":"tutorials/How_to_Make_a_Python_Package/#step-2-create-setuppy","title":"Step 2: Create setup.py
","text":"setup.py
is the script used to build and install your package. Here's a sample:
from setuptools import setup, find_packages\n\nsetup(\n name=\"your_project\", # Your package name\n version=\"0.1.0\", # Package version\n author=\"Your Name\", # Your name\n author_email=\"your.email@example.com\", # Your email\n description=\"A brief description\", # Short description\n long_description=open('README.md').read(), # Long description from README\n long_description_content_type='text/markdown', # Markdown format\n url=\"https://github.com/username/repository\", # Project repository\n packages=find_packages(where=\"src\"), # Find packages in src/\n package_dir={\"\": \"src\"}, # Root directory for packages\n classifiers=[ # Metadata for PyPI\n \"Programming Language :: Python :: 3\",\n \"License :: OSI Approved :: MIT License\",\n \"Operating System :: OS Independent\",\n ],\n python_requires='>=3.6', # Minimum Python version\n install_requires=[ # Dependencies\n \"numpy\", # Example dependency\n ],\n)\n
"},{"location":"tutorials/How_to_Make_a_Python_Package/#step-3-create-readmemd","title":"Step 3: Create README.md
","text":"Write a README.md
file to describe your project. Use Markdown for formatting. Example:
# Your Project Name\n\nA short description of your project.\n\n## Installation\n\n```bash\npip install your_project\n
"},{"location":"tutorials/How_to_Make_a_Python_Package/#usage","title":"Usage","text":"import your_project\nyour_project.some_function()\n
"},{"location":"tutorials/How_to_Make_a_Python_Package/#step-4-test-your-package-locally","title":"Step 4: Test Your Package Locally","text":"Test your package before publishing:
-
Navigate to your project root: bash cd /path/to/your_project
-
Install it in editable mode: bash pip install -e .
-
Import your package to verify: ```bash python
import your_project your_project.some_function() ```
"},{"location":"tutorials/How_to_Make_a_Python_Package/#step-5-build-the-package","title":"Step 5: Build the Package","text":"Install the necessary tools:
pip install build\n
Build your package:
python -m build\n
This creates a dist/
directory with .tar.gz
and .whl
files.
"},{"location":"tutorials/How_to_Make_a_Python_Package/#step-6-upload-to-pypi","title":"Step 6: Upload to PyPI","text":" - Register on PyPI:
- Create an account at PyPI.
-
Optionally, register on TestPyPI for testing.
-
Install Twine: bash pip install twine
-
Upload Your Package: bash python -m twine upload dist/*
To test uploads on TestPyPI: bash python -m twine upload --repository testpypi dist/*
- Provide Your PyPI Token:
-
If prompted, enter your PyPI API token.
-
Alternative way to upload
python -m twine upload --repository-url https://upload.pypi.org/legacy/ dist/* -u __token__ -p pypi-<your token password here>\n
"},{"location":"tutorials/How_to_Make_a_Python_Package/#step-7-verify-installation","title":"Step 7: Verify Installation","text":"Install your package from PyPI:
pip install your_project\n
Verify it works as expected:
python\n>>> import your_project\n>>> your_project.some_function()\n
"},{"location":"tutorials/How_to_Make_a_Python_Package/#tips-and-best-practices","title":"Tips and Best Practices","text":" - Include a License: Add a
LICENSE
file to clarify usage terms. - Automate Versioning: Use tools like
bumpversion
to manage versions. - Test Thoroughly: Use TestPyPI before uploading to the main PyPI repository.
- Secure Tokens: Use project-specific tokens for uploads.
Congratulations! Your project is now a pip-installable Python package.
"},{"location":"tutorials/coding_guidelines/","title":"Coding Guidelines","text":""},{"location":"tutorials/coding_guidelines/#code-hierarchy","title":"Code Hierarchy","text":"Generally, our code can be organized into three different levels:
-
Framework: This level forms the architectural backbone of your project. It houses the core functionalities that define the basic structure and shared logic for your project. These files, typically stored under the package_name/bases/
directory, establish the protocols and high-level operations that the rest of your project will adhere to.
-
Brick: The \"Brick\" level acts as a collection of modular, reusable components used across your project. These components, which are stored in the package_name/xxx/
directory, promote code reusability and reduce redundancy, thereby enhancing the efficiency of your codebase.
-
Applications: This level contains specific implementations associated with particular datasets, models, or experiments. These files, which are stored in the package_name/
directory, are separate from the abstract base classes and reusable functions found in the other levels. This separation aids in code navigation and readability, making it easier to locate and understand the specific components of your project.
By adhering to this structure, your codebase will be well-organized, easily navigable, and efficient. This organization adheres to best practices in software development, promoting code reusability and a clear separation of concerns.
"},{"location":"tutorials/coding_guidelines/#generate-requirementstxt","title":"Generate requirements.txt","text":"Use pipreqs: pipreqs is a useful tool that generates a requirements.txt file based on the imports in your Python project, not on the installed packages in your current environment. You can install it and use it as follows:
pip install pipreqs\npipreqs --force /path/to/your/project\n
"},{"location":"tutorials/coding_guidelines/#args-and-kwargs","title":"args and *kwargs","text":"*args
and **kwargs
in Python allow a function to accept optional arguments, meaning that the user can pass a variable number of arguments to these functions. Here's when you might want to use them:
-
When you're not sure how many arguments might be passed to your function: *args
is used to send a non-keyworded variable-length argument list to your function. You might use it when you're not sure how many arguments might be passed to your function, or if you want to support an arbitrary number of arguments.
-
When you want to write a function that must accept a dictionary: **kwargs
is used to pass a keyworded, variable-length argument list. You would use this if you want your function to be able to accept a dictionary of attributes.
-
When creating wrapper functions or decorators: *args
and **kwargs
are commonly used when you're writing higher-order functions or decorators that need to manipulate the inputs to another function that they're wrapping.
-
When subclassing and you want to extend the parent class's methods: In this case, you may not know exactly what the parent class's method takes as arguments. *args
and **kwargs
let you pass any parameters from the child class to the parent class's method without having to know what those parameters are.
However, while *args
and **kwargs
are very helpful, they should be used judiciously. Overuse can make your code harder to understand and debug since it's not immediately clear what arguments a function expects. When writing a function, if you know the exact number and role of each argument, it's better to list them explicitly.
In summary, *args
and **kwargs
are powerful tools that make Python functions more flexible. However, as with any tool, they should be used judiciously and appropriately.
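The decorator use case (point 3) is the most common in practice. A minimal sketch using only the standard library; the `timed` decorator and function names are illustrative:

```python
import functools
import time

def timed(func):
    """Forward any call signature to the wrapped function, recording its duration."""
    @functools.wraps(func)  # preserve the wrapped function's name and docstring
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        result = func(*args, **kwargs)  # *args/**kwargs pass everything through unchanged
        wrapper.last_elapsed = time.perf_counter() - start
        return result
    return wrapper

@timed
def add(a, b, *, scale=1):
    return (a + b) * scale

print(add(2, 3, scale=10))  # prints 50
```

Because `wrapper` accepts `*args` and `**kwargs`, the same decorator works for any function, regardless of its signature, which is exactly why explicit parameter lists would not work here.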
"},{"location":"tutorials/coding_guidelines/#order-of-function-arguments-in-python","title":"Order of Function Arguments in Python","text":"In Python, the recommended order of function parameters is as follows:
-
Required positional arguments: These are arguments that need to be in a specific positional order. When calling the function, Python interprets them based on their order.
Example: def func(name, age):
-
Optional positional arguments / Default Parameters: These are arguments that are optional and have a default value. They are also interpreted based on their order.
Example: def func(name, age=22):
-
Required keyword-only arguments: These are arguments that must be supplied by keyword and follow a \"*,\" in the function definition.
Example: def func(name, age, *, city):
-
Optional keyword-only arguments / Default Keyword Parameters: These are keyword arguments that are optional. The function will use the default value if no value is provided.
Example: def func(name, age, *, city='New York'):
-
Arbitrary argument lists: The *args
and **kwargs
parameters, which collect all positional and keyword arguments that are not already caught by other parameters.
Example: def func(name, age, *args, city='New York', **kwargs):
This order can help make your function definitions clear and easy to read. It also helps prevent common bugs caused by confusing positional and keyword arguments.
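The ordering rules above can all be combined in a single signature; here is a minimal sketch (the function and its parameter names are invented for illustration):

```python
def describe(name, age=22, *hobbies, city='New York', **extra):
    # Invented example combining the parameter kinds in the recommended order:
    # required positional, optional positional (default), arbitrary
    # positionals, keyword-only with a default, then arbitrary keywords.
    return {
        'name': name,        # required positional
        'age': age,          # optional positional / default parameter
        'hobbies': hobbies,  # *args: collects extra positionals as a tuple
        'city': city,        # keyword-only (it comes after *hobbies)
        'extra': extra,      # **kwargs: collects extra keywords as a dict
    }

info = describe('Ada', 36, 'chess', 'math', city='London', lang='en')
```

Because `city` appears after `*hobbies`, callers can only pass it by keyword; everything positional beyond `name` and `age` lands in `hobbies`.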
"},{"location":"tutorials/coding_guidelines/#naming-noun-or-verb","title":"Naming: Noun or Verb?","text":"Thing Choice of Word Modules Noun Data types Noun or Adjective Functions Noun or Verb Constants/Variables Noun - Try to keep names short, and avoid names longer than three words where possible.
- Whether to use a verb or a noun for a function or method depends on what you want to emphasize: the returned result or the process that produces it.
To better choose verbs for functions, below are some suggestions:
- Is the function a test? -> test_...
-
Does the function have a @property decorator? -> don\u2019t use a verb in the function name.
-
Does the function use a disk or a network:
3.1. \u2026 to store data? -> save_to, send, write_to
3.2. \u2026 to receive data? -> fetch, load, read
-
Does the function output any data? -> print, output
-
Returns boolean value? -> is_, has_/have_, can_, check_if_...
-
Aggregates data? -> calculate, extract, analyze
-
Puts data from one form to another:
7.1. Creates a single meaningful object? -> create
7.2. Fills an existing object with data? -> initialize, configure
7.3. Cleans raw data? -> clean
7.4. Receives a string as input? -> parse
7.5. Returns a string as output? -> render
7.6. Returns an iterator as output? -> iter
7.7. Mutates its arguments or some global state? -> update, mutate, add, remove, insert, set
7.8. Returns a list of errors? -> validate
7.9. Checks data items recursively? -> walk
7.10. Finds an appropriate item in data? -> find, search, match
7.11. Transforms a data type? -> ..._to_...
7.12. None of the above, but still works with data? -> Check one of these: morph, compose, prepare, extract, generate, initialize, filter, map, aggregate, export, import, normalize, calculate.
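To illustrate a few of the suggestions above, here is a small sketch (the email helpers are invented examples; the validity check is deliberately naive, not a real validator):

```python
def is_valid_email(address: str) -> bool:
    # Boolean predicate -> is_ prefix (suggestion 5).
    return '@' in address and '.' in address.split('@')[-1]

def parse_email(address: str) -> tuple:
    # Receives a string as input -> parse (suggestion 7.4).
    local, _, domain = address.partition('@')
    return local, domain

def render_email(local: str, domain: str) -> str:
    # Returns a string as output -> render (suggestion 7.5).
    return f'{local}@{domain}'
```

Each name signals the function's contract: `is_` promises a boolean, `parse_` promises structured data from a string, `render_` promises a string back.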
"},{"location":"tutorials/coding_guidelines/#install-package","title":"Install package","text":"We can install the package we are developing with the following command:
pip install -e .\n
It means we are installing it in editable mode. In Python, if you want to be able to edit your package and have the changes be reflected immediately without needing to reinstall the package every time, you can use pip to install the package in \"editable\" mode.
If you are worried about the state of your package affecting other parts of your system or other projects, you might consider using a virtual environment. A virtual environment is an isolated Python environment, separate from your system Python and other virtual environments. You can install your package in a virtual environment and make changes and test without worrying about affecting other projects.
"},{"location":"tutorials/coding_guidelines/#reference","title":"Reference","text":"[1] https://ahsmart.com/pub/naming-things-properly/
[2] https://melevir.medium.com/python-functions-naming-the-algorithm-74320a18278d
"},{"location":"tutorials/generate_docs/","title":"How to generate docs automatically","text":"Author: Bang Liu
Date: 2023-08-05
In this document, I will introduce how to automatically generate the documentation for your python project with several tools.
"},{"location":"tutorials/generate_docs/#install-libraries","title":"Install libraries","text":"We use the following python packages:
- MkDocs for building static pages from Markdown
- mkdocstrings for auto-generating documentation from docstrings in your code
- Material for MkDocs for styling your documentation
pip install --upgrade pip\npip install mkdocs\npip install mkdocstrings\npip install mkdocs-material\n
You can install support for specific languages using extras, for example:
pip install 'mkdocstrings[crystal,python]'\n
Note: support for specific languages is not installed by default, so I would recommend installing it with the above command.
"},{"location":"tutorials/generate_docs/#create-mkdocs-project","title":"Create mkdocs project","text":"Now assume you are in the root directory of your project:
mkdocs new .\n
You will see:
INFO - Writing config file: ./mkdocs.yml\nINFO - Writing initial docs: ./docs/index.md\n
MkDocs comes with a built-in dev-server that lets you preview your documentation as you work on it. Make sure you're in the same directory as the mkdocs.yml
configuration file, and then start the server by running the mkdocs serve
command:
% mkdocs serve\nINFO - Building documentation...\nINFO - Cleaning site directory\nWARNING - Excluding 'README.md' from the site because it conflicts with\n 'index.md'.\nINFO - Documentation built in 0.08 seconds\nINFO - [14:25:59] Watching paths for changes: 'docs', 'mkdocs.yml'\nINFO - [14:25:59] Serving on http://127.0.0.1:8000/\nINFO - [14:26:11] Browser connected: http://127.0.0.1:8000/\n
Open up http://127.0.0.1:8000/ in your browser, and you'll see the default home page being displayed.
"},{"location":"tutorials/generate_docs/#customize-your-mkdocsyml","title":"Customize your mkdocs.yml","text":"We can customize the style of our documentation. Edit the ./mkdocs.yml file:
site_name: your-project-name\nsite_url: your-project-website\nnav:\n - Home: index.md\ntheme:\n name: \"material\"\n
This way, we can use the material theme. You can also use other themes [1,2].
"},{"location":"tutorials/generate_docs/#add-more-markdown-files-to-the-documentation","title":"Add more markdown files to the documentation","text":"As described in [1], we can follow the structure proposed in the Di\u00e1taxis documentation framework, which suggests splitting your documentation into four distinct parts:
- Tutorials
- How-To Guides
- Reference
- Explanation
Therefore, we can create these markdown files and put them into the ./docs/ folder. Then we edit our mkdocs.yml configuration file to add them:
site_name: your-project-name\nsite_url: your-project-website\n\nnav:\n - index.md\n - tutorials.md\n - how-to-guides.md\n - reference.md\n - explanation.md\n\ntheme:\n name: \"material\"\n
We can also edit the titles for each page, adjust their order, and so on. See [1] for more details.
"},{"location":"tutorials/generate_docs/#generate-document-from-docstrings","title":"Generate document from Docstrings","text":"We need to use mkdocstrings
package for this purpose.
MkDocs is a static-site generator geared toward writing documentation. However, you can\u2019t fetch docstring information from your code using MkDocs alone. You can make it work with an additional package called mkdocstrings.
You already installed mkdocstrings into your virtual environment at the beginning of this tutorial, so you only need to add it as a plugin to your MkDocs configuration file:
site_name: your-project-name\nsite_url: your-project-website\n\nplugins:\n - mkdocstrings\n\nnav:\n - index.md\n - tutorials.md\n - how-to-guides.md\n - reference.md\n - explanation.md\n\ntheme:\n name: \"material\"\n
Now, to generate documentation from source code docstrings, we can select a markdown file, e.g., the reference.md file we have created, and put identifiers in it.
Mkdocstrings allows you to insert docstring information right into your Markdown pages using a special syntax of three colons (:::) followed by the code identifier that you want to document:
::: identifier\n
The identifier is a string identifying the object you want to document. The format of an identifier
can vary from one handler to another. For example, the Python handler expects the full dotted-path to a Python object: my_package.my_module.MyClass.my_method
[3].
The full syntax for using an identifier is:
::: identifier\n YAML block\n
See https://mkdocstrings.github.io/usage/ for more details.
Basically, the YAML block is optional, and contains some configuration options.
For global options, we can put it in mkdocs.yml
. For example:
plugins:\n- mkdocstrings:\n enabled: !ENV [ENABLE_MKDOCSTRINGS, true]\n custom_templates: templates\n default_handler: python\n handlers:\n python:\n options:\n show_source: false\n
And global configurations can be overridden by local configurations.
See [3] for more detailed tutorials. To briefly summarize: with mkdocstrings, we can use identifiers to gather the docstrings in our code and turn them into documentation.
Tip: Maintaining a good coding style is very important. I prefer the docstring style described here: https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings
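For reference, here is a minimal sketch of that Google docstring style, which mkdocstrings can render into Args/Returns/Raises sections (the function itself is an invented example):

```python
def moving_average(values: list, window: int) -> list:
    '''Compute the simple moving average of a sequence.

    Args:
        values: Numeric values to average.
        window: Size of the sliding window; must be positive.

    Returns:
        A list of averages, one per full window.

    Raises:
        ValueError: If window is not a positive number.
    '''
    if window <= 0:
        raise ValueError('window must be positive')
    return [sum(values[i:i + window]) / window
            for i in range(len(values) - window + 1)]
```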
"},{"location":"tutorials/generate_docs/#automatically-collect-all-the-docstrings-in-a-module","title":"Automatically collect all the docstrings in a module","text":"To avoid manually writing an identifier for each submodule/class/method in a markdown file just to include the corresponding docstrings in our documentation, we can use the following option:
::: src.aeiva.agent\n options:\n show_submodules: true\n
The above example will automatically introduce all the docstrings in the aeiva.agent package into our documentation.
"},{"location":"tutorials/generate_docs/#advanced-theme-customization","title":"Advanced Theme Customization","text":""},{"location":"tutorials/generate_docs/#changing-the-logo-and-icons","title":"Changing the logo and icons","text":"See: https://squidfunk.github.io/mkdocs-material/setup/changing-the-logo-and-icons/
"},{"location":"tutorials/generate_docs/#customize-the-landing-home-page","title":"Customize the landing home page","text":"We can further customize the home page of our documentation.
First, set your custom_dir in mkdocs.yml:
theme:\n custom_dir: docs/overrides\n...\n\n
The above setting uses the overrides directory in docs/ as the custom directory.
We then copy all the contents in https://github.com/squidfunk/mkdocs-material/tree/master/src/.overrides to our docs/overrides/
folder.
Next, in the front matter of your index.md, you need to specify the template to use (copy below to index.md):
---\ntitle: Title\ntemplate: home.html\n---\n
One important thing that took me a while to realize: you need a newline at the end of your md file. If you don't have one, the content will not display [6].
Finally, we can customize the home.html
and main.html
in the overrides folder to make it consistent with our project.
See [6] for a reference.
Note: I found the landing page on https://squidfunk.github.io/mkdocs-material/ is really fancy! It is based on a parallax image effect using HTML and CSS. To DIY the effect, I downloaded the source file of the webpage directly, and then replaced all assets/images/layers/ in the HTML source file with ./Material for MkDocs_files/, because that is the only folder I could get from the download. I haven't finished understanding and customizing the landing home page based on this template; to be tested in the future. :) (I put this version in docs/overrides-dev/)
"},{"location":"tutorials/generate_docs/#organize-your-documentation","title":"Organize your documentation","text":""},{"location":"tutorials/generate_docs/#navbar-nesting","title":"Navbar nesting","text":"You can add an additional level to your navbar like this:
nav:\n - Home: index.md\n - About: about.md\n - Foo:\n - Overview: foo/index.md\n - Bar: foo/bar.md\n
"},{"location":"tutorials/generate_docs/#reference-to-another-markdown-file","title":"Reference to another markdown file","text":"In a markdown document, we can refer to another file from one file, like the following:
[How to generate project documentation automatically from docstrings](./GENERATE_DOCS.md)\n
"},{"location":"tutorials/generate_docs/#deploy-your-documentation-to-github","title":"Deploy Your Documentation to GitHub","text":"GitHub repositories automatically serve static content when committed to a branch named gh-pages. MkDocs integrates with that and allows you to build and deploy your project documentation in a single step:
mkdocs gh-deploy\n
Running this command rebuilds the documentation from your Markdown files and source code and pushes it to the gh-pages branch on your remote GitHub repository.
Because of GitHub\u2019s default configuration, that\u2019ll make your documentation available at the URL that MkDocs shows you at the end of your terminal output:
INFO - Your documentation should shortly be available at:\n https://user-name.github.io/project-name/\n
"},{"location":"tutorials/generate_docs/#summarize","title":"Summarize","text":"So we basically follow the following procedures to create our documentation:
- Create a virtual environment for your project, create your project, and create your GitHub repository.
- Install the libraries: mkdocs, mkdocstrings, mkdocs-material
- Go to the project root directory.
- Use mkdocs to create the docs. It will produce
mkdocs.yml
and ./docs/index.md
. - Customize the
mkdocs.yml
. Basically, this is the global setting of the documentation. See [2] for details. You can customize your documentation theme to materials
theme that supported by mkdocs-material
python package. - Customize the contents in
./docs/
. Basically, you can create different markdown files here; you can automatically create documentation contents from docstrings of your code by using ::: identifier
that supported by mkdocstrings
. See [4] for details. - Customize the organization of your documentation. For example, you can use nested navigation; you can use cross-reference, etc.
- Build your documentation using mkdocs build.
- Host your documentation using
mkdocs gh-deploy
. Your documentation should shortly be available at: https://user-name.github.io/project-name/
.
"},{"location":"tutorials/generate_docs/#more","title":"More","text":"Please read [1,2,3,4] for more detailed tutorials.
"},{"location":"tutorials/generate_docs/#reference","title":"Reference","text":"[1] https://realpython.com/python-project-documentation-with-mkdocs/
[2] https://www.mkdocs.org/getting-started/
[3] https://mkdocstrings.github.io/
[4] https://github.com/squidfunk/mkdocs-material
[5] Di\u00e1taxis: A systematic approach to technical documentation authoring.
[6] https://github.com/squidfunk/mkdocs-material/issues/1996
"},{"location":"tutorials/install_minedojo/","title":"Install MineDojo platform on MacBook Pro with M1 Chip","text":"Author: Bang Liu
Date: 2023-08-01
"},{"location":"tutorials/install_minedojo/#setup-java-environment","title":"Setup Java Environment","text":"I followed the instructions on: https://docs.minedojo.org/sections/getting_started/install.html#prerequisites
Specifically, remember to list all installed Java versions and export the temurin8 version of Java:
/usr/libexec/java_home -V\nexport JAVA_HOME=path/to/eclipse/temurin8\n
After running
java -version\n
I got
openjdk version \"1.8.0_332\"\nOpenJDK Runtime Environment (Temurin)(build 1.8.0_332-b09)\nOpenJDK 64-Bit Server VM (Temurin)(build 25.332-b09, mixed mode)\n
"},{"location":"tutorials/install_minedojo/#install-minedojo","title":"Install MineDojo","text":"I used the following command: (Assume Java JDK 8 is already installed)
pip3 install setuptools==65.5.0 pip==21\npip3 install gym==0.21\ngit clone https://github.com/MineDojo/MineDojo && cd MineDojo\npip install -e .\n
Note: I found that, at the end, if I install from source, I cannot remove the source directory. So after resolving all the bugs as described below, I reinstalled minedojo via pip in my conda virtual env:
pip install minedojo\n
So I would recommend installing via pip rather than from source.
"},{"location":"tutorials/install_minedojo/#debug-experience","title":"Debug experience","text":"I ran into many different bugs when trying to run
python scripts/validate_install.py\n
Below, I list all the operations I have done.
"},{"location":"tutorials/install_minedojo/#upgraded-gradle","title":"Upgraded gradle","text":"Check the following: https://gradle.org/install/
After installing the new Gradle, I got:
>>> gradle -v\n\n------------------------------------------------------------\nGradle 8.2.1\n------------------------------------------------------------\n\nBuild time: 2023-07-10 12:12:35 UTC\nRevision: a38ec64d3c4612da9083cc506a1ccb212afeecaa\n\nKotlin: 1.8.20\nGroovy: 3.0.17\nAnt: Apache Ant(TM) version 1.10.13 compiled on January 4 2023\nJVM: 1.8.0_332 (Temurin 25.332-b09)\nOS: Mac OS X 10.16 x86_64\n\n
"},{"location":"tutorials/install_minedojo/#malmo-errors","title":"Malmo errors","text":"I referred to: https://github.com/MineDojo/MineDojo/issues/32#issuecomment-1237247417 It says:
For Deprecated Gradle feature --> Go to Malmo project download latest prebuild version https://github.com/Microsoft/malmo/releases. Then find and replace the Malmo directory in your python package directory @ xxx/minedojo/sim/Malmo on your computer. (Reminder directory shall keep the same name \"Malmo\")
For \"OpenGL: ERROR RuntimeException: No OpenGL context found in the current thread.\" (X Error & bad value) --> make sure you run sudo apt update && sudo apt upgrade before you compile the minecraft java program as the same problem has been described in https://stackoverflow.com/questions/28867285/lwjgl-reports-that-opengl-is-not-supported-on-a-modern-nvidia-card. This works for me.
Before running python Minedojo code, go xxx/minedojo/sim/Malmo/Minecraft/ where your python put minedojo package and execute ./launchClient.sh (for linux/unix) or .\\launchClient (for windows, there's a launchClient.bat file) and make sure it can run normally before you start with Minedojo.
Specifically, when I try to run ./launchClient.sh, I got error due to tools.jar, so I did the following:
copy tools.jar from \n/Library/Java/JavaVirtualMachines/temurin-8.jdk/Contents/Home/lib\nto\n/Library/Internet Plug-Ins/JavaAppletPlugin.plugin/Contents/Home/lib\n\n>>> sudo copy XXX XXX\npasswd: (For me, it is the same as the passwd when I login to my macbook pro: the name :)\n
Then it still failed. So I went back to the original Malmo from the MineDojo installation (i.e., maybe we DON'T need to download the latest prebuilt version from https://github.com/Microsoft/malmo/releases and replace the Malmo directory in the Python package directory).
Now it can run, but there is still an error due to
raise NoSuchProcess(self.pid, self._name)\npsutil.NoSuchProcess: process no longer exists (pid=50957, name='bash')\n
I removed the
env.close()\n
in the script and it works.
This is not the end of the story: I found the script doesn't always work. Sometimes, I don't need to remove the env.close()
and it still works. Sometimes it doesn't work due to errors like
...\n at org.apache.http.impl.DefaultBHttpClientConnection.receiveResponseHeader(DefaultBHttpClientConnection.java:163)\n at org.apache.http.impl.conn.CPoolProxy.receiveResponseHeader(CPoolProxy.java:165)\n at org.apache.http.protocol.HttpRequestExecutor.doReceiveResponse(HttpRequestExecutor.java:273)\n at org.apache.http.protocol.HttpRequestExecutor.execute(HttpRequestExecutor.java:125)\n at org.apache.http.impl.execchain.MainClientExec.createTunnelToTarget(MainClientExec.java:473)\n at org.apache.http.impl.execchain.MainClientExec.establishRoute(MainClientExec.java:398)\n at org.apache.http.impl.execchain.MainClientExec.execute(MainClientExec.java:237)\n at org.apache.http.impl.execchain.ProtocolExec.execute(ProtocolExec.java:185)\n at org.apache.http.impl.execchain.RetryExec.execute(RetryExec.java:89)\n at org.apache.http.impl.execchain.RedirectExec.execute(RedirectExec.java:111)\n at org.apache.http.impl.client.InternalHttpClient.doExecute(InternalHttpClient.java:185)\n at org.apache.http.impl.client.CloseableHttpClient.execute(CloseableHttpClient.java:83)\n at org.gradle.internal.resource.transport.http.HttpClientHelper.performHttpRequest(HttpClientHelper.java:148)\n at org.gradle.internal.resource.transport.http.HttpClientHelper.performHttpRequest(HttpClientHelper.java:126)\n at org.gradle.internal.resource.transport.http.HttpClientHelper.executeGetOrHead(HttpClientHelper.java:103)\n at org.gradle.internal.resource.transport.http.HttpClientHelper.performRequest(HttpClientHelper.java:94)\n ... 171 more\n\n\n* Get more help at https://help.gradle.org\n\nBUILD FAILED in 31s\n\n\nMinecraft process finished unexpectedly. There was an error with Malmo.\n
I suppose it is due to some network connection errors?
Anyway, now it can work.
"}]}
\ No newline at end of file
+{"config":{"lang":["en"],"separator":"[\\s\\-]+","pipeline":["stopWordFilter"]},"docs":[{"location":"","title":"Welcome to Aeiva","text":"Home page...
"},{"location":"blogs/","title":"Blogs","text":"Here we summarize some experience we learned during developing Aeiva.
Thoughts on Several Key Concepts for Agentic Intelligence
"},{"location":"intro/","title":"Introduction","text":""},{"location":"intro/#aeiva-an-evolving-intelligent-virtual-assistant","title":"AEIVA: An Evolving Intelligent Virtual Assistant","text":""},{"location":"intro/#introduction","title":"Introduction","text":"In this project, our objective is to develop a modular and flexible intelligent agent and society system, designed as a virtual assistant capable of performing diverse tasks, learning from data, environment, and interactions, and self-evolving over time. The system will leverage deep learning models, primarily transformers, while also exploring innovative models and learning methods.
Our ultimate goal is to develop a General AI Agent System capable of forming a \u201cgenius society\u201d of AI agents. These agents will:
- Collaboratively address and solve societal challenges across domains.
- Function in diverse environments, from virtual simulations to real-world applications.
- Continuously evolve and improve through self-assessment and adaptation.
- Serve as versatile assistants in various roles, such as AI researchers, software engineers, game players, or digital society members.
Currently, Aeiva supports the following interaction modes:
- Chat in terminal: chat with an agent in the terminal interface
- Chat with Gradio Webui: we developed a gradio web UI interface that allows user to chat with the agent. We plan to support multimodality in the near future.
- Chat with desktop Waifu mode: by combining with our another project
Maid
, we can use our agent as the backend and call it through Maid desktop assistant.
\u2b50\ufe0f Documentation \ud83d\udc49 aeiva documentation
"},{"location":"intro/#key-features","title":"Key Features","text":"Currently, we features with the following functionalities:
- Rich Toolkits: I have implemented a series of different API tools and I'm keep improving the API library.
- Open Operator: By implementing computer-use related tools, aeiva is able to understand and operate user's computer and complete daily tasks. We are keep enhancing the functionality in this part. Note: use this feature with caution!
- Memory Palace: I have designed and implemented a layered memory palace for storaging agent memories. It is flexible and can be customized to represent and query different types of memories.
More functionalities and modules will be implemented gradually. Keep tuned! If you find any errors or bugs, feel free to report by opening an issue, thanks a lot!
"},{"location":"intro/#installation","title":"Installation","text":"To install AEIVA, follow these steps:
"},{"location":"intro/#prerequisites","title":"Prerequisites","text":" Python 3.10
or newer pip
(Python package manager)
"},{"location":"intro/#option-1-install-via-pip-recommended","title":"Option 1: Install via pip
[recommended]","text":"You can easily install vai pip by:
pip install aeiva\n
"},{"location":"intro/#option-2-install-from-repository","title":"Option 2: Install from Repository","text":" -
Clone the AEIVA Repository
First, clone the AEIVA repository to your local machine using Git:
bash git clone https://github.com/chatsci/Aeiva.git cd Aeiva
-
Create a Virtual Environment (Recommended) It's a good practice to create a virtual environment for Python projects. This keeps dependencies required by different projects separate. Use the following command to create a virtual environment with conda
:
bash conda create --name <my-env>
Replace <my-env>
with the name of your environment.
To acivate your env:
bash conda activate <my-env>
For more advanced configurations or options, please check the online document of conda
.
-
Install Dependencies Install all dependencies listed in requirements.txt:
bash pip install -r requirements.txt
-
Install Aeiva Finally, install AEIVA using the setup.py script:
bash python setup.py install
-
Verify Installation To verify that AEIVA has been installed correctly, you can run the following command:
bash python -c \"import aeiva; print(aeiva.__version__)\"
"},{"location":"intro/#dependencies","title":"Dependencies","text":"Our memory module utilizes different types of databases.
-
Vector Database: Our memory module also utilizes vector database. Please install vector database such as milvus
(recommended), chroma
, qdrant
, or weaviate
.
-
Graph Database: Ensure Neo4j is installed and the NEO4J_HOME
environment variable is set.
-
Relational Database: We use sqlite
(recommended) or postgre sql
.
"},{"location":"intro/#commands","title":"Commands","text":"After installing Neo4j and setting the environment variable, follow these steps to run different aeiva chat commands.
"},{"location":"intro/#aeiva-chat-in-terminal-mode","title":"\ud83e\ude84\u2b50Aeiva Chat in Terminal Mode","text":"Run the following command in terminal:
aeiva-chat-terminal --config configs/agent_config.yaml --verbose\n
-
Options:
--config
or -c
: Path to the configuration file (default: configs/agent_config.yaml
). --verbose
or -v
: Enable verbose logging for detailed output.
-
Using the Interface:
- Interact with the chatbot directly in your terminal after running the command. - View Logs:
- Logs are stored at
~/.aeiva/logs/aeiva-chat-terminal.log
. - To monitor logs in real-time, use:
shell tail -f ~/.aeiva/logs/aeiva-chat-terminal.log
You will see a terminal like the one below:
"},{"location":"intro/#aeiva-chat-in-gradio-mode","title":"\ud83e\ude84\u2b50Aeiva Chat in Gradio Mode","text":"Run the following command in terminal:
aeiva-chat-gradio --config configs/agent_config.yaml --verbose\n
-
Options:
--config
or -c
: Path to the configuration file (default: configs/agent_config.yaml
). --verbose
or -v
: Enable verbose logging for detailed output.
-
Access the Gradio Interface:
- Open your web browser and navigate to http://localhost:7860.
- Alternatively, use the public URL provided in the terminal output (e.g., https://1b1f89328e57b2f2e1.gradio.live) to access the interface remotely.
- View Logs:
- Logs are stored at
~/.aeiva/logs/aeiva-chat-gradio.log
. - To monitor logs in real-time, use:
shell tail -f ~/.aeiva/logs/aeiva-chat-gradio.log
By visiting the gradio interface, you will see a gradio web-ui like below:
"},{"location":"intro/#aeiva-server","title":"\ud83e\ude84\u2b50Aeiva Server","text":"Run the following command in terminal:
aeiva-server --config configs/agent_config.yaml --host 0.0.0.0 --port 8000 --verbose\n
- Options:
--config
or -c
: Path to the configuration file (default: configs/agent_config.yaml). --host
or -H
: Host address to run the server on (default: 0.0.0.0). --port
or -p
: Port number to run the server on (default: 8000). --verbose
or -v
: Enable verbose logging for detailed output.
- Access the Server:
- Open your web browser and navigate to
http://localhost:8000/docs
to access the interactive API documentation.
- View Logs:
- Logs are stored at
~/.aeiva/logs/aeiva-server.log
. - To monitor logs in real-time, use:
shell tail -f ~/.aeiva/logs/aeiva-server.log
"},{"location":"intro/#maid-chat-your-intelligent-assistant-on-desktop","title":"\ud83e\ude84\u2b50Maid Chat (Your Intelligent Assistant on Desktop!)","text":"Run the following command in terminal to get an animated virtual assistant on your desktop that you can talk to in voice mode or by typing:
maid-chat --config configs/agent_config.yaml --host 0.0.0.0 --port 8000 --verbose\n
- Options:
--config
or -c
: Path to the configuration file (default: configs/agent_config.yaml
). --host
or -H
: Host address to run the server on (default: 0.0.0.0
). --port
or -p
: Port number to run the server on (default: 8000
). --verbose
or -v
: Enable verbose logging for detailed output.
- Download
Maid.app
: - Download
Maid.app
from here.
-
Set MAID_HOME
Environment Variable:
- Unix/Linux/macOS:
shell export MAID_HOME='/path/to/my/unity.app/Contents/MacOS/Maid - Your Intelligent Waifu !' source ~/.bashrc # or source ~/.zshrc
- Windows (Command Prompt):
shell set MAID_HOME=C:\\path\\to\\my\\unity\\app
- Windows (PowerShell):
shell $env:MAID_HOME = \"C:\\path\\to\\my\\unity\\app\"
Replace /path/to/my/unity/app
or C:\\path\\to\\my\\unity\\app
with the actual path to your Unity application.
-
Using the Interface:
- Interact with the server through the Maid.app Unity application after running the command.
- View Logs:
- Logs are stored at
~/.aeiva/logs/maid-chat.log
. - To monitor logs in real-time, use:
shell tail -f ~/.aeiva/logs/maid-chat.log
- Troubleshooting:
-
Permission Denied Error When Starting Unity Application: If you encounter an error like: Error: Failed to start Unity application: [Errno 13] Permission denied: '/path/to/my/unity/app'
Solution:
-
macOS Users:
- Open System Preferences.
- Navigate to Security & Privacy.
- Click on the Privacy tab.
- Select Accessibility from the sidebar.
- Click the lock icon to make changes and enter your password.
- Click the \"+\" button and add your terminal application (e.g., Terminal, iTerm).
- Ensure that your terminal application is checked, granting it the necessary permissions to run the Unity application.
-
Windows Users:
- Right-click on the Unity application executable.
- Select Properties.
- Go to the Compatibility tab.
- Check Run this program as an administrator.
- Click Apply, then OK.
- Try running the command again.
Ensure that the MAID_HOME
environment variable points to the correct path of your Unity application.
Demo of Maid-chat:
"},{"location":"intro/#citation","title":"Citation","text":"To cite Aeiva in publications, please use the following BibTeX entries.
@misc{bang2024aeiva,\n title={Aeiva: An Evolving Intelligent Virtual Assistant}, \n author={Bang Liu},\n year={2024},\n url={https://github.com/chatsci/Aeiva}\n}\n
"},{"location":"intro/#contact","title":"Contact","text":""},{"location":"reference/","title":"Reference","text":"This part of the project documentation focuses on an information-oriented approach. Use it as a reference for the technical implementation of the Aeiva
project code.
"},{"location":"reference/#aeiva-api-references","title":"Aeiva API references","text":""},{"location":"reference/#src.aeiva.action","title":"action
","text":""},{"location":"reference/#src.aeiva.action.action","title":"action
","text":""},{"location":"reference/#src.aeiva.action.action.Action","title":"Action
","text":" Bases: Step
Represents an action that can be executed, extending from the Step class. An action is a tool with states and state management methods. It can execute functionality.
Source code in src/aeiva/action/action.py
class Action(Step):\n \"\"\"\n Represents an action that can be executed, extending from the Step class.\n An action is a tool with states and state management methods. It can execute functionality.\n \"\"\"\n\n def __init__(self, name: str, params: Dict[str, Any] = None,\n id: str = None, dependent_ids: Optional[List[str]] = None, \n type: Optional[str] = None, description: Optional[str] = None,\n metadata: Optional[Dict[str, Any]] = None):\n super().__init__(name=name, params=params,\n id=id, dependent_ids=dependent_ids,\n type=type, description=description,\n metadata=metadata)\n self.type = \"Action\"\n self.tool = Tool(name)\n self.result = None\n\n def reset(self) -> None:\n \"\"\"\n Resets the step status, making it ready for re-execution.\n \"\"\"\n self.result = None\n self.status = Status.NOT_EXECUTED\n\n async def execute(self, params: Dict[str, Any]) -> Any:\n if self.tool is None:\n raise ValueError(f\"Action {self.id} has no tool assigned for execution.\")\n\n self.start()\n try:\n result = await self.tool.execute(params) # Assuming the tool's execute method is async\n self.end(success=True)\n self.result = result\n return result\n except Exception as e:\n self.end(success=False)\n raise RuntimeError(f\"Action {self.id} failed: {str(e)}\")\n
"},{"location":"reference/#src.aeiva.action.action.Action.reset","title":"reset()
","text":"Resets the step status, making it ready for re-execution.
Source code in src/aeiva/action/action.py
def reset(self) -> None:\n \"\"\"\n Resets the step status, making it ready for re-execution.\n \"\"\"\n self.result = None\n self.status = Status.NOT_EXECUTED\n
"},{"location":"reference/#src.aeiva.action.action_system","title":"action_system
","text":""},{"location":"reference/#src.aeiva.action.action_system.ActionSystem","title":"ActionSystem
","text":"A concrete Action System responsible for translating Plans into executable Skills and managing the execution of Skills.
Source code in src/aeiva/action/action_system.py
class ActionSystem:\n \"\"\"\n A concrete Action System responsible for translating Plans into executable Skills\n and managing the execution of Skills.\n \"\"\"\n\n def __init__(self, config: Dict):\n self.config = config\n self.state = {\n \"current_skill\": None,\n \"execution_status\": \"Not Started\",\n }\n self.tools = []\n self.skill = None\n\n def setup(self) -> None:\n if \"tools\" in self.config.keys():\n for tool_name in self.config[\"tools\"]:\n self.tools.append(Tool.load_tool_schema(tool_name))\n print(\"ActionSystem setup complete.\")\n\n def plan_to_skill(self, plan: Plan) -> Skill:\n actions = []\n\n for task in plan.steps:\n if isinstance(task, Task):\n action = Action(\n name=task.name,\n params=task.params,\n id=task.id,\n dependent_ids=task.dependent_ids,\n type=\"Action\",\n description=task.description,\n metadata=task.metadata\n )\n actions.append(action)\n elif isinstance(task, Plan):\n sub_skill = self.plan_to_skill(task) # Recursively handle sub-plans\n actions.append(sub_skill)\n else:\n raise TypeError(f\"Unexpected step type: {type(task)} in plan {plan.id}\")\n\n if not actions:\n raise ValueError(f\"The plan {plan.id} does not contain any valid actions or sub-plans.\")\n\n return Skill(\n name=plan.name,\n steps=actions,\n id=plan.id,\n dependent_ids=plan.dependent_ids,\n type=\"Skill\",\n description=plan.description,\n metadata=plan.metadata\n )\n\n async def execute(self, plan: Plan) -> None:\n self.state[\"execution_status\"] = \"Executing\"\n\n try:\n self.skill = self.plan_to_skill(plan) \n await self.skill.execute() \n self.state[\"execution_status\"] = \"Completed\" if self.skill.is_successful else \"Failed\"\n except Exception as e:\n self.state[\"execution_status\"] = \"Failed\"\n self.handle_error(e)\n raise # Ensure to re-throw the exception\n\n def handle_error(self, error: Exception) -> None:\n print(f\"ActionSystem encountered an error: {error}\")\n
"},{"location":"reference/#src.aeiva.action.experience","title":"experience
","text":""},{"location":"reference/#src.aeiva.action.experience.Experience","title":"Experience
","text":" Bases: Procedure
Represents an experience, which is a structured composition of actions. Unlike a skill, an experience cannot be executed until it is validated and transformed into a skill.
Attributes:
Name Type Description owner
str
The person or agent who owns the experience.
reliable
bool
A flag indicating whether the experience is reliable enough to be transformed into a skill.
Source code in src/aeiva/action/experience.py
class Experience(Procedure):\n \"\"\"\n Represents an experience, which is a structured composition of actions.\n Unlike a skill, an experience cannot be executed until it is validated and transformed into a skill.\n\n Attributes:\n owner (str): The person or agent who owns the experience.\n reliable (bool): A flag indicating whether the experience is reliable enough to be transformed into a skill.\n \"\"\"\n\n def __init__(self, name: str, steps: List[Union['Experience', Action]],\n owner: Optional[str] = None, reliable: Optional[bool] = False,\n id: Optional[str] = None, dependent_ids: Optional[List[str]] = None,\n type: Optional[str] = None, description: Optional[str] = None,\n metadata: Optional[Dict[str, Any]] = None):\n \"\"\"\n Initializes an Experience by extending Procedure.\n \"\"\"\n super().__init__(name=name, steps=steps,\n id=id, dependent_ids=dependent_ids,\n type=type, description=description,\n metadata=metadata)\n self.type = \"Experience\"\n self.owner = owner # The owner of the experience\n self.reliable = reliable # Whether the experience can be transformed into a skill. 
\n # We can use metadata to store some scores and decide whether it is reliable.\n\n @property\n def is_reliable(self) -> bool:\n \"\"\"\n Checks if the experience is reliable enough to be transformed into a skill.\n \"\"\"\n return self.reliable\n\n def mark_reliable(self) -> None:\n \"\"\"\n Marks the experience as reliable, allowing it to be transformed into a skill.\n \"\"\"\n self.reliable = True\n\n def to_skill(self) -> Skill:\n \"\"\"\n Converts this experience into a skill, but only if the experience is marked as reliable.\n If the experience is not reliable, raises a ValueError.\n\n Returns:\n Skill: A new Skill object that is based on the actions from this experience.\n \"\"\"\n if not self.reliable:\n raise ValueError(f\"Experience {self.id} cannot be transformed into a skill because it is not marked as reliable.\")\n\n # Create and return a new Skill instance\n return Skill(\n name=self.name,\n steps=self.steps, # Use the same steps (actions) from the experience\n id=self.id,\n dependent_ids=self.dependent_ids,\n type=\"Skill\",\n description=f\"Skill derived from Experience: {self.description}\", \n metadata=self.metadata\n )\n\n def to_dict(self) -> Dict[str, Any]:\n \"\"\"\n Returns a dictionary representation of the object.\n \"\"\"\n experience_dict = super().to_dict()\n experience_dict.update({\n \"owner\": self.owner,\n \"reliable\": self.reliable,\n })\n return experience_dict\n
"},{"location":"reference/#src.aeiva.action.experience.Experience.is_reliable","title":"is_reliable: bool
property
","text":"Checks if the experience is reliable enough to be transformed into a skill.
"},{"location":"reference/#src.aeiva.action.experience.Experience.__init__","title":"__init__(name, steps, owner=None, reliable=False, id=None, dependent_ids=None, type=None, description=None, metadata=None)
","text":"Initializes an Experience by extending Procedure.
Source code in src/aeiva/action/experience.py
def __init__(self, name: str, steps: List[Union['Experience', Action]],\n owner: Optional[str] = None, reliable: Optional[bool] = False,\n id: Optional[str] = None, dependent_ids: Optional[List[str]] = None,\n type: Optional[str] = None, description: Optional[str] = None,\n metadata: Optional[Dict[str, Any]] = None):\n \"\"\"\n Initializes an Experience by extending Procedure.\n \"\"\"\n super().__init__(name=name, steps=steps,\n id=id, dependent_ids=dependent_ids,\n type=type, description=description,\n metadata=metadata)\n self.type = \"Experience\"\n self.owner = owner # The owner of the experience\n self.reliable = reliable # Whether the experience can be transformed into a skill. \n
"},{"location":"reference/#src.aeiva.action.experience.Experience.mark_reliable","title":"mark_reliable()
","text":"Marks the experience as reliable, allowing it to be transformed into a skill.
Source code in src/aeiva/action/experience.py
def mark_reliable(self) -> None:\n \"\"\"\n Marks the experience as reliable, allowing it to be transformed into a skill.\n \"\"\"\n self.reliable = True\n
"},{"location":"reference/#src.aeiva.action.experience.Experience.to_dict","title":"to_dict()
","text":"Returns a dictionary representation of the object.
Source code in src/aeiva/action/experience.py
def to_dict(self) -> Dict[str, Any]:\n \"\"\"\n Returns a dictionary representation of the object.\n \"\"\"\n experience_dict = super().to_dict()\n experience_dict.update({\n \"owner\": self.owner,\n \"reliable\": self.reliable,\n })\n return experience_dict\n
"},{"location":"reference/#src.aeiva.action.experience.Experience.to_skill","title":"to_skill()
","text":"Converts this experience into a skill, but only if the experience is marked as reliable. If the experience is not reliable, raises a ValueError.
Returns:
Name Type Description Skill
Skill
A new Skill object that is based on the actions from this experience.
Source code in src/aeiva/action/experience.py
def to_skill(self) -> Skill:\n \"\"\"\n Converts this experience into a skill, but only if the experience is marked as reliable.\n If the experience is not reliable, raises a ValueError.\n\n Returns:\n Skill: A new Skill object that is based on the actions from this experience.\n \"\"\"\n if not self.reliable:\n raise ValueError(f\"Experience {self.id} cannot be transformed into a skill because it is not marked as reliable.\")\n\n # Create and return a new Skill instance\n return Skill(\n name=self.name,\n steps=self.steps, # Use the same steps (actions) from the experience\n id=self.id,\n dependent_ids=self.dependent_ids,\n type=\"Skill\",\n description=f\"Skill derived from Experience: {self.description}\", \n metadata=self.metadata\n )\n
"},{"location":"reference/#src.aeiva.action.plan","title":"plan
","text":""},{"location":"reference/#src.aeiva.action.plan.Plan","title":"Plan
","text":" Bases: Procedure
Represents a plan, which is a structured roadmap for achieving a goal by executing tasks and sub-plans. Inherits common functionality from Procedure.
Source code in src/aeiva/action/plan.py
class Plan(Procedure):\n \"\"\"\n Represents a plan, which is a structured roadmap for achieving a goal by executing tasks and sub-plans.\n Inherits common functionality from Procedure.\n \"\"\"\n\n def __init__(self, name: str, steps: List[Union['Plan', Task]],\n id: Optional[str] = None, dependent_ids: Optional[List[str]] = None,\n type: Optional[str] = None, description: Optional[str] = None,\n metadata: Optional[Dict[str, Any]] = None):\n \"\"\"\n Initializes a Plan by extending Procedure.\n \"\"\"\n super().__init__(name=name, steps=steps,\n id=id, dependent_ids=dependent_ids,\n type=type, description=description,\n metadata=metadata)\n self.type = \"Plan\"\n
"},{"location":"reference/#src.aeiva.action.plan.Plan.__init__","title":"__init__(name, steps, id=None, dependent_ids=None, type=None, description=None, metadata=None)
","text":"Initializes a Plan by extending Procedure.
Source code in src/aeiva/action/plan.py
def __init__(self, name: str, steps: List[Union['Plan', Task]],\n id: Optional[str] = None, dependent_ids: Optional[List[str]] = None,\n type: Optional[str] = None, description: Optional[str] = None,\n metadata: Optional[Dict[str, Any]] = None):\n \"\"\"\n Initializes a Plan by extending Procedure.\n \"\"\"\n super().__init__(name=name, steps=steps,\n id=id, dependent_ids=dependent_ids,\n type=type, description=description,\n metadata=metadata)\n self.type = \"Plan\"\n
"},{"location":"reference/#src.aeiva.action.procedure","title":"procedure
","text":""},{"location":"reference/#src.aeiva.action.procedure.Procedure","title":"Procedure
","text":"Abstract base class for composite structures like Plan and Skill. Contains shared attributes and methods for organizing and managing steps (e.g., tasks, sub-procedures) in a directed acyclic graph (DAG).
Source code in src/aeiva/action/procedure.py
class Procedure:\n \"\"\"\n Abstract base class for composite structures like Plan and Skill.\n Contains shared attributes and methods for organizing and managing steps (e.g., tasks, sub-procedures) \n in a directed acyclic graph (DAG).\n \"\"\"\n\n def __init__(self, name: str, steps: List[Union['Procedure', Step]],\n id: Optional[str] = None, dependent_ids: Optional[List[str]] = None,\n type: Optional[str] = None, description: Optional[str] = None,\n metadata: Optional[Dict[str, Any]] = None,\n *args, **kwargs):\n self.name = name\n self.steps = steps\n self.id = id\n self.dependent_ids = dependent_ids or []\n self.type = type\n self.description = description\n self.metadata = metadata or {}\n\n self.graph = nx.DiGraph()\n self.step_map = {step.id: step for step in steps}\n self.status = Status.NOT_EXECUTED\n\n # Add all steps as nodes in the graph\n for step in steps:\n self.graph.add_node(step)\n\n # Handle dependencies for steps\n for step in steps:\n for dep_id in step.dependent_ids:\n if dep_id in self.step_map:\n self.graph.add_edge(self.step_map[dep_id], step)\n else:\n raise ValueError(f\"Dependency {dep_id} not found for step {step.id}.\")\n\n def get_topological_sort(self):\n \"\"\"\n Returns the steps in topologically sorted order based on the dependency graph.\n Ensures that all prerequisite steps are executed before the dependent ones.\n \"\"\"\n try:\n return list(nx.topological_sort(self.graph))\n except nx.NetworkXUnfeasible:\n raise ValueError(\"The dependency graph contains cycles, which is not allowed in a procedure.\")\n\n def reset(self) -> None:\n \"\"\"\n Resets the status of the procedure and all its steps.\n \"\"\"\n self.status = Status.NOT_EXECUTED\n for step in self.steps:\n step.reset()\n\n def start(self) -> None:\n \"\"\"\n Marks the procedure as in progress. 
Raises an error if it's already in progress or finished.\n \"\"\"\n if self.status != Status.NOT_EXECUTED:\n raise ValueError(f\"{self.type} {self.id} has already been started or finished.\")\n self.status = Status.EXECUTING\n\n def end(self, success: bool) -> None:\n \"\"\"\n Marks the procedure as completed. Raises an error if it hasn't started yet.\n \"\"\"\n if self.status != Status.EXECUTING:\n raise ValueError(f\"Cannot finish {self.type} that hasn't started.\")\n self.status = Status.SUCCESS if success else Status.FAIL\n\n @property\n def is_successful(self) -> bool:\n return all(step.is_successful for step in self.steps)\n\n @property\n def is_failed(self) -> bool:\n return any(step.is_failed for step in self.steps)\n\n @property\n def is_in_progress(self) -> bool:\n return any(step.is_in_progress for step in self.steps)\n\n @property\n def is_not_started(self) -> bool:\n return all(step.is_not_started for step in self.steps)\n\n @property\n def is_finished(self) -> bool:\n return all(step.is_finished for step in self.steps)\n\n def visualize(self, save_path: Optional[str] = None):\n \"\"\"\n Visualizes the procedure's structure using networkx and matplotlib.\n \"\"\"\n pos = nx.spring_layout(self.graph) # Layout for the graph\n labels = {node: f\"{node.id} ({node.description}, {node.status})\" for node in self.graph.nodes()}\n\n # Draw the graph with labels\n nx.draw(self.graph, pos, with_labels=True, labels=labels, node_color='lightblue', node_size=2000, font_size=10, font_weight='bold', arrows=True)\n\n plt.title(f\"{self.type} {self.description} Visualization\")\n if save_path:\n plt.savefig(save_path)\n else:\n plt.show()\n\n def to_dict(self) -> Dict[str, Any]:\n return {\n \"name\": self.name,\n \"steps\": [step.to_dict() for step in self.steps],\n \"id\": self.id,\n \"dependent_ids\": self.dependent_ids,\n \"type\": self.type,\n \"description\": self.description,\n \"metadata\": self.metadata,\n \"status\": self.status\n }\n
"},{"location":"reference/#src.aeiva.action.procedure.Procedure.end","title":"end(success)
","text":"Marks the procedure as completed. Raises an error if it hasn't started yet.
Source code in src/aeiva/action/procedure.py
def end(self, success: bool) -> None:\n \"\"\"\n Marks the procedure as completed. Raises an error if it hasn't started yet.\n \"\"\"\n if self.status != Status.EXECUTING:\n raise ValueError(f\"Cannot finish {self.type} that hasn't started.\")\n self.status = Status.SUCCESS if success else Status.FAIL\n
"},{"location":"reference/#src.aeiva.action.procedure.Procedure.get_topological_sort","title":"get_topological_sort()
","text":"Returns the steps in topologically sorted order based on the dependency graph. Ensures that all prerequisite steps are executed before the dependent ones.
Source code in src/aeiva/action/procedure.py
def get_topological_sort(self):\n \"\"\"\n Returns the steps in topologically sorted order based on the dependency graph.\n Ensures that all prerequisite steps are executed before the dependent ones.\n \"\"\"\n try:\n return list(nx.topological_sort(self.graph))\n except nx.NetworkXUnfeasible:\n raise ValueError(\"The dependency graph contains cycles, which is not allowed in a procedure.\")\n
"},{"location":"reference/#src.aeiva.action.procedure.Procedure.reset","title":"reset()
","text":"Resets the status of the procedure and all its steps.
Source code in src/aeiva/action/procedure.py
def reset(self) -> None:\n \"\"\"\n Resets the status of the procedure and all its steps.\n \"\"\"\n self.status = Status.NOT_EXECUTED\n for step in self.steps:\n step.reset()\n
"},{"location":"reference/#src.aeiva.action.procedure.Procedure.start","title":"start()
","text":"Marks the procedure as in progress. Raises an error if it's already in progress or finished.
Source code in src/aeiva/action/procedure.py
def start(self) -> None:\n \"\"\"\n Marks the procedure as in progress. Raises an error if it's already in progress or finished.\n \"\"\"\n if self.status != Status.NOT_EXECUTED:\n raise ValueError(f\"{self.type} {self.id} has already been started or finished.\")\n self.status = Status.EXECUTING\n
"},{"location":"reference/#src.aeiva.action.procedure.Procedure.visualize","title":"visualize(save_path=None)
","text":"Visualizes the procedure's structure using networkx and matplotlib.
Source code in src/aeiva/action/procedure.py
def visualize(self, save_path: Optional[str] = None):\n \"\"\"\n Visualizes the procedure's structure using networkx and matplotlib.\n \"\"\"\n pos = nx.spring_layout(self.graph) # Layout for the graph\n labels = {node: f\"{node.id} ({node.description}, {node.status})\" for node in self.graph.nodes()}\n\n # Draw the graph with labels\n nx.draw(self.graph, pos, with_labels=True, labels=labels, node_color='lightblue', node_size=2000, font_size=10, font_weight='bold', arrows=True)\n\n plt.title(f\"{self.type} {self.description} Visualization\")\n if save_path:\n plt.savefig(save_path)\n else:\n plt.show()\n
"},{"location":"reference/#src.aeiva.action.skill","title":"skill
","text":""},{"location":"reference/#src.aeiva.action.skill.Skill","title":"Skill
","text":" Bases: Procedure
Represents a skill, which is a structured roadmap for executing actions. Skills are composed of actions and can be executed. Inherits common functionality from Procedure.
Source code in src/aeiva/action/skill.py
class Skill(Procedure):\n \"\"\"\n Represents a skill, which is a structured roadmap for executing actions.\n Skills are composed of actions and can be executed.\n Inherits common functionality from Procedure.\n \"\"\"\n\n def __init__(self, name: str, steps: List[Union['Skill', Action]],\n id: Optional[str] = None, dependent_ids: Optional[List[str]] = None,\n type: Optional[str] = None, description: Optional[str] = None,\n metadata: Optional[Dict[str, Any]] = None):\n \"\"\"\n Initializes a Skill by extending Procedure.\n \"\"\"\n super().__init__(name=name, steps=steps,\n id=id, dependent_ids=dependent_ids,\n type=type, description=description,\n metadata=metadata)\n self.type = \"Skill\"\n\n def get_topological_sort(self):\n \"\"\"\n Returns the steps in topologically sorted order based on the dependency graph.\n Ensures that all prerequisite steps are executed before the dependent ones.\n \"\"\"\n return list(nx.topological_sort(self.graph))\n\n async def execute(self):\n \"\"\"\n Executes all actions in the skill based on the dependencies defined in the graph.\n This will run the actions asynchronously, respecting their dependencies.\n \"\"\"\n self.start()\n\n # Perform topological sort right before execution\n sorted_steps = self.get_topological_sort()\n\n for step in sorted_steps:\n if isinstance(step, Action):\n print(f\"Executing Action: {step.id} - {step.description}\")\n await step.execute(step.params) # Execute the action asynchronously\n elif isinstance(step, Skill):\n print(f\"Executing Sub-Skill: {step.id}\")\n await step.execute() # If it's a sub-skill, execute the sub-skill\n\n self.end(success=self.is_successful)\n
"},{"location":"reference/#src.aeiva.action.skill.Skill.__init__","title":"__init__(name, steps, id=None, dependent_ids=None, type=None, description=None, metadata=None)
","text":"Initializes a Skill by extending Procedure.
Source code in src/aeiva/action/skill.py
def __init__(self, name: str, steps: List[Union['Skill', Action]],\n id: Optional[str] = None, dependent_ids: Optional[List[str]] = None,\n type: Optional[str] = None, description: Optional[str] = None,\n metadata: Optional[Dict[str, Any]] = None):\n \"\"\"\n Initializes a Skill by extending Procedure.\n \"\"\"\n super().__init__(name=name, steps=steps,\n id=id, dependent_ids=dependent_ids,\n type=type, description=description,\n metadata=metadata)\n self.type = \"Skill\"\n
"},{"location":"reference/#src.aeiva.action.skill.Skill.execute","title":"execute()
async
","text":"Executes all actions in the skill based on the dependencies defined in the graph. This will run the actions asynchronously, respecting their dependencies.
Source code in src/aeiva/action/skill.py
async def execute(self):\n \"\"\"\n Executes all actions in the skill based on the dependencies defined in the graph.\n This will run the actions asynchronously, respecting their dependencies.\n \"\"\"\n self.start()\n\n # Perform topological sort right before execution\n sorted_steps = self.get_topological_sort()\n\n for step in sorted_steps:\n if isinstance(step, Action):\n print(f\"Executing Action: {step.id} - {step.description}\")\n await step.execute(step.params) # Execute the action asynchronously\n elif isinstance(step, Skill):\n print(f\"Executing Sub-Skill: {step.id}\")\n await step.execute() # If it's a sub-skill, execute the sub-skill\n\n self.end(success=self.is_successful)\n
"},{"location":"reference/#src.aeiva.action.skill.Skill.get_topological_sort","title":"get_topological_sort()
","text":"Returns the steps in topologically sorted order based on the dependency graph. Ensures that all prerequisite steps are executed before the dependent ones.
Source code in src/aeiva/action/skill.py
def get_topological_sort(self):\n \"\"\"\n Returns the steps in topologically sorted order based on the dependency graph.\n Ensures that all prerequisite steps are executed before the dependent ones.\n \"\"\"\n return list(nx.topological_sort(self.graph))\n
"},{"location":"reference/#src.aeiva.action.status","title":"status
","text":""},{"location":"reference/#src.aeiva.action.status.Status","title":"Status
","text":"A class to hold status constants.
Source code in src/aeiva/action/status.py
class Status:\n \"\"\"\n A class to hold status constants.\n \"\"\"\n NOT_EXECUTED = \"Not Executed\"\n EXECUTING = \"Executing\"\n SUCCESS = \"Success\"\n FAIL = \"Fail\"\n
"},{"location":"reference/#src.aeiva.action.step","title":"step
","text":""},{"location":"reference/#src.aeiva.action.step.Step","title":"Step
","text":"Abstract base class for atomic units like Task and Action. Contains shared attributes and methods for managing their execution and dependencies.
Source code in src/aeiva/action/step.py
class Step:\n \"\"\"\n Abstract base class for atomic units like Task and Action.\n Contains shared attributes and methods for managing their execution and dependencies.\n \"\"\"\n\n def __init__(self, name: str, params: Dict[str, Any] = None,\n id: Optional[str] = None, dependent_ids: Optional[List[str]] = None, \n type: Optional[str] = None, description: Optional[str] = None,\n metadata: Optional[Dict[str, Any]] = None,\n *args, **kwargs):\n self.name = name # The name of the step. It can be a task/action/tool/api/function name\n self.params = params # The parameters for this step. it can be a task/action/tool/api/function's params\n self.id = id # Unique identifier for the step\n self.dependent_ids = dependent_ids or [] # List of IDs of steps that must be completed before this one\n self.type = type # The type of this step, e.g., task or action\n self.description = description # A description for this step\n self.metadata = metadata or {} # Optional metadata (e.g., id, type, description, priority, etc.)\n self.status = Status.NOT_EXECUTED # Initial status\n\n def reset(self) -> None:\n \"\"\"\n Resets the step status, making it ready for re-execution.\n \"\"\"\n self.status = Status.NOT_EXECUTED\n\n def start(self) -> None:\n \"\"\"\n Marks the step as in progress. 
Raises an error if the step is already started or finished.\n \"\"\"\n if self.status != Status.NOT_EXECUTED:\n raise ValueError(f\"{self.type} {self.description} {self.id} has already been started or finished.\")\n self.status = Status.EXECUTING\n\n def end(self, success: bool) -> None:\n \"\"\"\n Marks the step as finished and indicates whether it was successful.\n Can only be called if the step is in progress.\n \"\"\"\n if self.status != Status.EXECUTING:\n raise ValueError(f\"Cannot finish a {self.type} that hasn't started.\")\n self.status = Status.SUCCESS if success else Status.FAIL\n\n @property\n def is_successful(self) -> bool:\n \"\"\"\n Returns True if the step was completed successfully.\n \"\"\"\n return self.status == Status.SUCCESS\n\n @property\n def is_failed(self) -> bool:\n \"\"\"\n Returns True if the step has finished but failed.\n \"\"\"\n return self.status == Status.FAIL\n\n @property\n def is_in_progress(self) -> bool:\n \"\"\"\n Returns True if the step is in progress (executing but not finished).\n \"\"\"\n return self.status == Status.EXECUTING\n\n @property\n def is_not_started(self) -> bool:\n \"\"\"\n Returns True if the step has not started yet.\n \"\"\"\n return self.status == Status.NOT_EXECUTED\n\n @property\n def is_finished(self) -> bool:\n \"\"\"\n Returns True if the step has finished execution, either successfully or failed.\n \"\"\"\n return self.status == Status.SUCCESS or self.status == Status.FAIL\n\n def to_dict(self) -> Dict[str, Any]:\n \"\"\"\n Converts the step into a dictionary representation.\n \"\"\"\n return {\n \"name\": self.name,\n \"params\": self.params,\n \"id\": self.id,\n \"dependent_ids\": self.dependent_ids,\n \"type\": self.type,\n \"description\": self.description,\n \"status\": self.status,\n \"metadata\": self.metadata\n }\n
"},{"location":"reference/#src.aeiva.action.step.Step.is_failed","title":"is_failed: bool
property
","text":"Returns True if the step has finished but failed.
"},{"location":"reference/#src.aeiva.action.step.Step.is_finished","title":"is_finished: bool
property
","text":"Returns True if the step has finished execution, either successfully or failed.
"},{"location":"reference/#src.aeiva.action.step.Step.is_in_progress","title":"is_in_progress: bool
property
","text":"Returns True if the step is in progress (executing but not finished).
"},{"location":"reference/#src.aeiva.action.step.Step.is_not_started","title":"is_not_started: bool
property
","text":"Returns True if the step has not started yet.
"},{"location":"reference/#src.aeiva.action.step.Step.is_successful","title":"is_successful: bool
property
","text":"Returns True if the step was completed successfully.
"},{"location":"reference/#src.aeiva.action.step.Step.end","title":"end(success)
","text":"Marks the step as finished and indicates whether it was successful. Can only be called if the step is in progress.
Source code in src/aeiva/action/step.py
def end(self, success: bool) -> None:\n \"\"\"\n Marks the step as finished and indicates whether it was successful.\n Can only be called if the step is in progress.\n \"\"\"\n if self.status != Status.EXECUTING:\n raise ValueError(f\"Cannot finish a {self.type} that hasn't started.\")\n self.status = Status.SUCCESS if success else Status.FAIL\n
"},{"location":"reference/#src.aeiva.action.step.Step.reset","title":"reset()
","text":"Resets the step status, making it ready for re-execution.
Source code in src/aeiva/action/step.py
def reset(self) -> None:\n \"\"\"\n Resets the step status, making it ready for re-execution.\n \"\"\"\n self.status = Status.NOT_EXECUTED\n
"},{"location":"reference/#src.aeiva.action.step.Step.start","title":"start()
","text":"Marks the step as in progress. Raises an error if the step is already started or finished.
Source code in src/aeiva/action/step.py
def start(self) -> None:\n \"\"\"\n Marks the step as in progress. Raises an error if the step is already started or finished.\n \"\"\"\n if self.status != Status.NOT_EXECUTED:\n raise ValueError(f\"{self.type} {self.description} {self.id} has already been started or finished.\")\n self.status = Status.EXECUTING\n
"},{"location":"reference/#src.aeiva.action.step.Step.to_dict","title":"to_dict()
","text":"Converts the step into a dictionary representation.
Source code in src/aeiva/action/step.py
def to_dict(self) -> Dict[str, Any]:\n \"\"\"\n Converts the step into a dictionary representation.\n \"\"\"\n return {\n \"name\": self.name,\n \"params\": self.params,\n \"id\": self.id,\n \"dependent_ids\": self.dependent_ids,\n \"type\": self.type,\n \"description\": self.description,\n \"status\": self.status,\n \"metadata\": self.metadata\n }\n
"},{"location":"reference/#src.aeiva.action.task","title":"task
","text":""},{"location":"reference/#src.aeiva.action.task.Task","title":"Task
","text":" Bases: Step
Represents the fundamental unit of work, extending from the Step class. Inherits shared attributes and methods from Step and adds task-specific functionality.
Source code in src/aeiva/action/task.py
class Task(Step):\n \"\"\"\n Represents the fundamental unit of work, extending from the Step class.\n Inherits shared attributes and methods from Step and adds task-specific functionality.\n \"\"\"\n\n def __init__(self, name: str, params: Dict[str, Any] = None,\n id: Optional[str] = None, dependent_ids: Optional[List[str]] = None, \n type: Optional[str] = None, description: Optional[str] = None,\n metadata: Optional[Dict[str, Any]] = None):\n super().__init__(name=name, params=params,\n id=id, dependent_ids=dependent_ids,\n type=type, description=description,\n metadata=metadata)\n self.type = \"Task\"\n\n def show(self) -> None:\n print(\"---- Task Information ----\")\n pprint(self.to_dict(), sort_dicts=False)\n print(\"---- End of Task ----\")\n
"},{"location":"reference/#src.aeiva.agent","title":"agent
","text":""},{"location":"reference/#src.aeiva.agent.agent","title":"agent
","text":""},{"location":"reference/#src.aeiva.agent.agent.Agent","title":"Agent
","text":"Represents the agent that integrates perception, cognition, and action systems.
Source code in src/aeiva/agent/agent.py
class Agent:\n \"\"\"\n Represents the agent that integrates perception, cognition, and action systems.\n \"\"\"\n def __init__(self, config: Dict):\n self.config_dict = config\n self.config = None\n self.event_bus = EventBus()\n self.perception_system = None\n self.cognition_system = None\n self.action_system = None\n\n def setup(self) -> None:\n \"\"\"\n Set up all systems.\n \"\"\"\n perception_config = self.config_dict.get('perception_config', {})\n cognition_config = self.config_dict # NOTE: we didn't define a cognition config class yet.\n action_config = self.config_dict.get('action_config', {})\n\n self.perception_system = PerceptionSystem(perception_config, self.event_bus)\n self.cognition_system = CognitionSystem(cognition_config)\n self.action_system = ActionSystem(action_config)\n\n self.perception_system.setup()\n self.cognition_system.setup()\n self.action_system.setup()\n\n async def run(self) -> None:\n \"\"\"\n Run the agent by connecting perception, cognition, and action systems using the event bus.\n \"\"\"\n # Start the event bus within the running event loop\n self.event_bus.start()\n # Assign the current running loop to the EventBus\n self.event_bus.loop = asyncio.get_running_loop()\n # Set up event handlers\n self.setup_event_handlers()\n # Start the perception system\n await self.perception_system.start()\n\n # Keep the event loop running until interrupted\n try:\n while True:\n await asyncio.sleep(1)\n except KeyboardInterrupt:\n # Handle graceful shutdown\n self.perception_system.stop()\n await self.event_bus.wait_until_all_events_processed()\n self.event_bus.stop()\n except asyncio.CancelledError:\n pass\n except Exception as e:\n # logger.error(f\"Unexpected error in agent run loop: {e}\")\n print(f\"Unexpected error in agent run loop: {e}\", flush=True)\n await self.perception_system.stop()\n await self.event_bus.wait_until_all_events_processed()\n self.event_bus.stop()\n\n async def process_input(self, input_text: str) -> str:\n 
\"\"\"\n Process input text and return the agent's response.\n \"\"\"\n stream = self.config_dict.get(\"llm_gateway_config\").get(\"llm_stream\")\n use_async = self.config_dict.get(\"llm_gateway_config\").get(\"llm_use_async\")\n stimuli = Stimuli(signals=[Signal(data=input_text, modularity='text')])\n output = \"\"\n try:\n response_gen = self.cognition_system.think(stimuli, tools=self.action_system.tools, stream=stream, use_async=use_async)\n async for chunk in response_gen:\n if isinstance(chunk, str):\n # For streaming chunks\n output += chunk\n elif isinstance(chunk, Thought) or isinstance(chunk, Plan):\n # For non-streaming responses\n output += chunk.content\n except Exception as e:\n logger.error(f\"Error in response: {e}\")\n return output\n\n def setup_event_handlers(self) -> None:\n \"\"\"\n Set up event handlers for perception, cognition, and action events.\n \"\"\"\n\n @self.event_bus.on('perception.stimuli')\n async def handle_stimuli(event: Event):\n # print(\"handle_stimuli called\", flush=True)\n user_input = event.payload\n stimuli = Stimuli(signals=[Signal(data=user_input, modularity='text')])\n #print(f\"Received stimuli: {stimuli}\", flush=True)\n # Process stimuli through cognition system\n #stimuli = [{\"role\": \"user\", \"content\": stimuli}]\n\n stream = self.config_dict.get(\"llm_gateway_config\").get(\"llm_stream\")\n use_async = self.config_dict.get(\"llm_gateway_config\").get(\"llm_use_async\")\n sys.stdout.write(\"\\r\\033[K\") # Return to start of the line and clear it\\\n print(\"Response: \", end='', flush=True)\n\n try:\n response_gen = self.cognition_system.think(stimuli, tools=self.action_system.tools, stream=stream, use_async=use_async)\n async for chunk in response_gen:\n if isinstance(chunk, str):\n # For streaming chunks\n print(f\"{chunk}\", end='', flush=True)\n elif isinstance(chunk, Thought) or isinstance(chunk, Plan):\n # For non-streaming responses\n print(f\"{chunk.content}\", end='', flush=True)\n except Exception as 
e:\n logger.error(f\"Error in response: {e}\")\n\n print(\"\\nYou: \", end='', flush=True)\n\n # # Determine if output is a Plan or Thought\n # if isinstance(output, Plan): # TODO: change later\n # print(\"Output is a Plan\", flush=True)\n # await self.event_bus.emit('action.plan', payload=output)\n # elif isinstance(output, Thought):\n # print(\"Output is a Thought\", flush=True)\n # print(f\"Agent Response: {output.content}\", flush=True)\n # else:\n # print(\"Unknown output from cognition system.\", flush=True)\n\n @self.event_bus.on('action.plan')\n async def handle_plan(event: Event):\n print(\"handle_plan called\", flush=True)\n plan = event.payload\n await self.action_system.execute(plan)\n\n @self.event_bus.on('perception.gradio')\n async def handle_gradio_input(event: Event):\n \"\"\"\n Handle input from Gradio and emit response.gradio events.\n \"\"\"\n user_input = event.payload\n stimuli = Stimuli(signals=[Signal(data=user_input, modularity='text')])\n\n stream = self.config_dict.get(\"llm_gateway_config\").get(\"llm_stream\")\n use_async = self.config_dict.get(\"llm_gateway_config\").get(\"llm_use_async\")\n logger.info(f\"Handling Gradio input: {user_input} | Stream: {stream}\")\n try:\n response_gen = self.cognition_system.think(stimuli, tools=self.action_system.tools, stream=stream, use_async=use_async)\n\n async for chunk in response_gen:\n if isinstance(chunk, str):\n # For streaming chunks\n await self.event_bus.emit('response.gradio', payload=chunk)\n elif isinstance(chunk, Thought) or isinstance(chunk, Plan):\n # For non-streaming responses\n await self.event_bus.emit('response.gradio', payload=chunk.content if hasattr(chunk, 'content') else str(chunk))\n\n if stream:\n await self.event_bus.emit('response.gradio', payload=\"<END_OF_RESPONSE>\")\n except Exception as e:\n logger.error(f\"Error in streaming response: {e}\")\n await self.event_bus.emit('response.gradio', payload=\"An error occurred during response generation.\")\n if stream:\n 
await self.event_bus.emit('response.gradio', payload=\"<END_OF_RESPONSE>\")\n
"},{"location":"reference/#src.aeiva.agent.agent.Agent.process_input","title":"process_input(input_text)
async
","text":"Process input text and return the agent's response.
Source code in src/aeiva/agent/agent.py
async def process_input(self, input_text: str) -> str:\n \"\"\"\n Process input text and return the agent's response.\n \"\"\"\n stream = self.config_dict.get(\"llm_gateway_config\").get(\"llm_stream\")\n use_async = self.config_dict.get(\"llm_gateway_config\").get(\"llm_use_async\")\n stimuli = Stimuli(signals=[Signal(data=input_text, modularity='text')])\n output = \"\"\n try:\n response_gen = self.cognition_system.think(stimuli, tools=self.action_system.tools, stream=stream, use_async=use_async)\n async for chunk in response_gen:\n if isinstance(chunk, str):\n # For streaming chunks\n output += chunk\n elif isinstance(chunk, Thought) or isinstance(chunk, Plan):\n # For non-streaming responses\n output += chunk.content\n except Exception as e:\n logger.error(f\"Error in response: {e}\")\n return output\n
"},{"location":"reference/#src.aeiva.agent.agent.Agent.run","title":"run()
async
","text":"Run the agent by connecting perception, cognition, and action systems using the event bus.
Source code in src/aeiva/agent/agent.py
async def run(self) -> None:\n \"\"\"\n Run the agent by connecting perception, cognition, and action systems using the event bus.\n \"\"\"\n # Start the event bus within the running event loop\n self.event_bus.start()\n # Assign the current running loop to the EventBus\n self.event_bus.loop = asyncio.get_running_loop()\n # Set up event handlers\n self.setup_event_handlers()\n # Start the perception system\n await self.perception_system.start()\n\n # Keep the event loop running until interrupted\n try:\n while True:\n await asyncio.sleep(1)\n except KeyboardInterrupt:\n # Handle graceful shutdown\n self.perception_system.stop()\n await self.event_bus.wait_until_all_events_processed()\n self.event_bus.stop()\n except asyncio.CancelledError:\n pass\n except Exception as e:\n # logger.error(f\"Unexpected error in agent run loop: {e}\")\n print(f\"Unexpected error in agent run loop: {e}\", flush=True)\n await self.perception_system.stop()\n await self.event_bus.wait_until_all_events_processed()\n self.event_bus.stop()\n
"},{"location":"reference/#src.aeiva.agent.agent.Agent.setup","title":"setup()
","text":"Set up all systems.
Source code in src/aeiva/agent/agent.py
def setup(self) -> None:\n \"\"\"\n Set up all systems.\n \"\"\"\n perception_config = self.config_dict.get('perception_config', {})\n cognition_config = self.config_dict # NOTE: we didn't define a cognition config class yet.\n action_config = self.config_dict.get('action_config', {})\n\n self.perception_system = PerceptionSystem(perception_config, self.event_bus)\n self.cognition_system = CognitionSystem(cognition_config)\n self.action_system = ActionSystem(action_config)\n\n self.perception_system.setup()\n self.cognition_system.setup()\n self.action_system.setup()\n
"},{"location":"reference/#src.aeiva.agent.agent.Agent.setup_event_handlers","title":"setup_event_handlers()
","text":"Set up event handlers for perception, cognition, and action events.
Source code in src/aeiva/agent/agent.py
def setup_event_handlers(self) -> None:\n \"\"\"\n Set up event handlers for perception, cognition, and action events.\n \"\"\"\n\n @self.event_bus.on('perception.stimuli')\n async def handle_stimuli(event: Event):\n # print(\"handle_stimuli called\", flush=True)\n user_input = event.payload\n stimuli = Stimuli(signals=[Signal(data=user_input, modularity='text')])\n #print(f\"Received stimuli: {stimuli}\", flush=True)\n # Process stimuli through cognition system\n #stimuli = [{\"role\": \"user\", \"content\": stimuli}]\n\n stream = self.config_dict.get(\"llm_gateway_config\").get(\"llm_stream\")\n use_async = self.config_dict.get(\"llm_gateway_config\").get(\"llm_use_async\")\n sys.stdout.write(\"\\r\\033[K\") # Return to start of the line and clear it\\\n print(\"Response: \", end='', flush=True)\n\n try:\n response_gen = self.cognition_system.think(stimuli, tools=self.action_system.tools, stream=stream, use_async=use_async)\n async for chunk in response_gen:\n if isinstance(chunk, str):\n # For streaming chunks\n print(f\"{chunk}\", end='', flush=True)\n elif isinstance(chunk, Thought) or isinstance(chunk, Plan):\n # For non-streaming responses\n print(f\"{chunk.content}\", end='', flush=True)\n except Exception as e:\n logger.error(f\"Error in response: {e}\")\n\n print(\"\\nYou: \", end='', flush=True)\n\n # # Determine if output is a Plan or Thought\n # if isinstance(output, Plan): # TODO: change later\n # print(\"Output is a Plan\", flush=True)\n # await self.event_bus.emit('action.plan', payload=output)\n # elif isinstance(output, Thought):\n # print(\"Output is a Thought\", flush=True)\n # print(f\"Agent Response: {output.content}\", flush=True)\n # else:\n # print(\"Unknown output from cognition system.\", flush=True)\n\n @self.event_bus.on('action.plan')\n async def handle_plan(event: Event):\n print(\"handle_plan called\", flush=True)\n plan = event.payload\n await self.action_system.execute(plan)\n\n @self.event_bus.on('perception.gradio')\n async def 
handle_gradio_input(event: Event):\n \"\"\"\n Handle input from Gradio and emit response.gradio events.\n \"\"\"\n user_input = event.payload\n stimuli = Stimuli(signals=[Signal(data=user_input, modularity='text')])\n\n stream = self.config_dict.get(\"llm_gateway_config\").get(\"llm_stream\")\n use_async = self.config_dict.get(\"llm_gateway_config\").get(\"llm_use_async\")\n logger.info(f\"Handling Gradio input: {user_input} | Stream: {stream}\")\n try:\n response_gen = self.cognition_system.think(stimuli, tools=self.action_system.tools, stream=stream, use_async=use_async)\n\n async for chunk in response_gen:\n if isinstance(chunk, str):\n # For streaming chunks\n await self.event_bus.emit('response.gradio', payload=chunk)\n elif isinstance(chunk, Thought) or isinstance(chunk, Plan):\n # For non-streaming responses\n await self.event_bus.emit('response.gradio', payload=chunk.content if hasattr(chunk, 'content') else str(chunk))\n\n if stream:\n await self.event_bus.emit('response.gradio', payload=\"<END_OF_RESPONSE>\")\n except Exception as e:\n logger.error(f\"Error in streaming response: {e}\")\n await self.event_bus.emit('response.gradio', payload=\"An error occurred during response generation.\")\n if stream:\n await self.event_bus.emit('response.gradio', payload=\"<END_OF_RESPONSE>\")\n
"},{"location":"reference/#src.aeiva.agent.base_agent","title":"base_agent
","text":""},{"location":"reference/#src.aeiva.agent.base_agent.BaseAgent","title":"BaseAgent
","text":" Bases: ABC
Abstract base class for autonomous agents with perception, cognition, and action capabilities.
Source code in src/aeiva/agent/base_agent.py
class BaseAgent(ABC):\n \"\"\"\n Abstract base class for autonomous agents with perception, cognition, and action capabilities.\n \"\"\"\n\n def __init__(self, config: Any):\n \"\"\"\n Initialize the agent with configuration.\n\n Args:\n config (Any): Configuration settings for the agent.\n \"\"\"\n self.config = config\n self.state = self.initialize_state() # can be a dict that includes: id, profile, motivation, goal, task, plan, etc.\n self.stop_event = asyncio.Event()\n\n # Systems will be initialized in the setup method\n self.perception_system: PerceptionSystem = None\n self.cognition_system: CognitionSystem = None\n self.action_system: ActionSystem = None\n\n @abstractmethod\n def initialize_state(self) -> Any:\n \"\"\"\n Initialize the agent's state.\n\n Returns:\n Any: The initial state of the agent.\n \"\"\"\n pass\n\n @abstractmethod\n def setup(self) -> None:\n \"\"\"\n Set up the agent's components (perception, cognition, action, etc.).\n Perform any asynchronous initialization if necessary.\n \"\"\"\n pass\n\n @abstractmethod\n async def cycle(self) -> None:\n \"\"\"\n Execute one cycle of perception, cognition, and action.\n This method should be overridden to define the agent's behavior per cycle.\n \"\"\"\n pass\n\n async def run(self) -> None:\n \"\"\"\n Run the agent, continuously executing cycles until stopped.\n \"\"\"\n await self.setup()\n cycle_interval = self.config.get('cycle_interval', 1.0)\n while not self.stop_event.is_set():\n try:\n await self.cycle()\n except Exception as e:\n self.handle_error(e)\n await asyncio.sleep(cycle_interval)\n\n def stop(self) -> None:\n \"\"\"\n Signal the agent to stop running.\n \"\"\"\n self.stop_event.set()\n\n def handle_error(self, error: Exception) -> None:\n \"\"\"\n Handle errors that occur during cycle execution.\n\n Args:\n error (Exception): The exception that was raised.\n \"\"\"\n # Implement your error handling logic here (e.g., logging)\n print(f\"Error during agent cycle: {error}\")\n
"},{"location":"reference/#src.aeiva.agent.base_agent.BaseAgent.__init__","title":"__init__(config)
","text":"Initialize the agent with configuration.
Parameters:
Name Type Description Default config
Any
Configuration settings for the agent.
required Source code in src/aeiva/agent/base_agent.py
def __init__(self, config: Any):\n \"\"\"\n Initialize the agent with configuration.\n\n Args:\n config (Any): Configuration settings for the agent.\n \"\"\"\n self.config = config\n self.state = self.initialize_state() # can be a dict that includes: id, profile, motivation, goal, task, plan, etc.\n self.stop_event = asyncio.Event()\n\n # Systems will be initialized in the setup method\n self.perception_system: PerceptionSystem = None\n self.cognition_system: CognitionSystem = None\n self.action_system: ActionSystem = None\n
"},{"location":"reference/#src.aeiva.agent.base_agent.BaseAgent.cycle","title":"cycle()
abstractmethod
async
","text":"Execute one cycle of perception, cognition, and action. This method should be overridden to define the agent's behavior per cycle.
Source code in src/aeiva/agent/base_agent.py
@abstractmethod\nasync def cycle(self) -> None:\n \"\"\"\n Execute one cycle of perception, cognition, and action.\n This method should be overridden to define the agent's behavior per cycle.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.agent.base_agent.BaseAgent.handle_error","title":"handle_error(error)
","text":"Handle errors that occur during cycle execution.
Parameters:
Name Type Description Default error
Exception
The exception that was raised.
required Source code in src/aeiva/agent/base_agent.py
def handle_error(self, error: Exception) -> None:\n \"\"\"\n Handle errors that occur during cycle execution.\n\n Args:\n error (Exception): The exception that was raised.\n \"\"\"\n # Implement your error handling logic here (e.g., logging)\n print(f\"Error during agent cycle: {error}\")\n
"},{"location":"reference/#src.aeiva.agent.base_agent.BaseAgent.initialize_state","title":"initialize_state()
abstractmethod
","text":"Initialize the agent's state.
Returns:
Name Type Description Any
Any
The initial state of the agent.
Source code in src/aeiva/agent/base_agent.py
@abstractmethod\ndef initialize_state(self) -> Any:\n \"\"\"\n Initialize the agent's state.\n\n Returns:\n Any: The initial state of the agent.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.agent.base_agent.BaseAgent.run","title":"run()
async
","text":"Run the agent, continuously executing cycles until stopped.
Source code in src/aeiva/agent/base_agent.py
async def run(self) -> None:\n \"\"\"\n Run the agent, continuously executing cycles until stopped.\n \"\"\"\n await self.setup()\n cycle_interval = self.config.get('cycle_interval', 1.0)\n while not self.stop_event.is_set():\n try:\n await self.cycle()\n except Exception as e:\n self.handle_error(e)\n await asyncio.sleep(cycle_interval)\n
"},{"location":"reference/#src.aeiva.agent.base_agent.BaseAgent.setup","title":"setup()
abstractmethod
","text":"Set up the agent's components (perception, cognition, action, etc.). Perform any asynchronous initialization if necessary.
Source code in src/aeiva/agent/base_agent.py
@abstractmethod\ndef setup(self) -> None:\n \"\"\"\n Set up the agent's components (perception, cognition, action, etc.).\n Perform any asynchronous initialization if necessary.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.agent.base_agent.BaseAgent.stop","title":"stop()
","text":"Signal the agent to stop running.
Source code in src/aeiva/agent/base_agent.py
def stop(self) -> None:\n \"\"\"\n Signal the agent to stop running.\n \"\"\"\n self.stop_event.set()\n
"},{"location":"reference/#src.aeiva.cognition","title":"cognition
","text":""},{"location":"reference/#src.aeiva.cognition.brain","title":"brain
","text":""},{"location":"reference/#src.aeiva.cognition.brain.brain","title":"brain
","text":""},{"location":"reference/#src.aeiva.cognition.brain.brain.Brain","title":"Brain
","text":" Bases: ABC
Abstract base class representing the cognitive processing unit.
The Brain is responsible for processing input stimuli to generate cognitive states that the CognitionSystem will translate into actions.
Attributes:
Name Type Description config
Any
Configuration settings for the Brain.
state
Any
The internal state of the Brain.
Source code in src/aeiva/cognition/brain/brain.py
class Brain(ABC):\n \"\"\"\n Abstract base class representing the cognitive processing unit.\n\n The Brain is responsible for processing input stimuli to generate cognitive states\n that the CognitionSystem will translate into actions.\n\n Attributes:\n config (Any): Configuration settings for the Brain.\n state (Any): The internal state of the Brain.\n \"\"\"\n\n def __init__(self, config: Any):\n \"\"\"\n Initialize the Brain with the provided configuration.\n\n Args:\n config (Any): Configuration settings for the Brain.\n \"\"\"\n self.config = config\n self.state = self.init_state()\n\n @abstractmethod\n def init_state(self) -> Any:\n \"\"\"\n Initialize the internal state of the Brain.\n\n This method should set up the initial state required for the Brain's operations.\n\n Returns:\n Any: The initial state of the Brain.\n \"\"\"\n pass\n\n @abstractmethod\n def setup(self) -> None:\n \"\"\"\n Asynchronously set up the Brain's components.\n\n This method should initialize any necessary components or resources\n based on the provided configuration.\n\n Raises:\n ConfigurationError: If the configuration is invalid or incomplete.\n \"\"\"\n pass\n\n @abstractmethod\n async def think(self, stimuli: Any, *args, **kwargs) -> Any:\n \"\"\"\n Asynchronously process input stimuli to update the cognitive state.\n\n Args:\n stimuli (Any): The input stimuli to process.\n\n Returns:\n Any: The updated cognitive state.\n\n Raises:\n ProcessingError: If processing the stimuli fails.\n \"\"\"\n pass\n\n def handle_error(self, error: Exception) -> None:\n \"\"\"\n Handle errors that occur during cognitive processing.\n\n This method can be overridden to implement custom error handling logic.\n\n Args:\n error (Exception): The exception that was raised.\n \"\"\"\n # Default error handling: log the error\n print(f\"Brain encountered an error: {error}\")\n
"},{"location":"reference/#src.aeiva.cognition.brain.brain.Brain.__init__","title":"__init__(config)
","text":"Initialize the Brain with the provided configuration.
Parameters:
Name Type Description Default config
Any
Configuration settings for the Brain.
required Source code in src/aeiva/cognition/brain/brain.py
def __init__(self, config: Any):\n \"\"\"\n Initialize the Brain with the provided configuration.\n\n Args:\n config (Any): Configuration settings for the Brain.\n \"\"\"\n self.config = config\n self.state = self.init_state()\n
"},{"location":"reference/#src.aeiva.cognition.brain.brain.Brain.handle_error","title":"handle_error(error)
","text":"Handle errors that occur during cognitive processing.
This method can be overridden to implement custom error handling logic.
Parameters:
Name Type Description Default error
Exception
The exception that was raised.
required Source code in src/aeiva/cognition/brain/brain.py
def handle_error(self, error: Exception) -> None:\n \"\"\"\n Handle errors that occur during cognitive processing.\n\n This method can be overridden to implement custom error handling logic.\n\n Args:\n error (Exception): The exception that was raised.\n \"\"\"\n # Default error handling: log the error\n print(f\"Brain encountered an error: {error}\")\n
"},{"location":"reference/#src.aeiva.cognition.brain.brain.Brain.init_state","title":"init_state()
abstractmethod
","text":"Initialize the internal state of the Brain.
This method should set up the initial state required for the Brain's operations.
Returns:
Name Type Description Any
Any
The initial state of the Brain.
Source code in src/aeiva/cognition/brain/brain.py
@abstractmethod\ndef init_state(self) -> Any:\n \"\"\"\n Initialize the internal state of the Brain.\n\n This method should set up the initial state required for the Brain's operations.\n\n Returns:\n Any: The initial state of the Brain.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.brain.brain.Brain.setup","title":"setup()
abstractmethod
","text":"Asynchronously set up the Brain's components.
This method should initialize any necessary components or resources based on the provided configuration.
Raises:
Type Description ConfigurationError
If the configuration is invalid or incomplete.
Source code in src/aeiva/cognition/brain/brain.py
@abstractmethod\ndef setup(self) -> None:\n \"\"\"\n Asynchronously set up the Brain's components.\n\n This method should initialize any necessary components or resources\n based on the provided configuration.\n\n Raises:\n ConfigurationError: If the configuration is invalid or incomplete.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.brain.brain.Brain.think","title":"think(stimuli, *args, **kwargs)
abstractmethod
async
","text":"Asynchronously process input stimuli to update the cognitive state.
Parameters:
Name Type Description Default stimuli
Any
The input stimuli to process.
required Returns:
Name Type Description Any
Any
The updated cognitive state.
Raises:
Type Description ProcessingError
If processing the stimuli fails.
Source code in src/aeiva/cognition/brain/brain.py
@abstractmethod\nasync def think(self, stimuli: Any, *args, **kwargs) -> Any:\n \"\"\"\n Asynchronously process input stimuli to update the cognitive state.\n\n Args:\n stimuli (Any): The input stimuli to process.\n\n Returns:\n Any: The updated cognitive state.\n\n Raises:\n ProcessingError: If processing the stimuli fails.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.brain.llm_brain","title":"llm_brain
","text":""},{"location":"reference/#src.aeiva.cognition.brain.llm_brain.LLMBrain","title":"LLMBrain
","text":" Bases: Brain
Concrete implementation of the Brain, using an LLM to process stimuli and generate cognitive states.
This brain uses the LLMClient to communicate with a language model to process input stimuli and produce outputs.
Source code in src/aeiva/cognition/brain/llm_brain.py
class LLMBrain(Brain):\n \"\"\"\n Concrete implementation of the Brain, using an LLM to process stimuli\n and generate cognitive states.\n\n This brain uses the LLMClient to communicate with a language model to\n process input stimuli and produce outputs.\n \"\"\"\n\n def __init__(self, config: Dict):\n \"\"\"\n Initialize the LLMBrain with the provided LLM configuration.\n\n Args:\n config (LLMGatewayConfig): Configuration settings for the LLMBrain.\n \"\"\"\n super().__init__(config)\n self.config_dict = config\n self.config = None\n self.llm_client = None\n\n def init_state(self) -> Any:\n \"\"\"\n Initialize the internal state of the Brain.\n\n The state can track the ongoing conversation or task context.\n\n Returns:\n dict: Initial empty state.\n \"\"\"\n return {\"conversation\": [], \"cognitive_state\": None}\n\n def setup(self) -> None:\n \"\"\"\n Set up the Brain's components.\n\n For the LLMBrain, this might involve validating the LLM configuration\n and ensuring that all necessary resources are in place.\n \"\"\"\n llm_conf_dict = self.config_dict.get('llm_gateway_config', {})\n self.config = LLMGatewayConfig(\n llm_api_key=llm_conf_dict.get('llm_api_key'),\n llm_model_name=llm_conf_dict.get('llm_model_name', 'gpt-4o'),\n llm_temperature=llm_conf_dict.get('llm_temperature', 0.7),\n llm_max_output_tokens=llm_conf_dict.get('llm_max_output_tokens', 10000),\n llm_use_async=llm_conf_dict.get('llm_use_async', False),\n llm_stream=llm_conf_dict.get('llm_stream', False)\n )\n self.llm_client = LLMClient(self.config)\n\n system_prompt = llm_conf_dict.get('llm_system_prompt', None)\n if system_prompt is not None: # TODO: only add system prompt for llms that support it.\n self.state[\"conversation\"] += [{ \"role\": \"system\", \"content\": system_prompt }]\n\n print(\"LLMBrain setup complete.\")\n\n async def think(\n self,\n stimuli: Any,\n tools: List[Dict[str, Any]] = None,\n stream: bool = False,\n use_async: bool = False\n ) -> AsyncGenerator[str, None]:\n 
\"\"\"\n Asynchronously process input stimuli to update the cognitive state.\n\n Args:\n stimuli (Any): The input stimuli to process.\n stream (bool): Whether to use streaming mode. Default is False.\n\n Returns:\n str: The full response in both streaming and non-streaming modes.\n \"\"\"\n try:\n # Assume stimuli is a list of messages (conversation context)\n if not isinstance(stimuli, list):\n raise ValueError(\"Stimuli must be a list of messages.\")\n\n self.state[\"conversation\"] += stimuli #!! NOTE: to let LLM remember the history. \n\n if not use_async: # NOTE: stream mode only works when use_async!!!\n response = self.llm_client(self.state[\"conversation\"], tools=tools, stream=stream) #!! NOTE: llm client will update conversation\n # self.state[\"conversation\"] += [{\"role\": \"assistant\", \"content\": response}]\n self.state[\"cognitive_state\"] = response\n yield response\n elif stream:\n # Stream mode: collect all parts of the streamed response\n response = \"\"\n # messages = self.state[\"conversation\"].copy()\n async for delta in self.llm_client(self.state[\"conversation\"], tools=tools, stream=stream): #!! NOTE: llm client will update conversation\n response += delta # Collect the streamed content\n yield delta\n # self.state[\"conversation\"] += [{\"role\": \"assistant\", \"content\": response}]\n self.state[\"cognitive_state\"] = response\n #return response\n else:\n # messages = self.state[\"conversation\"].copy()\n response = await self.llm_client(self.state[\"conversation\"], tools=tools, stream=stream) #!! 
NOTE: llm client will update conversation\n # self.state[\"conversation\"] += [{\"role\": \"assistant\", \"content\": response}]\n self.state[\"cognitive_state\"] = response\n yield response\n #return response\n\n except Exception as e:\n self.handle_error(e)\n raise\n\n def handle_error(self, error: Exception) -> None:\n \"\"\"\n Handle errors that occur during cognitive processing.\n\n Args:\n error (Exception): The exception that was raised.\n \"\"\"\n super().handle_error(error)\n # Custom error handling logic for LLM-related issues\n print(f\"LLMBrain encountered an error: {error}\")\n
"},{"location":"reference/#src.aeiva.cognition.brain.llm_brain.LLMBrain.__init__","title":"__init__(config)
","text":"Initialize the LLMBrain with the provided LLM configuration.
Parameters:
Name Type Description Default config
LLMGatewayConfig
Configuration settings for the LLMBrain.
required Source code in src/aeiva/cognition/brain/llm_brain.py
def __init__(self, config: Dict):\n \"\"\"\n Initialize the LLMBrain with the provided LLM configuration.\n\n Args:\n config (LLMGatewayConfig): Configuration settings for the LLMBrain.\n \"\"\"\n super().__init__(config)\n self.config_dict = config\n self.config = None\n self.llm_client = None\n
"},{"location":"reference/#src.aeiva.cognition.brain.llm_brain.LLMBrain.handle_error","title":"handle_error(error)
","text":"Handle errors that occur during cognitive processing.
Parameters:
Name Type Description Default error
Exception
The exception that was raised.
required Source code in src/aeiva/cognition/brain/llm_brain.py
def handle_error(self, error: Exception) -> None:\n \"\"\"\n Handle errors that occur during cognitive processing.\n\n Args:\n error (Exception): The exception that was raised.\n \"\"\"\n super().handle_error(error)\n # Custom error handling logic for LLM-related issues\n print(f\"LLMBrain encountered an error: {error}\")\n
"},{"location":"reference/#src.aeiva.cognition.brain.llm_brain.LLMBrain.init_state","title":"init_state()
","text":"Initialize the internal state of the Brain.
The state can track the ongoing conversation or task context.
Returns:
Name Type Description dict
Any
Initial empty state.
Source code in src/aeiva/cognition/brain/llm_brain.py
def init_state(self) -> Any:\n \"\"\"\n Initialize the internal state of the Brain.\n\n The state can track the ongoing conversation or task context.\n\n Returns:\n dict: Initial empty state.\n \"\"\"\n return {\"conversation\": [], \"cognitive_state\": None}\n
"},{"location":"reference/#src.aeiva.cognition.brain.llm_brain.LLMBrain.setup","title":"setup()
","text":"Set up the Brain's components.
For the LLMBrain, this might involve validating the LLM configuration and ensuring that all necessary resources are in place.
Source code in src/aeiva/cognition/brain/llm_brain.py
def setup(self) -> None:\n \"\"\"\n Set up the Brain's components.\n\n For the LLMBrain, this might involve validating the LLM configuration\n and ensuring that all necessary resources are in place.\n \"\"\"\n llm_conf_dict = self.config_dict.get('llm_gateway_config', {})\n self.config = LLMGatewayConfig(\n llm_api_key=llm_conf_dict.get('llm_api_key'),\n llm_model_name=llm_conf_dict.get('llm_model_name', 'gpt-4o'),\n llm_temperature=llm_conf_dict.get('llm_temperature', 0.7),\n llm_max_output_tokens=llm_conf_dict.get('llm_max_output_tokens', 10000),\n llm_use_async=llm_conf_dict.get('llm_use_async', False),\n llm_stream=llm_conf_dict.get('llm_stream', False)\n )\n self.llm_client = LLMClient(self.config)\n\n system_prompt = llm_conf_dict.get('llm_system_prompt', None)\n if system_prompt is not None: # TODO: only add system prompt for llms that support it.\n self.state[\"conversation\"] += [{ \"role\": \"system\", \"content\": system_prompt }]\n\n print(\"LLMBrain setup complete.\")\n
"},{"location":"reference/#src.aeiva.cognition.brain.llm_brain.LLMBrain.think","title":"think(stimuli, tools=None, stream=False, use_async=False)
async
","text":"Asynchronously process input stimuli to update the cognitive state.
Parameters:
Name Type Description Default stimuli
Any
The input stimuli to process.
required stream
bool
Whether to use streaming mode. Default is False.
False
Returns:
Name Type Description str
AsyncGenerator[str, None]
The full response in both streaming and non-streaming modes.
Source code in src/aeiva/cognition/brain/llm_brain.py
async def think(\n self,\n stimuli: Any,\n tools: List[Dict[str, Any]] = None,\n stream: bool = False,\n use_async: bool = False\n ) -> AsyncGenerator[str, None]:\n \"\"\"\n Asynchronously process input stimuli to update the cognitive state.\n\n Args:\n stimuli (Any): The input stimuli to process.\n stream (bool): Whether to use streaming mode. Default is False.\n\n Returns:\n str: The full response in both streaming and non-streaming modes.\n \"\"\"\n try:\n # Assume stimuli is a list of messages (conversation context)\n if not isinstance(stimuli, list):\n raise ValueError(\"Stimuli must be a list of messages.\")\n\n self.state[\"conversation\"] += stimuli #!! NOTE: to let LLM remember the history. \n\n if not use_async: # NOTE: stream mode only works when use_async!!!\n response = self.llm_client(self.state[\"conversation\"], tools=tools, stream=stream) #!! NOTE: llm client will update conversation\n # self.state[\"conversation\"] += [{\"role\": \"assistant\", \"content\": response}]\n self.state[\"cognitive_state\"] = response\n yield response\n elif stream:\n # Stream mode: collect all parts of the streamed response\n response = \"\"\n # messages = self.state[\"conversation\"].copy()\n async for delta in self.llm_client(self.state[\"conversation\"], tools=tools, stream=stream): #!! NOTE: llm client will update conversation\n response += delta # Collect the streamed content\n yield delta\n # self.state[\"conversation\"] += [{\"role\": \"assistant\", \"content\": response}]\n self.state[\"cognitive_state\"] = response\n #return response\n else:\n # messages = self.state[\"conversation\"].copy()\n response = await self.llm_client(self.state[\"conversation\"], tools=tools, stream=stream) #!! NOTE: llm client will update conversation\n # self.state[\"conversation\"] += [{\"role\": \"assistant\", \"content\": response}]\n self.state[\"cognitive_state\"] = response\n yield response\n #return response\n\n except Exception as e:\n self.handle_error(e)\n raise\n
"},{"location":"reference/#src.aeiva.cognition.cognition_system","title":"cognition_system
","text":""},{"location":"reference/#src.aeiva.cognition.cognition_system.CognitionSystem","title":"CognitionSystem
","text":"Processes Stimuli into Observations, uses the Brain to generate Thoughts, and orchestrates output into Plans.
Source code in src/aeiva/cognition/cognition_system.py
class CognitionSystem:\n \"\"\"\n Processes Stimuli into Observations, uses the Brain to generate Thoughts, and orchestrates output into Plans.\n \"\"\"\n def __init__(self, config: Dict):\n self.config_dict = config\n self.config = None\n self.input_interpreter = None\n self.brain = None\n self.output_orchestrator = None\n self.memory = None\n self.emotion = None\n self.world_model = None\n self.state = self.init_state()\n\n def init_state(self) -> Dict[str, Any]:\n return {\n \"cognitive_state\": None,\n \"last_input\": None,\n \"last_output\": None\n }\n\n def setup(self) -> None:\n \"\"\"\n Set up the cognition system's components.\n \"\"\"\n self.brain = LLMBrain(config=self.config_dict)\n self.memory = MemoryPalace(config=self.config_dict)\n self.emotion = SimpleEmotion() # TODO: replace\n self.world_model = SimpleWorldModel() # TODO: replace\n self.input_interpreter = SimpleInputInterpreter() # TODO: replace\n self.output_orchestrator = SimpleOutputOrchestrator() # TODO: replace\n\n self.brain.setup()\n self.memory.setup()\n self.world_model.setup()\n self.emotion.setup()\n self.input_interpreter.setup()\n self.output_orchestrator.setup()\n\n def handle_error(self, error: Exception) -> None:\n print(f\"CognitionSystem encountered an error: {error}\")\n\n async def think(\n self,\n stimuli: Stimuli,\n tools: List[Dict[str, Any]] = None,\n stream: bool=False,\n use_async: bool=False\n ) -> AsyncGenerator[str, None]:\n \"\"\"\n Processes stimuli and produces a thought or plan.\n\n Args:\n stimuli (Stimuli): The input stimuli.\n stream (bool): Whether to use streaming mode.\n tools (List[Dict[str, Any]]): Optional tools for function calls.\n\n Yields:\n str: Chunks of the assistant's response.\n \"\"\"\n self.state[\"last_input\"] = stimuli\n\n # Step 1: Use InputInterpreter to process stimuli into observation\n if self.input_interpreter.gate(stimuli):\n observation = await self.input_interpreter.interpret(stimuli)\n else:\n # Directly pass stimuli as observation (assuming it's acceptable)\n observation = Observation(data=stimuli.to_dict())\n\n # Step 2: Brain processes the observation into a thought or plan\n brain_input = [{\"role\": \"user\", \"content\": observation.data}]\n # Initiate brain processing\n response_gen = self.brain.think(brain_input, tools=tools, stream=stream, use_async=use_async)\n\n async for chunk in response_gen:\n if isinstance(chunk, str):\n # Streaming chunk or full response in non-streaming mode\n yield chunk\n elif isinstance(chunk, Thought):\n thought = chunk\n self.state[\"cognitive_state\"] = thought\n\n # Step 3: Use OutputOrchestrator if applicable\n if self.output_orchestrator.gate(thought):\n plan = await self.output_orchestrator.orchestrate(thought)\n self.state[\"last_output\"] = plan\n yield plan.content if hasattr(plan, 'content') else str(plan)\n else:\n self.state[\"last_output\"] = thought\n yield thought.content\n elif isinstance(chunk, Plan):\n plan = chunk\n self.state[\"last_output\"] = plan\n yield plan.content if hasattr(plan, 'content') else str(plan)\n else:\n # Handle unexpected chunk types\n #logger.warning(f\"Unexpected chunk type: {type(chunk)}\")\n yield str(chunk)\n
"},{"location":"reference/#src.aeiva.cognition.cognition_system.CognitionSystem.setup","title":"setup()
","text":"Set up the cognition system's components.
Source code in src/aeiva/cognition/cognition_system.py
def setup(self) -> None:\n \"\"\"\n Set up the cognition system's components.\n \"\"\"\n self.brain = LLMBrain(config=self.config_dict)\n self.memory = MemoryPalace(config=self.config_dict)\n self.emotion = SimpleEmotion() # TODO: replace\n self.world_model = SimpleWorldModel() # TODO: replace\n self.input_interpreter = SimpleInputInterpreter() # TODO: replace\n self.output_orchestrator = SimpleOutputOrchestrator() # TODO: replace\n\n self.brain.setup()\n self.memory.setup()\n self.world_model.setup()\n self.emotion.setup()\n self.input_interpreter.setup()\n self.output_orchestrator.setup()\n
"},{"location":"reference/#src.aeiva.cognition.cognition_system.CognitionSystem.think","title":"think(stimuli, tools=None, stream=False, use_async=False)
async
","text":"Processes stimuli and produces a thought or plan.
Parameters:
Name Type Description Default stimuli
Stimuli
The input stimuli.
required stream
bool
Whether to use streaming mode.
False
tools
List[Dict[str, Any]]
Optional tools for function calls.
None
Yields:
Name Type Description str
AsyncGenerator[str, None]
Chunks of the assistant's response.
Source code in src/aeiva/cognition/cognition_system.py
async def think(\n self,\n stimuli: Stimuli,\n tools: List[Dict[str, Any]] = None,\n stream: bool=False,\n use_async: bool=False\n ) -> AsyncGenerator[str, None]:\n \"\"\"\n Processes stimuli and produces a thought or plan.\n\n Args:\n stimuli (Stimuli): The input stimuli.\n stream (bool): Whether to use streaming mode.\n tools (List[Dict[str, Any]]): Optional tools for function calls.\n\n Yields:\n str: Chunks of the assistant's response.\n \"\"\"\n self.state[\"last_input\"] = stimuli\n\n # Step 1: Use InputInterpreter to process stimuli into observation\n if self.input_interpreter.gate(stimuli):\n observation = await self.input_interpreter.interpret(stimuli)\n else:\n # Directly pass stimuli as observation (assuming it's acceptable)\n observation = Observation(data=stimuli.to_dict())\n\n # Step 2: Brain processes the observation into a thought or plan\n brain_input = [{\"role\": \"user\", \"content\": observation.data}]\n # Initiate brain processing\n response_gen = self.brain.think(brain_input, tools=tools, stream=stream, use_async=use_async)\n\n async for chunk in response_gen:\n if isinstance(chunk, str):\n # Streaming chunk or full response in non-streaming mode\n yield chunk\n elif isinstance(chunk, Thought):\n thought = chunk\n self.state[\"cognitive_state\"] = thought\n\n # Step 3: Use OutputOrchestrator if applicable\n if self.output_orchestrator.gate(thought):\n plan = await self.output_orchestrator.orchestrate(thought)\n self.state[\"last_output\"] = plan\n yield plan.content if hasattr(plan, 'content') else str(plan)\n else:\n self.state[\"last_output\"] = thought\n yield thought.content\n elif isinstance(chunk, Plan):\n plan = chunk\n self.state[\"last_output\"] = plan\n yield plan.content if hasattr(plan, 'content') else str(plan)\n else:\n # Handle unexpected chunk types\n #logger.warning(f\"Unexpected chunk type: {type(chunk)}\")\n yield str(chunk)\n
"},{"location":"reference/#src.aeiva.cognition.emotion","title":"emotion
","text":""},{"location":"reference/#src.aeiva.cognition.emotion.emotion","title":"emotion
","text":""},{"location":"reference/#src.aeiva.cognition.emotion.emotion.ConfigurationError","title":"ConfigurationError
","text":" Bases: Exception
Exception raised for errors in the configuration.
Source code in src/aeiva/cognition/emotion/emotion.py
class ConfigurationError(Exception):\n \"\"\"Exception raised for errors in the configuration.\"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.emotion.emotion.Emotion","title":"Emotion
","text":" Bases: ABC
, Generic[T]
Abstract base class representing the Emotion system of an agent with generic state type.
The Emotion system manages the agent's emotional states, allowing it to respond to various stimuli in an emotionally coherent manner.
Attributes:
Name Type Description config
Dict[str, Any]
Configuration settings for the Emotion system.
state
T
The internal emotional state of the agent, defined by subclasses.
Source code in src/aeiva/cognition/emotion/emotion.py
class Emotion(ABC, Generic[T]):\n \"\"\"\n Abstract base class representing the Emotion system of an agent with generic state type.\n\n The Emotion system manages the agent's emotional states, allowing it to respond\n to various stimuli in an emotionally coherent manner.\n\n Attributes:\n config (Dict[str, Any]): Configuration settings for the Emotion system.\n state (T): The internal emotional state of the agent, defined by subclasses.\n \"\"\"\n\n def __init__(self, config: Dict[str, Any]):\n \"\"\"\n Initialize the Emotion system with the provided configuration.\n\n Args:\n config (Dict[str, Any]): Configuration settings for the Emotion system.\n \"\"\"\n self.config = config\n self.state = self.init_state()\n\n @abstractmethod\n def init_state(self) -> T:\n \"\"\"\n Initialize the internal emotional state of the Emotion system.\n\n This method should set up the initial emotional state required for the\n Emotion system's operations.\n\n Returns:\n T: The initial emotional state of the agent.\n \"\"\"\n pass\n\n @abstractmethod\n async def setup(self) -> None:\n \"\"\"\n Asynchronously set up the Emotion system's components.\n\n This method should initialize any necessary components or resources\n based on the provided configuration.\n\n Raises:\n ConfigurationError: If the configuration is invalid or incomplete.\n \"\"\"\n pass\n\n @abstractmethod\n async def update(self, input_data: Dict[str, Any]) -> None:\n \"\"\"\n Asynchronously update the emotional state based on input data.\n\n Args:\n input_data (Dict[str, Any]): The data or stimuli that influence the emotional state.\n\n Raises:\n UpdateError: If updating the emotional state fails.\n \"\"\"\n pass\n\n @abstractmethod\n def regulate(self, strategy: str) -> None:\n \"\"\"\n Regulate the emotional state using a specified strategy.\n\n Args:\n strategy (str): The regulation strategy to apply (e.g., 'suppression', 'amplification').\n\n Raises:\n RegulationError: If the regulation strategy is invalid or fails.\n \"\"\"\n pass\n\n @abstractmethod\n def express(self) -> str:\n \"\"\"\n Generate a representation of the current emotional state.\n\n Returns:\n str: A string describing the current emotion (e.g., \"I feel happy!\").\n \"\"\"\n pass\n\n @abstractmethod\n def serialize(self) -> str:\n \"\"\"\n Serialize the current emotional state into a string format.\n\n Returns:\n str: Serialized emotional state.\n \"\"\"\n pass\n\n @abstractmethod\n def deserialize(self, data: str) -> None:\n \"\"\"\n Deserialize the emotional state from a string format.\n\n Args:\n data (str): Serialized emotional state.\n \"\"\"\n pass\n\n def get_current_state(self) -> T:\n \"\"\"\n Retrieve the current emotional state of the agent.\n\n Returns:\n T: The current emotional state.\n \"\"\"\n return self.state\n\n def handle_error(self, error: Exception) -> None:\n \"\"\"\n Handle errors that occur during emotional processing.\n\n This method can be overridden to implement custom error handling logic.\n\n Args:\n error (Exception): The exception that was raised.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.emotion.emotion.Emotion.__init__","title":"__init__(config)
","text":"Initialize the Emotion system with the provided configuration.
Parameters:
Name Type Description Default config
Dict[str, Any]
Configuration settings for the Emotion system.
required Source code in src/aeiva/cognition/emotion/emotion.py
def __init__(self, config: Dict[str, Any]):\n \"\"\"\n Initialize the Emotion system with the provided configuration.\n\n Args:\n config (Dict[str, Any]): Configuration settings for the Emotion system.\n \"\"\"\n self.config = config\n self.state = self.init_state()\n
"},{"location":"reference/#src.aeiva.cognition.emotion.emotion.Emotion.deserialize","title":"deserialize(data)
abstractmethod
","text":"Deserialize the emotional state from a string format.
Parameters:
Name Type Description Default data
str
Serialized emotional state.
required Source code in src/aeiva/cognition/emotion/emotion.py
@abstractmethod\ndef deserialize(self, data: str) -> None:\n \"\"\"\n Deserialize the emotional state from a string format.\n\n Args:\n data (str): Serialized emotional state.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.emotion.emotion.Emotion.express","title":"express()
abstractmethod
","text":"Generate a representation of the current emotional state.
Returns:
Name Type Description str
str
A string describing the current emotion (e.g., \"I feel happy!\").
Source code in src/aeiva/cognition/emotion/emotion.py
@abstractmethod\ndef express(self) -> str:\n \"\"\"\n Generate a representation of the current emotional state.\n\n Returns:\n str: A string describing the current emotion (e.g., \"I feel happy!\").\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.emotion.emotion.Emotion.get_current_state","title":"get_current_state()
","text":"Retrieve the current emotional state of the agent.
Returns:
Name Type Description T
T
The current emotional state.
Source code in src/aeiva/cognition/emotion/emotion.py
def get_current_state(self) -> T:\n \"\"\"\n Retrieve the current emotional state of the agent.\n\n Returns:\n T: The current emotional state.\n \"\"\"\n return self.state\n
"},{"location":"reference/#src.aeiva.cognition.emotion.emotion.Emotion.handle_error","title":"handle_error(error)
","text":"Handle errors that occur during emotional processing.
This method can be overridden to implement custom error handling logic.
Parameters:
Name Type Description Default error
Exception
The exception that was raised.
required Source code in src/aeiva/cognition/emotion/emotion.py
def handle_error(self, error: Exception) -> None:\n \"\"\"\n Handle errors that occur during emotional processing.\n\n This method can be overridden to implement custom error handling logic.\n\n Args:\n error (Exception): The exception that was raised.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.emotion.emotion.Emotion.init_state","title":"init_state()
abstractmethod
","text":"Initialize the internal emotional state of the Emotion system.
This method should set up the initial emotional state required for the Emotion system's operations.
Returns:
Name Type Description T
T
The initial emotional state of the agent.
Source code in src/aeiva/cognition/emotion/emotion.py
@abstractmethod\ndef init_state(self) -> T:\n \"\"\"\n Initialize the internal emotional state of the Emotion system.\n\n This method should set up the initial emotional state required for the\n Emotion system's operations.\n\n Returns:\n T: The initial emotional state of the agent.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.emotion.emotion.Emotion.regulate","title":"regulate(strategy)
abstractmethod
","text":"Regulate the emotional state using a specified strategy.
Parameters:
Name Type Description Default strategy
str
The regulation strategy to apply (e.g., 'suppression', 'amplification').
required Raises:
Type Description RegulationError
If the regulation strategy is invalid or fails.
Source code in src/aeiva/cognition/emotion/emotion.py
@abstractmethod\ndef regulate(self, strategy: str) -> None:\n \"\"\"\n Regulate the emotional state using a specified strategy.\n\n Args:\n strategy (str): The regulation strategy to apply (e.g., 'suppression', 'amplification').\n\n Raises:\n RegulationError: If the regulation strategy is invalid or fails.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.emotion.emotion.Emotion.serialize","title":"serialize()
abstractmethod
","text":"Serialize the current emotional state into a string format.
Returns:
Name Type Description str
str
Serialized emotional state.
Source code in src/aeiva/cognition/emotion/emotion.py
@abstractmethod\ndef serialize(self) -> str:\n \"\"\"\n Serialize the current emotional state into a string format.\n\n Returns:\n str: Serialized emotional state.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.emotion.emotion.Emotion.setup","title":"setup()
abstractmethod
async
","text":"Asynchronously set up the Emotion system's components.
This method should initialize any necessary components or resources based on the provided configuration.
Raises:
Type Description ConfigurationError
If the configuration is invalid or incomplete.
Source code in src/aeiva/cognition/emotion/emotion.py
@abstractmethod\nasync def setup(self) -> None:\n \"\"\"\n Asynchronously set up the Emotion system's components.\n\n This method should initialize any necessary components or resources\n based on the provided configuration.\n\n Raises:\n ConfigurationError: If the configuration is invalid or incomplete.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.emotion.emotion.Emotion.update","title":"update(input_data)
abstractmethod
async
","text":"Asynchronously update the emotional state based on input data.
Parameters:
Name Type Description Default input_data
Dict[str, Any]
The data or stimuli that influence the emotional state.
required Raises:
Type Description UpdateError
If updating the emotional state fails.
Source code in src/aeiva/cognition/emotion/emotion.py
@abstractmethod\nasync def update(self, input_data: Dict[str, Any]) -> None:\n \"\"\"\n Asynchronously update the emotional state based on input data.\n\n Args:\n input_data (Dict[str, Any]): The data or stimuli that influence the emotional state.\n\n Raises:\n UpdateError: If updating the emotional state fails.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.emotion.emotion.RegulationError","title":"RegulationError
","text":" Bases: Exception
Exception raised for errors during emotion regulation.
Source code in src/aeiva/cognition/emotion/emotion.py
class RegulationError(Exception):\n \"\"\"Exception raised for errors during emotion regulation.\"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.emotion.emotion.UpdateError","title":"UpdateError
","text":" Bases: Exception
Exception raised for errors during emotion state updates.
Source code in src/aeiva/cognition/emotion/emotion.py
class UpdateError(Exception):\n \"\"\"Exception raised for errors during emotion state updates.\"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.emotion.emotion_categorical","title":"emotion_categorical
","text":""},{"location":"reference/#src.aeiva.cognition.emotion.emotion_categorical.CategoricalEmotionState","title":"CategoricalEmotionState
","text":"Represents the emotional state in a Categorical Model.
Source code in src/aeiva/cognition/emotion/emotion_categorical.py
class CategoricalEmotionState:\n \"\"\"\n Represents the emotional state in a Categorical Model.\n \"\"\"\n def __init__(self, emotion_label: str = \"neutral\"):\n self.emotion_label = emotion_label\n\n def to_dict(self) -> Dict[str, Any]:\n return {\n 'emotion_label': self.emotion_label\n }\n\n @staticmethod\n def from_dict(data: Dict[str, Any]):\n return CategoricalEmotionState(\n emotion_label=data.get('emotion_label', 'neutral')\n )\n
"},{"location":"reference/#src.aeiva.cognition.emotion.emotion_category","title":"emotion_category
","text":""},{"location":"reference/#src.aeiva.cognition.emotion.emotion_category.CategoryEmotionState","title":"CategoryEmotionState
dataclass
","text":"Represents the emotional state in a Category-Based Model with extensive categories.
Attributes:
Name Type Description emotion_label
str
The current emotion category.
intensity
float
The intensity of the current emotion (range: 0.0 to 1.0).
Source code in src/aeiva/cognition/emotion/emotion_category.py
@dataclass\nclass CategoryEmotionState:\n \"\"\"\n Represents the emotional state in a Category-Based Model with extensive categories.\n\n Attributes:\n emotion_label (str): The current emotion category.\n intensity (float): The intensity of the current emotion (range: 0.0 to 1.0).\n \"\"\"\n emotion_label: str = \"neutral\"\n intensity: float = 0.0 # Optional: Represents the strength of the emotion\n\n def to_dict(self) -> Dict[str, Any]:\n return {\n 'emotion_label': self.emotion_label,\n 'intensity': self.intensity\n }\n\n @staticmethod\n def from_dict(data: Dict[str, Any]):\n return CategoryEmotionState(\n emotion_label=data.get('emotion_label', 'neutral'),\n intensity=data.get('intensity', 0.0)\n )\n
"},{"location":"reference/#src.aeiva.cognition.emotion.emotion_circumplex","title":"emotion_circumplex
","text":""},{"location":"reference/#src.aeiva.cognition.emotion.emotion_circumplex.CircumplexEmotionState","title":"CircumplexEmotionState
","text":"Represents the emotional state in the Circumplex Model.
Source code in src/aeiva/cognition/emotion/emotion_circumplex.py
class CircumplexEmotionState:\n \"\"\"\n Represents the emotional state in the Circumplex Model.\n \"\"\"\n def __init__(self, valence: float = 0.0, arousal: float = 0.0):\n self.valence = valence # Range: [-1.0, 1.0]\n self.arousal = arousal # Range: [-1.0, 1.0]\n\n def to_dict(self) -> Dict[str, Any]:\n return {\n 'valence': self.valence,\n 'arousal': self.arousal\n }\n\n @staticmethod\n def from_dict(data: Dict[str, Any]):\n return CircumplexEmotionState(\n valence=data.get('valence', 0.0),\n arousal=data.get('arousal', 0.0)\n )\n
"},{"location":"reference/#src.aeiva.cognition.emotion.emotion_componential","title":"emotion_componential
","text":""},{"location":"reference/#src.aeiva.cognition.emotion.emotion_componential.ComponentialEmotionState","title":"ComponentialEmotionState
dataclass
","text":"Represents the emotional state based on the Componential Model.
Attributes:
Name Type Description emotion_label
str
Current emotion category.
intensity
float
Intensity of the emotion (0.0 to 1.0).
Source code in src/aeiva/cognition/emotion/emotion_componential.py
@dataclass\nclass ComponentialEmotionState:\n \"\"\"\n Represents the emotional state based on the Componential Model.\n\n Attributes:\n emotion_label (str): Current emotion category.\n intensity (float): Intensity of the emotion (0.0 to 1.0).\n \"\"\"\n emotion_label: str = \"neutral\"\n intensity: float = 0.0\n\n def to_dict(self) -> Dict[str, Any]:\n return {\n 'emotion_label': self.emotion_label,\n 'intensity': self.intensity\n }\n\n @staticmethod\n def from_dict(data: Dict[str, Any]):\n return ComponentialEmotionState(\n emotion_label=data.get('emotion_label', 'neutral'),\n intensity=data.get('intensity', 0.0)\n )\n
"},{"location":"reference/#src.aeiva.cognition.emotion.emotion_hybrid","title":"emotion_hybrid
","text":""},{"location":"reference/#src.aeiva.cognition.emotion.emotion_hybrid.HybridEmotionState","title":"HybridEmotionState
","text":"Represents the emotional state in the Hybrid Categorical-Dimensional Model.
Source code in src/aeiva/cognition/emotion/emotion_hybrid.py
class HybridEmotionState:\n \"\"\"\n Represents the emotional state in the Hybrid Categorical-Dimensional Model.\n \"\"\"\n def __init__(self, emotion_label: str = \"neutral\", valence: float = 0.0, arousal: float = 0.0):\n self.emotion_label = emotion_label # Categorical label\n self.valence = valence # Dimensional valence\n self.arousal = arousal # Dimensional arousal\n\n def to_dict(self) -> Dict[str, Any]:\n return {\n 'emotion_label': self.emotion_label,\n 'valence': self.valence,\n 'arousal': self.arousal\n }\n\n @staticmethod\n def from_dict(data: Dict[str, Any]):\n return HybridEmotionState(\n emotion_label=data.get('emotion_label', 'neutral'),\n valence=data.get('valence', 0.0),\n arousal=data.get('arousal', 0.0)\n )\n
"},{"location":"reference/#src.aeiva.cognition.emotion.emotion_occ","title":"emotion_occ
","text":""},{"location":"reference/#src.aeiva.cognition.emotion.emotion_occ.OCCEmotionState","title":"OCCEmotionState
","text":"Represents the emotional state in the OCC Appraisal-Based Model.
Source code in src/aeiva/cognition/emotion/emotion_occ.py
class OCCEmotionState:\n \"\"\"\n Represents the emotional state in the OCC Appraisal-Based Model.\n \"\"\"\n def __init__(self, emotion_categories: Dict[str, float] = None):\n \"\"\"\n Initialize the OCC emotion state with emotion categories and their intensities.\n \"\"\"\n # Initialize with zero intensities if not provided\n self.emotion_categories = emotion_categories if emotion_categories else {\n 'joy': 0.0,\n 'sadness': 0.0,\n 'anger': 0.0,\n 'fear': 0.0,\n 'surprise': 0.0,\n 'disgust': 0.0\n }\n\n def to_dict(self) -> Dict[str, Any]:\n return {\n 'emotion_categories': self.emotion_categories\n }\n\n @staticmethod\n def from_dict(data: Dict[str, Any]):\n return OCCEmotionState(\n emotion_categories=data.get('emotion_categories', {})\n )\n
"},{"location":"reference/#src.aeiva.cognition.emotion.emotion_occ.OCCEmotionState.__init__","title":"__init__(emotion_categories=None)
","text":"Initialize the OCC emotion state with emotion categories and their intensities.
Source code in src/aeiva/cognition/emotion/emotion_occ.py
def __init__(self, emotion_categories: Dict[str, float] = None):\n \"\"\"\n Initialize the OCC emotion state with emotion categories and their intensities.\n \"\"\"\n # Initialize with zero intensities if not provided\n self.emotion_categories = emotion_categories if emotion_categories else {\n 'joy': 0.0,\n 'sadness': 0.0,\n 'anger': 0.0,\n 'fear': 0.0,\n 'surprise': 0.0,\n 'disgust': 0.0\n }\n
"},{"location":"reference/#src.aeiva.cognition.emotion.emotion_pad","title":"emotion_pad
","text":""},{"location":"reference/#src.aeiva.cognition.emotion.emotion_pad.PADEmotionState","title":"PADEmotionState
","text":"Represents the emotional state in the PAD Model.
Source code in src/aeiva/cognition/emotion/emotion_pad.py
class PADEmotionState:\n \"\"\"\n Represents the emotional state in the PAD Model.\n \"\"\"\n def __init__(self, pleasure: float = 0.0, arousal: float = 0.0, dominance: float = 0.0):\n self.pleasure = pleasure # Range: [-1.0, 1.0]\n self.arousal = arousal # Range: [-1.0, 1.0]\n self.dominance = dominance # Range: [-1.0, 1.0]\n\n def to_dict(self) -> Dict[str, Any]:\n return {\n 'pleasure': self.pleasure,\n 'arousal': self.arousal,\n 'dominance': self.dominance\n }\n\n @staticmethod\n def from_dict(data: Dict[str, Any]):\n return PADEmotionState(\n pleasure=data.get('pleasure', 0.0),\n arousal=data.get('arousal', 0.0),\n dominance=data.get('dominance', 0.0)\n )\n
"},{"location":"reference/#src.aeiva.cognition.emotion.emotion_plutchik","title":"emotion_plutchik
","text":""},{"location":"reference/#src.aeiva.cognition.emotion.emotion_plutchik.PlutchikEmotionState","title":"PlutchikEmotionState
dataclass
","text":"Represents the emotional state in Plutchik's Wheel of Emotions.
Attributes:
Name Type Description joy
float
Intensity of Joy.
trust
float
Intensity of Trust.
fear
float
Intensity of Fear.
surprise
float
Intensity of Surprise.
sadness
float
Intensity of Sadness.
disgust
float
Intensity of Disgust.
anger
float
Intensity of Anger.
anticipation
float
Intensity of Anticipation.
Source code in src/aeiva/cognition/emotion/emotion_plutchik.py
@dataclass\nclass PlutchikEmotionState:\n \"\"\"\n Represents the emotional state in Plutchik's Wheel of Emotions.\n\n Attributes:\n joy (float): Intensity of Joy.\n trust (float): Intensity of Trust.\n fear (float): Intensity of Fear.\n surprise (float): Intensity of Surprise.\n sadness (float): Intensity of Sadness.\n disgust (float): Intensity of Disgust.\n anger (float): Intensity of Anger.\n anticipation (float): Intensity of Anticipation.\n \"\"\"\n joy: float = 0.0\n trust: float = 0.0\n fear: float = 0.0\n surprise: float = 0.0\n sadness: float = 0.0\n disgust: float = 0.0\n anger: float = 0.0\n anticipation: float = 0.0\n\n def to_dict(self) -> Dict[str, Any]:\n return {\n 'joy': self.joy,\n 'trust': self.trust,\n 'fear': self.fear,\n 'surprise': self.surprise,\n 'sadness': self.sadness,\n 'disgust': self.disgust,\n 'anger': self.anger,\n 'anticipation': self.anticipation\n }\n\n @staticmethod\n def from_dict(data: Dict[str, Any]):\n return PlutchikEmotionState(\n joy=data.get('joy', 0.0),\n trust=data.get('trust', 0.0),\n fear=data.get('fear', 0.0),\n surprise=data.get('surprise', 0.0),\n sadness=data.get('sadness', 0.0),\n disgust=data.get('disgust', 0.0),\n anger=data.get('anger', 0.0),\n anticipation=data.get('anticipation', 0.0)\n )\n
"},{"location":"reference/#src.aeiva.cognition.emotion.exceptions","title":"exceptions
","text":""},{"location":"reference/#src.aeiva.cognition.emotion.exceptions.ConfigurationError","title":"ConfigurationError
","text":" Bases: Exception
Exception raised for errors in the configuration.
Source code in src/aeiva/cognition/emotion/exceptions.py
class ConfigurationError(Exception):\n \"\"\"Exception raised for errors in the configuration.\"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.emotion.exceptions.RegulationError","title":"RegulationError
","text":" Bases: Exception
Exception raised for errors during emotion regulation.
Source code in src/aeiva/cognition/emotion/exceptions.py
class RegulationError(Exception):\n \"\"\"Exception raised for errors during emotion regulation.\"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.emotion.exceptions.UpdateError","title":"UpdateError
","text":" Bases: Exception
Exception raised for errors during emotion state updates.
Source code in src/aeiva/cognition/emotion/exceptions.py
class UpdateError(Exception):\n \"\"\"Exception raised for errors during emotion state updates.\"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.memory","title":"memory
","text":""},{"location":"reference/#src.aeiva.cognition.memory.memory","title":"memory
","text":""},{"location":"reference/#src.aeiva.cognition.memory.memory.Memory","title":"Memory
","text":" Bases: ABC
Abstract base class for memory operations in the intelligent agent.
This class defines methods corresponding to different layers of memory processing, such as creating, filtering, grouping, deriving, structuring, skillizing, embedding, and parameterizing memory units.
Source code in src/aeiva/cognition/memory/memory.py
class Memory(ABC):\n \"\"\"\n Abstract base class for memory operations in the intelligent agent.\n\n This class defines methods corresponding to different layers of memory processing,\n such as creating, filtering, grouping, deriving, structuring, skillizing, embedding,\n and parameterizing memory units.\n \"\"\"\n\n def __init__(self, config: Any):\n \"\"\"\n Initialize the Memory system with the provided configuration.\n\n Args:\n config (Any): Configuration settings for the Memory system.\n \"\"\"\n self.config = config\n\n @abstractmethod\n def setup(self) -> None:\n \"\"\"\n Asynchronously set up the Memory system's components.\n\n This method should initialize any necessary components or resources based on the provided configuration.\n\n Raises:\n ConfigurationError: If the configuration is invalid or incomplete.\n \"\"\"\n pass\n\n @abstractmethod\n def create(self, content: Any, **kwargs) -> MemoryUnit:\n \"\"\"\n Creates a new memory unit with the given content and metadata.\n\n Args:\n content (Any): The core content of the memory unit.\n **kwargs: Additional metadata for the memory unit.\n\n Returns:\n MemoryUnit: The created memory unit.\n \"\"\"\n pass\n\n @abstractmethod\n def get(self, unit_id: str) -> MemoryUnit:\n \"\"\"\n Retrieves a memory unit by its unique identifier.\n\n Args:\n unit_id (str): The unique identifier of the memory unit.\n\n Returns:\n MemoryUnit: The retrieved memory unit.\n \"\"\"\n pass\n\n @abstractmethod\n def update(self, unit_id: str, updates: Dict[str, Any]) -> None:\n \"\"\"\n Updates a memory unit with the given updates.\n\n Args:\n unit_id (str): The unique identifier of the memory unit.\n updates (Dict[str, Any]): A dictionary of fields to update.\n \"\"\"\n pass\n\n @abstractmethod\n def delete(self, unit_id: str) -> None:\n \"\"\"\n Deletes a memory unit by its unique identifier.\n\n Args:\n unit_id (str): The unique identifier of the memory unit.\n \"\"\"\n pass\n\n @abstractmethod\n def get_all(self) -> List[MemoryUnit]:\n \"\"\"\n Retrieves all memory units.\n\n Returns:\n List[MemoryUnit]: A list of all memory units.\n \"\"\"\n pass\n\n @abstractmethod\n def delete_all(self) -> None:\n \"\"\"\n Deletes all memory units.\n \"\"\"\n pass\n\n @abstractmethod\n def load(self) -> None:\n \"\"\"\n Loads the memory from file. The path is specified in config.\n \"\"\"\n pass\n\n @abstractmethod\n def save(self) -> None:\n \"\"\"\n Save the memory to database or file. The path is specified in config.\n \"\"\"\n pass\n\n @abstractmethod\n def filter(self, criteria: Dict[str, Any]) -> List[MemoryUnit]:\n \"\"\"\n Filters memory units based on the given criteria.\n\n Args:\n criteria (Dict[str, Any]): A dictionary of filter conditions.\n\n Returns:\n List[MemoryUnit]: A list of memory units matching the criteria.\n \"\"\"\n pass\n\n @abstractmethod\n def organize(self, unit_ids: List[str], organize_type: str, metadata: Optional[Dict[str, Any]] = None) -> str:\n \"\"\"\n Groups memory units into a meaningful group.\n\n Args:\n unit_ids (List[str]): A list of memory unit IDs to group.\n organize_type (str): The type of group (e.g., 'dialogue_session', 'procedure').\n metadata (Optional[Dict[str, Any]]): Additional metadata for the group.\n\n Returns:\n str: A unique identifier for the created group.\n \"\"\"\n pass\n\n # @abstractmethod\n # def derive(self, unit_ids: List[str], derivation_type: str, **kwargs) -> MemoryUnit:\n # \"\"\"\n # Derives a new memory unit from existing ones.\n\n # Args:\n # unit_ids (List[str]): A list of memory unit IDs to derive from.\n # derivation_type (str): The type of derivation (e.g., 'summary', 'transformation').\n # **kwargs: Additional parameters for the derivation process.\n\n # Returns:\n # MemoryUnit: The derived memory unit.\n # \"\"\"\n # pass\n\n @abstractmethod\n def structurize(self, unit_ids: List[str], structure_type: str, **kwargs) -> None:\n \"\"\"\n Structures memory units into a knowledge graph or other structures.\n\n Args:\n unit_ids (List[str]): A list of memory unit IDs to structurize.\n structure_type (str): The type of structure (e.g., 'knowledge_graph').\n **kwargs: Additional parameters for the structuring process.\n \"\"\"\n pass\n\n @abstractmethod\n def skillize(self, unit_ids: List[str], skill_name: str, **kwargs) -> str:\n \"\"\"\n Converts memory units into a reusable skill.\n\n Args:\n unit_ids (List[str]): A list of memory unit IDs to skillize.\n skill_name (str): The name of the skill to create.\n **kwargs: Additional parameters for skill creation.\n\n Returns:\n str: The unique identifier of the created skill.\n \"\"\"\n pass\n\n @abstractmethod\n def embed(self, unit_id: str) -> None:\n \"\"\"\n Generates an embedding for a memory unit.\n\n Args:\n unit_id (str): The unique identifier of the memory unit.\n \"\"\"\n pass\n\n @abstractmethod\n def parameterize(self, **kwargs) -> None:\n \"\"\"\n Trains a parametric model using the memory data.\n\n Args:\n **kwargs: Additional parameters for the training process.\n \"\"\"\n pass\n\n @abstractmethod\n def retrieve(self, query: Any, retrieve_type: str, **kwargs) -> List[MemoryUnit]:\n \"\"\"\n Asynchronously retrieve data from memory based on a query.\n\n Args:\n query (Any): The query or criteria to retrieve specific memory data.\n retrieve_type (str): The type of retrieval (e.g., 'retrieve_related', 'retrieve_similar').\n **kwargs: Additional parameters for the structuring process.\n\n Returns:\n Any: The retrieved memory data.\n\n Raises:\n RetrievalError: If the retrieval process fails.\n \"\"\"\n pass\n\n def handle_error(self, error: Exception) -> None:\n \"\"\"\n Handle errors that occur during memory operations.\n\n This method can be overridden to implement custom error handling logic.\n\n Args:\n error (Exception): The exception that was raised.\n \"\"\"\n # Default error handling: log the error\n print(f\"Memory system encountered an error: {error}\")\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory.Memory.__init__","title":"__init__(config)
","text":"Initialize the Memory system with the provided configuration.
Parameters:
Name Type Description Default config
Any
Configuration settings for the Memory system.
required Source code in src/aeiva/cognition/memory/memory.py
def __init__(self, config: Any):\n \"\"\"\n Initialize the Memory system with the provided configuration.\n\n Args:\n config (Any): Configuration settings for the Memory system.\n \"\"\"\n self.config = config\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory.Memory.create","title":"create(content, **kwargs)
abstractmethod
","text":"Creates a new memory unit with the given content and metadata.
Parameters:
Name Type Description Default content
Any
The core content of the memory unit.
required **kwargs
Additional metadata for the memory unit.
{}
Returns:
Name Type Description MemoryUnit
MemoryUnit
The created memory unit.
Source code in src/aeiva/cognition/memory/memory.py
@abstractmethod\ndef create(self, content: Any, **kwargs) -> MemoryUnit:\n \"\"\"\n Creates a new memory unit with the given content and metadata.\n\n Args:\n content (Any): The core content of the memory unit.\n **kwargs: Additional metadata for the memory unit.\n\n Returns:\n MemoryUnit: The created memory unit.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory.Memory.delete","title":"delete(unit_id)
abstractmethod
","text":"Deletes a memory unit by its unique identifier.
Parameters:
Name Type Description Default unit_id
str
The unique identifier of the memory unit.
required Source code in src/aeiva/cognition/memory/memory.py
@abstractmethod\ndef delete(self, unit_id: str) -> None:\n \"\"\"\n Deletes a memory unit by its unique identifier.\n\n Args:\n unit_id (str): The unique identifier of the memory unit.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory.Memory.delete_all","title":"delete_all()
abstractmethod
","text":"Deletes all memory units.
Source code in src/aeiva/cognition/memory/memory.py
@abstractmethod\ndef delete_all(self) -> None:\n \"\"\"\n Deletes all memory units.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory.Memory.embed","title":"embed(unit_id)
abstractmethod
","text":"Generates an embedding for a memory unit.
Parameters:
Name Type Description Default unit_id
str
The unique identifier of the memory unit.
required Source code in src/aeiva/cognition/memory/memory.py
@abstractmethod\ndef embed(self, unit_id: str) -> None:\n \"\"\"\n Generates an embedding for a memory unit.\n\n Args:\n unit_id (str): The unique identifier of the memory unit.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory.Memory.filter","title":"filter(criteria)
abstractmethod
","text":"Filters memory units based on the given criteria.
Parameters:
Name Type Description Default criteria
Dict[str, Any]
A dictionary of filter conditions.
required Returns:
Type Description List[MemoryUnit]
List[MemoryUnit]: A list of memory units matching the criteria.
Source code in src/aeiva/cognition/memory/memory.py
@abstractmethod\ndef filter(self, criteria: Dict[str, Any]) -> List[MemoryUnit]:\n \"\"\"\n Filters memory units based on the given criteria.\n\n Args:\n criteria (Dict[str, Any]): A dictionary of filter conditions.\n\n Returns:\n List[MemoryUnit]: A list of memory units matching the criteria.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory.Memory.get","title":"get(unit_id)
abstractmethod
","text":"Retrieves a memory unit by its unique identifier.
Parameters:
Name Type Description Default unit_id
str
The unique identifier of the memory unit.
required Returns:
Name Type Description MemoryUnit
MemoryUnit
The retrieved memory unit.
Source code in src/aeiva/cognition/memory/memory.py
@abstractmethod\ndef get(self, unit_id: str) -> MemoryUnit:\n \"\"\"\n Retrieves a memory unit by its unique identifier.\n\n Args:\n unit_id (str): The unique identifier of the memory unit.\n\n Returns:\n MemoryUnit: The retrieved memory unit.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory.Memory.get_all","title":"get_all()
abstractmethod
","text":"Retrieves all memory units.
Returns:
Type Description List[MemoryUnit]
List[MemoryUnit]: A list of all memory units.
Source code in src/aeiva/cognition/memory/memory.py
@abstractmethod\ndef get_all(self) -> List[MemoryUnit]:\n \"\"\"\n Retrieves all memory units.\n\n Returns:\n List[MemoryUnit]: A list of all memory units.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory.Memory.handle_error","title":"handle_error(error)
","text":"Handle errors that occur during memory operations.
This method can be overridden to implement custom error handling logic.
Parameters:
Name Type Description Default error
Exception
The exception that was raised.
required Source code in src/aeiva/cognition/memory/memory.py
def handle_error(self, error: Exception) -> None:\n \"\"\"\n Handle errors that occur during memory operations.\n\n This method can be overridden to implement custom error handling logic.\n\n Args:\n error (Exception): The exception that was raised.\n \"\"\"\n # Default error handling: log the error\n print(f\"Memory system encountered an error: {error}\")\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory.Memory.load","title":"load()
abstractmethod
","text":"Loads the memory from file. The path is specified in config.
Source code in src/aeiva/cognition/memory/memory.py
@abstractmethod\ndef load(self) -> None:\n \"\"\"\n Loads the memory from file. The path is specified in config.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory.Memory.organize","title":"organize(unit_ids, organize_type, metadata=None)
abstractmethod
","text":"Groups memory units into a meaningful group.
Parameters:
Name Type Description Default unit_ids
List[str]
A list of memory unit IDs to group.
required organize_type
str
The type of group (e.g., 'dialogue_session', 'procedure').
required metadata
Optional[Dict[str, Any]]
Additional metadata for the group.
None
Returns:
Name Type Description str
str
A unique identifier for the created group.
Source code in src/aeiva/cognition/memory/memory.py
@abstractmethod\ndef organize(self, unit_ids: List[str], organize_type: str, metadata: Optional[Dict[str, Any]] = None) -> str:\n \"\"\"\n Groups memory units into a meaningful group.\n\n Args:\n unit_ids (List[str]): A list of memory unit IDs to group.\n organize_type (str): The type of group (e.g., 'dialogue_session', 'procedure').\n metadata (Optional[Dict[str, Any]]): Additional metadata for the group.\n\n Returns:\n str: A unique identifier for the created group.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory.Memory.parameterize","title":"parameterize(**kwargs)
abstractmethod
","text":"Trains a parametric model using the memory data.
Parameters:
Name Type Description Default **kwargs
Additional parameters for the training process.
{}
Source code in src/aeiva/cognition/memory/memory.py
@abstractmethod\ndef parameterize(self, **kwargs) -> None:\n \"\"\"\n Trains a parametric model using the memory data.\n\n Args:\n **kwargs: Additional parameters for the training process.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory.Memory.retrieve","title":"retrieve(query, retrieve_type, **kwargs)
abstractmethod
","text":"Asynchronously retrieve data from memory based on a query.
Parameters:
Name Type Description Default query
Any
The query or criteria to retrieve specific memory data.
required retrieve_type
str
The type of retrieval (e.g., 'retrieve_related', 'retrieve_similar').
required **kwargs
Additional parameters for the retrieval process.
{}
Returns:
Type Description List[MemoryUnit]
List[MemoryUnit]: The retrieved memory units.
Raises:
Type Description RetrievalError
If the retrieval process fails.
Source code in src/aeiva/cognition/memory/memory.py
@abstractmethod\ndef retrieve(self, query: Any, retrieve_type: str, **kwargs) -> List[MemoryUnit]:\n \"\"\"\n Retrieve data from memory based on a query.\n\n Args:\n query (Any): The query or criteria to retrieve specific memory data.\n retrieve_type (str): The type of retrieval (e.g., 'retrieve_related', 'retrieve_similar').\n **kwargs: Additional parameters for the retrieval process.\n\n Returns:\n List[MemoryUnit]: The retrieved memory units.\n\n Raises:\n RetrievalError: If the retrieval process fails.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory.Memory.save","title":"save()
abstractmethod
","text":"Save the memory to database or file. The path is specified in config.
Source code in src/aeiva/cognition/memory/memory.py
@abstractmethod\ndef save(self) -> None:\n \"\"\"\n Save the memory to database or file. The path is specified in config.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory.Memory.setup","title":"setup()
abstractmethod
","text":"Asynchronously set up the Memory system's components.
This method should initialize any necessary components or resources based on the provided configuration.
Raises:
Type Description ConfigurationError
If the configuration is invalid or incomplete.
Source code in src/aeiva/cognition/memory/memory.py
@abstractmethod\ndef setup(self) -> None:\n \"\"\"\n Set up the Memory system's components.\n\n This method should initialize any necessary components or resources based on the provided configuration.\n\n Raises:\n ConfigurationError: If the configuration is invalid or incomplete.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory.Memory.skillize","title":"skillize(unit_ids, skill_name, **kwargs)
abstractmethod
","text":"Converts memory units into a reusable skill.
Parameters:
Name Type Description Default unit_ids
List[str]
A list of memory unit IDs to skillize.
required skill_name
str
The name of the skill to create.
required **kwargs
Additional parameters for skill creation.
{}
Returns:
Name Type Description str
str
The unique identifier of the created skill.
Source code in src/aeiva/cognition/memory/memory.py
@abstractmethod\ndef skillize(self, unit_ids: List[str], skill_name: str, **kwargs) -> str:\n \"\"\"\n Converts memory units into a reusable skill.\n\n Args:\n unit_ids (List[str]): A list of memory unit IDs to skillize.\n skill_name (str): The name of the skill to create.\n **kwargs: Additional parameters for skill creation.\n\n Returns:\n str: The unique identifier of the created skill.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory.Memory.structurize","title":"structurize(unit_ids, structure_type, **kwargs)
abstractmethod
","text":"Structures memory units into a knowledge graph or other structures.
Parameters:
Name Type Description Default unit_ids
List[str]
A list of memory unit IDs to structurize.
required structure_type
str
The type of structure (e.g., 'knowledge_graph').
required **kwargs
Additional parameters for the structuring process.
{}
Source code in src/aeiva/cognition/memory/memory.py
@abstractmethod\ndef structurize(self, unit_ids: List[str], structure_type: str, **kwargs) -> None:\n \"\"\"\n Structures memory units into a knowledge graph or other structures.\n\n Args:\n unit_ids (List[str]): A list of memory unit IDs to structurize.\n structure_type (str): The type of structure (e.g., 'knowledge_graph').\n **kwargs: Additional parameters for the structuring process.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory.Memory.update","title":"update(unit_id, updates)
abstractmethod
","text":"Updates a memory unit with the given updates.
Parameters:
Name Type Description Default unit_id
str
The unique identifier of the memory unit.
required updates
Dict[str, Any]
A dictionary of fields to update.
required Source code in src/aeiva/cognition/memory/memory.py
@abstractmethod\ndef update(self, unit_id: str, updates: Dict[str, Any]) -> None:\n \"\"\"\n Updates a memory unit with the given updates.\n\n Args:\n unit_id (str): The unique identifier of the memory unit.\n updates (Dict[str, Any]): A dictionary of fields to update.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_cleaner","title":"memory_cleaner
","text":""},{"location":"reference/#src.aeiva.cognition.memory.memory_cleaner.MemoryCleaner","title":"MemoryCleaner
","text":"A class to clean memory units based on various filtering algorithms.
Supported filter types:
- 'time': Removes memory units older than a specified threshold.
- 'modality': Keeps only memory units matching specified modalities.
- 'type': Keeps only memory units matching specified types.
Source code in src/aeiva/cognition/memory/memory_cleaner.py
class MemoryCleaner:\n \"\"\"\n A class to clean memory units based on various filtering algorithms.\n\n Supported filter types:\n - 'time': Removes memory units older than a specified threshold.\n - 'modality': Keeps only memory units matching specified modalities.\n - 'type': Keeps only memory units matching specified types.\n \"\"\"\n\n def __init__(self):\n \"\"\"\n Initializes the MemoryCleaner.\n\n Currently, no initialization parameters are required.\n \"\"\"\n self.logger = logging.getLogger(self.__class__.__name__)\n self.logger.debug(\"Initialized MemoryCleaner without default parameters.\")\n\n def filter(\n self,\n memory_units: List[MemoryUnit],\n filter_type: str,\n **kwargs\n ) -> List[MemoryUnit]:\n \"\"\"\n Filters the provided memory units based on the specified filter type.\n\n Args:\n memory_units (List[MemoryUnit]): The list of memory units to be filtered.\n filter_type (str): The type of filtering algorithm to use ('time', 'modality', 'type').\n **kwargs: Additional parameters required for specific filters.\n For 'time' filter:\n - threshold_days (int): Number of days beyond which memory units are removed.\n For 'modality' filter:\n - modalities (List[str]): List of modalities to retain (e.g., ['text', 'image']).\n For 'type' filter:\n - types (List[str]): List of types to retain (e.g., ['dialogue', 'summary']).\n\n Returns:\n List[MemoryUnit]: The list of memory units after filtering.\n\n Raises:\n MemoryCleanerError: If an unknown filter_type is provided or if required parameters are missing.\n \"\"\"\n self.logger.debug(f\"Filtering memory units using filter_type='{filter_type}' with kwargs={kwargs}\")\n try:\n if filter_type == 'time':\n threshold_days = kwargs.get('threshold_days')\n if threshold_days is None:\n self.logger.error(\"Missing 'threshold_days' parameter for time-based filtering.\")\n raise MemoryCleanerError(\"Missing 'threshold_days' parameter for time-based filtering.\")\n return self.filter_by_time(memory_units, 
threshold_days)\n elif filter_type == 'modality':\n modalities = kwargs.get('modalities')\n if not modalities:\n self.logger.error(\"Missing 'modalities' parameter for modality-based filtering.\")\n raise MemoryCleanerError(\"Missing 'modalities' parameter for modality-based filtering.\")\n return self.filter_by_modality(memory_units, modalities)\n elif filter_type == 'type':\n types = kwargs.get('types')\n if not types:\n self.logger.error(\"Missing 'types' parameter for type-based filtering.\")\n raise MemoryCleanerError(\"Missing 'types' parameter for type-based filtering.\")\n return self.filter_by_type(memory_units, types)\n else:\n self.logger.error(f\"Unknown filter_type: {filter_type}\")\n raise MemoryCleanerError(f\"Unknown filter_type: {filter_type}\")\n except MemoryCleanerError:\n # Re-raise custom errors without modification\n raise\n except Exception as e:\n self.logger.error(f\"Failed to filter memory units: {e}\")\n raise MemoryCleanerError(f\"Failed to filter memory units: {e}\")\n # TODO: more filter options\n\n def filter_by_time(self, memory_units: List[MemoryUnit], threshold_days: int) -> List[MemoryUnit]:\n \"\"\"\n Removes memory units older than the specified threshold_days.\n\n Args:\n memory_units (List[MemoryUnit]): The list of memory units to be filtered.\n threshold_days (int): Number of days beyond which memory units are removed.\n\n Returns:\n List[MemoryUnit]: The list of memory units after time-based filtering.\n\n Raises:\n MemoryCleanerError: If filtering fails.\n \"\"\"\n self.logger.debug(f\"Applying time-based filtering with threshold_days={threshold_days}\")\n try:\n current_time = datetime.now(UTC)\n threshold = timedelta(days=threshold_days)\n filtered_memory = [\n mu for mu in memory_units\n if (current_time - mu.timestamp) <= threshold\n ]\n removed_count = len(memory_units) - len(filtered_memory)\n self.logger.info(\n f\"Time-based filter: Removed {removed_count} memory units older than {threshold_days} days.\"\n )\n 
return filtered_memory\n except Exception as e:\n self.logger.error(f\"Time-based filtering failed: {e}\")\n raise MemoryCleanerError(f\"Time-based filtering failed: {e}\")\n\n def filter_by_modality(self, memory_units: List[MemoryUnit], modalities: List[str]) -> List[MemoryUnit]:\n \"\"\"\n Keeps only memory units that match the specified modalities.\n\n Args:\n memory_units (List[MemoryUnit]): The list of memory units to be filtered.\n modalities (List[str]): List of modalities to retain (e.g., ['text', 'image']).\n\n Returns:\n List[MemoryUnit]: The list of memory units after modality-based filtering.\n\n Raises:\n MemoryCleanerError: If filtering fails.\n \"\"\"\n self.logger.debug(f\"Applying modality-based filtering with modalities={modalities}\")\n try:\n if not modalities:\n self.logger.warning(\"No modalities specified for modality-based filtering. Returning original memory units.\")\n return memory_units\n\n filtered_memory = [\n mu for mu in memory_units\n if mu.modality in modalities\n ]\n removed_count = len(memory_units) - len(filtered_memory)\n self.logger.info(\n f\"Modality-based filter: Removed {removed_count} memory units not in modalities {modalities}.\"\n )\n return filtered_memory\n except Exception as e:\n self.logger.error(f\"Modality-based filtering failed: {e}\")\n raise MemoryCleanerError(f\"Modality-based filtering failed: {e}\")\n\n def filter_by_type(self, memory_units: List[MemoryUnit], types: List[str]) -> List[MemoryUnit]:\n \"\"\"\n Keeps only memory units that match the specified types.\n\n Args:\n memory_units (List[MemoryUnit]): The list of memory units to be filtered.\n types (List[str]): List of types to retain (e.g., ['dialogue', 'summary']).\n\n Returns:\n List[MemoryUnit]: The list of memory units after type-based filtering.\n\n Raises:\n MemoryCleanerError: If filtering fails.\n \"\"\"\n self.logger.debug(f\"Applying type-based filtering with types={types}\")\n try:\n if not types:\n self.logger.warning(\"No types 
specified for type-based filtering. Returning original memory units.\")\n return memory_units\n\n filtered_memory = [\n mu for mu in memory_units\n if mu.type in types\n ]\n removed_count = len(memory_units) - len(filtered_memory)\n self.logger.info(\n f\"Type-based filter: Removed {removed_count} memory units not in types {types}.\"\n )\n return filtered_memory\n except Exception as e:\n self.logger.error(f\"Type-based filtering failed: {e}\")\n raise MemoryCleanerError(f\"Type-based filtering failed: {e}\")\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_cleaner.MemoryCleaner.__init__","title":"__init__()
","text":"Initializes the MemoryCleaner.
Currently, no initialization parameters are required.
Source code in src/aeiva/cognition/memory/memory_cleaner.py
def __init__(self):\n \"\"\"\n Initializes the MemoryCleaner.\n\n Currently, no initialization parameters are required.\n \"\"\"\n self.logger = logging.getLogger(self.__class__.__name__)\n self.logger.debug(\"Initialized MemoryCleaner without default parameters.\")\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_cleaner.MemoryCleaner.filter","title":"filter(memory_units, filter_type, **kwargs)
","text":"Filters the provided memory units based on the specified filter type.
Parameters:
Name Type Description Default memory_units
List[MemoryUnit]
The list of memory units to be filtered.
required filter_type
str
The type of filtering algorithm to use ('time', 'modality', 'type').
required **kwargs
Additional parameters required for specific filters. For 'time' filter: - threshold_days (int): Number of days beyond which memory units are removed. For 'modality' filter: - modalities (List[str]): List of modalities to retain (e.g., ['text', 'image']). For 'type' filter: - types (List[str]): List of types to retain (e.g., ['dialogue', 'summary']).
{}
Returns:
Type Description List[MemoryUnit]
List[MemoryUnit]: The list of memory units after filtering.
Raises:
Type Description MemoryCleanerError
If an unknown filter_type is provided or if required parameters are missing.
Source code in src/aeiva/cognition/memory/memory_cleaner.py
def filter(\n self,\n memory_units: List[MemoryUnit],\n filter_type: str,\n **kwargs\n) -> List[MemoryUnit]:\n \"\"\"\n Filters the provided memory units based on the specified filter type.\n\n Args:\n memory_units (List[MemoryUnit]): The list of memory units to be filtered.\n filter_type (str): The type of filtering algorithm to use ('time', 'modality', 'type').\n **kwargs: Additional parameters required for specific filters.\n For 'time' filter:\n - threshold_days (int): Number of days beyond which memory units are removed.\n For 'modality' filter:\n - modalities (List[str]): List of modalities to retain (e.g., ['text', 'image']).\n For 'type' filter:\n - types (List[str]): List of types to retain (e.g., ['dialogue', 'summary']).\n\n Returns:\n List[MemoryUnit]: The list of memory units after filtering.\n\n Raises:\n MemoryCleanerError: If an unknown filter_type is provided or if required parameters are missing.\n \"\"\"\n self.logger.debug(f\"Filtering memory units using filter_type='{filter_type}' with kwargs={kwargs}\")\n try:\n if filter_type == 'time':\n threshold_days = kwargs.get('threshold_days')\n if threshold_days is None:\n self.logger.error(\"Missing 'threshold_days' parameter for time-based filtering.\")\n raise MemoryCleanerError(\"Missing 'threshold_days' parameter for time-based filtering.\")\n return self.filter_by_time(memory_units, threshold_days)\n elif filter_type == 'modality':\n modalities = kwargs.get('modalities')\n if not modalities:\n self.logger.error(\"Missing 'modalities' parameter for modality-based filtering.\")\n raise MemoryCleanerError(\"Missing 'modalities' parameter for modality-based filtering.\")\n return self.filter_by_modality(memory_units, modalities)\n elif filter_type == 'type':\n types = kwargs.get('types')\n if not types:\n self.logger.error(\"Missing 'types' parameter for type-based filtering.\")\n raise MemoryCleanerError(\"Missing 'types' parameter for type-based filtering.\")\n return 
self.filter_by_type(memory_units, types)\n else:\n self.logger.error(f\"Unknown filter_type: {filter_type}\")\n raise MemoryCleanerError(f\"Unknown filter_type: {filter_type}\")\n except MemoryCleanerError:\n # Re-raise custom errors without modification\n raise\n except Exception as e:\n self.logger.error(f\"Failed to filter memory units: {e}\")\n raise MemoryCleanerError(f\"Failed to filter memory units: {e}\")\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_cleaner.MemoryCleaner.filter_by_modality","title":"filter_by_modality(memory_units, modalities)
","text":"Keeps only memory units that match the specified modalities.
Parameters:
Name Type Description Default memory_units
List[MemoryUnit]
The list of memory units to be filtered.
required modalities
List[str]
List of modalities to retain (e.g., ['text', 'image']).
required Returns:
Type Description List[MemoryUnit]
List[MemoryUnit]: The list of memory units after modality-based filtering.
Raises:
Type Description MemoryCleanerError
If filtering fails.
Source code in src/aeiva/cognition/memory/memory_cleaner.py
def filter_by_modality(self, memory_units: List[MemoryUnit], modalities: List[str]) -> List[MemoryUnit]:\n \"\"\"\n Keeps only memory units that match the specified modalities.\n\n Args:\n memory_units (List[MemoryUnit]): The list of memory units to be filtered.\n modalities (List[str]): List of modalities to retain (e.g., ['text', 'image']).\n\n Returns:\n List[MemoryUnit]: The list of memory units after modality-based filtering.\n\n Raises:\n MemoryCleanerError: If filtering fails.\n \"\"\"\n self.logger.debug(f\"Applying modality-based filtering with modalities={modalities}\")\n try:\n if not modalities:\n self.logger.warning(\"No modalities specified for modality-based filtering. Returning original memory units.\")\n return memory_units\n\n filtered_memory = [\n mu for mu in memory_units\n if mu.modality in modalities\n ]\n removed_count = len(memory_units) - len(filtered_memory)\n self.logger.info(\n f\"Modality-based filter: Removed {removed_count} memory units not in modalities {modalities}.\"\n )\n return filtered_memory\n except Exception as e:\n self.logger.error(f\"Modality-based filtering failed: {e}\")\n raise MemoryCleanerError(f\"Modality-based filtering failed: {e}\")\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_cleaner.MemoryCleaner.filter_by_time","title":"filter_by_time(memory_units, threshold_days)
","text":"Removes memory units older than the specified threshold_days.
Parameters:
Name Type Description Default memory_units
List[MemoryUnit]
The list of memory units to be filtered.
required threshold_days
int
Number of days beyond which memory units are removed.
required Returns:
Type Description List[MemoryUnit]
List[MemoryUnit]: The list of memory units after time-based filtering.
Raises:
Type Description MemoryCleanerError
If filtering fails.
Source code in src/aeiva/cognition/memory/memory_cleaner.py
def filter_by_time(self, memory_units: List[MemoryUnit], threshold_days: int) -> List[MemoryUnit]:\n \"\"\"\n Removes memory units older than the specified threshold_days.\n\n Args:\n memory_units (List[MemoryUnit]): The list of memory units to be filtered.\n threshold_days (int): Number of days beyond which memory units are removed.\n\n Returns:\n List[MemoryUnit]: The list of memory units after time-based filtering.\n\n Raises:\n MemoryCleanerError: If filtering fails.\n \"\"\"\n self.logger.debug(f\"Applying time-based filtering with threshold_days={threshold_days}\")\n try:\n current_time = datetime.now(UTC)\n threshold = timedelta(days=threshold_days)\n filtered_memory = [\n mu for mu in memory_units\n if (current_time - mu.timestamp) <= threshold\n ]\n removed_count = len(memory_units) - len(filtered_memory)\n self.logger.info(\n f\"Time-based filter: Removed {removed_count} memory units older than {threshold_days} days.\"\n )\n return filtered_memory\n except Exception as e:\n self.logger.error(f\"Time-based filtering failed: {e}\")\n raise MemoryCleanerError(f\"Time-based filtering failed: {e}\")\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_cleaner.MemoryCleaner.filter_by_type","title":"filter_by_type(memory_units, types)
","text":"Keeps only memory units that match the specified types.
Parameters:
Name Type Description Default memory_units
List[MemoryUnit]
The list of memory units to be filtered.
required types
List[str]
List of types to retain (e.g., ['dialogue', 'summary']).
required Returns:
Type Description List[MemoryUnit]
List[MemoryUnit]: The list of memory units after type-based filtering.
Raises:
Type Description MemoryCleanerError
If filtering fails.
Source code in src/aeiva/cognition/memory/memory_cleaner.py
def filter_by_type(self, memory_units: List[MemoryUnit], types: List[str]) -> List[MemoryUnit]:\n \"\"\"\n Keeps only memory units that match the specified types.\n\n Args:\n memory_units (List[MemoryUnit]): The list of memory units to be filtered.\n types (List[str]): List of types to retain (e.g., ['dialogue', 'summary']).\n\n Returns:\n List[MemoryUnit]: The list of memory units after type-based filtering.\n\n Raises:\n MemoryCleanerError: If filtering fails.\n \"\"\"\n self.logger.debug(f\"Applying type-based filtering with types={types}\")\n try:\n if not types:\n self.logger.warning(\"No types specified for type-based filtering. Returning original memory units.\")\n return memory_units\n\n filtered_memory = [\n mu for mu in memory_units\n if mu.type in types\n ]\n removed_count = len(memory_units) - len(filtered_memory)\n self.logger.info(\n f\"Type-based filter: Removed {removed_count} memory units not in types {types}.\"\n )\n return filtered_memory\n except Exception as e:\n self.logger.error(f\"Type-based filtering failed: {e}\")\n raise MemoryCleanerError(f\"Type-based filtering failed: {e}\")\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_cleaner.MemoryCleanerError","title":"MemoryCleanerError
","text":" Bases: Exception
Exception raised when an error occurs in the MemoryCleaner.
Source code in src/aeiva/cognition/memory/memory_cleaner.py
class MemoryCleanerError(Exception):\n \"\"\"Exception raised when an error occurs in the MemoryCleaner.\"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_config","title":"memory_config
","text":""},{"location":"reference/#src.aeiva.cognition.memory.memory_config.MemoryConfig","title":"MemoryConfig
dataclass
","text":" Bases: BaseConfig
Configuration class for the Memory system.
Attributes:
Name Type Description embedder_config
EmbedderConfig
Configuration for the embedding model.
storage_config
StorageConfig
Configuration for the storage system.
Source code in src/aeiva/cognition/memory/memory_config.py
@dataclass\nclass MemoryConfig(BaseConfig):\n \"\"\"\n Configuration class for the Memory system.\n\n Attributes:\n embedder_config (EmbedderConfig): Configuration for the embedding model.\n storage_config (StorageConfig): Configuration for the storage system.\n \"\"\"\n\n embedder_config: EmbedderConfig = field(\n metadata={\"help\": \"Configuration for the embedding model.\"}\n )\n storage_config: StorageConfig = field(\n metadata={\"help\": \"Configuration for the storage system.\"}\n )\n\n def __post_init__(self):\n super().__post_init__()\n # Perform any necessary validation\n if not self.embedder_config:\n raise ValueError(\"Embedder configuration must be provided.\")\n if not self.storage_config:\n raise ValueError(\"Storage configuration must be provided.\")\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_link","title":"memory_link
","text":""},{"location":"reference/#src.aeiva.cognition.memory.memory_link.MemoryLink","title":"MemoryLink
","text":" Bases: BaseModel
MemoryLink represents a relationship between two memory units, allowing complex structures to be built by linking individual memory units.
Attributes:
Name Type Description id
str
Unique identifier for the edge, generated as a UUID string by default.
source_id
str
Unique identifier of the source memory unit.
target_id
str
Unique identifier of the target memory unit.
relationship
str
Type of relationship between memory units, such as 'causal' or 'association'.
metadata
Optional[Dict[str, Any]]
Additional metadata for the edge.
Source code in src/aeiva/cognition/memory/memory_link.py
class MemoryLink(BaseModel):\n \"\"\"\n MemoryLink represents a relationship between two memory units, allowing\n complex structures to be built by linking individual memory units.\n\n Attributes:\n id (str): Unique identifier for the edge, generated as a UUID string by default.\n source_id (str): Unique identifier of the source memory unit.\n target_id (str): Unique identifier of the target memory unit.\n relationship (str): Type of relationship between memory units, such as 'causal' or 'association'.\n metadata (Optional[Dict[str, Any]]): Additional metadata for the edge.\n \"\"\"\n id: str = Field(default_factory=lambda: uuid4().hex, description=\"Unique identifier for the edge.\")\n source_id: str = Field(..., description=\"Unique identifier of the source memory unit.\")\n target_id: str = Field(..., description=\"Unique identifier of the target memory unit.\")\n relationship: str = Field(\"\", description=\"Type of relationship, e.g., 'causal', 'temporal'.\")\n metadata: Optional[Dict[str, Any]] = Field(default_factory=dict, description=\"Additional metadata for the edge.\")\n\n def to_dict(self) -> dict:\n \"\"\"Converts the MemoryLink instance to a dictionary format for serialization.\"\"\"\n return self.dict()\n\n @classmethod\n def from_dict(cls, data: dict) -> \"MemoryLink\":\n \"\"\"Creates a MemoryLink instance from a dictionary.\"\"\"\n return cls(**data)\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_link.MemoryLink.from_dict","title":"from_dict(data)
classmethod
","text":"Creates a MemoryLink instance from a dictionary.
Source code in src/aeiva/cognition/memory/memory_link.py
@classmethod\ndef from_dict(cls, data: dict) -> \"MemoryLink\":\n \"\"\"Creates a MemoryLink instance from a dictionary.\"\"\"\n return cls(**data)\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_link.MemoryLink.to_dict","title":"to_dict()
","text":"Converts the MemoryLink instance to a dictionary format for serialization.
Source code in src/aeiva/cognition/memory/memory_link.py
def to_dict(self) -> dict:\n \"\"\"Converts the MemoryLink instance to a dictionary format for serialization.\"\"\"\n return self.dict()\n
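The `to_dict`/`from_dict` pair gives a lossless round trip. A std-lib sketch of the same pattern, using a dataclass stand-in for the pydantic model above (the `Sketch` name is hypothetical):

```python
# Round-trip sketch of MemoryLink serialization with a dataclass stand-in.
from dataclasses import dataclass, field, asdict
from uuid import uuid4


@dataclass
class MemoryLinkSketch:
    source_id: str
    target_id: str
    relationship: str = ""
    metadata: dict = field(default_factory=dict)
    id: str = field(default_factory=lambda: uuid4().hex)

    def to_dict(self) -> dict:
        return asdict(self)

    @classmethod
    def from_dict(cls, data: dict) -> "MemoryLinkSketch":
        return cls(**data)


link = MemoryLinkSketch(source_id="mu-1", target_id="mu-2", relationship="causal")
restored = MemoryLinkSketch.from_dict(link.to_dict())
print(restored == link)  # True: dataclass equality compares all fields, id included
```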
"},{"location":"reference/#src.aeiva.cognition.memory.memory_organizer","title":"memory_organizer
","text":""},{"location":"reference/#src.aeiva.cognition.memory.memory_organizer.MemoryOrganizer","title":"MemoryOrganizer
","text":"A class to organize memory units based on various organizing algorithms.
Supported organize types: - 'dialogue': Groups memory units by 'dialogue_session_id'.
Source code in src/aeiva/cognition/memory/memory_organizer.py
class MemoryOrganizer:\n    \"\"\"\n    A class to organize memory units based on various organizing algorithms.\n\n    Supported organize types:\n    - 'dialogue': Groups memory units by 'dialogue_session_id'.\n    # Future organize types can be added here.\n    \"\"\"\n\n    def __init__(self):\n        \"\"\"\n        Initializes the MemoryOrganizer.\n\n        Currently, no initialization parameters are required.\n        \"\"\"\n        self.logger = logging.getLogger(self.__class__.__name__)\n        self.logger.debug(\"Initialized MemoryOrganizer without default parameters.\")\n\n    def organize(\n        self,\n        memory_units: List[MemoryUnit],\n        organize_type: str,\n        **kwargs\n    ) -> List[MemoryUnit]:\n        \"\"\"\n        Organizes the provided memory units based on the specified organize type.\n\n        Args:\n            memory_units (List[MemoryUnit]): The list of memory units to be organized.\n            organize_type (str): The type of organizing algorithm to use ('dialogue').\n            **kwargs: Additional parameters required for specific organizers.\n                For 'dialogue' organize:\n                - group_field (str): The metadata field to group by (default: 'dialogue_session_id').\n                - derive_content (bool): Whether to derive content for the group (default: True).\n                - derivation_type (str): The type of derivation to perform ('summary', etc.).\n\n        Returns:\n            List[MemoryUnit]: The list of memory units after organizing.\n\n        Raises:\n            MemoryOrganizerError: If an unknown organize_type is provided or if required parameters are missing.\n        \"\"\"\n        self.logger.debug(f\"Organizing memory units using organize_type='{organize_type}' with kwargs={kwargs}\")\n        try:\n            if organize_type == 'dialogue':\n                group_field = kwargs.get('group_field', 'dialogue_session_id')\n                derive_content = kwargs.get('derive_content', True)\n                derivation_type = kwargs.get('derivation_type', 'summary')\n                return self.organize_by_dialogue(memory_units, group_field, derive_content, derivation_type)\n            else:\n                self.logger.error(f\"Unknown organize_type: {organize_type}\")\n                raise MemoryOrganizerError(f\"Unknown organize_type: {organize_type}\")\n        except MemoryOrganizerError:\n            # Re-raise custom errors without modification\n            raise\n        except Exception as e:\n            self.logger.error(f\"Failed to organize memory units: {e}\")\n            raise MemoryOrganizerError(f\"Failed to organize memory units: {e}\")\n\n    def organize_by_dialogue(\n        self,\n        memory_units: List[MemoryUnit],\n        group_field: str = 'dialogue_session_id',  # NOTE: here we assume the metadata field of dialogue memory units has a dialogue_session_id\n        derive_content: bool = False,\n        derivation_type: str = 'summary'\n    ) -> List[MemoryUnit]:\n        \"\"\"\n        Organizes memory units into dialogue sessions based on a common 'dialogue_session_id'.\n\n        Args:\n            memory_units (List[MemoryUnit]): The list of memory units to be organized.\n            group_field (str): The metadata field to group by (default: 'dialogue_session_id').\n            derive_content (bool): Whether to derive content for the group (default: False).\n            derivation_type (str): The type of derivation to perform ('summary', etc.).\n\n        Returns:\n            List[MemoryUnit]: The list of memory units after organizing, including new dialogue groups.\n\n        Raises:\n            MemoryOrganizerError: If organizing fails.\n        \"\"\"\n        self.logger.debug(f\"Organizing by dialogue with group_field='{group_field}', derive_content={derive_content}, derivation_type='{derivation_type}'\")\n        try:\n            # Group memory units by the specified group_field\n            groups = defaultdict(list)\n            for mu in memory_units:\n                group_id = mu.metadata.get(group_field)\n                if group_id:\n                    groups[group_id].append(mu)\n                else:\n                    self.logger.debug(f\"MemoryUnit '{mu.id}' does not have '{group_field}'. Skipping grouping.\")\n\n            self.logger.info(f\"Found {len(groups)} dialogue groups based on '{group_field}'.\")\n\n            # Create new MemoryUnit for each group\n            new_memory_units = []\n            for group_id, group_mus in groups.items():\n                self.logger.debug(f\"Creating DialogueGroup for group_id='{group_id}' with {len(group_mus)} memory units.\")\n\n                # Create a new MemoryUnit to represent the DialogueGroup\n                dialogue_group = MemoryUnit(\n                    content=\"\",  # Content to be derived\n                    type=\"dialogue_session\",\n                    metadata={\n                        \"organized_at\": datetime.now(timezone.utc).isoformat(),\n                        \"member_ids\": [mu.id for mu in group_mus],\n                        \"derivation_type\": derivation_type\n                    }\n                )\n\n                # Link each memory unit to the DialogueGroup\n                for mu in group_mus:\n                    link = MemoryLink(\n                        source_id=mu.id,\n                        target_id=dialogue_group.id,\n                        relationship='part_of'\n                    )\n                    mu.edges.append(link)\n                    self.logger.debug(f\"Linked MemoryUnit '{mu.id}' to DialogueGroup '{dialogue_group.id}'.\")\n\n                # Optionally, derive content for the group\n                if derive_content:\n                    if derivation_type == 'summary':\n                        derived_content = self.derive_summary(group_mus)\n                    elif derivation_type == 'reflection':\n                        derived_content = self.derive_reflection(group_mus)\n                    else:\n                        self.logger.warning(f\"Unknown derivation_type '{derivation_type}'. Skipping content derivation.\")\n                        derived_content = \"\"\n                    dialogue_group.content = derived_content\n                    dialogue_group.status = 'derived'\n                    self.logger.debug(f\"Derived content for DialogueGroup '{dialogue_group.id}': {derived_content}\")\n\n                new_memory_units.append(dialogue_group)\n                self.logger.info(f\"DialogueGroup '{dialogue_group.id}' created for group_id='{group_id}'.\")\n\n            # Return the original memory units plus the new dialogue groups\n            organized_memory = memory_units + new_memory_units\n            self.logger.debug(f\"Organizing by dialogue completed. Total memory units after organizing: {len(organized_memory)}\")\n            return organized_memory\n\n        except Exception as e:\n            self.logger.error(f\"Error organizing by dialogue: {e}\")\n            raise MemoryOrganizerError(f\"Error organizing by dialogue: {e}\")\n\n    def derive_summary(self, memory_units: List[MemoryUnit]) -> str:  # TODO: replace with lmp implementation\n        \"\"\"\n        Derives a summary from the given memory units.\n\n        Args:\n            memory_units (List[MemoryUnit]): The list of memory units to summarize.\n\n        Returns:\n            str: A summary string.\n        \"\"\"\n        self.logger.debug(f\"Deriving summary from {len(memory_units)} memory units.\")\n        try:\n            summary = \"Summary of dialogue session:\\n\"\n            for mu in memory_units:\n                summary += f\"- {mu.content}\\n\"\n            derived_summary = summary.strip()\n            self.logger.debug(f\"Derived summary: {derived_summary}\")\n            return derived_summary\n        except Exception as e:\n            self.logger.error(f\"Failed to derive summary: {e}\")\n            raise MemoryOrganizerError(f\"Failed to derive summary: {e}\")\n\n    def derive_reflection(self, memory_units: List[MemoryUnit]) -> str:  # TODO: replace with lmp implementation\n        \"\"\"\n        Derives a reflection from the given memory units.\n\n        Args:\n            memory_units (List[MemoryUnit]): The list of memory units to reflect upon.\n\n        Returns:\n            str: A reflection string.\n        \"\"\"\n        self.logger.debug(f\"Deriving reflection from {len(memory_units)} memory units.\")\n        try:\n            reflection = \"Reflection on dialogue session:\\n\"\n            for mu in memory_units:\n                reflection += f\"- {mu.content}\\n\"\n            derived_reflection = reflection.strip()\n            self.logger.debug(f\"Derived reflection: {derived_reflection}\")\n            return derived_reflection\n        except Exception as e:\n            self.logger.error(f\"Failed to derive reflection: {e}\")\n            raise MemoryOrganizerError(f\"Failed to derive reflection: {e}\")\n
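The 'dialogue' organize step boils down to grouping by a metadata key and appending one derived unit per group. A stand-alone sketch with plain dicts in place of MemoryUnit objects (an assumption; linking via MemoryLink edges is omitted), reusing the summary format from derive_summary:

```python
# Sketch of the 'dialogue' organize step: group unit dicts by
# metadata['dialogue_session_id'], derive a summary per group, and append
# the derived session units to the originals.
from collections import defaultdict
from typing import Dict, List


def organize_by_dialogue(units: List[Dict], group_field: str = "dialogue_session_id") -> List[Dict]:
    groups = defaultdict(list)
    for unit in units:
        session = unit["metadata"].get(group_field)
        if session:  # units without the field are skipped, as in the class above
            groups[session].append(unit)
    derived = []
    for session, members in groups.items():
        summary = "Summary of dialogue session:\n" + "".join(
            f"- {m['content']}\n" for m in members
        )
        derived.append({
            "content": summary.strip(),
            "type": "dialogue_session",
            "metadata": {"member_ids": [m["id"] for m in members]},
        })
    return units + derived


units = [
    {"id": "u1", "content": "hi", "metadata": {"dialogue_session_id": "s1"}},
    {"id": "u2", "content": "hello", "metadata": {"dialogue_session_id": "s1"}},
    {"id": "u3", "content": "orphan", "metadata": {}},
]
organized = organize_by_dialogue(units)
print(len(organized))  # 4: three originals plus one derived session unit
```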
"},{"location":"reference/#src.aeiva.cognition.memory.memory_organizer.MemoryOrganizer--future-organize-types-can-be-added-here","title":"Future organize types can be added here.","text":""},{"location":"reference/#src.aeiva.cognition.memory.memory_organizer.MemoryOrganizer.__init__","title":"__init__()
","text":"Initializes the MemoryOrganizer.
Currently, no initialization parameters are required.
Source code in src/aeiva/cognition/memory/memory_organizer.py
def __init__(self):\n \"\"\"\n Initializes the MemoryOrganizer.\n\n Currently, no initialization parameters are required.\n \"\"\"\n self.logger = logging.getLogger(self.__class__.__name__)\n self.logger.debug(\"Initialized MemoryOrganizer without default parameters.\")\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_organizer.MemoryOrganizer.derive_reflection","title":"derive_reflection(memory_units)
","text":"Derives a reflection from the given memory units.
Parameters:
Name Type Description Default memory_units
List[MemoryUnit]
The list of memory units to reflect upon.
required Returns:
Name Type Description str
str
A reflection string.
Source code in src/aeiva/cognition/memory/memory_organizer.py
def derive_reflection(self, memory_units: List[MemoryUnit]) -> str: # TODO: replace with lmp implementation\n \"\"\"\n Derives a reflection from the given memory units.\n\n Args:\n memory_units (List[MemoryUnit]): The list of memory units to reflect upon.\n\n Returns:\n str: A reflection string.\n \"\"\"\n self.logger.debug(f\"Deriving reflection from {len(memory_units)} memory units.\")\n try:\n reflection = \"Reflection on dialogue session:\\n\"\n for mu in memory_units:\n reflection += f\"- {mu.content}\\n\"\n derived_reflection = reflection.strip()\n self.logger.debug(f\"Derived reflection: {derived_reflection}\")\n return derived_reflection\n except Exception as e:\n self.logger.error(f\"Failed to derive reflection: {e}\")\n raise MemoryOrganizerError(f\"Failed to derive reflection: {e}\")\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_organizer.MemoryOrganizer.derive_summary","title":"derive_summary(memory_units)
","text":"Derives a summary from the given memory units.
Parameters:
Name Type Description Default memory_units
List[MemoryUnit]
The list of memory units to summarize.
required Returns:
Name Type Description str
str
A summary string.
Source code in src/aeiva/cognition/memory/memory_organizer.py
def derive_summary(self, memory_units: List[MemoryUnit]) -> str: # TODO: replace with lmp implementation\n \"\"\"\n Derives a summary from the given memory units.\n\n Args:\n memory_units (List[MemoryUnit]): The list of memory units to summarize.\n\n Returns:\n str: A summary string.\n \"\"\"\n self.logger.debug(f\"Deriving summary from {len(memory_units)} memory units.\")\n try:\n summary = \"Summary of dialogue session:\\n\"\n for mu in memory_units:\n summary += f\"- {mu.content}\\n\"\n derived_summary = summary.strip()\n self.logger.debug(f\"Derived summary: {derived_summary}\")\n return derived_summary\n except Exception as e:\n self.logger.error(f\"Failed to derive summary: {e}\")\n raise MemoryOrganizerError(f\"Failed to derive summary: {e}\")\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_organizer.MemoryOrganizer.organize","title":"organize(memory_units, organize_type, **kwargs)
","text":"Organizes the provided memory units based on the specified organize type.
Parameters:
Name Type Description Default memory_units
List[MemoryUnit]
The list of memory units to be organized.
required organize_type
str
The type of organizing algorithm to use ('dialogue').
required **kwargs
Additional parameters required for specific organizers. For 'dialogue' organize: - group_field (str): The metadata field to group by (default: 'dialogue_session_id'). - derive_content (bool): Whether to derive content for the group (default: True). - derivation_type (str): The type of derivation to perform ('summary', etc.).
{}
Returns:
Type Description List[MemoryUnit]
List[MemoryUnit]: The list of memory units after organizing.
Raises:
Type Description MemoryOrganizerError
If an unknown organize_type is provided or if required parameters are missing.
Source code in src/aeiva/cognition/memory/memory_organizer.py
def organize(\n self,\n memory_units: List[MemoryUnit],\n organize_type: str,\n **kwargs\n) -> List[MemoryUnit]:\n \"\"\"\n Organizes the provided memory units based on the specified organize type.\n\n Args:\n memory_units (List[MemoryUnit]): The list of memory units to be organized.\n organize_type (str): The type of organizing algorithm to use ('dialogue').\n **kwargs: Additional parameters required for specific organizers.\n For 'dialogue' organize:\n - group_field (str): The metadata field to group by (default: 'dialogue_session_id').\n - derive_content (bool): Whether to derive content for the group (default: True).\n - derivation_type (str): The type of derivation to perform ('summary', etc.).\n\n Returns:\n List[MemoryUnit]: The list of memory units after organizing.\n\n Raises:\n MemoryOrganizerError: If an unknown organize_type is provided or if required parameters are missing.\n \"\"\"\n self.logger.debug(f\"Organizing memory units using organize_type='{organize_type}' with kwargs={kwargs}\")\n try:\n if organize_type == 'dialogue':\n group_field = kwargs.get('group_field', 'dialogue_session_id')\n derive_content = kwargs.get('derive_content', True)\n derivation_type = kwargs.get('derivation_type', 'summary')\n return self.organize_by_dialogue(memory_units, group_field, derive_content, derivation_type)\n else:\n self.logger.error(f\"Unknown organize_type: {organize_type}\")\n raise MemoryOrganizerError(f\"Unknown organize_type: {organize_type}\")\n except MemoryOrganizerError:\n # Re-raise custom errors without modification\n raise\n except Exception as e:\n self.logger.error(f\"Failed to organize memory units: {e}\")\n raise MemoryOrganizerError(f\"Failed to organize memory units: {e}\")\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_organizer.MemoryOrganizer.organize_by_dialogue","title":"organize_by_dialogue(memory_units, group_field='dialogue_session_id', derive_content=False, derivation_type='summary')
","text":"Organizes memory units into dialogue sessions based on a common 'dialogue_session_id'.
Parameters:
Name Type Description Default memory_units
List[MemoryUnit]
The list of memory units to be organized.
required group_field
str
The metadata field to group by (default: 'dialogue_session_id').
'dialogue_session_id'
derive_content
bool
Whether to derive content for the group (default: False).
False
derivation_type
str
The type of derivation to perform ('summary', etc.).
'summary'
Returns:
Type Description List[MemoryUnit]
List[MemoryUnit]: The list of memory units after organizing, including new dialogue groups.
Raises:
Type Description MemoryOrganizerError
If organizing fails.
Source code in src/aeiva/cognition/memory/memory_organizer.py
def organize_by_dialogue(\n    self,\n    memory_units: List[MemoryUnit],\n    group_field: str = 'dialogue_session_id',  # NOTE: here we assume the metadata field of dialogue memory units has a dialogue_session_id\n    derive_content: bool = False,\n    derivation_type: str = 'summary'\n) -> List[MemoryUnit]:\n    \"\"\"\n    Organizes memory units into dialogue sessions based on a common 'dialogue_session_id'.\n\n    Args:\n        memory_units (List[MemoryUnit]): The list of memory units to be organized.\n        group_field (str): The metadata field to group by (default: 'dialogue_session_id').\n        derive_content (bool): Whether to derive content for the group (default: False).\n        derivation_type (str): The type of derivation to perform ('summary', etc.).\n\n    Returns:\n        List[MemoryUnit]: The list of memory units after organizing, including new dialogue groups.\n\n    Raises:\n        MemoryOrganizerError: If organizing fails.\n    \"\"\"\n    self.logger.debug(f\"Organizing by dialogue with group_field='{group_field}', derive_content={derive_content}, derivation_type='{derivation_type}'\")\n    try:\n        # Group memory units by the specified group_field\n        groups = defaultdict(list)\n        for mu in memory_units:\n            group_id = mu.metadata.get(group_field)\n            if group_id:\n                groups[group_id].append(mu)\n            else:\n                self.logger.debug(f\"MemoryUnit '{mu.id}' does not have '{group_field}'. Skipping grouping.\")\n\n        self.logger.info(f\"Found {len(groups)} dialogue groups based on '{group_field}'.\")\n\n        # Create new MemoryUnit for each group\n        new_memory_units = []\n        for group_id, group_mus in groups.items():\n            self.logger.debug(f\"Creating DialogueGroup for group_id='{group_id}' with {len(group_mus)} memory units.\")\n\n            # Create a new MemoryUnit to represent the DialogueGroup\n            dialogue_group = MemoryUnit(\n                content=\"\",  # Content to be derived\n                type=\"dialogue_session\",\n                metadata={\n                    \"organized_at\": datetime.now(timezone.utc).isoformat(),\n                    \"member_ids\": [mu.id for mu in group_mus],\n                    \"derivation_type\": derivation_type\n                }\n            )\n\n            # Link each memory unit to the DialogueGroup\n            for mu in group_mus:\n                link = MemoryLink(\n                    source_id=mu.id,\n                    target_id=dialogue_group.id,\n                    relationship='part_of'\n                )\n                mu.edges.append(link)\n                self.logger.debug(f\"Linked MemoryUnit '{mu.id}' to DialogueGroup '{dialogue_group.id}'.\")\n\n            # Optionally, derive content for the group\n            if derive_content:\n                if derivation_type == 'summary':\n                    derived_content = self.derive_summary(group_mus)\n                elif derivation_type == 'reflection':\n                    derived_content = self.derive_reflection(group_mus)\n                else:\n                    self.logger.warning(f\"Unknown derivation_type '{derivation_type}'. Skipping content derivation.\")\n                    derived_content = \"\"\n                dialogue_group.content = derived_content\n                dialogue_group.status = 'derived'\n                self.logger.debug(f\"Derived content for DialogueGroup '{dialogue_group.id}': {derived_content}\")\n\n            new_memory_units.append(dialogue_group)\n            self.logger.info(f\"DialogueGroup '{dialogue_group.id}' created for group_id='{group_id}'.\")\n\n        # Return the original memory units plus the new dialogue groups\n        organized_memory = memory_units + new_memory_units\n        self.logger.debug(f\"Organizing by dialogue completed. Total memory units after organizing: {len(organized_memory)}\")\n        return organized_memory\n\n    except Exception as e:\n        self.logger.error(f\"Error organizing by dialogue: {e}\")\n        raise MemoryOrganizerError(f\"Error organizing by dialogue: {e}\")\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_organizer.MemoryOrganizerError","title":"MemoryOrganizerError
","text":" Bases: Exception
Exception raised when an error occurs in the MemoryOrganizer.
Source code in src/aeiva/cognition/memory/memory_organizer.py
class MemoryOrganizerError(Exception):\n \"\"\"Exception raised when an error occurs in the MemoryOrganizer.\"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_palace","title":"memory_palace
","text":""},{"location":"reference/#src.aeiva.cognition.memory.memory_palace.MemoryPalace","title":"MemoryPalace
","text":" Bases: Memory
Concrete implementation of the Memory abstract base class.
This class provides methods to manage memory units, including creation, retrieval, updating, deletion, filtering, grouping, structurizing, skillizing, parameterizing, and more. It delegates specific operations to specialized components like MemoryCleaner, MemoryOrganizer, MemoryRetriever, MemoryStructurer, MemorySkillizer, and MemoryParameterizer.
Source code in src/aeiva/cognition/memory/memory_palace.py
class MemoryPalace(Memory):\n    \"\"\"\n    Concrete implementation of the Memory abstract base class.\n\n    This class provides methods to manage memory units, including creation, retrieval,\n    updating, deletion, filtering, grouping, structurizing, skillizing, parameterizing,\n    and more. It delegates specific operations to specialized components like\n    MemoryCleaner, MemoryOrganizer, MemoryRetriever, MemoryStructurer, MemorySkillizer,\n    and MemoryParameterizer.\n    \"\"\"\n\n    def __init__(self, config: Dict):\n        \"\"\"\n        Initialize the MemoryPalace with the provided configuration.\n\n        Args:\n            config (Dict): Configuration settings for the MemoryPalace.\n        \"\"\"\n        self.config_dict = config\n        self.config = None\n        self.storage = None\n        self.embedder = None\n        self.cleaner = None\n        self.organizer = None\n        self.retriever = None\n        self.structurer = None\n        self.skillizer = None\n        self.parameterizer = None\n        self.setup()\n\n    def setup(self):\n        \"\"\"\n        Setup the MemoryPalace by initializing all components.\n        \"\"\"\n        try:\n            # Initialize EmbedderConfig\n            embedder_config_dict = self.config_dict.get('embedder_config', {})\n            self.embedder = Embedder(embedder_config_dict)\n\n            storage_config_dict = self.config_dict.get('storage_config', {})\n            self.storage = MemoryStorage(storage_config_dict)\n\n            # Initialize Memory Configuration\n            self.config = MemoryConfig(\n                embedder_config=self.embedder.config,\n                storage_config=self.storage.config\n            )\n\n            logger.info(\"MemoryPalace: MemoryStorage and Embedder initialized successfully.\")\n\n            # Initialize specialized components\n            self.cleaner = MemoryCleaner()\n            self.organizer = MemoryOrganizer()\n            self.retriever = MemoryRetriever(embedder=self.embedder, storage=self.storage)\n            self.structurer = MemoryStructurer()\n            self.skillizer = MemorySkillizer()\n            self.parameterizer = MemoryParameterizer()\n            logger.info(\"MemoryPalace: Specialized components initialized successfully.\")\n\n        except Exception as e:\n            logger.error(f\"MemoryPalace setup failed: {e}\")\n            self.handle_error(e)\n            raise\n\n    # CRUD Operations\n\n    def create(self, content: Any, **kwargs) -> MemoryUnit:\n        \"\"\"\n        Creates a new memory unit with the given content and metadata.\n\n        Args:\n            content (Any): The core content of the memory unit.\n            **kwargs: Additional metadata for the memory unit.\n\n        Returns:\n            MemoryUnit: The created memory unit.\n        \"\"\"\n        try:\n            # Instantiate MemoryUnit\n            memory_unit = MemoryUnit(content=content, **kwargs)\n\n            # Generate embedding\n            embedding_response = self.embedder.embed(content)\n            if embedding_response.get(\"data\"):\n                memory_unit.embedding = embedding_response[\"data\"][0].get(\"embedding\")\n            else:\n                raise ValueError(\"Failed to generate embedding for the content.\")\n\n            # Delegate storage operations to MemoryStorage\n            self.storage.add_memory_unit(memory_unit)\n\n            logger.info(f\"Created new MemoryUnit with ID: {memory_unit.id}\")\n            return memory_unit\n        except Exception as e:\n            logger.error(f\"Error creating MemoryUnit: {e}\")\n            self.handle_error(e)\n            raise\n\n    def get(self, unit_id: str) -> MemoryUnit:\n        \"\"\"\n        Retrieves a memory unit by its unique identifier.\n\n        Args:\n            unit_id (str): The unique identifier of the memory unit.\n\n        Returns:\n            MemoryUnit: The retrieved memory unit.\n        \"\"\"\n        try:\n            memory_unit = self.storage.get_memory_unit(unit_id)\n            logger.info(f\"Retrieved MemoryUnit with ID: {unit_id}\")\n            return memory_unit\n        except Exception as e:\n            logger.error(f\"Error retrieving MemoryUnit with ID {unit_id}: {e}\")\n            self.handle_error(e)\n            raise\n\n    def update(self, unit_id: str, updates: Dict[str, Any]) -> None:\n        \"\"\"\n        Updates a memory unit with the given updates.\n\n        Args:\n            unit_id (str): The unique identifier of the memory unit.\n            updates (Dict[str, Any]): A dictionary of fields to update.\n        \"\"\"\n        try:\n            # Delegate update operations to MemoryStorage\n            self.storage.update_memory_unit(unit_id, updates)\n            logger.info(f\"Updated MemoryUnit with ID: {unit_id}\")\n        except Exception as e:\n            logger.error(f\"Error updating MemoryUnit with ID {unit_id}: {e}\")\n            self.handle_error(e)\n            raise\n\n    def delete(self, unit_id: str) -> None:\n        \"\"\"\n        Deletes a memory unit by its unique identifier.\n\n        Args:\n            unit_id (str): The unique identifier of the memory unit.\n        \"\"\"\n        try:\n            # Delegate deletion to MemoryStorage\n            self.storage.delete_memory_unit(unit_id)\n            logger.info(f\"Deleted MemoryUnit with ID: {unit_id}\")\n        except Exception as e:\n            logger.error(f\"Error deleting MemoryUnit with ID {unit_id}: {e}\")\n            self.handle_error(e)\n            raise\n\n    def get_all(self) -> List[MemoryUnit]:\n        \"\"\"\n        Retrieves all memory units.\n\n        Returns:\n            List[MemoryUnit]: A list of all memory units.\n        \"\"\"\n        try:\n            memory_units = self.storage.get_all_memory_units()\n            logger.info(f\"Retrieved all MemoryUnits. Total count: {len(memory_units)}\")\n            return memory_units\n        except Exception as e:\n            logger.error(f\"Error retrieving all MemoryUnits: {e}\")\n            self.handle_error(e)\n            raise\n\n    def delete_all(self) -> None:\n        \"\"\"\n        Deletes all memory units.\n        \"\"\"\n        try:\n            self.storage.delete_all_memory_units()  # TODO: does not seem to work correctly; needs checking\n            logger.info(\"Deleted all MemoryUnits.\")\n        except Exception as e:\n            logger.error(f\"Error deleting all MemoryUnits: {e}\")\n            self.handle_error(e)\n            raise\n\n    def load(self) -> List[MemoryUnit]:\n        \"\"\"\n        Loads all memory units from the storage.\n\n        Returns:\n            List[MemoryUnit]: A list of all loaded memory units.\n        \"\"\"\n        try:\n            # Retrieve all memory units from storage\n            memory_units = self.get_all()\n            logger.info(f\"Loaded {len(memory_units)} MemoryUnits from storage.\")\n            return memory_units\n        except Exception as e:\n            logger.error(f\"Error loading MemoryUnits: {e}\")\n            self.handle_error(e)\n            raise\n\n    def save(self, export_path: Optional[str] = None) -> None:\n        \"\"\"\n        Saves all memory units to the storage or exports them to a specified path.\n\n        Args:\n            export_path (Optional[str]): The file path to export memory units as JSON.\n                If None, saves are handled by MemoryStorage.\n        \"\"\"\n        try:\n            if export_path:\n                # Export memory units to a JSON file\n                memory_units = self.get_all()\n                export_data = [mu.to_dict() for mu in memory_units]\n                with open(export_path, 'w', encoding='utf-8') as f:\n                    json.dump(export_data, f, ensure_ascii=False, indent=4)\n                logger.info(f\"Exported {len(memory_units)} MemoryUnits to {export_path}.\")\n            else:\n                # If no export path is provided, assume that MemoryStorage handles persistence\n                logger.info(\"Save operation delegated to MemoryStorage.\")\n                # Example: self.storage.persist_changes()\n        except Exception as e:\n            logger.error(f\"Error saving MemoryUnits: {e}\")\n            self.handle_error(e)\n            raise\n\n    # Delegated Operations\n\n    def filter(self, criteria: Dict[str, Any]) -> List[MemoryUnit]:\n        \"\"\"\n        Filters memory units based on the given criteria.\n\n        Args:\n            criteria (Dict[str, Any]): A dictionary of filter conditions.\n\n        Returns:\n            List[MemoryUnit]: A list of memory units matching the criteria.\n        \"\"\"\n        try:\n            memory_units = self.get_all()\n            filter_type = criteria.get('filter_type')\n            if not filter_type:\n                raise ValueError(\"Missing 'filter_type' in criteria.\")\n\n            # Delegate filtering to MemoryCleaner\n            filtered_memories = self.cleaner.filter(memory_units, filter_type, **criteria)\n            logger.info(f\"Filtered memories based on criteria: {criteria}\")\n            return filtered_memories\n        except Exception as e:\n            logger.error(f\"Error filtering memories: {e}\")\n            self.handle_error(e)\n            raise\n\n    def organize(self, unit_ids: List[str], organize_type: str, metadata: Optional[Dict[str, Any]] = None) -> str:\n        \"\"\"\n        Groups memory units into a meaningful group.\n\n        Args:\n            unit_ids (List[str]): A list of memory unit IDs to group.\n            organize_type (str): The type of group (e.g., 'dialogue_session', 'procedure').\n            metadata (Optional[Dict[str, Any]]): Additional metadata for the group.\n\n        Returns:\n            str: A unique identifier for the created group.\n        \"\"\"\n        try:\n            # Retrieve the memory units to group\n            memory_units = [self.get(unit_id) for unit_id in unit_ids]\n            logger.debug(f\"Grouping {len(memory_units)} MemoryUnits into group_type='{organize_type}'.\")\n\n            # Delegate grouping to MemoryOrganizer\n            organized_memories = self.organizer.organize(memory_units, organize_type, metadata=metadata)\n            logger.info(f\"Grouped memories into '{organize_type}'. Total memory units after grouping: {len(organized_memories)}\")\n            return \"group_id_placeholder\"  # Replace with actual group ID if applicable\n        except Exception as e:\n            logger.error(f\"Error grouping memories: {e}\")\n            self.handle_error(e)\n            raise\n\n    def structurize(self, unit_ids: List[str], structure_type: str, **kwargs) -> None:\n        \"\"\"\n        Structures memory units into a knowledge graph or other structures.\n\n        Args:\n            unit_ids (List[str]): A list of memory unit IDs to structurize.\n            structure_type (str): The type of structure (e.g., 'knowledge_graph').\n            **kwargs: Additional parameters for the structuring process.\n        \"\"\"\n        try:\n            # Retrieve the memory units to structurize\n            memory_units = [self.get(uid) for uid in unit_ids]\n            logger.debug(f\"Structurizing {len(memory_units)} MemoryUnits with structure_type='{structure_type}'.\")\n\n            # Delegate structuring to MemoryStructurer\n            self.structurer.structure(memory_units, structure_type, **kwargs)\n            logger.info(f\"Structurized memories with structure_type='{structure_type}'.\")\n        except Exception as e:\n            logger.error(f\"Error structurizing memories: {e}\")\n            self.handle_error(e)\n            raise\n\n    def skillize(self, unit_ids: List[str], skill_name: str, **kwargs) -> str:\n        \"\"\"\n        Converts memory units into a reusable skill.\n\n        Args:\n            unit_ids (List[str]): A list of memory unit IDs to skillize.\n            skill_name (str): The name of the skill to create.\n            **kwargs: Additional parameters for skill creation.\n\n        Returns:\n            str: The unique identifier of the created skill.\n        \"\"\"\n        try:\n            # Retrieve the memory units to skillize\n            memory_units = [self.get(uid) for uid in unit_ids]\n            logger.debug(f\"Skillizing {len(memory_units)} MemoryUnits into skill_name='{skill_name}'.\")\n\n            # Delegate skillizing to MemorySkillizer\n            skill_id = self.skillizer.skillize(memory_units, skill_name, **kwargs)\n            logger.info(f\"Skillized memories into skill with ID: {skill_id}\")\n            return skill_id\n        except Exception as e:\n            logger.error(f\"Error skillizing memories: {e}\")\n            self.handle_error(e)\n            raise\n\n    def parameterize(self, **kwargs) -> None:\n        \"\"\"\n        Trains a parametric model using the memory data.\n\n        Args:\n            **kwargs: Additional parameters for the training process.\n        \"\"\"\n        try:\n            # Retrieve all memory units\n            memory_units = self.get_all()\n            logger.debug(f\"Parameterizing {len(memory_units)} MemoryUnits.\")\n\n            # Delegate parameterizing to MemoryParameterizer\n            self.parameterizer.parameterize(memory_units, **kwargs)\n            logger.info(\"Parameterized memories successfully.\")\n        except Exception as e:\n            logger.error(f\"Error parameterizing memories: {e}\")\n            self.handle_error(e)\n            raise\n\n    def retrieve(self, query: Any, retrieve_type: str, **kwargs) -> List[MemoryUnit]:\n        \"\"\"\n        Retrieve data from memory based on a query.\n\n        Args:\n            query (Any): The query or criteria to retrieve specific memory data.\n            retrieve_type (str): The type of retrieval (e.g., 'similar', 'related').\n            **kwargs: Additional parameters for the retrieval process.\n\n        Returns:\n            List[MemoryUnit]: The retrieved memory data.\n        \"\"\"\n        try:\n            # Delegate retrieval to MemoryRetriever\n            memories = self.retriever.retrieve(query=query, retrieve_type=retrieve_type, **kwargs)\n            logger.info(f\"Retrieved {len(memories)} memories using retrieve_type='{retrieve_type}'.\")\n            return memories\n        except Exception as e:\n            logger.error(f\"Error retrieving MemoryUnits: {e}\")\n            self.handle_error(e)\n            raise\n\n    def embed(self, unit_id: str) -> None:\n        \"\"\"\n        Generates an embedding for a memory unit.\n\n        Args:\n            unit_id (str): The unique identifier of the memory unit.\n        \"\"\"\n        try:\n            # Delegate embedding to MemoryRetriever\n            memory_units = self.retriever.retrieve(query=unit_id, retrieve_type='similar', top_k=1)\n            if not memory_units:\n                raise ValueError(f\"No MemoryUnit found with ID {unit_id} to embed.\")\n\n            memory_unit = memory_units[0]\n\n            # Generate embedding using the embedder\n            embedding_response = self.embedder.embed(memory_unit.content)\n            if embedding_response.get(\"data\") and len(embedding_response[\"data\"]) > 0:\n                memory_unit.embedding = embedding_response[\"data\"][0].get(\"embedding\")\n            else:\n                raise ValueError(\"Failed to generate embedding for the content.\")\n\n            # Update the memory unit with the new embedding\n            self.update(unit_id, {'embedding': memory_unit.embedding})\n\n            logger.info(f\"Generated embedding for MemoryUnit ID: {unit_id}\")\n        except Exception as e:\n            logger.error(f\"Error generating embedding for MemoryUnit ID {unit_id}: {e}\")\n            self.handle_error(e)\n            raise\n\n    # Error Handling\n\n    def handle_error(self, error: Exception) -> None:\n        \"\"\"\n        Handle errors that occur during memory operations.\n\n        Args:\n            error (Exception): The exception that was raised.\n        \"\"\"\n        logger.error(f\"MemoryPalace encountered an error: {error}\")\n        # Additional error handling can be implemented here\n\n    @staticmethod\n    def get_api_key(config_section: Dict[str, Any], key_field: str, env_var_field: str) -> Optional[str]:\n        \"\"\"\n        Retrieve an API key from the configuration section.\n\n        Args:\n            config_section (Dict[str, Any]): The configuration section (e.g., embedder_config).\n            key_field (str): The key in the config_section that may contain the API key directly.\n            env_var_field (str): The key in the config_section that specifies the environment variable name.\n\n        Returns:\n            Optional[str]: The API key if found, else None.\n\n        Raises:\n            EnvironmentError: If the environment variable is specified but not set.\n        \"\"\"\n        # Check if API key is provided directly\n        api_key = config_section.get(key_field)\n        if 
api_key:\n logger.info(f\"Using provided API key for '{key_field}'.\")\n return api_key\n\n # Else, check if an environment variable is specified\n env_var = config_section.get(env_var_field)\n if env_var:\n api_key = os.getenv(env_var)\n if api_key:\n logger.info(f\"Retrieved API key for '{key_field}' from environment variable '{env_var}'.\")\n return api_key\n else:\n logger.error(f\"Environment variable '{env_var}' for '{key_field}' is not set.\")\n raise EnvironmentError(f\"Environment variable '{env_var}' for '{key_field}' is not set.\")\n\n logger.warning(f\"No API key provided for '{key_field}'.\")\n return None\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_palace.MemoryPalace.__init__","title":"__init__(config)
","text":"Initialize the MemoryPalace with the provided configuration.
Parameters:
Name Type Description Default config
Dict
Configuration settings for the MemoryPalace.
required Source code in src/aeiva/cognition/memory/memory_palace.py
def __init__(self, config: Dict):\n    \"\"\"\n    Initialize the MemoryPalace with the provided configuration.\n\n    Args:\n        config (Dict): Configuration settings for the MemoryPalace.\n    \"\"\"\n    self.config_dict = config\n    self.config = None\n    self.storage = None\n    self.embedder = None\n    self.cleaner = None\n    self.organizer = None\n    self.retriever = None\n    self.structurer = None\n    self.skillizer = None\n    self.parameterizer = None\n    self.setup()\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_palace.MemoryPalace.create","title":"create(content, **kwargs)
","text":"Creates a new memory unit with the given content and metadata.
Parameters:
Name Type Description Default content
Any
The core content of the memory unit.
required **kwargs
Additional metadata for the memory unit.
{}
Returns:
Name Type Description MemoryUnit
MemoryUnit
The created memory unit.
Source code in src/aeiva/cognition/memory/memory_palace.py
def create(self, content: Any, **kwargs) -> MemoryUnit:\n \"\"\"\n Creates a new memory unit with the given content and metadata.\n\n Args:\n content (Any): The core content of the memory unit.\n **kwargs: Additional metadata for the memory unit.\n\n Returns:\n MemoryUnit: The created memory unit.\n \"\"\"\n try:\n # Instantiate MemoryUnit\n memory_unit = MemoryUnit(content=content, **kwargs)\n\n # Generate embedding\n embedding_response = self.embedder.embed(content)\n if embedding_response.get(\"data\"):\n memory_unit.embedding = embedding_response[\"data\"][0].get(\"embedding\")\n else:\n raise ValueError(\"Failed to generate embedding for the content.\")\n\n # Delegate storage operations to MemoryStorage\n self.storage.add_memory_unit(memory_unit)\n\n logger.info(f\"Created new MemoryUnit with ID: {memory_unit.id}\")\n return memory_unit\n except Exception as e:\n logger.error(f\"Error creating MemoryUnit: {e}\")\n self.handle_error(e)\n raise\n
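The create flow above (build the unit, embed its content, then store it) can be sketched independently of the real components. `StubMemoryUnit`, `create_unit`, and the `embed_fn` callable below are illustrative stand-ins, not the aeiva classes; only the embedding response shape `{"data": [{"embedding": [...]}]}` is taken from the source.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
import uuid

@dataclass
class StubMemoryUnit:
    """Simplified stand-in for aeiva's MemoryUnit."""
    content: Any
    id: str = field(default_factory=lambda: uuid.uuid4().hex)
    embedding: Optional[List[float]] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

def create_unit(content: Any, embed_fn, store: Dict[str, StubMemoryUnit],
                **metadata) -> StubMemoryUnit:
    # Mirror MemoryPalace.create: instantiate, embed, then delegate storage.
    unit = StubMemoryUnit(content=content, metadata=metadata)
    response = embed_fn(content)  # expected shape: {"data": [{"embedding": [...]}]}
    if not response.get("data"):
        raise ValueError("Failed to generate embedding for the content.")
    unit.embedding = response["data"][0].get("embedding")
    store[unit.id] = unit
    return unit
```

In the real class the embedder call and storage are provided by `self.embedder` and `self.storage`; here they are injected so the flow is testable in isolation.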
"},{"location":"reference/#src.aeiva.cognition.memory.memory_palace.MemoryPalace.delete","title":"delete(unit_id)
","text":"Deletes a memory unit by its unique identifier.
Parameters:
Name Type Description Default unit_id
str
The unique identifier of the memory unit.
required Source code in src/aeiva/cognition/memory/memory_palace.py
def delete(self, unit_id: str) -> None:\n \"\"\"\n Deletes a memory unit by its unique identifier.\n\n Args:\n unit_id (str): The unique identifier of the memory unit.\n \"\"\"\n try:\n # Delegate deletion to MemoryStorage\n self.storage.delete_memory_unit(unit_id)\n logger.info(f\"Deleted MemoryUnit with ID: {unit_id}\")\n except Exception as e:\n logger.error(f\"Error deleting MemoryUnit with ID {unit_id}: {e}\")\n self.handle_error(e)\n raise\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_palace.MemoryPalace.delete_all","title":"delete_all()
","text":"Deletes all memory units.
Source code in src/aeiva/cognition/memory/memory_palace.py
def delete_all(self) -> None:\n    \"\"\"\n    Deletes all memory units.\n    \"\"\"\n    try:\n        self.storage.delete_all_memory_units()  # TODO: does not seem to work correctly; needs verification\n        logger.info(\"Deleted all MemoryUnits.\")\n    except Exception as e:\n        logger.error(f\"Error deleting all MemoryUnits: {e}\")\n        self.handle_error(e)\n        raise\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_palace.MemoryPalace.embed","title":"embed(unit_id)
","text":"Generates an embedding for a memory unit.
Parameters:
Name Type Description Default unit_id
str
The unique identifier of the memory unit.
required Source code in src/aeiva/cognition/memory/memory_palace.py
def embed(self, unit_id: str) -> None:\n \"\"\"\n Generates an embedding for a memory unit.\n\n Args:\n unit_id (str): The unique identifier of the memory unit.\n \"\"\"\n try:\n # Delegate embedding to MemoryRetriever\n memory_units = self.retriever.retrieve(query=unit_id, retrieve_type='similar', top_k=1)\n if not memory_units:\n raise ValueError(f\"No MemoryUnit found with ID {unit_id} to embed.\")\n\n memory_unit = memory_units[0]\n\n # Generate embedding using the embedder\n embedding_response = self.embedder.embed(memory_unit.content)\n if embedding_response.get(\"data\") and len(embedding_response[\"data\"]) > 0:\n memory_unit.embedding = embedding_response[\"data\"][0].get(\"embedding\")\n else:\n raise ValueError(\"Failed to generate embedding for the content.\")\n\n # Update the memory unit with the new embedding\n self.update(unit_id, {'embedding': memory_unit.embedding})\n\n logger.info(f\"Generated embedding for MemoryUnit ID: {unit_id}\")\n except Exception as e:\n logger.error(f\"Error generating embedding for MemoryUnit ID {unit_id}: {e}\")\n self.handle_error(e)\n raise\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_palace.MemoryPalace.filter","title":"filter(criteria)
","text":"Filters memory units based on the given criteria.
Parameters:
Name Type Description Default criteria
Dict[str, Any]
A dictionary of filter conditions.
required Returns:
Type Description List[MemoryUnit]
List[MemoryUnit]: A list of memory units matching the criteria.
Source code in src/aeiva/cognition/memory/memory_palace.py
def filter(self, criteria: Dict[str, Any]) -> List[MemoryUnit]:\n \"\"\"\n Filters memory units based on the given criteria.\n\n Args:\n criteria (Dict[str, Any]): A dictionary of filter conditions.\n\n Returns:\n List[MemoryUnit]: A list of memory units matching the criteria.\n \"\"\"\n try:\n memory_units = self.get_all()\n filter_type = criteria.get('filter_type')\n if not filter_type:\n raise ValueError(\"Missing 'filter_type' in criteria.\")\n\n # Delegate filtering to MemoryCleaner\n filtered_memories = self.cleaner.filter(memory_units, filter_type, **criteria)\n logger.info(f\"Filtered memories based on criteria: {criteria}\")\n return filtered_memories\n except Exception as e:\n logger.error(f\"Error filtering memories: {e}\")\n self.handle_error(e)\n raise\n
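The contract above requires a `filter_type` key in the criteria dict and dispatches on it. A minimal sketch of that dispatch, using plain dicts instead of MemoryUnits and a hypothetical `by_modality` filter type (the source does not name the concrete filter types):

```python
from typing import Any, Dict, List

def filter_units(units: List[Dict[str, Any]],
                 criteria: Dict[str, Any]) -> List[Dict[str, Any]]:
    # Mirror MemoryPalace.filter's contract: 'filter_type' selects the algorithm.
    filter_type = criteria.get("filter_type")
    if not filter_type:
        raise ValueError("Missing 'filter_type' in criteria.")
    if filter_type == "by_modality":  # hypothetical filter type, for illustration
        wanted = criteria.get("modality")
        return [u for u in units if u.get("modality") == wanted]
    raise ValueError(f"Unknown filter_type: {filter_type}")
```

In the real class the dispatch is delegated to `MemoryCleaner.filter`, which receives the full criteria dict as keyword arguments.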
"},{"location":"reference/#src.aeiva.cognition.memory.memory_palace.MemoryPalace.get","title":"get(unit_id)
","text":"Retrieves a memory unit by its unique identifier.
Parameters:
Name Type Description Default unit_id
str
The unique identifier of the memory unit.
required Returns:
Name Type Description MemoryUnit
MemoryUnit
The retrieved memory unit.
Source code in src/aeiva/cognition/memory/memory_palace.py
def get(self, unit_id: str) -> MemoryUnit:\n \"\"\"\n Retrieves a memory unit by its unique identifier.\n\n Args:\n unit_id (str): The unique identifier of the memory unit.\n\n Returns:\n MemoryUnit: The retrieved memory unit.\n \"\"\"\n try:\n memory_unit = self.storage.get_memory_unit(unit_id)\n logger.info(f\"Retrieved MemoryUnit with ID: {unit_id}\")\n return memory_unit\n except Exception as e:\n logger.error(f\"Error retrieving MemoryUnit with ID {unit_id}: {e}\")\n self.handle_error(e)\n raise\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_palace.MemoryPalace.get_all","title":"get_all()
","text":"Retrieves all memory units.
Returns:
Type Description List[MemoryUnit]
List[MemoryUnit]: A list of all memory units.
Source code in src/aeiva/cognition/memory/memory_palace.py
def get_all(self) -> List[MemoryUnit]:\n \"\"\"\n Retrieves all memory units.\n\n Returns:\n List[MemoryUnit]: A list of all memory units.\n \"\"\"\n try:\n memory_units = self.storage.get_all_memory_units()\n logger.info(f\"Retrieved all MemoryUnits. Total count: {len(memory_units)}\")\n return memory_units\n except Exception as e:\n logger.error(f\"Error retrieving all MemoryUnits: {e}\")\n self.handle_error(e)\n raise\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_palace.MemoryPalace.get_api_key","title":"get_api_key(config_section, key_field, env_var_field)
staticmethod
","text":"Retrieve an API key from the configuration section.
Parameters:
Name Type Description Default config_section
Dict[str, Any]
The configuration section (e.g., embedder_config).
required key_field
str
The key in the config_section that may contain the API key directly.
required env_var_field
str
The key in the config_section that specifies the environment variable name.
required Returns:
Type Description Optional[str]
Optional[str]: The API key if found, else None.
Raises:
Type Description EnvironmentError
If the environment variable is specified but not set.
Source code in src/aeiva/cognition/memory/memory_palace.py
@staticmethod\ndef get_api_key(config_section: Dict[str, Any], key_field: str, env_var_field: str) -> Optional[str]:\n    \"\"\"\n    Retrieve an API key from the configuration section.\n\n    Args:\n        config_section (Dict[str, Any]): The configuration section (e.g., embedder_config).\n        key_field (str): The key in the config_section that may contain the API key directly.\n        env_var_field (str): The key in the config_section that specifies the environment variable name.\n\n    Returns:\n        Optional[str]: The API key if found, else None.\n\n    Raises:\n        EnvironmentError: If the environment variable is specified but not set.\n    \"\"\"\n    # Check if API key is provided directly\n    api_key = config_section.get(key_field)\n    if api_key:\n        logger.info(f\"Using provided API key for '{key_field}'.\")\n        return api_key\n\n    # Else, check if an environment variable is specified\n    env_var = config_section.get(env_var_field)\n    if env_var:\n        api_key = os.getenv(env_var)\n        if api_key:\n            logger.info(f\"Retrieved API key for '{key_field}' from environment variable '{env_var}'.\")\n            return api_key\n        else:\n            logger.error(f\"Environment variable '{env_var}' for '{key_field}' is not set.\")\n            raise EnvironmentError(f\"Environment variable '{env_var}' for '{key_field}' is not set.\")\n\n    logger.warning(f\"No API key provided for '{key_field}'.\")\n    return None\n
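The resolution order above (direct value first, then a named environment variable, else `None`) can be exercised standalone. `resolve_api_key` and the default field names `api_key`/`api_key_env_var` below are illustrative assumptions, not names from the source:

```python
import os
from typing import Any, Dict, Optional

def resolve_api_key(config_section: Dict[str, Any],
                    key_field: str = "api_key",
                    env_var_field: str = "api_key_env_var") -> Optional[str]:
    """Resolve an API key: direct config value first, then named env var, else None."""
    api_key = config_section.get(key_field)
    if api_key:
        return api_key
    env_var = config_section.get(env_var_field)
    if env_var:
        value = os.getenv(env_var)
        if value is None:
            # Fail loudly: a named but unset env var is a configuration error.
            raise EnvironmentError(f"Environment variable '{env_var}' is not set.")
        return value
    return None
```

Note that an absent key yields `None`, while a *named but unset* environment variable raises, matching the documented `Raises` contract.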
"},{"location":"reference/#src.aeiva.cognition.memory.memory_palace.MemoryPalace.handle_error","title":"handle_error(error)
","text":"Handle errors that occur during memory operations.
Parameters:
Name Type Description Default error
Exception
The exception that was raised.
required Source code in src/aeiva/cognition/memory/memory_palace.py
def handle_error(self, error: Exception) -> None:\n \"\"\"\n Handle errors that occur during memory operations.\n\n Args:\n error (Exception): The exception that was raised.\n \"\"\"\n logger.error(f\"MemoryPalace encountered an error: {error}\")\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_palace.MemoryPalace.load","title":"load()
","text":"Loads all memory units from the storage.
Returns:
Type Description List[MemoryUnit]
List[MemoryUnit]: A list of all loaded memory units.
Source code in src/aeiva/cognition/memory/memory_palace.py
def load(self) -> List[MemoryUnit]:\n \"\"\"\n Loads all memory units from the storage.\n\n Returns:\n List[MemoryUnit]: A list of all loaded memory units.\n \"\"\"\n try:\n # Retrieve all memory units from storage\n memory_units = self.get_all()\n logger.info(f\"Loaded {len(memory_units)} MemoryUnits from storage.\")\n return memory_units\n except Exception as e:\n logger.error(f\"Error loading MemoryUnits: {e}\")\n self.handle_error(e)\n raise\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_palace.MemoryPalace.organize","title":"organize(unit_ids, organize_type, metadata=None)
","text":"Groups memory units into a meaningful group.
Parameters:
Name Type Description Default unit_ids
List[str]
A list of memory unit IDs to group.
required organize_type
str
The type of group (e.g., 'dialogue_session', 'procedure').
required metadata
Optional[Dict[str, Any]]
Additional metadata for the group.
None
Returns:
Name Type Description str
str
A unique identifier for the created group.
Source code in src/aeiva/cognition/memory/memory_palace.py
def organize(self, unit_ids: List[str], organize_type: str, metadata: Optional[Dict[str, Any]] = None) -> str:\n \"\"\"\n Groups memory units into a meaningful group.\n\n Args:\n unit_ids (List[str]): A list of memory unit IDs to group.\n organize_type (str): The type of group (e.g., 'dialogue_session', 'procedure').\n metadata (Optional[Dict[str, Any]]): Additional metadata for the group.\n\n Returns:\n str: A unique identifier for the created group.\n \"\"\"\n try:\n # Retrieve the memory units to group\n memory_units = [self.get(unit_id) for unit_id in unit_ids]\n logger.debug(f\"Grouping {len(memory_units)} MemoryUnits into group_type='{organize_type}'.\")\n\n # Delegate grouping to MemoryOrganizer\n organized_memories = self.organizer.organize(memory_units, organize_type, metadata=metadata)\n logger.info(f\"Grouped memories into '{organize_type}'. Total memory units after grouping: {len(organized_memories)}\")\n return \"group_id_placeholder\" # Replace with actual group ID if applicable\n except Exception as e:\n logger.error(f\"Error grouping memories: {e}\")\n self.handle_error(e)\n raise\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_palace.MemoryPalace.parameterize","title":"parameterize(**kwargs)
","text":"Trains a parametric model using the memory data.
Parameters:
Name Type Description Default **kwargs
Additional parameters for the training process.
{}
Source code in src/aeiva/cognition/memory/memory_palace.py
def parameterize(self, **kwargs) -> None:\n \"\"\"\n Trains a parametric model using the memory data.\n\n Args:\n **kwargs: Additional parameters for the training process.\n \"\"\"\n try:\n # Retrieve all memory units\n memory_units = self.get_all()\n logger.debug(f\"Parameterizing {len(memory_units)} MemoryUnits.\")\n\n # Delegate parameterizing to MemoryParameterizer\n self.parameterizer.parameterize(memory_units, **kwargs)\n logger.info(\"Parameterized memories successfully.\")\n except Exception as e:\n logger.error(f\"Error parameterizing memories: {e}\")\n self.handle_error(e)\n raise\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_palace.MemoryPalace.retrieve","title":"retrieve(query, retrieve_type, **kwargs)
","text":"Retrieve data from memory based on a query.
Parameters:
Name Type Description Default query
Any
The query or criteria to retrieve specific memory data.
required retrieve_type
str
The type of retrieval (e.g., 'similar', 'related').
required **kwargs
Additional parameters for the retrieval process.
{}
Returns:
Type Description List[MemoryUnit]
List[MemoryUnit]: The retrieved memory data.
Source code in src/aeiva/cognition/memory/memory_palace.py
def retrieve(self, query: Any, retrieve_type: str, **kwargs) -> List[MemoryUnit]:\n \"\"\"\n Retrieve data from memory based on a query.\n\n Args:\n query (Any): The query or criteria to retrieve specific memory data.\n retrieve_type (str): The type of retrieval (e.g., 'similar', 'related').\n **kwargs: Additional parameters for the retrieval process.\n\n Returns:\n List[MemoryUnit]: The retrieved memory data.\n \"\"\"\n try:\n # Delegate retrieval to MemoryRetriever\n memories = self.retriever.retrieve(query=query, retrieve_type=retrieve_type, **kwargs)\n logger.info(f\"Retrieved {len(memories)} memories using retrieve_type='{retrieve_type}'.\")\n return memories\n except Exception as e:\n logger.error(f\"Error retrieving MemoryUnits: {e}\")\n self.handle_error(e)\n raise\n
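For the 'similar' retrieval type, the typical implementation ranks stored embeddings by cosine similarity to the query embedding. A minimal sketch of that ranking, assuming plain `(id, embedding)` pairs rather than aeiva's MemoryRetriever internals:

```python
import math
from typing import List, Tuple

def cosine(a: List[float], b: List[float]) -> float:
    # Cosine similarity; returns 0.0 for zero-length vectors.
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(x * x for x in b))
    return dot / (na * nb) if na and nb else 0.0

def retrieve_similar(query_embedding: List[float],
                     units: List[Tuple[str, List[float]]],
                     top_k: int = 2) -> List[str]:
    # Rank stored (id, embedding) pairs by similarity to the query, best first.
    ranked = sorted(units, key=lambda u: cosine(query_embedding, u[1]), reverse=True)
    return [uid for uid, _ in ranked[:top_k]]
```

The real `MemoryRetriever` embeds the raw query first (via the shared `Embedder`) and likely delegates the nearest-neighbor search to the storage backend; the ranking criterion is the same.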
"},{"location":"reference/#src.aeiva.cognition.memory.memory_palace.MemoryPalace.save","title":"save(export_path=None)
","text":"Saves all memory units to the storage or exports them to a specified path.
Parameters:
Name Type Description Default export_path
Optional[str]
The file path to export memory units as JSON. If None, saves are handled by MemoryStorage.
None
Source code in src/aeiva/cognition/memory/memory_palace.py
def save(self, export_path: Optional[str] = None) -> None:\n \"\"\"\n Saves all memory units to the storage or exports them to a specified path.\n\n Args:\n export_path (Optional[str]): The file path to export memory units as JSON.\n If None, saves are handled by MemoryStorage.\n \"\"\"\n try:\n if export_path:\n # Export memory units to a JSON file\n memory_units = self.get_all()\n export_data = [mu.to_dict() for mu in memory_units]\n with open(export_path, 'w', encoding='utf-8') as f:\n json.dump(export_data, f, ensure_ascii=False, indent=4)\n logger.info(f\"Exported {len(memory_units)} MemoryUnits to {export_path}.\")\n else:\n # If no export path is provided, assume that MemoryStorage handles persistence\n logger.info(\"Save operation delegated to MemoryStorage.\")\n # Example: self.storage.persist_changes()\n except Exception as e:\n logger.error(f\"Error saving MemoryUnits: {e}\")\n self.handle_error(e)\n raise\n
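The export branch of `save` serializes every unit's dict form to a JSON file with `ensure_ascii=False, indent=4`, as shown above. A round-trip sketch of that branch, operating on plain dicts instead of MemoryUnits:

```python
import json
import os
import tempfile
from typing import Any, Dict, List

def export_units(unit_dicts: List[Dict[str, Any]], export_path: str) -> None:
    # Mirror the export branch of MemoryPalace.save: dump units as pretty JSON.
    with open(export_path, "w", encoding="utf-8") as f:
        json.dump(unit_dicts, f, ensure_ascii=False, indent=4)

def load_units(export_path: str) -> List[Dict[str, Any]]:
    with open(export_path, "r", encoding="utf-8") as f:
        return json.load(f)
```

With no `export_path`, persistence is left to `MemoryStorage`, so `save` is a no-op at this layer.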
"},{"location":"reference/#src.aeiva.cognition.memory.memory_palace.MemoryPalace.setup","title":"setup()
","text":"Setup the MemoryPalace by initializing all components.
Source code in src/aeiva/cognition/memory/memory_palace.py
def setup(self):\n \"\"\"\n Setup the MemoryPalace by initializing all components.\n \"\"\"\n try:\n # Initialize EmbedderConfig\n embedder_config_dict = self.config_dict.get('embedder_config', {})\n self.embedder = Embedder(embedder_config_dict)\n\n storage_config_dict = self.config_dict.get('storage_config', {})\n self.storage = MemoryStorage(storage_config_dict) \n\n # Initialize Memory Configuration\n self.config = MemoryConfig(\n embedder_config=self.embedder.config,\n storage_config=self.storage.config\n )\n\n logger.info(\"MemoryPalace: MemoryStorage and Embedder initialized successfully.\")\n\n # Initialize specialized components\n self.cleaner = MemoryCleaner()\n self.organizer = MemoryOrganizer()\n self.retriever = MemoryRetriever(embedder=self.embedder, storage=self.storage)\n self.structurer = MemoryStructurer()\n self.skillizer = MemorySkillizer()\n self.parameterizer = MemoryParameterizer()\n logger.info(\"MemoryPalace: Specialized components initialized successfully.\")\n\n except Exception as e:\n logger.error(f\"MemoryPalace setup failed: {e}\")\n self.handle_error(e)\n raise\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_palace.MemoryPalace.skillize","title":"skillize(unit_ids, skill_name, **kwargs)
","text":"Converts memory units into a reusable skill.
Parameters:
Name Type Description Default unit_ids
List[str]
A list of memory unit IDs to skillize.
required skill_name
str
The name of the skill to create.
required **kwargs
Additional parameters for skill creation.
{}
Returns:
Name Type Description str
str
The unique identifier of the created skill.
Source code in src/aeiva/cognition/memory/memory_palace.py
def skillize(self, unit_ids: List[str], skill_name: str, **kwargs) -> str:\n \"\"\"\n Converts memory units into a reusable skill.\n\n Args:\n unit_ids (List[str]): A list of memory unit IDs to skillize.\n skill_name (str): The name of the skill to create.\n **kwargs: Additional parameters for skill creation.\n\n Returns:\n str: The unique identifier of the created skill.\n \"\"\"\n try:\n # Retrieve the memory units to skillize\n memory_units = [self.get(uid) for uid in unit_ids]\n logger.debug(f\"Skillizing {len(memory_units)} MemoryUnits into skill_name='{skill_name}'.\")\n\n # Delegate skillizing to MemorySkillizer\n skill_id = self.skillizer.skillize(memory_units, skill_name, **kwargs)\n logger.info(f\"Skillized memories into skill with ID: {skill_id}\")\n return skill_id\n except Exception as e:\n logger.error(f\"Error skillizing memories: {e}\")\n self.handle_error(e)\n raise\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_palace.MemoryPalace.structurize","title":"structurize(unit_ids, structure_type, **kwargs)
","text":"Structures memory units into a knowledge graph or other structures.
Parameters:
Name Type Description Default unit_ids
List[str]
A list of memory unit IDs to structurize.
required structure_type
str
The type of structure (e.g., 'knowledge_graph').
required **kwargs
Additional parameters for the structuring process.
{}
Source code in src/aeiva/cognition/memory/memory_palace.py
def structurize(self, unit_ids: List[str], structure_type: str, **kwargs) -> None:\n \"\"\"\n Structures memory units into a knowledge graph or other structures.\n\n Args:\n unit_ids (List[str]): A list of memory unit IDs to structurize.\n structure_type (str): The type of structure (e.g., 'knowledge_graph').\n **kwargs: Additional parameters for the structuring process.\n \"\"\"\n try:\n # Retrieve the memory units to structurize\n memory_units = [self.get(uid) for uid in unit_ids]\n logger.debug(f\"Structurizing {len(memory_units)} MemoryUnits with structure_type='{structure_type}'.\")\n\n # Delegate structuring to MemoryStructurer\n self.structurer.structure(memory_units, structure_type, **kwargs)\n logger.info(f\"Structurized memories with structure_type='{structure_type}'.\")\n except Exception as e:\n logger.error(f\"Error structurizing memories: {e}\")\n self.handle_error(e)\n raise\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_palace.MemoryPalace.update","title":"update(unit_id, updates)
","text":"Updates a memory unit with the given updates.
Parameters:
Name Type Description Default unit_id
str
The unique identifier of the memory unit.
required updates
Dict[str, Any]
A dictionary of fields to update.
required Source code in src/aeiva/cognition/memory/memory_palace.py
def update(self, unit_id: str, updates: Dict[str, Any]) -> None:\n \"\"\"\n Updates a memory unit with the given updates.\n\n Args:\n unit_id (str): The unique identifier of the memory unit.\n updates (Dict[str, Any]): A dictionary of fields to update.\n \"\"\"\n try:\n # Delegate update operations to MemoryStorage\n self.storage.update_memory_unit(unit_id, updates)\n logger.info(f\"Updated MemoryUnit with ID: {unit_id}\")\n except Exception as e:\n logger.error(f\"Error updating MemoryUnit with ID {unit_id}: {e}\")\n self.handle_error(e)\n raise\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_parameterizer","title":"memory_parameterizer
","text":""},{"location":"reference/#src.aeiva.cognition.memory.memory_parameterizer.MemoryParameterizer","title":"MemoryParameterizer
","text":"A class to parameterize memory units based on various parameterizing algorithms.
Supported parameterize types: 'parameterize_type_example' (placeholder for future parameterizing algorithms).
Source code in src/aeiva/cognition/memory/memory_parameterizer.py
class MemoryParameterizer:\n \"\"\"\n A class to parameterize memory units based on various parameterizing algorithms.\n\n Supported parameterize types:\n - 'parameterize_type_example': Placeholder for future parameterizing algorithms.\n \"\"\"\n\n def __init__(self):\n \"\"\"\n Initializes the MemoryParameterizer.\n\n Currently, no initialization parameters are required.\n \"\"\"\n self.logger = logging.getLogger(self.__class__.__name__)\n self.logger.debug(\"Initialized MemoryParameterizer without default parameters.\")\n\n def parameterize(\n self,\n memory_units: List[MemoryUnit],\n parameterize_type: str,\n **kwargs\n ) -> List[MemoryUnit]:\n \"\"\"\n Parameterizes the provided memory units based on the specified parameterize type.\n\n Args:\n memory_units (List[MemoryUnit]): The list of memory units to be parameterized.\n parameterize_type (str): The type of parameterizing algorithm to use ('parameterize_type_example').\n **kwargs: Additional parameters required for specific parameterizers.\n\n Returns:\n List[MemoryUnit]: The list of memory units after parameterization.\n\n Raises:\n MemoryParameterizerError: If an unknown parameterize_type is provided or if parameterizing fails.\n \"\"\"\n self.logger.debug(f\"Parameterizing memory units using parameterize_type='{parameterize_type}' with kwargs={kwargs}\")\n try:\n if parameterize_type == 'parameterize_type_example':\n # Placeholder for actual parameterizing logic\n return self.parameterize_example(memory_units, **kwargs)\n else:\n self.logger.error(f\"Unknown parameterize_type: {parameterize_type}\")\n raise MemoryParameterizerError(f\"Unknown parameterize_type: {parameterize_type}\")\n except MemoryParameterizerError:\n # Re-raise custom errors without modification\n raise\n except Exception as e:\n self.logger.error(f\"Failed to parameterize memory units: {e}\")\n raise MemoryParameterizerError(f\"Failed to parameterize memory units: {e}\")\n\n def parameterize_example(\n self,\n memory_units: 
List[MemoryUnit],\n **kwargs\n ) -> List[MemoryUnit]:\n \"\"\"\n Example parameterizing method. Currently a placeholder that returns memory units unchanged.\n\n Args:\n memory_units (List[MemoryUnit]): The list of memory units to be parameterized.\n **kwargs: Additional parameters (currently unused).\n\n Returns:\n List[MemoryUnit]: The original list of memory units, unchanged.\n \"\"\"\n self.logger.debug(\"Executing parameterize_example: No changes applied to memory units.\")\n # Placeholder: No operation performed\n return memory_units\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_parameterizer.MemoryParameterizer.__init__","title":"__init__()
","text":"Initializes the MemoryParameterizer.
Currently, no initialization parameters are required.
Source code in src/aeiva/cognition/memory/memory_parameterizer.py
def __init__(self):\n \"\"\"\n Initializes the MemoryParameterizer.\n\n Currently, no initialization parameters are required.\n \"\"\"\n self.logger = logging.getLogger(self.__class__.__name__)\n self.logger.debug(\"Initialized MemoryParameterizer without default parameters.\")\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_parameterizer.MemoryParameterizer.parameterize","title":"parameterize(memory_units, parameterize_type, **kwargs)
","text":"Parameterizes the provided memory units based on the specified parameterize type.
Parameters:
Name Type Description Default memory_units
List[MemoryUnit]
The list of memory units to be parameterized.
required parameterize_type
str
The type of parameterizing algorithm to use ('parameterize_type_example').
required **kwargs
Additional parameters required for specific parameterizers.
{}
Returns:
Type Description List[MemoryUnit]
List[MemoryUnit]: The list of memory units after parameterization.
Raises:
Type Description MemoryParameterizerError
If an unknown parameterize_type is provided or if parameterizing fails.
Source code in src/aeiva/cognition/memory/memory_parameterizer.py
def parameterize(\n self,\n memory_units: List[MemoryUnit],\n parameterize_type: str,\n **kwargs\n) -> List[MemoryUnit]:\n \"\"\"\n Parameterizes the provided memory units based on the specified parameterize type.\n\n Args:\n memory_units (List[MemoryUnit]): The list of memory units to be parameterized.\n parameterize_type (str): The type of parameterizing algorithm to use ('parameterize_type_example').\n **kwargs: Additional parameters required for specific parameterizers.\n\n Returns:\n List[MemoryUnit]: The list of memory units after parameterization.\n\n Raises:\n MemoryParameterizerError: If an unknown parameterize_type is provided or if parameterizing fails.\n \"\"\"\n self.logger.debug(f\"Parameterizing memory units using parameterize_type='{parameterize_type}' with kwargs={kwargs}\")\n try:\n if parameterize_type == 'parameterize_type_example':\n # Placeholder for actual parameterizing logic\n return self.parameterize_example(memory_units, **kwargs)\n else:\n self.logger.error(f\"Unknown parameterize_type: {parameterize_type}\")\n raise MemoryParameterizerError(f\"Unknown parameterize_type: {parameterize_type}\")\n except MemoryParameterizerError:\n # Re-raise custom errors without modification\n raise\n except Exception as e:\n self.logger.error(f\"Failed to parameterize memory units: {e}\")\n raise MemoryParameterizerError(f\"Failed to parameterize memory units: {e}\")\n
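The error-handling pattern above (dispatch on a type string, re-raise the custom error untouched, wrap anything else) is common to these memory components. A condensed, standalone sketch of it, with `ParameterizerError` standing in for `MemoryParameterizerError`:

```python
from typing import Any, List

class ParameterizerError(Exception):
    """Stand-in for MemoryParameterizerError."""

def parameterize(units: List[Any], parameterize_type: str) -> List[Any]:
    # Dispatch on type; re-raise known errors as-is, wrap unexpected failures.
    try:
        if parameterize_type == "parameterize_type_example":
            return units  # placeholder: a no-op, as in the source
        raise ParameterizerError(f"Unknown parameterize_type: {parameterize_type}")
    except ParameterizerError:
        raise
    except Exception as e:
        raise ParameterizerError(f"Failed to parameterize memory units: {e}")
```

Re-raising the custom error before the broad `except Exception` keeps its message intact instead of double-wrapping it.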
"},{"location":"reference/#src.aeiva.cognition.memory.memory_parameterizer.MemoryParameterizer.parameterize_example","title":"parameterize_example(memory_units, **kwargs)
","text":"Example parameterizing method. Currently a placeholder that returns memory units unchanged.
Parameters:
Name Type Description Default memory_units
List[MemoryUnit]
The list of memory units to be parameterized.
required **kwargs
Additional parameters (currently unused).
{}
Returns:
Type Description List[MemoryUnit]
List[MemoryUnit]: The original list of memory units, unchanged.
Source code in src/aeiva/cognition/memory/memory_parameterizer.py
def parameterize_example(\n self,\n memory_units: List[MemoryUnit],\n **kwargs\n) -> List[MemoryUnit]:\n \"\"\"\n Example parameterizing method. Currently a placeholder that returns memory units unchanged.\n\n Args:\n memory_units (List[MemoryUnit]): The list of memory units to be parameterized.\n **kwargs: Additional parameters (currently unused).\n\n Returns:\n List[MemoryUnit]: The original list of memory units, unchanged.\n \"\"\"\n self.logger.debug(\"Executing parameterize_example: No changes applied to memory units.\")\n # Placeholder: No operation performed\n return memory_units\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_parameterizer.MemoryParameterizerError","title":"MemoryParameterizerError
","text":" Bases: Exception
Exception raised when an error occurs in the MemoryParameterizer.
Source code in src/aeiva/cognition/memory/memory_parameterizer.py
class MemoryParameterizerError(Exception):\n \"\"\"Exception raised when an error occurs in the MemoryParameterizer.\"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_retriever","title":"memory_retriever
","text":""},{"location":"reference/#src.aeiva.cognition.memory.memory_retriever.MemoryRetriever","title":"MemoryRetriever
","text":"A class to retrieve memory units based on various retrieval algorithms.
Supported retrieval types: - 'similar': Retrieves memory units similar to a given query based on embeddings.
- 'related': Retrieves memory units related to a specified query based on relationships.
Source code in src/aeiva/cognition/memory/memory_retriever.py
class MemoryRetriever:\n \"\"\"\n A class to retrieve memory units based on various retrieval algorithms.\n\n Supported retrieval types:\n - 'similar': Retrieves memory units similar to a given query based on embeddings.\n - 'related': Retrieves memory units related to a specified query based on relationships.\n \"\"\"\n\n def __init__(self, embedder: Embedder, storage: MemoryStorage):\n \"\"\"\n Initializes the MemoryRetriever.\n\n Args:\n embedder (Embedder): An instance responsible for generating embeddings.\n storage (MemoryStorage): An instance managing data storage and retrieval.\n \"\"\"\n self.embedder = embedder\n self.storage = storage\n self.logger = logging.getLogger(self.__class__.__name__)\n self.logger.debug(\"Initialized MemoryRetriever with provided embedder and storage.\")\n\n def retrieve(\n self,\n query: Any,\n retrieve_type: str,\n **kwargs\n ) -> List[MemoryUnit]:\n \"\"\"\n Factory method to retrieve memory units based on the specified retrieval type.\n\n Args:\n query (Any): The query for retrieval.\n retrieve_type (str): The type of retrieval ('similar' or 'related').\n **kwargs: Additional parameters required for specific retrieval types.\n For 'similar' retrieval:\n - top_k (int): The number of similar units to retrieve.\n For 'related' retrieval:\n - relationship (Optional[str]): The type of relationship to filter by.\n\n Returns:\n List[MemoryUnit]: A list of retrieved memory units.\n\n Raises:\n MemoryRetrieverError: If an unknown retrieval_type is provided or if retrieval fails.\n \"\"\"\n self.logger.info(f\"Initiating retrieval of type '{retrieve_type}' with query: {query}\")\n try:\n if retrieve_type == 'similar':\n top_k = kwargs.get('top_k', 5)\n self.logger.debug(f\"Retrieval Type: 'similar' with top_k={top_k}\")\n return self.retrieve_similar(query, top_k)\n elif retrieve_type == 'related':\n relationship = kwargs.get('relationship')\n self.logger.debug(f\"Retrieval Type: 'related' with relationship='{relationship}'\")\n 
return self.retrieve_related(query, relationship)\n else:\n self.logger.error(f\"Unknown retrieve_type: {retrieve_type}\")\n raise MemoryRetrieverError(f\"Unknown retrieve_type: {retrieve_type}\")\n except MemoryRetrieverError:\n # Re-raise custom errors without modification\n raise\n except Exception as e:\n self.logger.error(f\"Failed to retrieve memory units: {e}\")\n raise MemoryRetrieverError(f\"Failed to retrieve memory units: {e}\") from e\n\n def retrieve_similar(self, query: Any, top_k: int = 5) -> List[MemoryUnit]:\n \"\"\"\n Retrieves memory units similar to the given input based on embeddings.\n\n Args:\n query (Any): The query for retrieval.\n top_k (int): The number of similar units to retrieve.\n\n Returns:\n List[MemoryUnit]: A list of similar memory units.\n\n Raises:\n MemoryRetrieverError: If retrieval fails due to embedding generation or storage issues.\n \"\"\"\n self.logger.info(f\"Retrieving top {top_k} similar MemoryUnits based on the query.\")\n try:\n # Generate embedding for the query\n self.logger.debug(\"Generating embedding for the query.\")\n embedding_response = self.embedder.embed(query)\n if not embedding_response.get(\"data\"):\n self.logger.error(\"Failed to generate embedding for the query.\")\n raise MemoryRetrieverError(\"Failed to generate embedding for the query.\")\n\n query_embedding = embedding_response[\"data\"][0].get(\"embedding\")\n if not query_embedding:\n self.logger.error(\"Embedding data is missing in the response.\")\n raise MemoryRetrieverError(\"Embedding data is missing in the response.\")\n\n self.logger.debug(f\"Embedding generated successfully: {query_embedding}\")\n\n # Perform similarity search via MemoryStorage\n self.logger.debug(\"Performing similarity search in the vector database.\")\n similar_units = self.storage.retrieve_similar_memory_units(query_embedding, top_k)\n self.logger.info(f\"Retrieved {len(similar_units)} similar MemoryUnits.\")\n return similar_units\n\n except MemoryRetrieverError:\n 
# Re-raise custom errors without modification\n raise\n except Exception as e:\n self.logger.error(f\"Unexpected error during retrieve_similar: {e}\")\n raise MemoryRetrieverError(f\"Unexpected error during retrieve_similar: {e}\") from e\n\n def retrieve_related(\n self,\n query: Any,\n relationship: Optional[str] = None\n ) -> List[MemoryUnit]: # TODO: revise the method later\n \"\"\"\n Retrieves memory units related to the given query based on relationships.\n\n Args:\n query (Any): The query for retrieval. Expected to be a MemoryUnit ID or similar identifier.\n relationship (Optional[str]): The type of relationship to filter by.\n\n Returns:\n List[MemoryUnit]: A list of related memory units.\n\n Raises:\n MemoryRetrieverError: If retrieval fails due to storage issues or invalid queries.\n \"\"\"\n self.logger.info(f\"Retrieving memories related to the query with relationship: {relationship}\")\n try:\n # Assuming 'query' is a MemoryUnit ID or can be used to fetch a MemoryUnit\n self.logger.debug(\"Fetching the target MemoryUnit from storage.\")\n target_memory_unit = self.storage.get_memory_unit(query)\n if not target_memory_unit:\n self.logger.error(f\"MemoryUnit with ID '{query}' not found.\")\n raise MemoryRetrieverError(f\"MemoryUnit with ID '{query}' not found.\")\n\n self.logger.debug(f\"MemoryUnit fetched successfully: {target_memory_unit}\")\n\n # Perform related retrieval via MemoryStorage\n self.logger.debug(\"Retrieving related MemoryUnits from the graph database.\")\n related_units = self.storage.retrieve_related_memory_units(target_memory_unit.id, relationship)\n self.logger.info(f\"Retrieved {len(related_units)} related MemoryUnits.\")\n return related_units\n\n except MemoryRetrieverError:\n # Re-raise custom errors without modification\n raise\n except Exception as e:\n self.logger.error(f\"Unexpected error during retrieve_related: {e}\")\n raise MemoryRetrieverError(f\"Unexpected error during retrieve_related: {e}\") from e\n\n def 
handle_error(self, error: Exception):\n \"\"\"\n Handles errors by logging or performing other necessary actions.\n\n Args:\n error (Exception): The exception to handle.\n \"\"\"\n # Implement any error handling logic here\n # For now, we'll just log the error\n self.logger.error(f\"An error occurred: {error}\")\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_retriever.MemoryRetriever.__init__","title":"__init__(embedder, storage)
","text":"Initializes the MemoryRetriever.
Parameters:
Name Type Description Default embedder
Embedder
An instance responsible for generating embeddings.
required storage
MemoryStorage
An instance managing data storage and retrieval.
required Source code in src/aeiva/cognition/memory/memory_retriever.py
def __init__(self, embedder: Embedder, storage: MemoryStorage):\n \"\"\"\n Initializes the MemoryRetriever.\n\n Args:\n embedder (Embedder): An instance responsible for generating embeddings.\n storage (MemoryStorage): An instance managing data storage and retrieval.\n \"\"\"\n self.embedder = embedder\n self.storage = storage\n self.logger = logging.getLogger(self.__class__.__name__)\n self.logger.debug(\"Initialized MemoryRetriever with provided embedder and storage.\")\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_retriever.MemoryRetriever.handle_error","title":"handle_error(error)
","text":"Handles errors by logging or performing other necessary actions.
Parameters:
Name Type Description Default error
Exception
The exception to handle.
required Source code in src/aeiva/cognition/memory/memory_retriever.py
def handle_error(self, error: Exception):\n \"\"\"\n Handles errors by logging or performing other necessary actions.\n\n Args:\n error (Exception): The exception to handle.\n \"\"\"\n # Implement any error handling logic here\n # For now, we'll just log the error\n self.logger.error(f\"An error occurred: {error}\")\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_retriever.MemoryRetriever.retrieve","title":"retrieve(query, retrieve_type, **kwargs)
","text":"Factory method to retrieve memory units based on the specified retrieval type.
Parameters:
Name Type Description Default query
Any
The query for retrieval.
required retrieve_type
str
The type of retrieval ('similar' or 'related').
required **kwargs
Additional parameters required for specific retrieval types. For 'similar' retrieval: - top_k (int): The number of similar units to retrieve. For 'related' retrieval: - relationship (Optional[str]): The type of relationship to filter by.
{}
Returns:
Type Description List[MemoryUnit]
List[MemoryUnit]: A list of retrieved memory units.
Raises:
Type Description MemoryRetrieverError
If an unknown retrieval_type is provided or if retrieval fails.
Source code in src/aeiva/cognition/memory/memory_retriever.py
def retrieve(\n self,\n query: Any,\n retrieve_type: str,\n **kwargs\n) -> List[MemoryUnit]:\n \"\"\"\n Factory method to retrieve memory units based on the specified retrieval type.\n\n Args:\n query (Any): The query for retrieval.\n retrieve_type (str): The type of retrieval ('similar' or 'related').\n **kwargs: Additional parameters required for specific retrieval types.\n For 'similar' retrieval:\n - top_k (int): The number of similar units to retrieve.\n For 'related' retrieval:\n - relationship (Optional[str]): The type of relationship to filter by.\n\n Returns:\n List[MemoryUnit]: A list of retrieved memory units.\n\n Raises:\n MemoryRetrieverError: If an unknown retrieval_type is provided or if retrieval fails.\n \"\"\"\n self.logger.info(f\"Initiating retrieval of type '{retrieve_type}' with query: {query}\")\n try:\n if retrieve_type == 'similar':\n top_k = kwargs.get('top_k', 5)\n self.logger.debug(f\"Retrieval Type: 'similar' with top_k={top_k}\")\n return self.retrieve_similar(query, top_k)\n elif retrieve_type == 'related':\n relationship = kwargs.get('relationship')\n self.logger.debug(f\"Retrieval Type: 'related' with relationship='{relationship}'\")\n return self.retrieve_related(query, relationship)\n else:\n self.logger.error(f\"Unknown retrieve_type: {retrieve_type}\")\n raise MemoryRetrieverError(f\"Unknown retrieve_type: {retrieve_type}\")\n except MemoryRetrieverError:\n # Re-raise custom errors without modification\n raise\n except Exception as e:\n self.logger.error(f\"Failed to retrieve memory units: {e}\")\n raise MemoryRetrieverError(f\"Failed to retrieve memory units: {e}\") from e\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_retriever.MemoryRetriever.retrieve_related","title":"retrieve_related(query, relationship=None)
","text":"Retrieves memory units related to the given query based on relationships.
Parameters:
Name Type Description Default query
Any
The query for retrieval. Expected to be a MemoryUnit ID or similar identifier.
required relationship
Optional[str]
The type of relationship to filter by.
None
Returns:
Type Description List[MemoryUnit]
List[MemoryUnit]: A list of related memory units.
Raises:
Type Description MemoryRetrieverError
If retrieval fails due to storage issues or invalid queries.
Source code in src/aeiva/cognition/memory/memory_retriever.py
def retrieve_related(\n self,\n query: Any,\n relationship: Optional[str] = None\n) -> List[MemoryUnit]: # TODO: revise the method later\n \"\"\"\n Retrieves memory units related to the given query based on relationships.\n\n Args:\n query (Any): The query for retrieval. Expected to be a MemoryUnit ID or similar identifier.\n relationship (Optional[str]): The type of relationship to filter by.\n\n Returns:\n List[MemoryUnit]: A list of related memory units.\n\n Raises:\n MemoryRetrieverError: If retrieval fails due to storage issues or invalid queries.\n \"\"\"\n self.logger.info(f\"Retrieving memories related to the query with relationship: {relationship}\")\n try:\n # Assuming 'query' is a MemoryUnit ID or can be used to fetch a MemoryUnit\n self.logger.debug(\"Fetching the target MemoryUnit from storage.\")\n target_memory_unit = self.storage.get_memory_unit(query)\n if not target_memory_unit:\n self.logger.error(f\"MemoryUnit with ID '{query}' not found.\")\n raise MemoryRetrieverError(f\"MemoryUnit with ID '{query}' not found.\")\n\n self.logger.debug(f\"MemoryUnit fetched successfully: {target_memory_unit}\")\n\n # Perform related retrieval via MemoryStorage\n self.logger.debug(\"Retrieving related MemoryUnits from the graph database.\")\n related_units = self.storage.retrieve_related_memory_units(target_memory_unit.id, relationship)\n self.logger.info(f\"Retrieved {len(related_units)} related MemoryUnits.\")\n return related_units\n\n except MemoryRetrieverError:\n # Re-raise custom errors without modification\n raise\n except Exception as e:\n self.logger.error(f\"Unexpected error during retrieve_related: {e}\")\n raise MemoryRetrieverError(f\"Unexpected error during retrieve_related: {e}\") from e\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_retriever.MemoryRetriever.retrieve_similar","title":"retrieve_similar(query, top_k=5)
","text":"Retrieves memory units similar to the given input based on embeddings.
Parameters:
Name Type Description Default query
Any
The query for retrieval.
required top_k
int
The number of similar units to retrieve.
5
Returns:
Type Description List[MemoryUnit]
List[MemoryUnit]: A list of similar memory units.
Raises:
Type Description MemoryRetrieverError
If retrieval fails due to embedding generation or storage issues.
Source code in src/aeiva/cognition/memory/memory_retriever.py
def retrieve_similar(self, query: Any, top_k: int = 5) -> List[MemoryUnit]:\n \"\"\"\n Retrieves memory units similar to the given input based on embeddings.\n\n Args:\n query (Any): The query for retrieval.\n top_k (int): The number of similar units to retrieve.\n\n Returns:\n List[MemoryUnit]: A list of similar memory units.\n\n Raises:\n MemoryRetrieverError: If retrieval fails due to embedding generation or storage issues.\n \"\"\"\n self.logger.info(f\"Retrieving top {top_k} similar MemoryUnits based on the query.\")\n try:\n # Generate embedding for the query\n self.logger.debug(\"Generating embedding for the query.\")\n embedding_response = self.embedder.embed(query)\n if not embedding_response.get(\"data\"):\n self.logger.error(\"Failed to generate embedding for the query.\")\n raise MemoryRetrieverError(\"Failed to generate embedding for the query.\")\n\n query_embedding = embedding_response[\"data\"][0].get(\"embedding\")\n if not query_embedding:\n self.logger.error(\"Embedding data is missing in the response.\")\n raise MemoryRetrieverError(\"Embedding data is missing in the response.\")\n\n self.logger.debug(f\"Embedding generated successfully: {query_embedding}\")\n\n # Perform similarity search via MemoryStorage\n self.logger.debug(\"Performing similarity search in the vector database.\")\n similar_units = self.storage.retrieve_similar_memory_units(query_embedding, top_k)\n self.logger.info(f\"Retrieved {len(similar_units)} similar MemoryUnits.\")\n return similar_units\n\n except MemoryRetrieverError:\n # Re-raise custom errors without modification\n raise\n except Exception as e:\n self.logger.error(f\"Unexpected error during retrieve_similar: {e}\")\n raise MemoryRetrieverError(f\"Unexpected error during retrieve_similar: {e}\") from e\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_retriever.MemoryRetrieverError","title":"MemoryRetrieverError
","text":" Bases: Exception
Exception raised when an error occurs in the MemoryRetriever.
Source code in src/aeiva/cognition/memory/memory_retriever.py
class MemoryRetrieverError(Exception):\n \"\"\"Exception raised when an error occurs in the MemoryRetriever.\"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_skillizer","title":"memory_skillizer
","text":""},{"location":"reference/#src.aeiva.cognition.memory.memory_skillizer.MemorySkillizer","title":"MemorySkillizer
","text":"A class to skillize memory units based on various skillizing algorithms.
Supported skill types: - 'skill_type_example': Placeholder for future skillizing algorithms.
Source code in src/aeiva/cognition/memory/memory_skillizer.py
class MemorySkillizer:\n \"\"\"\n A class to skillize memory units based on various skillizing algorithms.\n\n Supported skill types:\n - 'skill_type_example': Placeholder for future skillizing algorithms.\n \"\"\"\n\n def __init__(self):\n \"\"\"\n Initializes the MemorySkillizer.\n\n Currently, no initialization parameters are required.\n \"\"\"\n self.logger = logging.getLogger(self.__class__.__name__)\n self.logger.debug(\"Initialized MemorySkillizer without default parameters.\")\n\n def skillize(\n self,\n memory_units: List[MemoryUnit],\n skill_type: str,\n **kwargs\n ) -> List[MemoryUnit]:\n \"\"\"\n Skillizes the provided memory units based on the specified skill type.\n\n Args:\n memory_units (List[MemoryUnit]): The list of memory units to be skillized.\n skill_type (str): The type of skillizing algorithm to use ('skill_type_example').\n **kwargs: Additional parameters required for specific skillizers.\n\n Returns:\n List[MemoryUnit]: The list of memory units after skillizing.\n\n Raises:\n MemorySkillizerError: If an unknown skill_type is provided or if skillizing fails.\n \"\"\"\n self.logger.debug(f\"Skillizing memory units using skill_type='{skill_type}' with kwargs={kwargs}\")\n try:\n if skill_type == 'skill_type_example':\n # Placeholder for actual skillizing logic\n return self.skillize_example(memory_units, **kwargs)\n else:\n self.logger.error(f\"Unknown skill_type: {skill_type}\")\n raise MemorySkillizerError(f\"Unknown skill_type: {skill_type}\")\n except MemorySkillizerError:\n # Re-raise custom errors without modification\n raise\n except Exception as e:\n self.logger.error(f\"Failed to skillize memory units: {e}\")\n raise MemorySkillizerError(f\"Failed to skillize memory units: {e}\")\n\n def skillize_example(\n self,\n memory_units: List[MemoryUnit],\n **kwargs\n ) -> List[MemoryUnit]:\n \"\"\"\n Example skillizing method. 
Currently a placeholder that returns memory units unchanged.\n\n Args:\n memory_units (List[MemoryUnit]): The list of memory units to be skillized.\n **kwargs: Additional parameters (currently unused).\n\n Returns:\n List[MemoryUnit]: The original list of memory units, unchanged.\n \"\"\"\n self.logger.debug(\"Executing skillize_example: No changes applied to memory units.\")\n # Placeholder: No operation performed\n return memory_units\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_skillizer.MemorySkillizer.__init__","title":"__init__()
","text":"Initializes the MemorySkillizer.
Currently, no initialization parameters are required.
Source code in src/aeiva/cognition/memory/memory_skillizer.py
def __init__(self):\n \"\"\"\n Initializes the MemorySkillizer.\n\n Currently, no initialization parameters are required.\n \"\"\"\n self.logger = logging.getLogger(self.__class__.__name__)\n self.logger.debug(\"Initialized MemorySkillizer without default parameters.\")\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_skillizer.MemorySkillizer.skillize","title":"skillize(memory_units, skill_type, **kwargs)
","text":"Skillizes the provided memory units based on the specified skill type.
Parameters:
Name Type Description Default memory_units
List[MemoryUnit]
The list of memory units to be skillized.
required skill_type
str
The type of skillizing algorithm to use ('skill_type_example').
required **kwargs
Additional parameters required for specific skillizers.
{}
Returns:
Type Description List[MemoryUnit]
List[MemoryUnit]: The list of memory units after skillizing.
Raises:
Type Description MemorySkillizerError
If an unknown skill_type is provided or if skillizing fails.
Source code in src/aeiva/cognition/memory/memory_skillizer.py
def skillize(\n self,\n memory_units: List[MemoryUnit],\n skill_type: str,\n **kwargs\n) -> List[MemoryUnit]:\n \"\"\"\n Skillizes the provided memory units based on the specified skill type.\n\n Args:\n memory_units (List[MemoryUnit]): The list of memory units to be skillized.\n skill_type (str): The type of skillizing algorithm to use ('skill_type_example').\n **kwargs: Additional parameters required for specific skillizers.\n\n Returns:\n List[MemoryUnit]: The list of memory units after skillizing.\n\n Raises:\n MemorySkillizerError: If an unknown skill_type is provided or if skillizing fails.\n \"\"\"\n self.logger.debug(f\"Skillizing memory units using skill_type='{skill_type}' with kwargs={kwargs}\")\n try:\n if skill_type == 'skill_type_example':\n # Placeholder for actual skillizing logic\n return self.skillize_example(memory_units, **kwargs)\n else:\n self.logger.error(f\"Unknown skill_type: {skill_type}\")\n raise MemorySkillizerError(f\"Unknown skill_type: {skill_type}\")\n except MemorySkillizerError:\n # Re-raise custom errors without modification\n raise\n except Exception as e:\n self.logger.error(f\"Failed to skillize memory units: {e}\")\n raise MemorySkillizerError(f\"Failed to skillize memory units: {e}\")\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_skillizer.MemorySkillizer.skillize_example","title":"skillize_example(memory_units, **kwargs)
","text":"Example skillizing method. Currently a placeholder that returns memory units unchanged.
Parameters:
Name Type Description Default memory_units
List[MemoryUnit]
The list of memory units to be skillized.
required **kwargs
Additional parameters (currently unused).
{}
Returns:
Type Description List[MemoryUnit]
List[MemoryUnit]: The original list of memory units, unchanged.
Source code in src/aeiva/cognition/memory/memory_skillizer.py
def skillize_example(\n self,\n memory_units: List[MemoryUnit],\n **kwargs\n) -> List[MemoryUnit]:\n \"\"\"\n Example skillizing method. Currently a placeholder that returns memory units unchanged.\n\n Args:\n memory_units (List[MemoryUnit]): The list of memory units to be skillized.\n **kwargs: Additional parameters (currently unused).\n\n Returns:\n List[MemoryUnit]: The original list of memory units, unchanged.\n \"\"\"\n self.logger.debug(\"Executing skillize_example: No changes applied to memory units.\")\n # Placeholder: No operation performed\n return memory_units\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_skillizer.MemorySkillizerError","title":"MemorySkillizerError
","text":" Bases: Exception
Exception raised when an error occurs in the MemorySkillizer.
Source code in src/aeiva/cognition/memory/memory_skillizer.py
class MemorySkillizerError(Exception):\n \"\"\"Exception raised when an error occurs in the MemorySkillizer.\"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_storage","title":"memory_storage
","text":""},{"location":"reference/#src.aeiva.cognition.memory.memory_storage.MemoryEventRepository","title":"MemoryEventRepository
","text":"Repository for MemoryEvent to handle CRUD operations without SQLAlchemy.
Source code in src/aeiva/cognition/memory/memory_storage.py
class MemoryEventRepository:\n \"\"\"\n Repository for MemoryEvent to handle CRUD operations without SQLAlchemy.\n \"\"\"\n\n def __init__(self, db: Any):\n \"\"\"\n Initialize the repository with a DatabaseFactory instance.\n\n Args:\n db (Any): An instance of DatabaseFactory for relational databases.\n \"\"\"\n self.db = db\n self.table_name = 'memory_events'\n self._create_table()\n\n def _create_table(self):\n \"\"\"\n Creates the memory_events table if it does not exist.\n \"\"\"\n create_table_query = f\"\"\"\n CREATE TABLE IF NOT EXISTS {self.table_name} (\n id TEXT PRIMARY KEY,\n memory_id TEXT NOT NULL,\n event_type TEXT NOT NULL,\n timestamp TEXT NOT NULL,\n memory_data TEXT,\n previous_data TEXT\n );\n \"\"\"\n self.db.execute_sql(create_table_query)\n\n def add(self, event: Dict[str, Any]) -> None:\n \"\"\"\n Adds a MemoryEvent to the relational database.\n\n Args:\n event (Dict[str, Any]): The event data to add.\n \"\"\"\n insert_query = f\"\"\"\n INSERT INTO {self.table_name} (id, memory_id, event_type, timestamp, memory_data, previous_data)\n VALUES (?, ?, ?, ?, ?, ?);\n \"\"\"\n data = (\n event.get('id', uuid4().hex),\n event['memory_id'],\n event['event_type'],\n datetime.utcnow().isoformat(), # TODO: revise utcnow.\n event.get('memory_data'),\n event.get('previous_data')\n )\n self.db.execute_sql(insert_query, data)\n\n def get(self, event_id: str) -> Optional[Dict[str, Any]]:\n \"\"\"\n Retrieves a MemoryEvent by its ID.\n\n Args:\n event_id (str): The unique identifier of the event.\n\n Returns:\n Optional[Dict[str, Any]]: The event data or None if not found.\n \"\"\"\n select_query = f\"SELECT * FROM {self.table_name} WHERE id = ?;\"\n result = self.db.execute_sql(select_query, (event_id,))\n row = result.fetchone()\n if row:\n return self._row_to_event(row)\n return None\n\n def delete_all(self) -> None:\n \"\"\"\n Deletes all MemoryEvents from the relational database.\n \"\"\"\n delete_query = f\"DELETE FROM {self.table_name};\"\n 
self.db.execute_sql(delete_query)\n\n def list_all(self) -> List[Dict[str, Any]]:\n \"\"\"\n Retrieves all MemoryEvents from the relational database.\n\n Returns:\n List[Dict[str, Any]]: A list of all events.\n \"\"\"\n select_query = f\"SELECT * FROM {self.table_name};\"\n results = self.db.execute_sql(select_query)\n return [self._row_to_event(row) for row in results.fetchall()]\n\n def _row_to_event(self, row: Any) -> Dict[str, Any]:\n \"\"\"\n Converts a database row to an event dictionary.\n\n Args:\n row (Any): A row fetched from the database.\n\n Returns:\n Dict[str, Any]: The corresponding event data.\n \"\"\"\n return {\n \"id\": row['id'],\n \"memory_id\": row['memory_id'],\n \"event_type\": row['event_type'],\n \"timestamp\": datetime.fromisoformat(row['timestamp']),\n \"memory_data\": json.loads(row['memory_data']) if row['memory_data'] else None,\n \"previous_data\": json.loads(row['previous_data']) if row['previous_data'] else None\n }\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_storage.MemoryEventRepository.__init__","title":"__init__(db)
","text":"Initialize the repository with a DatabaseFactory instance.
Parameters:
Name Type Description Default db
Any
An instance of DatabaseFactory for relational databases.
required Source code in src/aeiva/cognition/memory/memory_storage.py
def __init__(self, db: Any):\n \"\"\"\n Initialize the repository with a DatabaseFactory instance.\n\n Args:\n db (Any): An instance of DatabaseFactory for relational databases.\n \"\"\"\n self.db = db\n self.table_name = 'memory_events'\n self._create_table()\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_storage.MemoryEventRepository.add","title":"add(event)
","text":"Adds a MemoryEvent to the relational database.
Parameters:
Name Type Description Default event
Dict[str, Any]
The event data to add.
required Source code in src/aeiva/cognition/memory/memory_storage.py
def add(self, event: Dict[str, Any]) -> None:\n \"\"\"\n Adds a MemoryEvent to the relational database.\n\n Args:\n event (Dict[str, Any]): The event data to add.\n \"\"\"\n insert_query = f\"\"\"\n INSERT INTO {self.table_name} (id, memory_id, event_type, timestamp, memory_data, previous_data)\n VALUES (?, ?, ?, ?, ?, ?);\n \"\"\"\n data = (\n event.get('id', uuid4().hex),\n event['memory_id'],\n event['event_type'],\n datetime.utcnow().isoformat(), # TODO: revise utcnow.\n event.get('memory_data'),\n event.get('previous_data')\n )\n self.db.execute_sql(insert_query, data)\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_storage.MemoryEventRepository.delete_all","title":"delete_all()
","text":"Deletes all MemoryEvents from the relational database.
Source code in src/aeiva/cognition/memory/memory_storage.py
def delete_all(self) -> None:\n \"\"\"\n Deletes all MemoryEvents from the relational database.\n \"\"\"\n delete_query = f\"DELETE FROM {self.table_name};\"\n self.db.execute_sql(delete_query)\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_storage.MemoryEventRepository.get","title":"get(event_id)
","text":"Retrieves a MemoryEvent by its ID.
Parameters:
Name Type Description Default event_id
str
The unique identifier of the event.
required Returns:
Type Description Optional[Dict[str, Any]]
Optional[Dict[str, Any]]: The event data or None if not found.
Source code in src/aeiva/cognition/memory/memory_storage.py
def get(self, event_id: str) -> Optional[Dict[str, Any]]:\n \"\"\"\n Retrieves a MemoryEvent by its ID.\n\n Args:\n event_id (str): The unique identifier of the event.\n\n Returns:\n Optional[Dict[str, Any]]: The event data or None if not found.\n \"\"\"\n select_query = f\"SELECT * FROM {self.table_name} WHERE id = ?;\"\n result = self.db.execute_sql(select_query, (event_id,))\n row = result.fetchone()\n if row:\n return self._row_to_event(row)\n return None\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_storage.MemoryEventRepository.list_all","title":"list_all()
","text":"Retrieves all MemoryEvents from the relational database.
Returns:
Type Description List[Dict[str, Any]]
List[Dict[str, Any]]: A list of all events.
Source code in src/aeiva/cognition/memory/memory_storage.py
def list_all(self) -> List[Dict[str, Any]]:\n \"\"\"\n Retrieves all MemoryEvents from the relational database.\n\n Returns:\n List[Dict[str, Any]]: A list of all events.\n \"\"\"\n select_query = f\"SELECT * FROM {self.table_name};\"\n results = self.db.execute_sql(select_query)\n return [self._row_to_event(row) for row in results.fetchall()]\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_storage.MemoryStorage","title":"MemoryStorage
","text":"Handles storage operations for MemoryPalace, including interactions with vector, graph, and relational databases.
Source code in src/aeiva/cognition/memory/memory_storage.py
class MemoryStorage:\n \"\"\"\n Handles storage operations for MemoryPalace, including interactions with vector,\n graph, and relational databases.\n \"\"\"\n\n def __init__(self, config: Dict):\n \"\"\"\n Initialize the MemoryStorage with the provided configuration.\n\n Args:\n config (Any): Configuration settings for MemoryStorage.\n \"\"\"\n self.config_dict = config\n self.config = None\n self.setup()\n\n def setup(self) -> None:\n \"\"\"\n Set up the MemoryStorage's components based on the provided configuration.\n \"\"\"\n try:\n # Initialize Vector Database Configuration\n vector_db_conf_dict = self.config_dict.get('vector_db_config', {})\n vector_db_provider_name = vector_db_conf_dict.get('provider_name', 'milvus')\n vector_db_config = DatabaseConfigFactory.create(\n provider_name=vector_db_conf_dict.get('provider_name', 'milvus'),\n uri=vector_db_conf_dict.get('uri', 'storage/milvus_demo.db'),\n collection_name=vector_db_conf_dict.get('collection_name', 'test_collection'),\n embedding_model_dims=vector_db_conf_dict.get('embedding_model_dims', 1536), # 'text-embedding-ada-002': 1536,\n metric_type=vector_db_conf_dict.get('metric_type', 'COSINE')\n )\n\n # Initialize Graph Database Configuration\n graph_db_conf_dict = self.config_dict.get('graph_db_config', {})\n graph_db_provider_name = graph_db_conf_dict.get('provider_name', 'neo4j')\n graph_db_password = graph_db_conf_dict.get('password')\n graph_db_config = DatabaseConfigFactory.create(\n provider_name=graph_db_conf_dict.get('provider_name', 'neo4j'),\n uri=graph_db_conf_dict.get('uri', 'bolt://localhost:7687'),\n user=graph_db_conf_dict.get('user', 'neo4j'),\n password=graph_db_password,\n database=graph_db_conf_dict.get('database', 'neo4j'),\n encrypted=graph_db_conf_dict.get('encrypted', False)\n )\n\n # Initialize Relational Database Configuration\n relational_db_conf_dict = self.config_dict.get('relational_db_config', {})\n relational_db_provider_name = relational_db_conf_dict.get('provider_name', 
'sqlite')\n relational_db_config = DatabaseConfigFactory.create(\n provider_name=relational_db_conf_dict.get('provider_name', 'sqlite'),\n database=relational_db_conf_dict.get('database', 'storage/test_database.db')\n )\n\n self.config = StorageConfig(\n vector_db_provider=self.config_dict.get('vector_db_provider', 'milvus'),\n vector_db_config=vector_db_config,\n graph_db_provider=self.config_dict.get('graph_db_provider', 'neo4j'),\n graph_db_config=graph_db_config,\n relational_db_provider=self.config_dict.get('relational_db_provider', 'sqlite'),\n relational_db_config=relational_db_config,\n )\n\n # Initialize the vector database\n self.vector_db = DatabaseFactory.create(\n provider_name=vector_db_provider_name,\n config=self.config.vector_db_config\n )\n\n # Initialize the graph database if provided\n if graph_db_provider_name and self.config.graph_db_config:\n self.graph_db = DatabaseFactory.create(\n provider_name=graph_db_provider_name,\n config=self.config.graph_db_config\n )\n else:\n self.graph_db = None\n\n # Initialize the relational database if provided\n if relational_db_provider_name and self.config.relational_db_config:\n self.relational_db = DatabaseFactory.create(\n provider_name=relational_db_provider_name,\n config=self.config.relational_db_config\n )\n self.memory_unit_repo = MemoryUnitRepository(self.relational_db)\n self.memory_event_repo = MemoryEventRepository(self.relational_db)\n else:\n self.relational_db = None\n self.memory_unit_repo = None\n self.memory_event_repo = None\n\n logger.info(\"MemoryStorage setup completed successfully.\")\n except Exception as e:\n logger.error(f\"Error during MemoryStorage setup: {e}\")\n self.handle_error(e)\n raise # Re-raise the exception after logging\n\n def handle_error(self, error: Exception) -> None:\n \"\"\"\n Handle errors that occur during storage operations.\n\n Args:\n error (Exception): The exception that was raised.\n \"\"\"\n logger.error(f\"MemoryStorage encountered an error: 
{error}\")\n # Additional error handling can be implemented here\n\n def add_memory_unit(self, memory_unit: MemoryUnit) -> None:\n \"\"\"\n Adds a MemoryUnit to all configured databases.\n\n Args:\n memory_unit (MemoryUnit): The memory unit to add.\n \"\"\"\n try:\n # Add to vector database\n self._add_to_vector_db(memory_unit)\n\n # Add to graph database\n if self.graph_db:\n self._add_to_graph_db(memory_unit)\n\n # Add to relational database\n if self.relational_db and self.memory_unit_repo:\n self._add_to_relational_db(memory_unit)\n\n # Record creation event\n if self.relational_db and self.memory_event_repo:\n self._record_event(\n event_type=\"CREATE\",\n memory_unit=memory_unit\n )\n\n logger.info(f\"Added MemoryUnit with ID: {memory_unit.id} to all databases.\")\n except Exception as e:\n logger.error(f\"Error adding MemoryUnit to databases: {e}\")\n self.handle_error(e)\n raise\n\n def get_memory_unit(self, unit_id: str) -> MemoryUnit:\n \"\"\"\n Retrieves a MemoryUnit by its unique identifier from the relational database.\n\n Args:\n unit_id (str): The unique identifier of the memory unit.\n\n Returns:\n MemoryUnit: The retrieved memory unit.\n \"\"\"\n try:\n if not self.relational_db or not self.memory_unit_repo:\n raise ValueError(\"Relational database is not configured.\")\n\n memory_unit = self.memory_unit_repo.get(unit_id)\n if not memory_unit:\n raise ValueError(f\"MemoryUnit with ID {unit_id} does not exist.\")\n\n logger.info(f\"Retrieved MemoryUnit with ID: {unit_id} from Relational DB.\")\n return memory_unit\n except Exception as e:\n logger.error(f\"Error retrieving MemoryUnit with ID {unit_id}: {e}\")\n self.handle_error(e)\n raise\n\n def update_memory_unit(self, unit_id: str, updates: Dict[str, Any]) -> None:\n \"\"\"\n Updates a MemoryUnit in all configured databases.\n\n Args:\n unit_id (str): The unique identifier of the memory unit.\n updates (Dict[str, Any]): The updates to apply.\n \"\"\"\n try:\n # Retrieve existing MemoryUnit\n 
memory_unit = self.get_memory_unit(unit_id)\n previous_state = memory_unit.to_dict()\n\n # Apply updates\n for key, value in updates.items():\n setattr(memory_unit, key, value)\n\n # Update in vector database\n self._update_vector_db(memory_unit)\n\n # Update in graph database\n if self.graph_db:\n self._update_graph_db(memory_unit)\n\n # Update in relational database\n if self.relational_db and self.memory_unit_repo:\n self._update_relational_db(memory_unit)\n\n # Record update event\n if self.relational_db and self.memory_event_repo:\n self._record_event(\n event_type=\"UPDATE\",\n memory_unit=memory_unit,\n previous_state=previous_state\n )\n\n logger.info(f\"Updated MemoryUnit with ID: {unit_id} in all databases.\")\n except Exception as e:\n logger.error(f\"Error updating MemoryUnit with ID {unit_id}: {e}\")\n self.handle_error(e)\n raise\n\n def delete_memory_unit(self, unit_id: str) -> None:\n \"\"\"\n Deletes a MemoryUnit from all configured databases.\n\n Args:\n unit_id (str): The unique identifier of the memory unit.\n \"\"\"\n try:\n # Retrieve existing MemoryUnit\n memory_unit = self.get_memory_unit(unit_id)\n\n # Delete from vector database\n self._delete_from_vector_db(unit_id)\n\n # Delete from graph database\n if self.graph_db:\n self._delete_from_graph_db(unit_id)\n\n # Delete from relational database\n if self.relational_db and self.memory_unit_repo:\n self._delete_relational_db(unit_id)\n\n # Record deletion event\n if self.relational_db and self.memory_event_repo:\n self._record_event(\n event_type=\"DELETE\",\n memory_unit=memory_unit\n )\n\n logger.info(f\"Deleted MemoryUnit with ID: {unit_id} from all databases.\")\n except Exception as e:\n logger.error(f\"Error deleting MemoryUnit with ID {unit_id}: {e}\")\n self.handle_error(e)\n raise\n\n def get_all_memory_units(self) -> List[MemoryUnit]:\n \"\"\"\n Retrieves all MemoryUnits from the relational database.\n\n Returns:\n List[MemoryUnit]: A list of all memory units.\n \"\"\"\n try:\n if 
not self.relational_db or not self.memory_unit_repo:\n raise ValueError(\"Relational database is not configured.\")\n\n memory_units = self.memory_unit_repo.list_all()\n logger.info(f\"Retrieved all MemoryUnits from Relational DB. Total count: {len(memory_units)}\")\n return memory_units\n except Exception as e:\n logger.error(f\"Error retrieving all MemoryUnits: {e}\")\n self.handle_error(e)\n raise\n\n def delete_all_memory_units(self) -> None:\n \"\"\"\n Deletes all MemoryUnits from all configured databases.\n \"\"\"\n try:\n # Delete from vector database\n self.vector_db.delete_collection(\n collection_name=self.config.vector_db_config.collection_name\n )\n\n # Delete all nodes from graph database\n if self.graph_db:\n self.graph_db.delete_all()\n\n # Delete all records from relational database\n if self.relational_db and self.memory_unit_repo and self.memory_event_repo:\n self.memory_unit_repo.delete_all()\n self.memory_event_repo.delete_all()\n\n logger.info(\"Deleted all MemoryUnits from all databases.\")\n except Exception as e:\n logger.error(f\"Error deleting all MemoryUnits: {e}\")\n self.handle_error(e)\n raise\n\n # Internal helper methods\n\n def _add_to_vector_db(self, memory_unit: MemoryUnit) -> None:\n \"\"\"\n Adds the embedding vector of a MemoryUnit to the vector database.\n\n Args:\n memory_unit (MemoryUnit): The memory unit to add.\n \"\"\"\n try:\n # Ensure embedding exists\n if not memory_unit.embedding:\n raise ValueError(\"MemoryUnit does not have an embedding.\")\n\n # Prepare payload with essential fields\n payload = {\n \"id\": memory_unit.id,\n \"type\": memory_unit.type,\n \"modality\": memory_unit.modality\n }\n\n # Insert into vector database\n self.vector_db.insert_vectors(\n collection_name=self.config.vector_db_config.collection_name,\n vectors=[memory_unit.embedding],\n payloads=[payload],\n ids=[memory_unit.id]\n )\n\n logger.info(f\"Inserted embedding for MemoryUnit ID: {memory_unit.id} into Vector DB.\")\n except Exception as 
e:\n logger.error(f\"Error adding MemoryUnit to Vector DB: {e}\")\n self.handle_error(e)\n raise\n\n def _update_vector_db(self, memory_unit: MemoryUnit) -> None:\n \"\"\"\n Updates the embedding vector of a MemoryUnit in the vector database.\n\n Args:\n memory_unit (MemoryUnit): The memory unit to update.\n \"\"\"\n try:\n if not memory_unit.embedding:\n raise ValueError(\"MemoryUnit does not have an embedding.\")\n\n payload = {\n \"type\": memory_unit.type,\n \"modality\": memory_unit.modality\n }\n\n self.vector_db.update_vector(\n collection_name=self.config.vector_db_config.collection_name,\n vector_id=memory_unit.id,\n vector=memory_unit.embedding,\n payload=payload\n )\n\n logger.info(f\"Updated embedding for MemoryUnit ID: {memory_unit.id} in Vector DB.\")\n except Exception as e:\n logger.error(f\"Error updating MemoryUnit in Vector DB: {e}\")\n self.handle_error(e)\n raise\n\n def _delete_from_vector_db(self, unit_id: str) -> None:\n \"\"\"\n Deletes a MemoryUnit's embedding from the vector database.\n\n Args:\n unit_id (str): The unique identifier of the memory unit.\n \"\"\"\n try:\n self.vector_db.delete_vector(\n collection_name=self.config.vector_db_config.collection_name,\n vector_id=unit_id\n )\n\n logger.info(f\"Deleted embedding for MemoryUnit ID: {unit_id} from Vector DB.\")\n except Exception as e:\n logger.error(f\"Error deleting MemoryUnit from Vector DB: {e}\")\n self.handle_error(e)\n raise\n\n def _add_to_graph_db(self, memory_unit: MemoryUnit) -> None:\n \"\"\"\n Adds a MemoryUnit as a node in the graph database and establishes relationships.\n\n Args:\n memory_unit (MemoryUnit): The memory unit to add.\n \"\"\"\n try:\n # Serialize complex fields\n properties = {\n \"id\": memory_unit.id,\n \"content\": memory_unit.content,\n \"timestamp\": memory_unit.timestamp.isoformat(),\n \"modality\": memory_unit.modality,\n \"type\": memory_unit.type,\n \"status\": memory_unit.status,\n \"tags\": memory_unit.tags,\n \"embedding\": 
memory_unit.embedding,\n \"location\": json.dumps(memory_unit.location) if memory_unit.location else None, # Serialized\n \"source_role\": memory_unit.source_role,\n \"source_name\": memory_unit.source_name,\n \"source_id\": memory_unit.source_id,\n \"metadata\": json.dumps(memory_unit.metadata) if memory_unit.metadata else None # Serialized\n }\n\n # Add node to graph database\n self.graph_db.add_node(\n node_id=memory_unit.id,\n properties=properties,\n labels=[memory_unit.type or 'MemoryUnit']\n )\n\n logger.info(f\"Added MemoryUnit ID: {memory_unit.id} to Graph DB.\")\n\n # Add relationships (edges) if any\n for link in memory_unit.edges:\n # Serialize edge metadata if necessary\n edge_properties = {}\n if link.metadata:\n edge_properties['metadata'] = json.dumps(link.metadata)\n\n self.graph_db.add_edge(\n source_id=link.source_id,\n target_id=link.target_id,\n relationship=link.relationship,\n properties=edge_properties\n )\n\n logger.info(f\"Added {len(memory_unit.edges)} edges for MemoryUnit ID: {memory_unit.id} in Graph DB.\")\n except Exception as e:\n logger.error(f\"Error adding MemoryUnit to Graph DB: {e}\")\n self.handle_error(e)\n raise\n\n def _update_graph_db(self, memory_unit: MemoryUnit) -> None:\n \"\"\"\n Updates a MemoryUnit in the graph database.\n\n Args:\n memory_unit (MemoryUnit): The memory unit to update.\n \"\"\"\n try:\n # Update node properties\n properties = {\n \"content\": memory_unit.content,\n \"timestamp\": memory_unit.timestamp.isoformat(),\n \"modality\": memory_unit.modality,\n \"type\": memory_unit.type,\n \"status\": memory_unit.status,\n \"tags\": memory_unit.tags,\n \"embedding\": memory_unit.embedding,\n \"location\": json.dumps(memory_unit.location) if memory_unit.location else None, # Serialized\n \"source_role\": memory_unit.source_role,\n \"source_name\": memory_unit.source_name,\n \"source_id\": memory_unit.source_id,\n \"metadata\": json.dumps(memory_unit.metadata) if memory_unit.metadata else None # Serialized\n 
}\n\n self.graph_db.update_node(\n node_id=memory_unit.id,\n properties=properties\n )\n\n # Handle edges updates as needed\n # This can be complex and depends on your specific requirements\n\n logger.info(f\"Updated MemoryUnit ID: {memory_unit.id} in Graph DB.\")\n except Exception as e:\n logger.error(f\"Error updating MemoryUnit in Graph DB: {e}\")\n self.handle_error(e)\n raise\n\n def _delete_from_graph_db(self, unit_id: str) -> None:\n \"\"\"\n Deletes a MemoryUnit from the graph database.\n\n Args:\n unit_id (str): The unique identifier of the memory unit.\n \"\"\"\n try:\n self.graph_db.delete_node(node_id=unit_id)\n logger.info(f\"Deleted MemoryUnit ID: {unit_id} from Graph DB.\")\n except Exception as e:\n logger.error(f\"Error deleting MemoryUnit from Graph DB: {e}\")\n self.handle_error(e)\n raise\n\n def _add_to_relational_db(self, memory_unit: MemoryUnit) -> None:\n \"\"\"\n Adds a MemoryUnit to the relational database.\n\n Args:\n memory_unit (MemoryUnit): The memory unit to add.\n \"\"\"\n try:\n self.memory_unit_repo.add(memory_unit)\n logger.info(f\"Inserted MemoryUnit ID: {memory_unit.id} into Relational DB.\")\n except Exception as e:\n logger.error(f\"Error adding MemoryUnit to Relational DB: {e}\")\n raise\n\n def _update_relational_db(self, memory_unit: MemoryUnit) -> None:\n \"\"\"\n Updates a MemoryUnit in the relational database.\n\n Args:\n memory_unit (MemoryUnit): The memory unit to update.\n \"\"\"\n try:\n self.memory_unit_repo.update(memory_unit)\n logger.info(f\"Updated MemoryUnit ID: {memory_unit.id} in Relational DB.\")\n except Exception as e:\n logger.error(f\"Error updating MemoryUnit in Relational DB: {e}\")\n raise\n\n def _delete_relational_db(self, unit_id: str) -> None:\n \"\"\"\n Deletes a MemoryUnit from the relational database.\n\n Args:\n unit_id (str): The unique identifier of the memory unit.\n \"\"\"\n try:\n self.memory_unit_repo.delete(unit_id)\n logger.info(f\"Deleted MemoryUnit ID: {unit_id} from Relational 
DB.\")\n except Exception as e:\n logger.error(f\"Error deleting MemoryUnit from Relational DB: {e}\")\n raise\n\n def _record_event(self, event_type: str, memory_unit: MemoryUnit, previous_state: Optional[Dict[str, Any]] = None) -> None:\n \"\"\"\n Records an event in the relational database.\n\n Args:\n event_type (str): The type of event ('CREATE', 'UPDATE', 'DELETE').\n memory_unit (MemoryUnit): The memory unit involved in the event.\n previous_state (Optional[Dict[str, Any]]): The previous state of the memory unit (for updates).\n \"\"\"\n try:\n event_record = {\n \"memory_id\": memory_unit.id,\n \"event_type\": event_type,\n \"memory_data\": json.dumps(memory_unit.to_dict()),\n \"previous_data\": json.dumps(previous_state) if previous_state else None\n }\n\n self.memory_event_repo.add(event_record)\n logger.info(f\"Recorded event '{event_type}' for MemoryUnit ID: {memory_unit.id}.\")\n except Exception as e:\n logger.error(f\"Error recording event in Relational DB: {e}\")\n raise\n\n def retrieve_similar_memory_units(self, query_embedding: List[float], top_k: int) -> List[MemoryUnit]:\n \"\"\"\n Retrieves memory units similar to the given embedding.\n\n Args:\n query_embedding (List[float]): The embedding vector of the query.\n top_k (int): The number of similar units to retrieve.\n\n Returns:\n List[MemoryUnit]: A list of similar memory units.\n \"\"\"\n try:\n # Perform similarity search\n results = self.vector_db.search_vectors(\n collection_name=self.config.vector_db_config.collection_name,\n query_vector=query_embedding,\n top_k=top_k\n )\n\n memory_units = []\n for result in results:\n unit_id = result['id']\n memory_unit = self.get_memory_unit(unit_id)\n memory_units.append(memory_unit)\n\n logger.info(f\"Retrieved {len(memory_units)} similar MemoryUnits.\")\n return memory_units\n except Exception as e:\n logger.error(f\"Error retrieving similar MemoryUnits: {e}\")\n self.handle_error(e)\n raise\n\n def retrieve_related_memory_units(self, unit_id: 
str, relationship: Optional[str] = None) -> List[MemoryUnit]:\n \"\"\"\n Retrieves memory units related to the given one based on relationships.\n\n Args:\n unit_id (str): The unique identifier of the memory unit.\n relationship (Optional[str]): Filter by relationship type.\n\n Returns:\n List[MemoryUnit]: A list of related memory units.\n \"\"\"\n try:\n if not self.graph_db:\n raise ValueError(\"Graph database is not configured.\")\n\n # Retrieve related nodes from graph database\n neighbors = self.graph_db.get_neighbors(\n node_id=unit_id,\n relationship=relationship\n )\n\n related_units = []\n for neighbor in neighbors:\n related_unit = self.get_memory_unit(neighbor['id'])\n related_units.append(related_unit)\n\n logger.info(f\"Retrieved {len(related_units)} related MemoryUnits.\")\n return related_units\n except Exception as e:\n logger.error(f\"Error retrieving related MemoryUnits: {e}\")\n self.handle_error(e)\n raise\n
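Reading the `setup()` fallbacks together, a configuration dictionary for `MemoryStorage` can be sketched as below. The values mirror the defaults visible in the source; the password is a placeholder, and providers and URIs should be adjusted for your deployment:

```python
memory_storage_config = {
    "vector_db_config": {
        "provider_name": "milvus",
        "uri": "storage/milvus_demo.db",
        "collection_name": "test_collection",
        "embedding_model_dims": 1536,  # e.g. 'text-embedding-ada-002'
        "metric_type": "COSINE",
    },
    "graph_db_config": {
        "provider_name": "neo4j",
        "uri": "bolt://localhost:7687",
        "user": "neo4j",
        "password": "<your-password>",  # placeholder, supply your own
        "database": "neo4j",
        "encrypted": False,
    },
    "relational_db_config": {
        "provider_name": "sqlite",
        "database": "storage/test_database.db",
    },
}
```

If a `graph_db_config` or `relational_db_config` section is omitted, `setup()` still builds the corresponding config from these defaults; the graph and relational handles are set to `None` only when the provider cannot be initialized.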
"},{"location":"reference/#src.aeiva.cognition.memory.memory_storage.MemoryStorage.__init__","title":"__init__(config)
","text":"Initialize the MemoryStorage with the provided configuration.
Parameters:
Name Type Description Default config
Any
(The signature declares `config: Dict`; the docstring's `Any` annotation is looser than the actual expected type.)
Configuration settings for MemoryStorage.
required Source code in src/aeiva/cognition/memory/memory_storage.py
def __init__(self, config: Dict):\n    \"\"\"\n    Initialize the MemoryStorage with the provided configuration.\n\n    Args:\n        config (Dict): Configuration settings for MemoryStorage.\n    \"\"\"\n    self.config_dict = config\n    self.config = None\n    self.setup()\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_storage.MemoryStorage.add_memory_unit","title":"add_memory_unit(memory_unit)
","text":"Adds a MemoryUnit to all configured databases.
Parameters:
Name Type Description Default memory_unit
MemoryUnit
The memory unit to add.
required Source code in src/aeiva/cognition/memory/memory_storage.py
def add_memory_unit(self, memory_unit: MemoryUnit) -> None:\n \"\"\"\n Adds a MemoryUnit to all configured databases.\n\n Args:\n memory_unit (MemoryUnit): The memory unit to add.\n \"\"\"\n try:\n # Add to vector database\n self._add_to_vector_db(memory_unit)\n\n # Add to graph database\n if self.graph_db:\n self._add_to_graph_db(memory_unit)\n\n # Add to relational database\n if self.relational_db and self.memory_unit_repo:\n self._add_to_relational_db(memory_unit)\n\n # Record creation event\n if self.relational_db and self.memory_event_repo:\n self._record_event(\n event_type=\"CREATE\",\n memory_unit=memory_unit\n )\n\n logger.info(f\"Added MemoryUnit with ID: {memory_unit.id} to all databases.\")\n except Exception as e:\n logger.error(f\"Error adding MemoryUnit to databases: {e}\")\n self.handle_error(e)\n raise\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_storage.MemoryStorage.delete_all_memory_units","title":"delete_all_memory_units()
","text":"Deletes all MemoryUnits from all configured databases.
Source code in src/aeiva/cognition/memory/memory_storage.py
def delete_all_memory_units(self) -> None:\n \"\"\"\n Deletes all MemoryUnits from all configured databases.\n \"\"\"\n try:\n # Delete from vector database\n self.vector_db.delete_collection(\n collection_name=self.config.vector_db_config.collection_name\n )\n\n # Delete all nodes from graph database\n if self.graph_db:\n self.graph_db.delete_all()\n\n # Delete all records from relational database\n if self.relational_db and self.memory_unit_repo and self.memory_event_repo:\n self.memory_unit_repo.delete_all()\n self.memory_event_repo.delete_all()\n\n logger.info(\"Deleted all MemoryUnits from all databases.\")\n except Exception as e:\n logger.error(f\"Error deleting all MemoryUnits: {e}\")\n self.handle_error(e)\n raise\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_storage.MemoryStorage.delete_memory_unit","title":"delete_memory_unit(unit_id)
","text":"Deletes a MemoryUnit from all configured databases.
Parameters:
Name Type Description Default unit_id
str
The unique identifier of the memory unit.
required Source code in src/aeiva/cognition/memory/memory_storage.py
def delete_memory_unit(self, unit_id: str) -> None:\n \"\"\"\n Deletes a MemoryUnit from all configured databases.\n\n Args:\n unit_id (str): The unique identifier of the memory unit.\n \"\"\"\n try:\n # Retrieve existing MemoryUnit\n memory_unit = self.get_memory_unit(unit_id)\n\n # Delete from vector database\n self._delete_from_vector_db(unit_id)\n\n # Delete from graph database\n if self.graph_db:\n self._delete_from_graph_db(unit_id)\n\n # Delete from relational database\n if self.relational_db and self.memory_unit_repo:\n self._delete_relational_db(unit_id)\n\n # Record deletion event\n if self.relational_db and self.memory_event_repo:\n self._record_event(\n event_type=\"DELETE\",\n memory_unit=memory_unit\n )\n\n logger.info(f\"Deleted MemoryUnit with ID: {unit_id} from all databases.\")\n except Exception as e:\n logger.error(f\"Error deleting MemoryUnit with ID {unit_id}: {e}\")\n self.handle_error(e)\n raise\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_storage.MemoryStorage.get_all_memory_units","title":"get_all_memory_units()
","text":"Retrieves all MemoryUnits from the relational database.
Returns:
Type Description List[MemoryUnit]
List[MemoryUnit]: A list of all memory units.
Source code in src/aeiva/cognition/memory/memory_storage.py
def get_all_memory_units(self) -> List[MemoryUnit]:\n \"\"\"\n Retrieves all MemoryUnits from the relational database.\n\n Returns:\n List[MemoryUnit]: A list of all memory units.\n \"\"\"\n try:\n if not self.relational_db or not self.memory_unit_repo:\n raise ValueError(\"Relational database is not configured.\")\n\n memory_units = self.memory_unit_repo.list_all()\n logger.info(f\"Retrieved all MemoryUnits from Relational DB. Total count: {len(memory_units)}\")\n return memory_units\n except Exception as e:\n logger.error(f\"Error retrieving all MemoryUnits: {e}\")\n self.handle_error(e)\n raise\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_storage.MemoryStorage.get_memory_unit","title":"get_memory_unit(unit_id)
","text":"Retrieves a MemoryUnit by its unique identifier from the relational database.
Parameters:
Name Type Description Default unit_id
str
The unique identifier of the memory unit.
required Returns:
Name Type Description MemoryUnit
MemoryUnit
The retrieved memory unit.
Source code in src/aeiva/cognition/memory/memory_storage.py
def get_memory_unit(self, unit_id: str) -> MemoryUnit:\n \"\"\"\n Retrieves a MemoryUnit by its unique identifier from the relational database.\n\n Args:\n unit_id (str): The unique identifier of the memory unit.\n\n Returns:\n MemoryUnit: The retrieved memory unit.\n \"\"\"\n try:\n if not self.relational_db or not self.memory_unit_repo:\n raise ValueError(\"Relational database is not configured.\")\n\n memory_unit = self.memory_unit_repo.get(unit_id)\n if not memory_unit:\n raise ValueError(f\"MemoryUnit with ID {unit_id} does not exist.\")\n\n logger.info(f\"Retrieved MemoryUnit with ID: {unit_id} from Relational DB.\")\n return memory_unit\n except Exception as e:\n logger.error(f\"Error retrieving MemoryUnit with ID {unit_id}: {e}\")\n self.handle_error(e)\n raise\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_storage.MemoryStorage.handle_error","title":"handle_error(error)
","text":"Handle errors that occur during storage operations.
Parameters:
Name Type Description Default error
Exception
The exception that was raised.
required Source code in src/aeiva/cognition/memory/memory_storage.py
def handle_error(self, error: Exception) -> None:\n \"\"\"\n Handle errors that occur during storage operations.\n\n Args:\n error (Exception): The exception that was raised.\n \"\"\"\n logger.error(f\"MemoryStorage encountered an error: {error}\")\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_storage.MemoryStorage.retrieve_related_memory_units","title":"retrieve_related_memory_units(unit_id, relationship=None)
","text":"Retrieves memory units related to the given one based on relationships.
Parameters:
Name Type Description Default unit_id
str
The unique identifier of the memory unit.
required relationship
Optional[str]
Filter by relationship type.
None
Returns:
Type Description List[MemoryUnit]
List[MemoryUnit]: A list of related memory units.
Source code in src/aeiva/cognition/memory/memory_storage.py
def retrieve_related_memory_units(self, unit_id: str, relationship: Optional[str] = None) -> List[MemoryUnit]:\n \"\"\"\n Retrieves memory units related to the given one based on relationships.\n\n Args:\n unit_id (str): The unique identifier of the memory unit.\n relationship (Optional[str]): Filter by relationship type.\n\n Returns:\n List[MemoryUnit]: A list of related memory units.\n \"\"\"\n try:\n if not self.graph_db:\n raise ValueError(\"Graph database is not configured.\")\n\n # Retrieve related nodes from graph database\n neighbors = self.graph_db.get_neighbors(\n node_id=unit_id,\n relationship=relationship\n )\n\n related_units = []\n for neighbor in neighbors:\n related_unit = self.get_memory_unit(neighbor['id'])\n related_units.append(related_unit)\n\n logger.info(f\"Retrieved {len(related_units)} related MemoryUnits.\")\n return related_units\n except Exception as e:\n logger.error(f\"Error retrieving related MemoryUnits: {e}\")\n self.handle_error(e)\n raise\n
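The method walks the graph for neighbors, optionally filtered by relationship type, then fetches each full unit by id. A minimal sketch of that flow with a toy adjacency map (the `graph` and `records` dicts stand in for the graph DB and relational repository; they are not the aeiva API):

```python
# Toy graph: node id -> list of (neighbor_id, relationship).
graph = {
    "mu-1": [("mu-2", "DERIVED_FROM"), ("mu-3", "MENTIONS")],
}
# Toy relational store: id -> full record.
records = {"mu-2": {"id": "mu-2"}, "mu-3": {"id": "mu-3"}}

def retrieve_related(unit_id, relationship=None):
    related = []
    for neighbor_id, rel in graph.get(unit_id, []):
        # Optional relationship filter, as in the source: None means
        # "all relationships".
        if relationship is None or rel == relationship:
            related.append(records[neighbor_id])
    return related

all_related = retrieve_related("mu-1")
only_derived = retrieve_related("mu-1", relationship="DERIVED_FROM")
```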
"},{"location":"reference/#src.aeiva.cognition.memory.memory_storage.MemoryStorage.retrieve_similar_memory_units","title":"retrieve_similar_memory_units(query_embedding, top_k)
","text":"Retrieves memory units similar to the given embedding.
Parameters:
Name Type Description Default query_embedding
List[float]
The embedding vector of the query.
required top_k
int
The number of similar units to retrieve.
required Returns:
Type Description List[MemoryUnit]
List[MemoryUnit]: A list of similar memory units.
Source code in src/aeiva/cognition/memory/memory_storage.py
def retrieve_similar_memory_units(self, query_embedding: List[float], top_k: int) -> List[MemoryUnit]:\n \"\"\"\n Retrieves memory units similar to the given embedding.\n\n Args:\n query_embedding (List[float]): The embedding vector of the query.\n top_k (int): The number of similar units to retrieve.\n\n Returns:\n List[MemoryUnit]: A list of similar memory units.\n \"\"\"\n try:\n # Perform similarity search\n results = self.vector_db.search_vectors(\n collection_name=self.config.vector_db_config.collection_name,\n query_vector=query_embedding,\n top_k=top_k\n )\n\n memory_units = []\n for result in results:\n unit_id = result['id']\n memory_unit = self.get_memory_unit(unit_id)\n memory_units.append(memory_unit)\n\n logger.info(f\"Retrieved {len(memory_units)} similar MemoryUnits.\")\n return memory_units\n except Exception as e:\n logger.error(f\"Error retrieving similar MemoryUnits: {e}\")\n self.handle_error(e)\n raise\n
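Note the two-step lookup: the vector database returns only ids (plus payload), and each hit is then hydrated into a full unit via `get_memory_unit`. A self-contained sketch of that pattern with plain cosine similarity (the toy `index`/`records` dicts are illustrative stand-ins for the Milvus collection and relational repository):

```python
import math

def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(x * x for x in b)))

# Toy vector index: id -> embedding.
index = {"mu-1": [1.0, 0.0], "mu-2": [0.0, 1.0], "mu-3": [0.9, 0.1]}
# Toy relational store: id -> full record.
records = {i: {"id": i, "content": f"unit {i}"} for i in index}

def retrieve_similar(query, top_k):
    # Step 1: rank ids by similarity (the vector DB's job).
    ranked = sorted(index, key=lambda i: cosine(query, index[i]), reverse=True)
    # Step 2: hydrate each hit from the relational store.
    return [records[i] for i in ranked[:top_k]]

hits = retrieve_similar([1.0, 0.0], top_k=2)
```

One consequence of this design: each hit costs an extra relational lookup, so for large `top_k` a batched fetch may be preferable.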
"},{"location":"reference/#src.aeiva.cognition.memory.memory_storage.MemoryStorage.setup","title":"setup()
","text":"Set up the MemoryStorage's components based on the provided configuration.
Source code in src/aeiva/cognition/memory/memory_storage.py
def setup(self) -> None:\n \"\"\"\n Set up the MemoryStorage's components based on the provided configuration.\n \"\"\"\n try:\n # Initialize Vector Database Configuration\n vector_db_conf_dict = self.config_dict.get('vector_db_config', {})\n vector_db_provider_name = vector_db_conf_dict.get('provider_name', 'milvus')\n vector_db_config = DatabaseConfigFactory.create(\n provider_name=vector_db_conf_dict.get('provider_name', 'milvus'),\n uri=vector_db_conf_dict.get('uri', 'storage/milvus_demo.db'),\n collection_name=vector_db_conf_dict.get('collection_name', 'test_collection'),\n embedding_model_dims=vector_db_conf_dict.get('embedding_model_dims', 1536), # 'text-embedding-ada-002': 1536,\n metric_type=vector_db_conf_dict.get('metric_type', 'COSINE')\n )\n\n # Initialize Graph Database Configuration\n graph_db_conf_dict = self.config_dict.get('graph_db_config', {})\n graph_db_provider_name = graph_db_conf_dict.get('provider_name', 'neo4j')\n graph_db_password = graph_db_conf_dict.get('password')\n graph_db_config = DatabaseConfigFactory.create(\n provider_name=graph_db_conf_dict.get('provider_name', 'neo4j'),\n uri=graph_db_conf_dict.get('uri', 'bolt://localhost:7687'),\n user=graph_db_conf_dict.get('user', 'neo4j'),\n password=graph_db_password,\n database=graph_db_conf_dict.get('database', 'neo4j'),\n encrypted=graph_db_conf_dict.get('encrypted', False)\n )\n\n # Initialize Relational Database Configuration\n relational_db_conf_dict = self.config_dict.get('relational_db_config', {})\n relational_db_provider_name = relational_db_conf_dict.get('provider_name', 'sqlite')\n relational_db_config = DatabaseConfigFactory.create(\n provider_name=relational_db_conf_dict.get('provider_name', 'sqlite'),\n database=relational_db_conf_dict.get('database', 'storage/test_database.db')\n )\n\n self.config = StorageConfig(\n vector_db_provider=self.config_dict.get('vector_db_provider', 'milvus'),\n vector_db_config=vector_db_config,\n 
graph_db_provider=self.config_dict.get('graph_db_provider', 'neo4j'),\n graph_db_config=graph_db_config,\n relational_db_provider=self.config_dict.get('relational_db_provider', 'sqlite'),\n relational_db_config=relational_db_config,\n )\n\n # Initialize the vector database\n self.vector_db = DatabaseFactory.create(\n provider_name=vector_db_provider_name,\n config=self.config.vector_db_config\n )\n\n # Initialize the graph database if provided\n if graph_db_provider_name and self.config.graph_db_config:\n self.graph_db = DatabaseFactory.create(\n provider_name=graph_db_provider_name,\n config=self.config.graph_db_config\n )\n else:\n self.graph_db = None\n\n # Initialize the relational database if provided\n if relational_db_provider_name and self.config.relational_db_config:\n self.relational_db = DatabaseFactory.create(\n provider_name=relational_db_provider_name,\n config=self.config.relational_db_config\n )\n self.memory_unit_repo = MemoryUnitRepository(self.relational_db)\n self.memory_event_repo = MemoryEventRepository(self.relational_db)\n else:\n self.relational_db = None\n self.memory_unit_repo = None\n self.memory_event_repo = None\n\n logger.info(\"MemoryStorage setup completed successfully.\")\n except Exception as e:\n logger.error(f\"Error during MemoryStorage setup: {e}\")\n self.handle_error(e)\n raise # Re-raise the exception after logging\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_storage.MemoryStorage.update_memory_unit","title":"update_memory_unit(unit_id, updates)
","text":"Updates a MemoryUnit in all configured databases.
Parameters:
Name Type Description Default unit_id
str
The unique identifier of the memory unit.
required updates
Dict[str, Any]
The updates to apply.
required Source code in src/aeiva/cognition/memory/memory_storage.py
def update_memory_unit(self, unit_id: str, updates: Dict[str, Any]) -> None:\n \"\"\"\n Updates a MemoryUnit in all configured databases.\n\n Args:\n unit_id (str): The unique identifier of the memory unit.\n updates (Dict[str, Any]): The updates to apply.\n \"\"\"\n try:\n # Retrieve existing MemoryUnit\n memory_unit = self.get_memory_unit(unit_id)\n previous_state = memory_unit.to_dict()\n\n # Apply updates\n for key, value in updates.items():\n setattr(memory_unit, key, value)\n\n # Update in vector database\n self._update_vector_db(memory_unit)\n\n # Update in graph database\n if self.graph_db:\n self._update_graph_db(memory_unit)\n\n # Update in relational database\n if self.relational_db and self.memory_unit_repo:\n self._update_relational_db(memory_unit)\n\n # Record update event\n if self.relational_db and self.memory_event_repo:\n self._record_event(\n event_type=\"UPDATE\",\n memory_unit=memory_unit,\n previous_state=previous_state\n )\n\n logger.info(f\"Updated MemoryUnit with ID: {unit_id} in all databases.\")\n except Exception as e:\n logger.error(f\"Error updating MemoryUnit with ID {unit_id}: {e}\")\n self.handle_error(e)\n raise\n
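Note how `update_memory_unit` snapshots the unit with `to_dict()` before mutating it, so the `UPDATE` event can carry `previous_state`. The snapshot-then-apply pattern can be sketched in isolation (the dataclass below is a hypothetical stand-in for `MemoryUnit`, not the library class):

```python
from dataclasses import dataclass, asdict
from typing import Any, Dict

@dataclass
class UnitStub:
    # Hypothetical simplification of MemoryUnit, for illustration only.
    id: str
    content: str
    status: str = "raw"

def apply_updates(unit: UnitStub, updates: Dict[str, Any]) -> Dict[str, Any]:
    """Snapshot the unit, then apply updates in place; return the previous state."""
    previous_state = asdict(unit)       # captured before any mutation
    for key, value in updates.items():
        setattr(unit, key, value)       # same setattr loop as update_memory_unit
    return previous_state

unit = UnitStub(id="u1", content="hello")
prev = apply_updates(unit, {"status": "cleaned"})
```

The returned `prev` dictionary plays the role of `previous_state` in the recorded event, while `unit` now holds the updated values.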
"},{"location":"reference/#src.aeiva.cognition.memory.memory_storage.MemoryUnitRepository","title":"MemoryUnitRepository
","text":"Repository for MemoryUnit to handle CRUD operations without SQLAlchemy.
Source code in src/aeiva/cognition/memory/memory_storage.py
class MemoryUnitRepository:\n \"\"\"\n Repository for MemoryUnit to handle CRUD operations without SQLAlchemy.\n \"\"\"\n\n def __init__(self, db: Any):\n \"\"\"\n Initialize the repository with a DatabaseFactory instance.\n\n Args:\n db (Any): An instance of DatabaseFactory for relational databases.\n \"\"\"\n self.db = db\n self.table_name = 'memory_units'\n self._create_table()\n\n def _create_table(self):\n \"\"\"\n Creates the memory_units table if it does not exist.\n \"\"\"\n create_table_query = f\"\"\"\n CREATE TABLE IF NOT EXISTS {self.table_name} (\n id TEXT PRIMARY KEY,\n content TEXT NOT NULL,\n timestamp TEXT NOT NULL,\n modality TEXT,\n type TEXT,\n status TEXT,\n tags TEXT,\n embedding TEXT,\n location TEXT,\n source_role TEXT,\n source_name TEXT,\n source_id TEXT,\n edges TEXT,\n metadata TEXT\n );\n \"\"\"\n self.db.execute_sql(create_table_query)\n\n def add(self, memory_unit: MemoryUnit) -> None:\n \"\"\"\n Adds a MemoryUnit to the relational database.\n\n Args:\n memory_unit (MemoryUnit): The memory unit to add.\n \"\"\"\n insert_query = f\"\"\"\n INSERT INTO {self.table_name} (id, content, timestamp, modality, type, status, tags, embedding, location, \n source_role, source_name, source_id, edges, metadata)\n VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);\n \"\"\"\n data = (\n memory_unit.id,\n memory_unit.content,\n memory_unit.timestamp.isoformat(),\n memory_unit.modality,\n memory_unit.type,\n memory_unit.status,\n json.dumps(memory_unit.tags),\n json.dumps(memory_unit.embedding) if memory_unit.embedding else None,\n json.dumps(memory_unit.location) if memory_unit.location else None,\n memory_unit.source_role,\n memory_unit.source_name,\n memory_unit.source_id,\n json.dumps([link.to_dict() for link in memory_unit.edges]),\n json.dumps(memory_unit.metadata) if memory_unit.metadata else None\n )\n self.db.execute_sql(insert_query, data)\n\n def get(self, unit_id: str) -> Optional[MemoryUnit]:\n \"\"\"\n Retrieves a MemoryUnit by its 
ID.\n\n Args:\n unit_id (str): The unique identifier of the memory unit.\n\n Returns:\n Optional[MemoryUnit]: The retrieved memory unit or None if not found.\n \"\"\"\n select_query = f\"SELECT * FROM {self.table_name} WHERE id = ?;\"\n result = self.db.execute_sql(select_query, (unit_id,))\n row = result.fetchone()\n if row:\n return self._row_to_memory_unit(row)\n return None\n\n def update(self, memory_unit: MemoryUnit) -> None:\n \"\"\"\n Updates an existing MemoryUnit in the relational database.\n\n Args:\n memory_unit (MemoryUnit): The memory unit with updated data.\n \"\"\"\n update_query = f\"\"\"\n UPDATE {self.table_name}\n SET content = ?, timestamp = ?, modality = ?, type = ?, status = ?, tags = ?, embedding = ?, \n location = ?, source_role = ?, source_name = ?, source_id = ?, edges = ?, metadata = ?\n WHERE id = ?;\n \"\"\"\n data = (\n memory_unit.content,\n memory_unit.timestamp.isoformat(),\n memory_unit.modality,\n memory_unit.type,\n memory_unit.status,\n json.dumps(memory_unit.tags),\n json.dumps(memory_unit.embedding) if memory_unit.embedding else None,\n json.dumps(memory_unit.location) if memory_unit.location else None,\n memory_unit.source_role,\n memory_unit.source_name,\n memory_unit.source_id,\n json.dumps([link.to_dict() for link in memory_unit.edges]),\n json.dumps(memory_unit.metadata) if memory_unit.metadata else None,\n memory_unit.id\n )\n self.db.execute_sql(update_query, data)\n\n def delete(self, unit_id: str) -> None:\n \"\"\"\n Deletes a MemoryUnit from the relational database.\n\n Args:\n unit_id (str): The unique identifier of the memory unit to delete.\n \"\"\"\n delete_query = f\"DELETE FROM {self.table_name} WHERE id = ?;\"\n self.db.execute_sql(delete_query, (unit_id,))\n\n def list_all(self) -> List[MemoryUnit]:\n \"\"\"\n Retrieves all MemoryUnits from the relational database.\n\n Returns:\n List[MemoryUnit]: A list of all memory units.\n \"\"\"\n select_query = f\"SELECT * FROM {self.table_name};\"\n results = 
self.db.execute_sql(select_query)\n return [self._row_to_memory_unit(row) for row in results.fetchall()]\n\n def delete_all(self) -> None:\n \"\"\"\n Deletes all MemoryUnits from the relational database.\n \"\"\"\n delete_query = f\"DELETE FROM {self.table_name};\"\n self.db.execute_sql(delete_query)\n\n def _row_to_memory_unit(self, row: Any) -> MemoryUnit:\n \"\"\"\n Converts a database row to a MemoryUnit instance.\n\n Args:\n row (Any): A row fetched from the database.\n\n Returns:\n MemoryUnit: The corresponding MemoryUnit instance.\n \"\"\"\n return MemoryUnit(\n id=row['id'],\n content=row['content'],\n timestamp=datetime.fromisoformat(row['timestamp']),\n modality=row['modality'],\n type=row['type'],\n status=row['status'],\n tags=json.loads(row['tags']) if row['tags'] else [],\n embedding=json.loads(row['embedding']) if row['embedding'] else [],\n location=json.loads(row['location']) if row['location'] else {},\n source_role=row['source_role'],\n source_name=row['source_name'],\n source_id=row['source_id'],\n edges=[MemoryLink.from_dict(link) for link in json.loads(row['edges'])] if row['edges'] else [],\n metadata=json.loads(row['metadata']) if row['metadata'] else {}\n )\n
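`MemoryUnitRepository` never imports a database driver; it only calls `db.execute_sql(query, params)` and indexes result rows by column name. A minimal sqlite-backed adapter satisfying that interface might look like this (the class name and wiring are assumptions for illustration; the real object comes from `DatabaseFactory`):

```python
import sqlite3

class SQLiteAdapter:
    """Minimal stand-in for the relational database object the repository expects:
    it exposes execute_sql(query, params) and returns cursors whose rows
    support row['column'] access."""
    def __init__(self, path: str = ":memory:"):
        self.conn = sqlite3.connect(path)
        self.conn.row_factory = sqlite3.Row  # enables row['id'], row['content'], ...

    def execute_sql(self, query: str, params: tuple = ()):
        cursor = self.conn.execute(query, params)
        self.conn.commit()
        return cursor  # callers use fetchone() / fetchall()

db = SQLiteAdapter()
db.execute_sql("CREATE TABLE IF NOT EXISTS memory_units (id TEXT PRIMARY KEY, content TEXT);")
db.execute_sql("INSERT INTO memory_units (id, content) VALUES (?, ?);", ("u1", "hello"))
row = db.execute_sql("SELECT * FROM memory_units WHERE id = ?;", ("u1",)).fetchone()
```

Because the repository's queries use `?` placeholders and name-based row access, any adapter with this shape (sqlite, or a thin wrapper over another DB-API driver) can back it.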
"},{"location":"reference/#src.aeiva.cognition.memory.memory_storage.MemoryUnitRepository.__init__","title":"__init__(db)
","text":"Initialize the repository with a DatabaseFactory instance.
Parameters:
Name Type Description Default db
Any
An instance of DatabaseFactory for relational databases.
required Source code in src/aeiva/cognition/memory/memory_storage.py
def __init__(self, db: Any):\n \"\"\"\n Initialize the repository with a DatabaseFactory instance.\n\n Args:\n db (Any): An instance of DatabaseFactory for relational databases.\n \"\"\"\n self.db = db\n self.table_name = 'memory_units'\n self._create_table()\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_storage.MemoryUnitRepository.add","title":"add(memory_unit)
","text":"Adds a MemoryUnit to the relational database.
Parameters:
Name Type Description Default memory_unit
MemoryUnit
The memory unit to add.
required Source code in src/aeiva/cognition/memory/memory_storage.py
def add(self, memory_unit: MemoryUnit) -> None:\n \"\"\"\n Adds a MemoryUnit to the relational database.\n\n Args:\n memory_unit (MemoryUnit): The memory unit to add.\n \"\"\"\n insert_query = f\"\"\"\n INSERT INTO {self.table_name} (id, content, timestamp, modality, type, status, tags, embedding, location, \n source_role, source_name, source_id, edges, metadata)\n VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);\n \"\"\"\n data = (\n memory_unit.id,\n memory_unit.content,\n memory_unit.timestamp.isoformat(),\n memory_unit.modality,\n memory_unit.type,\n memory_unit.status,\n json.dumps(memory_unit.tags),\n json.dumps(memory_unit.embedding) if memory_unit.embedding else None,\n json.dumps(memory_unit.location) if memory_unit.location else None,\n memory_unit.source_role,\n memory_unit.source_name,\n memory_unit.source_id,\n json.dumps([link.to_dict() for link in memory_unit.edges]),\n json.dumps(memory_unit.metadata) if memory_unit.metadata else None\n )\n self.db.execute_sql(insert_query, data)\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_storage.MemoryUnitRepository.delete","title":"delete(unit_id)
","text":"Deletes a MemoryUnit from the relational database.
Parameters:
Name Type Description Default unit_id
str
The unique identifier of the memory unit to delete.
required Source code in src/aeiva/cognition/memory/memory_storage.py
def delete(self, unit_id: str) -> None:\n \"\"\"\n Deletes a MemoryUnit from the relational database.\n\n Args:\n unit_id (str): The unique identifier of the memory unit to delete.\n \"\"\"\n delete_query = f\"DELETE FROM {self.table_name} WHERE id = ?;\"\n self.db.execute_sql(delete_query, (unit_id,))\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_storage.MemoryUnitRepository.delete_all","title":"delete_all()
","text":"Deletes all MemoryUnits from the relational database.
Source code in src/aeiva/cognition/memory/memory_storage.py
def delete_all(self) -> None:\n \"\"\"\n Deletes all MemoryUnits from the relational database.\n \"\"\"\n delete_query = f\"DELETE FROM {self.table_name};\"\n self.db.execute_sql(delete_query)\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_storage.MemoryUnitRepository.get","title":"get(unit_id)
","text":"Retrieves a MemoryUnit by its ID.
Parameters:
Name Type Description Default unit_id
str
The unique identifier of the memory unit.
required Returns:
Type Description Optional[MemoryUnit]
Optional[MemoryUnit]: The retrieved memory unit or None if not found.
Source code in src/aeiva/cognition/memory/memory_storage.py
def get(self, unit_id: str) -> Optional[MemoryUnit]:\n \"\"\"\n Retrieves a MemoryUnit by its ID.\n\n Args:\n unit_id (str): The unique identifier of the memory unit.\n\n Returns:\n Optional[MemoryUnit]: The retrieved memory unit or None if not found.\n \"\"\"\n select_query = f\"SELECT * FROM {self.table_name} WHERE id = ?;\"\n result = self.db.execute_sql(select_query, (unit_id,))\n row = result.fetchone()\n if row:\n return self._row_to_memory_unit(row)\n return None\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_storage.MemoryUnitRepository.list_all","title":"list_all()
","text":"Retrieves all MemoryUnits from the relational database.
Returns:
Type Description List[MemoryUnit]
List[MemoryUnit]: A list of all memory units.
Source code in src/aeiva/cognition/memory/memory_storage.py
def list_all(self) -> List[MemoryUnit]:\n \"\"\"\n Retrieves all MemoryUnits from the relational database.\n\n Returns:\n List[MemoryUnit]: A list of all memory units.\n \"\"\"\n select_query = f\"SELECT * FROM {self.table_name};\"\n results = self.db.execute_sql(select_query)\n return [self._row_to_memory_unit(row) for row in results.fetchall()]\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_storage.MemoryUnitRepository.update","title":"update(memory_unit)
","text":"Updates an existing MemoryUnit in the relational database.
Parameters:
Name Type Description Default memory_unit
MemoryUnit
The memory unit with updated data.
required Source code in src/aeiva/cognition/memory/memory_storage.py
def update(self, memory_unit: MemoryUnit) -> None:\n \"\"\"\n Updates an existing MemoryUnit in the relational database.\n\n Args:\n memory_unit (MemoryUnit): The memory unit with updated data.\n \"\"\"\n update_query = f\"\"\"\n UPDATE {self.table_name}\n SET content = ?, timestamp = ?, modality = ?, type = ?, status = ?, tags = ?, embedding = ?, \n location = ?, source_role = ?, source_name = ?, source_id = ?, edges = ?, metadata = ?\n WHERE id = ?;\n \"\"\"\n data = (\n memory_unit.content,\n memory_unit.timestamp.isoformat(),\n memory_unit.modality,\n memory_unit.type,\n memory_unit.status,\n json.dumps(memory_unit.tags),\n json.dumps(memory_unit.embedding) if memory_unit.embedding else None,\n json.dumps(memory_unit.location) if memory_unit.location else None,\n memory_unit.source_role,\n memory_unit.source_name,\n memory_unit.source_id,\n json.dumps([link.to_dict() for link in memory_unit.edges]),\n json.dumps(memory_unit.metadata) if memory_unit.metadata else None,\n memory_unit.id\n )\n self.db.execute_sql(update_query, data)\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_structurer","title":"memory_structurer
","text":""},{"location":"reference/#src.aeiva.cognition.memory.memory_structurer.MemoryStructurer","title":"MemoryStructurer
","text":"A class to structure memory units based on various structuring algorithms.
Supported structure types: 'structure_type_example' (a placeholder for future structuring algorithms).
Source code in src/aeiva/cognition/memory/memory_structurer.py
class MemoryStructurer:\n \"\"\"\n A class to structure memory units based on various structuring algorithms.\n\n Supported structure types:\n - 'structure_type_example': Placeholder for future structuring algorithms.\n \"\"\"\n\n def __init__(self):\n \"\"\"\n Initializes the MemoryStructurer.\n\n Currently, no initialization parameters are required.\n \"\"\"\n self.logger = logging.getLogger(self.__class__.__name__)\n self.logger.debug(\"Initialized MemoryStructurer without default parameters.\")\n\n def structure(\n self,\n memory_units: List[MemoryUnit],\n structure_type: str,\n **kwargs\n ) -> List[MemoryUnit]:\n \"\"\"\n Structures the provided memory units based on the specified structure type.\n\n Args:\n memory_units (List[MemoryUnit]): The list of memory units to be structured.\n structure_type (str): The type of structuring algorithm to use ('structure_type_example').\n **kwargs: Additional parameters required for specific structurers.\n\n Returns:\n List[MemoryUnit]: The list of memory units after structuring.\n\n Raises:\n MemoryStructurerError: If an unknown structure_type is provided or if structuring fails.\n \"\"\"\n self.logger.debug(f\"Structuring memory units using structure_type='{structure_type}' with kwargs={kwargs}\")\n try:\n if structure_type == 'structure_type_example':\n # Placeholder for actual structuring logic\n return self.structure_example(memory_units, **kwargs)\n else:\n self.logger.error(f\"Unknown structure_type: {structure_type}\")\n raise MemoryStructurerError(f\"Unknown structure_type: {structure_type}\")\n except MemoryStructurerError:\n # Re-raise custom errors without modification\n raise\n except Exception as e:\n self.logger.error(f\"Failed to structure memory units: {e}\")\n raise MemoryStructurerError(f\"Failed to structure memory units: {e}\")\n\n def structure_example(\n self,\n memory_units: List[MemoryUnit],\n **kwargs\n ) -> List[MemoryUnit]:\n \"\"\"\n Example structuring method. 
Currently a placeholder that returns memory units unchanged.\n\n Args:\n memory_units (List[MemoryUnit]): The list of memory units to be structured.\n **kwargs: Additional parameters (currently unused).\n\n Returns:\n List[MemoryUnit]: The original list of memory units, unchanged.\n \"\"\"\n self.logger.debug(\"Executing structure_example: No changes applied to memory units.\")\n # Placeholder: No operation performed\n return memory_units\n
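`MemoryStructurer.structure` is essentially a dispatcher: it maps a `structure_type` string to a structuring routine and raises a dedicated error for unknown types. The pattern can be sketched standalone (names below are illustrative, not the library's):

```python
from typing import Any, Callable, Dict, List

class StructurerError(Exception):
    """Illustrative stand-in for MemoryStructurerError."""

def make_structurer(algorithms: Dict[str, Callable[[List[Any]], List[Any]]]):
    """Return a structure(units, structure_type) function that dispatches on
    the structure_type string, mirroring MemoryStructurer.structure."""
    def structure(units: List[Any], structure_type: str) -> List[Any]:
        if structure_type not in algorithms:
            raise StructurerError(f"Unknown structure_type: {structure_type}")
        return algorithms[structure_type](units)
    return structure

# 'identity' mirrors structure_example: it returns the units unchanged.
structure = make_structurer({"identity": lambda units: list(units)})
```

Registering real algorithms under new keys is then a dictionary entry rather than another `elif` branch, which is one way the placeholder design could grow.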
"},{"location":"reference/#src.aeiva.cognition.memory.memory_structurer.MemoryStructurer.__init__","title":"__init__()
","text":"Initializes the MemoryStructurer.
Currently, no initialization parameters are required.
Source code in src/aeiva/cognition/memory/memory_structurer.py
def __init__(self):\n \"\"\"\n Initializes the MemoryStructurer.\n\n Currently, no initialization parameters are required.\n \"\"\"\n self.logger = logging.getLogger(self.__class__.__name__)\n self.logger.debug(\"Initialized MemoryStructurer without default parameters.\")\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_structurer.MemoryStructurer.structure","title":"structure(memory_units, structure_type, **kwargs)
","text":"Structures the provided memory units based on the specified structure type.
Parameters:
Name Type Description Default memory_units
List[MemoryUnit]
The list of memory units to be structured.
required structure_type
str
The type of structuring algorithm to use ('structure_type_example').
required **kwargs
Additional parameters required for specific structurers.
{}
Returns:
Type Description List[MemoryUnit]
List[MemoryUnit]: The list of memory units after structuring.
Raises:
Type Description MemoryStructurerError
If an unknown structure_type is provided or if structuring fails.
Source code in src/aeiva/cognition/memory/memory_structurer.py
def structure(\n self,\n memory_units: List[MemoryUnit],\n structure_type: str,\n **kwargs\n) -> List[MemoryUnit]:\n \"\"\"\n Structures the provided memory units based on the specified structure type.\n\n Args:\n memory_units (List[MemoryUnit]): The list of memory units to be structured.\n structure_type (str): The type of structuring algorithm to use ('structure_type_example').\n **kwargs: Additional parameters required for specific structurers.\n\n Returns:\n List[MemoryUnit]: The list of memory units after structuring.\n\n Raises:\n MemoryStructurerError: If an unknown structure_type is provided or if structuring fails.\n \"\"\"\n self.logger.debug(f\"Structuring memory units using structure_type='{structure_type}' with kwargs={kwargs}\")\n try:\n if structure_type == 'structure_type_example':\n # Placeholder for actual structuring logic\n return self.structure_example(memory_units, **kwargs)\n else:\n self.logger.error(f\"Unknown structure_type: {structure_type}\")\n raise MemoryStructurerError(f\"Unknown structure_type: {structure_type}\")\n except MemoryStructurerError:\n # Re-raise custom errors without modification\n raise\n except Exception as e:\n self.logger.error(f\"Failed to structure memory units: {e}\")\n raise MemoryStructurerError(f\"Failed to structure memory units: {e}\")\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_structurer.MemoryStructurer.structure_example","title":"structure_example(memory_units, **kwargs)
","text":"Example structuring method. Currently a placeholder that returns memory units unchanged.
Parameters:
Name Type Description Default memory_units
List[MemoryUnit]
The list of memory units to be structured.
required **kwargs
Additional parameters (currently unused).
{}
Returns:
Type Description List[MemoryUnit]
List[MemoryUnit]: The original list of memory units, unchanged.
Source code in src/aeiva/cognition/memory/memory_structurer.py
def structure_example(\n self,\n memory_units: List[MemoryUnit],\n **kwargs\n) -> List[MemoryUnit]:\n \"\"\"\n Example structuring method. Currently a placeholder that returns memory units unchanged.\n\n Args:\n memory_units (List[MemoryUnit]): The list of memory units to be structured.\n **kwargs: Additional parameters (currently unused).\n\n Returns:\n List[MemoryUnit]: The original list of memory units, unchanged.\n \"\"\"\n self.logger.debug(\"Executing structure_example: No changes applied to memory units.\")\n # Placeholder: No operation performed\n return memory_units\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_structurer.MemoryStructurerError","title":"MemoryStructurerError
","text":" Bases: Exception
Exception raised when an error occurs in the MemoryStructurer.
Source code in src/aeiva/cognition/memory/memory_structurer.py
class MemoryStructurerError(Exception):\n \"\"\"Exception raised when an error occurs in the MemoryStructurer.\"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_unit","title":"memory_unit
","text":""},{"location":"reference/#src.aeiva.cognition.memory.memory_unit.MemoryUnit","title":"MemoryUnit
","text":" Bases: BaseModel
MemoryUnit represents a single unit of memory with core content and rich metadata. It includes fields for tracking information about the memory\u2019s source, modality, temporal and spatial attributes, and its connections to other memory units.
Essential Fields id (str): Unique identifier for the memory unit, generated as a UUID string by default. content (Any): Core content of the memory, which is convertible to a string.
Metadata timestamp (datetime): Creation timestamp, defaulting to the current time. modality (Optional[str]): Modality type, such as 'text', 'image', 'audio'. type (Optional[str]): Semantic type, such as 'dialogue', 'summary', 'document'. status (Optional[str]): Processing status, e.g., 'raw', 'cleaned', 'processed'. tags (Optional[List[str]]): Tags for categorization and filtering. embedding (Optional[List[float]]): Vector embedding for retrieval. location (Optional[Union[str, Dict]]): Spatial location data.
Source Information source_role (Optional[str]): Role of the source, e.g., 'user', 'agent'. source_name (Optional[str]): Descriptive name of the source. source_id (Optional[str]): Unique identifier for the memory source, generated as a UUID string.
Connections edges (List[MemoryLink]): List of edges connecting this memory unit to others.
Additional Metadata metadata (Optional[Dict[str, Any]]): Dictionary for extensible metadata.
Source code in src/aeiva/cognition/memory/memory_unit.py
class MemoryUnit(BaseModel):\n \"\"\"\n MemoryUnit represents a single unit of memory with core content and rich metadata.\n It includes fields for tracking information about the memory\u2019s source, modality,\n temporal and spatial attributes, and its connections to other memory units.\n\n Essential Fields:\n id (str): Unique identifier for the memory unit, generated as a UUID string by default.\n content (Any): Core content of the memory, which is convertible to a string.\n\n Metadata:\n timestamp (datetime): Creation timestamp, defaulting to the current time.\n modality (Optional[str]): Modality type, such as 'text', 'image', 'audio'.\n type (Optional[str]): Semantic type, such as 'dialogue', 'summary', 'document'.\n status (Optional[str]): Processing status, e.g., 'raw', 'cleaned', 'processed'.\n tags (Optional[List[str]]): Tags for categorization and filtering.\n embedding (Optional[List[float]]): Vector embedding for retrieval.\n location (Optional[Union[str, Dict]]): Spatial location data.\n\n Source Information:\n source_role (Optional[str]): Role of the source, e.g., 'user', 'agent'.\n source_name (Optional[str]): Descriptive name of the source.\n source_id (Optional[str]): Unique identifier for the memory source, generated as a UUID string.\n\n Connections:\n edges (List[MemoryLink]): List of edges connecting this memory unit to others.\n\n Additional Metadata:\n metadata (Optional[Dict[str, Any]]): Dictionary for extensible metadata.\n \"\"\"\n\n # Essential Fields\n id: str = Field(default_factory=lambda: uuid4().hex, description=\"Unique identifier for the memory unit.\")\n content: Any = Field(\"\", description=\"Core content of the memory unit, convertible to a string.\")\n\n # Metadata Fields\n timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc), description=\"Creation timestamp of the memory.\")\n modality: Optional[str] = Field(None, description=\"Modality type, e.g., 'text', 'image', 'audio'.\")\n type: Optional[str] = 
Field(None, description=\"Semantic type, e.g., 'dialogue', 'summary'.\")\n status: Optional[str] = Field(None, description=\"Processing status, e.g., 'raw', 'cleaned', 'derived', 'grouped', 'structured', 'indexed'.\")\n tags: Optional[List[str]] = Field(default_factory=list, description=\"Tags for categorization or filtering.\")\n embedding: Optional[List[float]] = Field(None, description=\"Embedding vector for memory.\")\n location: Optional[Union[str, Dict]] = Field(None, description=\"Location data as a string or structured dictionary.\")\n\n # Source Information\n source_role: Optional[str] = Field(None, description=\"Role of the memory source, e.g., 'user', 'agent'.\")\n source_name: Optional[str] = Field(None, description=\"Descriptive name of the source, e.g., 'User123'.\")\n source_id: Optional[str] = Field(default_factory=lambda: uuid4().hex, description=\"Unique identifier associated with the source.\")\n\n # Connections\n edges: List[MemoryLink] = Field(default_factory=list, description=\"List of edges linking this memory unit to others.\")\n\n # Additional Metadata\n metadata: Optional[Dict[str, Any]] = Field(default_factory=dict, description=\"Dictionary for extensible metadata.\")\n\n def to_dict(self) -> Dict[str, Any]:\n \"\"\"\n Converts the MemoryUnit instance to a dictionary format for serialization.\n Each field is handled explicitly to ensure proper serialization.\n\n Returns:\n Dict[str, Any]: A dictionary representation of the MemoryUnit.\n \"\"\"\n return {\n \"id\": self.id,\n \"content\": self.content,\n \"timestamp\": self.timestamp.isoformat(), # Convert datetime to string\n \"modality\": self.modality,\n \"type\": self.type,\n \"status\": self.status,\n \"tags\": self.tags,\n \"embedding\": self.embedding,\n \"location\": self.location,\n \"source_role\": self.source_role,\n \"source_name\": self.source_name,\n \"source_id\": self.source_id,\n \"edges\": [edge.to_dict() for edge in self.edges], # Serialize each MemoryLink\n 
\"metadata\": self.metadata\n }\n\n @classmethod\n def from_dict(cls, data: dict) -> \"MemoryUnit\":\n \"\"\"\n Creates a MemoryUnit instance from a dictionary.\n Each field is handled explicitly to ensure proper deserialization.\n\n Args:\n data (dict): A dictionary containing MemoryUnit data.\n\n Returns:\n MemoryUnit: The created MemoryUnit instance.\n \"\"\"\n try:\n return cls(\n id=data.get('id', uuid4().hex),\n content=data.get('content', \"\"),\n timestamp=datetime.fromisoformat(data['timestamp']) if 'timestamp' in data else datetime.now(timezone.utc),\n modality=data.get('modality'),\n type=data.get('type'),\n status=data.get('status'),\n tags=data.get('tags', []),\n embedding=data.get('embedding'),\n location=data.get('location'),\n source_role=data.get('source_role'),\n source_name=data.get('source_name'),\n source_id=data.get('source_id', uuid4().hex),\n edges=[MemoryLink.from_dict(edge) for edge in data.get('edges', [])],\n metadata=data.get('metadata', {})\n )\n except Exception as e:\n # logger.error(f\"Error deserializing MemoryUnit from dict: {e}\")\n raise e\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_unit.MemoryUnit.from_dict","title":"from_dict(data)
classmethod
","text":"Creates a MemoryUnit instance from a dictionary. Each field is handled explicitly to ensure proper deserialization.
Parameters:
Name Type Description Default data
dict
A dictionary containing MemoryUnit data.
required Returns:
Name Type Description MemoryUnit
MemoryUnit
The created MemoryUnit instance.
Source code in src/aeiva/cognition/memory/memory_unit.py
@classmethod\ndef from_dict(cls, data: dict) -> \"MemoryUnit\":\n \"\"\"\n Creates a MemoryUnit instance from a dictionary.\n Each field is handled explicitly to ensure proper deserialization.\n\n Args:\n data (dict): A dictionary containing MemoryUnit data.\n\n Returns:\n MemoryUnit: The created MemoryUnit instance.\n \"\"\"\n try:\n return cls(\n id=data.get('id', uuid4().hex),\n content=data.get('content', \"\"),\n timestamp=datetime.fromisoformat(data['timestamp']) if 'timestamp' in data else datetime.now(timezone.utc),\n modality=data.get('modality'),\n type=data.get('type'),\n status=data.get('status'),\n tags=data.get('tags', []),\n embedding=data.get('embedding'),\n location=data.get('location'),\n source_role=data.get('source_role'),\n source_name=data.get('source_name'),\n source_id=data.get('source_id', uuid4().hex),\n edges=[MemoryLink.from_dict(edge) for edge in data.get('edges', [])],\n metadata=data.get('metadata', {})\n )\n except Exception as e:\n # logger.error(f\"Error deserializing MemoryUnit from dict: {e}\")\n raise e\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_unit.MemoryUnit.to_dict","title":"to_dict()
","text":"Converts the MemoryUnit instance to a dictionary format for serialization. Each field is handled explicitly to ensure proper serialization.
Returns:
Type Description Dict[str, Any]
Dict[str, Any]: A dictionary representation of the MemoryUnit.
Source code in src/aeiva/cognition/memory/memory_unit.py
def to_dict(self) -> Dict[str, Any]:\n \"\"\"\n Converts the MemoryUnit instance to a dictionary format for serialization.\n Each field is handled explicitly to ensure proper serialization.\n\n Returns:\n Dict[str, Any]: A dictionary representation of the MemoryUnit.\n \"\"\"\n return {\n \"id\": self.id,\n \"content\": self.content,\n \"timestamp\": self.timestamp.isoformat(), # Convert datetime to string\n \"modality\": self.modality,\n \"type\": self.type,\n \"status\": self.status,\n \"tags\": self.tags,\n \"embedding\": self.embedding,\n \"location\": self.location,\n \"source_role\": self.source_role,\n \"source_name\": self.source_name,\n \"source_id\": self.source_id,\n \"edges\": [edge.to_dict() for edge in self.edges], # Serialize each MemoryLink\n \"metadata\": self.metadata\n }\n
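The `timestamp` handling is the only lossy-looking step in this serialization: `to_dict` encodes the `datetime` as an ISO 8601 string, and `from_dict` reverses it with `datetime.fromisoformat`. The round trip is exact for timezone-aware values, as this standalone check shows:

```python
from datetime import datetime, timezone

# Serialize the way to_dict does, then parse the way from_dict does.
ts = datetime(2024, 1, 2, 3, 4, 5, tzinfo=timezone.utc)
encoded = ts.isoformat()                   # ISO 8601 string with '+00:00' offset
decoded = datetime.fromisoformat(encoded)  # timezone-aware datetime, equal to ts
```

Because the offset survives the string form, comparisons against other timezone-aware timestamps remain valid after a round trip through JSON or the relational store.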
"},{"location":"reference/#src.aeiva.cognition.memory.memory_utils","title":"memory_utils
","text":""},{"location":"reference/#src.aeiva.cognition.memory.memory_utils.derive_content","title":"derive_content(derivation_type, data)
","text":"You are a creative assistant capable of deriving new content based on specified types. Your task is to derive a {derivation_type} from the provided combined content.
Source code in src/aeiva/cognition/memory/memory_utils.py
@simple(model='gpt-4', temperature=0.7)\ndef derive_content(derivation_type: str, data: str) -> str:\n \"\"\"\n You are a creative assistant capable of deriving new content based on specified types.\n Your task is to derive a {derivation_type} from the provided combined content.\n \"\"\"\n result = f\"Derive a {derivation_type} from the following content:\\n{data}\"\n return result\n
"},{"location":"reference/#src.aeiva.cognition.memory.memory_utils.extract_entities_relationships","title":"extract_entities_relationships(data)
","text":"You are an intelligent assistant skilled in natural language processing. Your task is to extract entities and the relationships between them from the provided content.
Source code in src/aeiva/cognition/memory/memory_utils.py
@simple(model='gpt-4', temperature=0.7)\ndef extract_entities_relationships(data: Any) -> str:\n \"\"\"\n You are an intelligent assistant skilled in natural language processing.\n Your task is to extract entities and the relationships between them from the provided content.\n \"\"\"\n result = f\"Extract entities and relationships from the following content:\\n{data}\"\n return result\n
"},{"location":"reference/#src.aeiva.cognition.memory.storage_config","title":"storage_config
","text":""},{"location":"reference/#src.aeiva.cognition.memory.storage_config.StorageConfig","title":"StorageConfig
dataclass
","text":" Bases: BaseConfig
Configuration class for the Memory storage.
Attributes:
Name Type Description vector_db_config
DatabaseConfig
Configuration for the vector database.
graph_db_config
Optional[DatabaseConfig]
Configuration for the graph database.
relational_db_config
Optional[DatabaseConfig]
Configuration for the relational database.
Source code in src/aeiva/cognition/memory/storage_config.py
@dataclass\nclass StorageConfig(BaseConfig):\n \"\"\"\n Configuration class for the Memory storage.\n\n Attributes:\n vector_db_config (DatabaseConfig): Configuration for the vector database.\n graph_db_config (Optional[DatabaseConfig]): Configuration for the graph database.\n relational_db_config (Optional[DatabaseConfig]): Configuration for the relational database.\n \"\"\"\n vector_db_provider: str = field(\n metadata={\"help\": \"Vector database provider name.\"}\n )\n vector_db_config: BaseConfig = field(\n metadata={\"help\": \"Configuration for the vector database.\"}\n )\n graph_db_provider: Optional[str] = field(\n default=None,\n metadata={\"help\": \"Graph database provider name.\"}\n )\n graph_db_config: Optional[BaseConfig] = field(\n default=None,\n metadata={\"help\": \"Configuration for the graph database.\"}\n )\n relational_db_provider: Optional[str] = field(\n default=None,\n metadata={\"help\": \"Relational database provider name.\"}\n )\n relational_db_config: Optional[BaseConfig] = field(\n default=None,\n metadata={\"help\": \"Configuration for the relational database.\"}\n )\n\n def __post_init__(self):\n super().__post_init__()\n # Perform any necessary validation\n if not self.vector_db_config:\n raise ValueError(\"Vector database configuration must be provided.\")\n
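As a usage sketch, the `__post_init__` validation above rejects a missing vector database configuration. The condensed stand-ins below (a minimal `BaseConfig` and a trimmed `StorageConfig`) illustrate the behavior without importing aeiva; provider names like `"milvus"` are only examples.

```python
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class BaseConfig:
    """Minimal stand-in for aeiva's BaseConfig."""
    def __post_init__(self):
        pass

@dataclass
class StorageConfig(BaseConfig):
    """Trimmed copy of the reference dataclass, keeping only the validated fields."""
    vector_db_provider: str = field(metadata={"help": "Vector database provider name."})
    vector_db_config: Optional[BaseConfig] = field(metadata={"help": "Vector DB config."})
    graph_db_provider: Optional[str] = None
    graph_db_config: Optional[BaseConfig] = None

    def __post_init__(self):
        super().__post_init__()
        if not self.vector_db_config:
            raise ValueError("Vector database configuration must be provided.")

ok = StorageConfig(vector_db_provider="milvus", vector_db_config=BaseConfig())
try:
    StorageConfig(vector_db_provider="milvus", vector_db_config=None)
except ValueError as exc:
    error = str(exc)
```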
"},{"location":"reference/#src.aeiva.cognition.observation","title":"observation
","text":""},{"location":"reference/#src.aeiva.cognition.observation.Observation","title":"Observation
","text":"Represents a processed input from the PerceptionSystem.
Source code in src/aeiva/cognition/observation.py
class Observation:\n \"\"\"\n Represents a processed input from the PerceptionSystem.\n \"\"\"\n def __init__(self, data: Any, modality: str = 'text', timestamp: Optional[datetime] = None, metadata: Optional[Dict[str, Any]] = None):\n self.data = data # The processed data (e.g., text)\n self.modality = modality\n self.timestamp = timestamp or datetime.now()\n self.metadata = metadata or {}\n\n def to_dict(self) -> Dict[str, Any]:\n return {\n 'data': self.data,\n 'modality': self.modality,\n 'timestamp': self.timestamp.isoformat(),\n 'metadata': self.metadata\n }\n
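`Observation` (and the companion `Thought` class below, which has the same shape) is a plain value object. A quick usage sketch, with the class copied from the reference above so the snippet is self-contained:

```python
from datetime import datetime
from typing import Any, Dict, Optional

class Observation:
    """Represents a processed input from the PerceptionSystem."""
    def __init__(self, data: Any, modality: str = 'text',
                 timestamp: Optional[datetime] = None,
                 metadata: Optional[Dict[str, Any]] = None):
        self.data = data
        self.modality = modality
        self.timestamp = timestamp or datetime.now()
        self.metadata = metadata or {}

    def to_dict(self) -> Dict[str, Any]:
        return {
            'data': self.data,
            'modality': self.modality,
            'timestamp': self.timestamp.isoformat(),
            'metadata': self.metadata,
        }

# The timestamp defaults to "now" and is serialized in ISO-8601 form.
obs = Observation("hello", metadata={"source": "terminal"})
record = obs.to_dict()
```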
"},{"location":"reference/#src.aeiva.cognition.thought","title":"thought
","text":""},{"location":"reference/#src.aeiva.cognition.thought.Thought","title":"Thought
","text":"Represents the output from the Brain after processing an Observation.
Source code in src/aeiva/cognition/thought.py
class Thought:\n \"\"\"\n Represents the output from the Brain after processing an Observation.\n \"\"\"\n def __init__(self, content: Any, modality: str = 'text', timestamp: Optional[datetime] = None, metadata: Optional[Dict[str, Any]] = None):\n self.content = content # The thought content (e.g., text)\n self.modality = modality\n self.timestamp = timestamp or datetime.now()\n self.metadata = metadata or {}\n\n def to_dict(self) -> Dict[str, Any]:\n return {\n 'content': self.content,\n 'modality': self.modality,\n 'timestamp': self.timestamp.isoformat(),\n 'metadata': self.metadata\n }\n
"},{"location":"reference/#src.aeiva.cognition.world_model","title":"world_model
","text":""},{"location":"reference/#src.aeiva.cognition.world_model.world_model","title":"world_model
","text":""},{"location":"reference/#src.aeiva.cognition.world_model.world_model.WorldModel","title":"WorldModel
","text":" Bases: ABC
Abstract base class representing the World Model system of an agent.
The World Model maintains an internal representation of the environment, enabling the agent to understand, predict, and interact with its surroundings effectively.
Attributes:
Name Type Description config
Any
Configuration settings for the World Model system.
state
Any
The internal state of the World Model system.
Source code in src/aeiva/cognition/world_model/world_model.py
class WorldModel(ABC):\n \"\"\"\n Abstract base class representing the World Model system of an agent.\n\n The World Model maintains an internal representation of the environment, enabling the agent\n to understand, predict, and interact with its surroundings effectively.\n\n Attributes:\n config (Any): Configuration settings for the World Model system.\n state (Any): The internal state of the World Model system.\n \"\"\"\n\n def __init__(self, config: Any):\n \"\"\"\n Initialize the World Model system with the provided configuration.\n\n Args:\n config (Any): Configuration settings for the World Model system.\n \"\"\"\n self.config = config\n self.state = self.init_state()\n\n @abstractmethod\n def init_state(self) -> Any:\n \"\"\"\n Initialize the internal state of the World Model system.\n\n This method should set up the initial state required for the World Model system's operations.\n\n Returns:\n Any: The initial state of the World Model system.\n \"\"\"\n pass\n\n @abstractmethod\n def setup(self) -> None:\n \"\"\"\n Asynchronously set up the World Model system's components.\n\n This method should initialize any necessary components or resources based on the provided configuration.\n\n Raises:\n ConfigurationError: If the configuration is invalid or incomplete.\n \"\"\"\n pass\n\n @abstractmethod\n async def update(self, observation: Any) -> None:\n \"\"\"\n Asynchronously update the world model based on new observations.\n\n Args:\n observation (Any): The new observation to incorporate into the world model.\n\n Raises:\n UpdateError: If updating the world model fails.\n \"\"\"\n pass\n\n @abstractmethod\n async def query(self, query: Any) -> Any:\n \"\"\"\n Asynchronously query the world model for specific information.\n\n Args:\n query (Any): The query or criteria to retrieve specific information from the world model.\n\n Returns:\n Any: The information retrieved from the world model.\n\n Raises:\n QueryError: If the query process fails.\n \"\"\"\n pass\n\n 
def get_current_state(self) -> Any:\n \"\"\"\n Retrieve the current internal state of the World Model system.\n\n Returns:\n Any: The current internal state.\n \"\"\"\n return self.state\n\n def handle_error(self, error: Exception) -> None:\n \"\"\"\n Handle errors that occur during world model operations.\n\n This method can be overridden to implement custom error handling logic.\n\n Args:\n error (Exception): The exception that was raised.\n \"\"\"\n # Default error handling: log the error\n print(f\"WorldModel system encountered an error: {error}\")\n
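A minimal concrete subclass shows the intended lifecycle: `__init__` calls `init_state`, updates mutate the internal state, and queries read from it. `DictWorldModel` below is an illustrative sketch (not part of aeiva) that keeps the world state as a plain dict; the ABC is condensed from the reference above.

```python
import asyncio
from abc import ABC, abstractmethod
from typing import Any

class WorldModel(ABC):
    """Condensed version of the abstract base class above."""
    def __init__(self, config: Any):
        self.config = config
        self.state = self.init_state()

    @abstractmethod
    def init_state(self) -> Any: ...
    @abstractmethod
    def setup(self) -> None: ...
    @abstractmethod
    async def update(self, observation: Any) -> None: ...
    @abstractmethod
    async def query(self, query: Any) -> Any: ...

    def get_current_state(self) -> Any:
        return self.state

class DictWorldModel(WorldModel):
    """Illustrative subclass: stores observed facts keyed by name."""
    def init_state(self) -> dict:
        return {}

    def setup(self) -> None:
        pass  # nothing to initialize for a dict-backed model

    async def update(self, observation: Any) -> None:
        key, value = observation
        self.state[key] = value

    async def query(self, query: Any) -> Any:
        return self.state.get(query)

wm = DictWorldModel(config=None)
asyncio.run(wm.update(("sky", "blue")))
result = asyncio.run(wm.query("sky"))
```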
"},{"location":"reference/#src.aeiva.cognition.world_model.world_model.WorldModel.__init__","title":"__init__(config)
","text":"Initialize the World Model system with the provided configuration.
Parameters:
Name Type Description Default config
Any
Configuration settings for the World Model system.
required Source code in src/aeiva/cognition/world_model/world_model.py
def __init__(self, config: Any):\n \"\"\"\n Initialize the World Model system with the provided configuration.\n\n Args:\n config (Any): Configuration settings for the World Model system.\n \"\"\"\n self.config = config\n self.state = self.init_state()\n
"},{"location":"reference/#src.aeiva.cognition.world_model.world_model.WorldModel.get_current_state","title":"get_current_state()
","text":"Retrieve the current internal state of the World Model system.
Returns:
Name Type Description Any
Any
The current internal state.
Source code in src/aeiva/cognition/world_model/world_model.py
def get_current_state(self) -> Any:\n \"\"\"\n Retrieve the current internal state of the World Model system.\n\n Returns:\n Any: The current internal state.\n \"\"\"\n return self.state\n
"},{"location":"reference/#src.aeiva.cognition.world_model.world_model.WorldModel.handle_error","title":"handle_error(error)
","text":"Handle errors that occur during world model operations.
This method can be overridden to implement custom error handling logic.
Parameters:
Name Type Description Default error
Exception
The exception that was raised.
required Source code in src/aeiva/cognition/world_model/world_model.py
def handle_error(self, error: Exception) -> None:\n \"\"\"\n Handle errors that occur during world model operations.\n\n This method can be overridden to implement custom error handling logic.\n\n Args:\n error (Exception): The exception that was raised.\n \"\"\"\n # Default error handling: log the error\n print(f\"WorldModel system encountered an error: {error}\")\n
"},{"location":"reference/#src.aeiva.cognition.world_model.world_model.WorldModel.init_state","title":"init_state()
abstractmethod
","text":"Initialize the internal state of the World Model system.
This method should set up the initial state required for the World Model system's operations.
Returns:
Name Type Description Any
Any
The initial state of the World Model system.
Source code in src/aeiva/cognition/world_model/world_model.py
@abstractmethod\ndef init_state(self) -> Any:\n \"\"\"\n Initialize the internal state of the World Model system.\n\n This method should set up the initial state required for the World Model system's operations.\n\n Returns:\n Any: The initial state of the World Model system.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.world_model.world_model.WorldModel.query","title":"query(query)
abstractmethod
async
","text":"Asynchronously query the world model for specific information.
Parameters:
Name Type Description Default query
Any
The query or criteria to retrieve specific information from the world model.
required Returns:
Name Type Description Any
Any
The information retrieved from the world model.
Raises:
Type Description QueryError
If the query process fails.
Source code in src/aeiva/cognition/world_model/world_model.py
@abstractmethod\nasync def query(self, query: Any) -> Any:\n \"\"\"\n Asynchronously query the world model for specific information.\n\n Args:\n query (Any): The query or criteria to retrieve specific information from the world model.\n\n Returns:\n Any: The information retrieved from the world model.\n\n Raises:\n QueryError: If the query process fails.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.world_model.world_model.WorldModel.setup","title":"setup()
abstractmethod
","text":"Asynchronously set up the World Model system's components.
This method should initialize any necessary components or resources based on the provided configuration.
Raises:
Type Description ConfigurationError
If the configuration is invalid or incomplete.
Source code in src/aeiva/cognition/world_model/world_model.py
@abstractmethod\ndef setup(self) -> None:\n \"\"\"\n Asynchronously set up the World Model system's components.\n\n This method should initialize any necessary components or resources based on the provided configuration.\n\n Raises:\n ConfigurationError: If the configuration is invalid or incomplete.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.cognition.world_model.world_model.WorldModel.update","title":"update(observation)
abstractmethod
async
","text":"Asynchronously update the world model based on new observations.
Parameters:
Name Type Description Default observation
Any
The new observation to incorporate into the world model.
required Raises:
Type Description UpdateError
If updating the world model fails.
Source code in src/aeiva/cognition/world_model/world_model.py
@abstractmethod\nasync def update(self, observation: Any) -> None:\n \"\"\"\n Asynchronously update the world model based on new observations.\n\n Args:\n observation (Any): The new observation to incorporate into the world model.\n\n Raises:\n UpdateError: If updating the world model fails.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.command","title":"command
","text":""},{"location":"reference/#src.aeiva.command.aeiva_chat_gradio","title":"aeiva_chat_gradio
","text":"We can run the command like below: (specify your own config file path)
aeiva-chat-gradio --config configs/agent_config.yaml
"},{"location":"reference/#src.aeiva.command.aeiva_chat_gradio.run","title":"run(config, verbose)
","text":"Starts the Aeiva chat Gradio interface with the provided configuration.
Source code in src/aeiva/command/aeiva_chat_gradio.py
@click.command(name=\"aeiva-chat-gradio\")\n@click.option('--config', '-c', default=str(DEFAULT_CONFIG_PATH),\n help='Path to the configuration file (YAML or JSON).',\n type=click.Path(exists=True, dir_okay=False))\n@click.option('--verbose', '-v', is_flag=True, help='Enable verbose logging.')\ndef run(config, verbose):\n \"\"\"\n Starts the Aeiva chat Gradio interface with the provided configuration.\n \"\"\"\n # Setup logging\n logger = setup_logging(DEFAULT_LOG_PATH, verbose)\n\n # Load environment variables (API keys, etc.)\n load_dotenv()\n\n logger.info(f\"Loading configuration from {config}\")\n config_dict = from_json_or_yaml(config)\n\n # Initialize the Agent\n try:\n agent = Agent(config_dict)\n agent.setup()\n logger.info(\"Agent initialized successfully.\")\n except Exception as e:\n logger.error(f\"Failed to initialize Agent: {e}\")\n click.echo(f\"Error: Failed to initialize Agent: {e}\")\n sys.exit(1)\n\n # Function to run the Agent's run method in a separate thread\n def run_agent(agent_instance):\n try:\n asyncio.run(agent_instance.run())\n except Exception as e:\n logger.error(f\"Error running Agent: {e}\")\n\n # Start the Agent in a separate daemon thread\n agent_thread = threading.Thread(target=run_agent, args=(agent,), daemon=True)\n agent_thread.start()\n logger.info(\"Agent run thread started.\")\n\n # Initialize a thread-safe queue to receive responses from the Agent\n response_queue = queue.Queue()\n\n # Define a handler for 'response.gradio' events\n def handle_response_gradio(event: Event):\n response = event.payload\n response_queue.put_nowait(response) # Put response into the thread-safe queue\n logger.info(f\"Received 'response.gradio' event: {response}\")\n\n # Register the handler with the Agent's EventBus\n agent.event_bus.on('response.gradio')(handle_response_gradio)\n logger.info(\"Registered handler for 'response.gradio' events.\")\n\n # Validate and start Neo4j\n neo4j_home = os.getenv('NEO4J_HOME')\n if not neo4j_home:\n 
logger.error(\"NEO4J_HOME environment variable is not set.\")\n click.echo(\"Error: NEO4J_HOME environment variable is not set.\")\n sys.exit(1)\n\n validate_neo4j_home(logger, neo4j_home)\n neo4j_process = start_neo4j(logger, neo4j_home)\n\n # Register signal handlers to ensure Neo4j stops gracefully\n for sig in [signal.SIGINT, signal.SIGTERM]:\n signal.signal(sig, lambda s, f: handle_exit(s, f, logger, neo4j_process))\n\n # Define handlers for multimodal inputs\n\n def handle_image_upload(image: Image.Image):\n if image is not None:\n timestamp = datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n image_path = f\"uploads/uploaded_image_{timestamp}.jpg\"\n try:\n image.save(image_path)\n logger.info(f\"Image uploaded and saved to {image_path}\")\n return \"User uploaded an image.\"\n except Exception as e:\n logger.error(f\"Error saving uploaded image: {e}\")\n return \"Failed to upload image.\"\n return \"\"\n\n def handle_video_upload(video):\n if video is not None:\n timestamp = datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n video_path = f\"uploads/uploaded_video_{timestamp}.mp4\"\n try:\n with open(video_path, \"wb\") as f:\n f.write(video.read())\n logger.info(f\"Video uploaded and saved to {video_path}\")\n return \"User uploaded a video.\"\n except Exception as e:\n logger.error(f\"Error saving uploaded video: {e}\")\n return \"Failed to upload video.\"\n return \"\"\n\n def handle_audio_upload(audio):\n if audio is not None:\n try:\n sample_rate, audio_data = audio\n # Normalize audio_data to float32 in the range -1.0 to 1.0\n audio_data_normalized = audio_data.astype(np.float32) / np.abs(audio_data).max()\n timestamp = datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n audio_path = f\"uploads/uploaded_audio_{timestamp}.wav\"\n sf.write(audio_path, audio_data_normalized, sample_rate, subtype='PCM_16')\n logger.info(f\"Audio uploaded and saved to {audio_path}\")\n return \"User uploaded an audio file.\"\n except Exception as e:\n logger.error(f\"Error saving uploaded 
audio: {e}\")\n return \"Failed to upload audio.\"\n return \"\"\n\n def handle_upload(file):\n \"\"\"\n Handles file uploads and delegates to specific handlers based on file type.\n\n Args:\n file: Uploaded file object.\n\n Returns:\n str: Message indicating the upload status.\n \"\"\"\n if file is None:\n return \"\"\n if file.type.startswith(\"image\"):\n return handle_image_upload(file)\n elif file.type.startswith(\"video\"):\n return handle_video_upload(file)\n elif file.type.startswith(\"audio\"):\n return handle_audio_upload(file)\n else:\n logger.warning(f\"Unsupported file type uploaded: {file.type}\")\n return \"Unsupported file type uploaded.\"\n\n def clear_media():\n \"\"\"\n Clears the uploaded media paths.\n \"\"\"\n # Implement any necessary logic to clear media paths or data\n logger.info(\"Cleared uploaded media paths.\")\n return \"\"\n\n async def bot(user_input, history):\n \"\"\"\n Handles chatbot logic by emitting perception.gradio events to the Agent and retrieving responses.\n \"\"\"\n if agent is None:\n logger.error(\"Agent is not initialized.\")\n history.append({\"role\": \"assistant\", \"content\": \"Agent is not initialized.\"})\n yield history, ''\n return\n\n try:\n # Append user's message to history\n history.append({\"role\": \"user\", \"content\": user_input})\n # Append an empty assistant response\n history.append({\"role\": \"assistant\", \"content\": \"\"})\n yield history, '' # Display the user's message\n logger.info(f\"User input appended to history: {user_input}\")\n\n stream = config_dict[\"llm_gateway_config\"][\"llm_stream\"]\n use_async = config_dict[\"llm_gateway_config\"][\"llm_use_async\"]\n\n # Emit the 'perception.gradio' event with stream=True\n emit_future = asyncio.run_coroutine_threadsafe(\n agent.event_bus.emit('perception.gradio', payload=user_input),\n agent.event_bus.loop\n )\n emit_future.result() # Ensure the event is emitted\n logger.info(f\"Emitted 'perception.gradio' event with payload: {user_input} | 
Stream: {stream}\")\n\n assistant_message = ''\n if stream:\n while True:\n try:\n # Non-blocking response retrieval from the thread-safe queue with timeout\n response = await asyncio.wait_for(\n asyncio.to_thread(response_queue.get, True, 30),\n timeout=30\n )\n logger.info(f\"Retrieved response from queue: {response}\")\n if response == \"<END_OF_RESPONSE>\":\n logger.info(\"Received end of response signal.\")\n break\n assistant_message += response\n # Create a new history list to ensure Gradio detects the update\n new_history = history.copy()\n new_history[-1][\"content\"] = assistant_message\n logger.info(f\"Yielding updated history: {new_history}\")\n yield new_history, ''\n except asyncio.TimeoutError:\n logger.warning(\"Timeout: No response received from Agent.\")\n # Create a new history list to ensure Gradio detects the update\n new_history = history.copy()\n new_history[-1][\"content\"] = \"I'm sorry, I didn't receive a response in time.\"\n yield new_history, ''\n break\n else:\n try:\n # Non-blocking response retrieval from the thread-safe queue with timeout\n response = await asyncio.wait_for(\n asyncio.to_thread(response_queue.get, True, 30),\n timeout=30\n )\n logger.info(f\"Retrieved response from queue: {response}\")\n assistant_message += response\n # Create a new history list to ensure Gradio detects the update\n new_history = history.copy()\n new_history[-1][\"content\"] = assistant_message\n logger.info(f\"Yielding updated history: {new_history}\")\n yield new_history, ''\n except asyncio.TimeoutError:\n logger.warning(\"Timeout: No response received from Agent.\")\n # Create a new history list to ensure Gradio detects the update\n new_history = history.copy()\n new_history[-1][\"content\"] = \"I'm sorry, I didn't receive a response in time.\"\n yield new_history, ''\n\n except Exception as e:\n logger.error(f\"Unexpected Error in bot function: {e}\")\n # Create a new history list to ensure Gradio detects the update\n new_history = 
history.copy()\n new_history[-1][\"content\"] = \"An unexpected error occurred.\"\n yield new_history, ''\n\n def launch_gradio_interface():\n \"\"\"\n Main gradio interface.\n \"\"\"\n with gr.Blocks(title=\"Multimodal LLM Chatbot with Tools\") as demo:\n # Header Section\n gr.Markdown(\"\"\"\n <h1 align=\"center\">\n <a href=\"https://github.com/chatsci/Aeiva\">\n <img src=\"https://i.ibb.co/P4zQHDk/aeiva-1024.png\",\n alt=\"Aeiva\" border=\"0\" style=\"margin: 0 auto; height: 200px;\" />\n </a>\n </h1>\n\n <h2 align=\"center\">\n AEIVA: An Evolving Intelligent Virtual Assistant\n </h2>\n\n <h5 align=\"center\">\n If you like our project, please give us a star \u2728 on Github for the latest update.\n </h5>\n\n <div align=\"center\">\n <div style=\"display:flex; gap: 0.25rem;\" align=\"center\">\n <a href='https://github.com/chatsci/Aeiva'><img src='https://img.shields.io/badge/Github-Code-blue'></a>\n <a href=\"https://arxiv.org/abs/2304.14178\"><img src=\"https://img.shields.io/badge/Arxiv-2304.14178-red\"></a>\n <a href='https://github.com/chatsci/Aeiva/stargazers'><img src='https://img.shields.io/github/stars/chatsci/Aeiva.svg?style=social'></a>\n </div>\n </div>\n \"\"\")\n\n # Main Layout: Two Columns\n with gr.Row():\n # Left Column: Parameter Settings and Multimodal Inputs\n with gr.Column(scale=1, min_width=700):\n # Parameter Settings Tab\n with gr.Tab(label=\"Parameter Setting\"):\n gr.Markdown(\"# Parameters\")\n top_p = gr.Slider(\n minimum=0,\n maximum=1.0,\n value=0.95,\n step=0.05,\n interactive=True,\n label=\"Top-p\"\n )\n temperature = gr.Slider(\n minimum=0.1,\n maximum=2.0,\n value=1.0,\n step=0.1,\n interactive=True,\n label=\"Temperature\"\n )\n max_length_tokens = gr.Slider(\n minimum=0,\n maximum=512,\n value=512,\n step=8,\n interactive=True,\n label=\"Max Generation Tokens\"\n )\n max_context_length_tokens = gr.Slider(\n minimum=0,\n maximum=4096,\n value=2048,\n step=128,\n interactive=True,\n label=\"Max History Tokens\"\n )\n\n # 
Multimodal Inputs Section\n with gr.Row():\n imagebox = gr.Image(type=\"pil\", label=\"Upload Image\")\n videobox = gr.File(label=\"Upload Video\", file_types=[\"video\"])\n audiobox = gr.Audio(label=\"Upload Audio\", type=\"numpy\")\n\n with gr.Row():\n record_videobox = gr.Video(label=\"Record Video\")\n record_audiobox = gr.Audio(label=\"Record Audio\")\n\n # Clear Media Button\n with gr.Row():\n clear_media_btn = gr.Button(\"\ud83e\uddf9 Clear Media\", variant=\"secondary\")\n\n # Right Column: Chat Interface and Action Buttons\n with gr.Column(scale=1, min_width=700):\n # Chatbot Component\n chatbot = gr.Chatbot(\n [],\n type=\"messages\", # Specify type as 'messages'\n elem_id=\"chatbot\",\n height=730\n )\n\n # Input Textbox and Upload Button\n with gr.Row():\n with gr.Column(scale=4, min_width=300):\n txt = gr.Textbox(\n show_label=False,\n placeholder=\"Enter text and press enter, or upload an image/video/audio\",\n lines=1,\n elem_classes=[\"input-textbox\"] # Assign a CSS class for styling\n )\n with gr.Column(scale=1, min_width=100):\n btn = gr.UploadButton(\"\ud83d\udcc1\", file_types=[\"image\", \"video\", \"audio\"], elem_classes=[\"upload-button\"])\n # Changed the button label to an icon for a more compact look\n\n # Action Buttons Placed Below the Input Box\n with gr.Row():\n upvote_btn = gr.Button(\"\ud83d\udc4d Upvote\", interactive=True)\n downvote_btn = gr.Button(\"\ud83d\udc4e Downvote\", interactive=True)\n flag_btn = gr.Button(\"\u26a0\ufe0f Flag\", interactive=True)\n regenerate_btn = gr.Button(\"\ud83d\udd04 Regenerate\", interactive=True)\n clear_history_btn = gr.Button(\"\ud83d\uddd1\ufe0f Clear History\", interactive=True)\n new_conv_btn = gr.Button(\"\ud83e\uddf9 New Conversation\", interactive=True)\n del_last_turn_btn = gr.Button(\"\ud83d\uddd1\ufe0f Remove Last Turn\", interactive=True)\n\n # Define interactions\n\n # Text input submission with streaming\n txt.submit(\n bot,\n inputs=[txt, chatbot],\n outputs=[chatbot, txt],\n 
queue=True, # Enable queue for better performance\n # stream=True # Enable streaming (already handled in the bot function)\n )\n # Removed the .then callback to prevent layout shifts\n\n # File upload (image/video/audio)\n btn.upload(\n handle_upload,\n inputs=btn,\n outputs=txt, # Set message in textbox to trigger bot\n queue=True\n )\n\n # Image upload\n imagebox.upload(\n handle_image_upload,\n inputs=imagebox,\n outputs=txt, # Set message in textbox to trigger bot\n queue=True\n )\n\n # Video upload\n videobox.upload(\n handle_video_upload,\n inputs=videobox,\n outputs=txt, # Set message in textbox to trigger bot\n queue=True\n )\n\n # Audio upload\n audiobox.upload(\n handle_audio_upload,\n inputs=audiobox,\n outputs=txt, # Set message in textbox to trigger bot\n queue=True\n )\n\n # Record Video\n record_videobox.change(\n handle_video_upload,\n inputs=record_videobox,\n outputs=txt, # Set message in textbox to trigger bot\n queue=True\n )\n\n # Record Audio\n record_audiobox.change(\n handle_audio_upload,\n inputs=record_audiobox,\n outputs=txt, # Set message in textbox to trigger bot\n queue=True\n )\n\n # Clear Media Button\n clear_media_btn.click(\n clear_media,\n inputs=None,\n outputs=None,\n queue=False\n )\n\n # Action Buttons Functionality\n\n # Clear History\n clear_history_btn.click(\n lambda: ([], \"\"),\n inputs=None,\n outputs=[chatbot, txt],\n queue=False\n )\n\n # New Conversation\n new_conv_btn.click(\n lambda: ([], \"\"),\n inputs=None,\n outputs=[chatbot, txt],\n queue=False\n )\n\n # Remove Last Turn (Removes the last user and assistant messages)\n del_last_turn_btn.click(\n lambda history: history[:-2] if len(history) >= 2 else history,\n inputs=chatbot,\n outputs=chatbot,\n queue=False\n )\n\n # Launch the Gradio interface\n demo.launch(share=True)\n\n # Launch aeiva chat gradio\n launch_gradio_interface()\n
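The command above bridges the Agent's asyncio event bus to Gradio's synchronous callbacks through a thread-safe queue: the `response.gradio` handler pushes each streamed chunk, and `bot` drains the queue via `asyncio.to_thread` until the `<END_OF_RESPONSE>` marker arrives. The pattern in isolation (stand-in names, no aeiva imports):

```python
import asyncio
import queue
import threading

response_queue: "queue.Queue[str]" = queue.Queue()

def emit_chunks():
    """Stand-in for the Agent thread emitting streamed response chunks."""
    for chunk in ["Hel", "lo", "<END_OF_RESPONSE>"]:
        response_queue.put_nowait(chunk)

async def collect_response(timeout: float = 5.0) -> str:
    """Mirrors the bot() loop: drain the queue until the end marker."""
    message = ""
    while True:
        # Blocking queue.get runs in a worker thread so the event loop stays free.
        chunk = await asyncio.wait_for(
            asyncio.to_thread(response_queue.get, True, timeout),
            timeout=timeout,
        )
        if chunk == "<END_OF_RESPONSE>":
            return message
        message += chunk

threading.Thread(target=emit_chunks, daemon=True).start()
result = asyncio.run(collect_response())
```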
"},{"location":"reference/#src.aeiva.command.aeiva_chat_terminal","title":"aeiva_chat_terminal
","text":"We can run the command like below: (specify your own config file path)
aeiva-chat-terminal --config configs/agent_config.yaml
"},{"location":"reference/#src.aeiva.command.aeiva_chat_terminal.run","title":"run(config, verbose)
","text":"Starts the Aeiva chat terminal with the provided configuration.
Source code in src/aeiva/command/aeiva_chat_terminal.py
@click.command()\n@click.option('--config', '-c', default=str(DEFAULT_CONFIG_PATH),\n help='Path to the configuration file (YAML or JSON).',\n type=click.Path(exists=True, dir_okay=False))\n@click.option('--verbose', '-v', is_flag=True, help='Enable verbose logging.')\ndef run(config, verbose):\n \"\"\"\n Starts the Aeiva chat terminal with the provided configuration.\n \"\"\"\n # Setup logging\n logger = setup_logging(DEFAULT_LOG_PATH, verbose)\n\n click.echo(f\"Loading configuration from {config}\")\n config_path = Path(config)\n\n # Parse the configuration file with error handling\n try:\n config_data = from_json_or_yaml(config_path)\n except Exception as e:\n logger.error(f\"Failed to parse configuration file: {e}\")\n click.echo(f\"Error: Failed to parse configuration file: {e}\")\n sys.exit(1)\n\n # Retrieve NEO4J_HOME from environment variables\n neo4j_home = os.getenv('NEO4J_HOME')\n if not neo4j_home:\n logger.error(\"NEO4J_HOME is not set in the environment.\")\n click.echo(\"Error: NEO4J_HOME is not set in the environment. 
Please set it in your shell configuration (e.g., .bashrc or .zshrc).\")\n sys.exit(1)\n\n # Validate NEO4J_HOME path\n validate_neo4j_home(logger, neo4j_home)\n\n # Start Neo4j\n neo4j_process = start_neo4j(logger, neo4j_home)\n\n # Register signal handlers to ensure Neo4j stops gracefully\n signal.signal(signal.SIGINT, lambda s, f: handle_exit(s, f, neo4j_process))\n signal.signal(signal.SIGTERM, lambda s, f: handle_exit(s, f, neo4j_process))\n\n # Start the Agent\n try:\n agent = Agent(config_data)\n agent.setup()\n asyncio.run(agent.run())\n except KeyboardInterrupt:\n logger.info(\"Agent execution interrupted by user.\")\n click.echo(\"\\nAgent execution interrupted by user.\")\n except Exception as e:\n logger.error(f\"An error occurred during agent execution: {e}\")\n click.echo(f\"An error occurred during agent execution: {e}\")\n finally:\n # # Perform any necessary cleanup\n # try:\n # agent.cognition_components['memory'].delete_all()\n # logger.info(\"All memory units deleted during cleanup.\")\n # except NotImplementedError as nie:\n # logger.warning(f\"Delete All feature not implemented: {nie}\")\n # except Exception as e:\n # logger.error(f\"Error during cleanup: {e}\")\n # click.echo(\"Failed to delete all memory units.\")\n\n # Stop Neo4j\n stop_neo4j(logger, neo4j_process)\n logger.info(\"Cleanup completed.\")\n
"},{"location":"reference/#src.aeiva.command.aeiva_server","title":"aeiva_server
","text":""},{"location":"reference/#src.aeiva.command.aeiva_server.run","title":"run(config, host, port, verbose)
","text":"Starts the Aeiva Agent Server using FastAPI.
Source code in src/aeiva/command/aeiva_server.py
@click.command(name=\"aeiva-server\")\n@click.option(\n '--config', '-c',\n default=None,\n help='Path to the configuration file (YAML or JSON).',\n type=click.Path(exists=True, dir_okay=False)\n)\n@click.option(\n '--host', '-H',\n default=\"0.0.0.0\",\n help='Host address to run the server on.',\n show_default=True\n)\n@click.option(\n '--port', '-p',\n default=8000,\n help='Port number to run the server on.',\n show_default=True\n)\n@click.option(\n '--verbose', '-v',\n is_flag=True,\n help='Enable verbose logging.'\n)\ndef run(config, host, port, verbose):\n \"\"\"\n Starts the Aeiva Agent Server using FastAPI.\n \"\"\"\n # Setup logging\n logger = setup_logging(get_log_dir() / 'aeiva-server.log', verbose)\n\n # Load configuration\n if config is None:\n PACKAGE_ROOT = get_package_root()\n config_path = PACKAGE_ROOT / 'configs' / 'agent_config.yaml'\n else:\n config_path = Path(config)\n\n logger.info(f\"Loading configuration from {config_path}\")\n config_dict = from_json_or_yaml(config_path)\n\n # Validate and start Neo4j\n neo4j_home = os.getenv('NEO4J_HOME')\n if not neo4j_home:\n logger.error(\"NEO4J_HOME environment variable is not set.\")\n click.echo(\"Error: NEO4J_HOME environment variable is not set.\")\n sys.exit(1)\n\n validate_neo4j_home(logger, neo4j_home)\n neo4j_process = start_neo4j(logger, neo4j_home)\n\n # Initialize the Agent\n try:\n agent = Agent(config_dict)\n agent.setup()\n logger.info(\"Agent initialized successfully.\")\n except Exception as e:\n logger.error(f\"Failed to initialize Agent: {e}\")\n click.echo(f\"Error: Failed to initialize Agent: {e}\")\n stop_neo4j(logger, neo4j_process)\n sys.exit(1)\n\n # Define the FastAPI app with lifespan\n @asynccontextmanager\n async def lifespan(app: FastAPI):\n app.state.agent = agent\n logger.info(\"Agent has been initialized and is ready to receive messages.\")\n try:\n yield\n finally:\n logger.info(\"Shutting down the agent server.\")\n # If the Agent class has a shutdown method, call it 
here\n if hasattr(app.state.agent, 'shutdown'):\n await app.state.agent.shutdown()\n stop_neo4j(logger, neo4j_process)\n logger.info(\"Agent server shut down gracefully.\")\n\n app = FastAPI(lifespan=lifespan)\n\n # Enable CORS for all origins (for development purposes)\n app.add_middleware(\n CORSMiddleware,\n allow_origins=[\"*\"], # Adjust in production\n allow_credentials=True,\n allow_methods=[\"*\"],\n allow_headers=[\"*\"],\n )\n\n # Define the endpoint\n @app.post(\"/process_text\", response_model=MessageResponse)\n async def process_text(request: MessageRequest):\n if not request.message:\n raise HTTPException(status_code=400, detail=\"No message provided\")\n\n logger.info(f\"Received message: {request.message}\")\n\n # Process the message using the agent\n try:\n response_text = await app.state.agent.process_input(request.message)\n logger.info(f\"Agent response: {response_text}\")\n return MessageResponse(response=response_text)\n except Exception as e:\n logger.error(f\"Error processing input: {e}\")\n raise HTTPException(status_code=500, detail=\"Internal Server Error\")\n\n # Register signal handlers for graceful shutdown using handle_exit\n for sig in [signal.SIGINT, signal.SIGTERM]:\n signal.signal(sig, lambda s, f: handle_exit(s, f, logger, neo4j_process))\n\n # Run the FastAPI app using Uvicorn\n try:\n logger.info(f\"Starting server at http://{host}:{port}\")\n uvicorn.run(app, host=host, port=port)\n except Exception as e:\n logger.error(f\"Server encountered an error: {e}\")\n handle_exit(None, None, logger, neo4j_process) # Ensure cleanup on exception\n sys.exit(1)\n finally:\n logger.info(\"Server has been stopped.\")\n
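Once the server is running, `/process_text` accepts a JSON body with a `message` field and returns a `MessageResponse` whose single field is `response`. A client sketch using only the standard library (the host and port are assumptions matching the server defaults above):

```python
import json
import urllib.request

def build_request(message: str,
                  base_url: str = "http://127.0.0.1:8000") -> urllib.request.Request:
    """Assemble a POST request for the /process_text endpoint."""
    payload = json.dumps({"message": message}).encode("utf-8")
    return urllib.request.Request(
        f"{base_url}/process_text",
        data=payload,
        headers={"Content-Type": "application/json"},
        method="POST",
    )

def process_text(message: str) -> str:
    """Send a message to a running Aeiva server and return the agent's reply."""
    with urllib.request.urlopen(build_request(message)) as resp:
        return json.load(resp)["response"]

req = build_request("Hello, Aeiva!")
```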
"},{"location":"reference/#src.aeiva.command.command_utils","title":"command_utils
","text":"Utility functions shared by the aeiva commands, covering database management, logging, and related setup.
"},{"location":"reference/#src.aeiva.command.command_utils.get_log_dir","title":"get_log_dir()
","text":"Determines a suitable path for the log file. Logs are stored in the user's home directory under '.aeiva/logs/'.
Source code in src/aeiva/command/command_utils.py
def get_log_dir():\n \"\"\"\n Determines a suitable path for the log file.\n Logs are stored in the user's home directory under '.aeiva/logs/'.\n \"\"\"\n home_dir = Path.home()\n log_dir = home_dir / '.aeiva' / 'logs' # Log saved to `~/.aeiva/logs/`\n log_dir.mkdir(parents=True, exist_ok=True) # Ensure the log directory exists\n return log_dir\n
"},{"location":"reference/#src.aeiva.command.command_utils.get_package_root","title":"get_package_root()
","text":"Determines the root path of the 'aeiva' package.
Source code in src/aeiva/command/command_utils.py
def get_package_root():\n \"\"\"\n Determines the root path of the 'aeiva' package.\n \"\"\"\n aeiva_path = Path(importlib_resources.files(\"aeiva\"))\n package_root = aeiva_path.parents[1]\n return package_root.resolve()\n
"},{"location":"reference/#src.aeiva.command.command_utils.handle_exit","title":"handle_exit(signum, frame, logger, neo4j_process)
","text":"Handles termination signals to ensure Neo4j is stopped gracefully.
Source code in src/aeiva/command/command_utils.py
def handle_exit(signum, frame, logger, neo4j_process):\n \"\"\"\n Handles termination signals to ensure Neo4j is stopped gracefully.\n \"\"\"\n logger.info(f\"Received signal {signum}. Shutting down Neo4j.\")\n click.echo(f\"\\nReceived signal {signum}. Shutting down Neo4j.\")\n stop_neo4j(logger, neo4j_process)\n sys.exit(0)\n
"},{"location":"reference/#src.aeiva.command.command_utils.setup_logging","title":"setup_logging(log_file, verbose=False)
","text":"Sets up logging to both file and console.
Source code in src/aeiva/command/command_utils.py
def setup_logging(log_file, verbose=False):\n \"\"\"\n Sets up logging to both file and console.\n \"\"\"\n logger = get_logger(__name__, level=\"DEBUG\" if verbose else \"INFO\")\n\n # Create a file handler\n file_handler = logging.FileHandler(log_file, mode='a')\n file_handler.setLevel(logging.DEBUG if verbose else logging.INFO)\n\n # Create a console handler\n console_handler = logging.StreamHandler()\n console_handler.setLevel(logging.DEBUG if verbose else logging.INFO)\n\n # Create a logging format\n formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')\n file_handler.setFormatter(formatter)\n console_handler.setFormatter(formatter)\n\n # Add handlers to the logger\n logger.addHandler(file_handler)\n logger.addHandler(console_handler)\n\n return logger\n
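To see the handler setup in isolation, here is a self-contained sketch of the same pattern using only the standard library. `logging.getLogger` stands in for aeiva's `get_logger` wrapper, and the console handler is omitted so the demo stays quiet:

```python
import logging
import os
import tempfile

def setup_file_logging(log_file, verbose=False):
    # Sketch of command_utils.setup_logging: one file handler, the same
    # formatter, and DEBUG level only when verbose is set.
    logger = logging.getLogger("aeiva_logging_demo")
    logger.setLevel(logging.DEBUG if verbose else logging.INFO)
    handler = logging.FileHandler(log_file, mode='a')
    handler.setLevel(logging.DEBUG if verbose else logging.INFO)
    handler.setFormatter(logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
    logger.addHandler(handler)
    return logger

log_path = os.path.join(tempfile.mkdtemp(), 'demo.log')
logger = setup_file_logging(log_path, verbose=True)
logger.debug("handlers configured")  # recorded because verbose=True
for h in logger.handlers:
    h.flush()
with open(log_path) as f:
    content = f.read()
```

With `verbose=False` the same `debug` call would be dropped, which is how the `--verbose` flag of the aeiva commands controls log detail.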
"},{"location":"reference/#src.aeiva.command.command_utils.start_neo4j","title":"start_neo4j(logger, neo4j_home)
","text":"Starts the Neo4j database as a subprocess.
Source code in src/aeiva/command/command_utils.py
def start_neo4j(logger, neo4j_home):\n \"\"\"\n Starts the Neo4j database as a subprocess.\n \"\"\"\n neo4j_command = [os.path.join(neo4j_home, 'bin', 'neo4j'), 'console']\n try:\n neo4j_process = subprocess.Popen(\n neo4j_command,\n stdout=subprocess.DEVNULL, # Suppress stdout\n stderr=subprocess.DEVNULL, # Suppress stderr\n stdin=subprocess.DEVNULL, # Prevent Neo4j from waiting for input\n preexec_fn=os.setsid # Start the process in a new session\n )\n logger.info(\"Neo4j database started successfully.\")\n click.echo(\"Neo4j database started successfully.\")\n return neo4j_process\n except FileNotFoundError:\n logger.error(f\"Neo4j executable not found in {neo4j_command}.\")\n click.echo(f\"Error: Neo4j executable not found in {neo4j_command}.\")\n sys.exit(1)\n except Exception as e:\n logger.error(f\"Failed to start Neo4j: {e}\")\n click.echo(f\"Error: Failed to start Neo4j: {e}\")\n sys.exit(1)\n
"},{"location":"reference/#src.aeiva.command.command_utils.stop_neo4j","title":"stop_neo4j(logger, neo4j_process)
","text":"Stops the Neo4j database subprocess gracefully.
Source code in src/aeiva/command/command_utils.py
def stop_neo4j(logger, neo4j_process):\n \"\"\"\n Stops the Neo4j database subprocess gracefully.\n \"\"\"\n try:\n # Check if the process is still running\n if neo4j_process.poll() is None:\n os.killpg(os.getpgid(neo4j_process.pid), signal.SIGINT) # Send SIGINT for graceful shutdown\n logger.info(\"Sent SIGINT to Neo4j subprocess.\")\n click.echo(\"Shutting down Neo4j...\")\n neo4j_process.wait(timeout=15) # Increased timeout to 15 seconds\n logger.info(\"Neo4j database stopped successfully.\")\n click.echo(\"Neo4j database stopped successfully.\")\n else:\n logger.warning(\"Neo4j subprocess is already terminated.\")\n click.echo(\"Warning: Neo4j subprocess is already terminated.\")\n except subprocess.TimeoutExpired:\n logger.error(\"Neo4j did not terminate within the timeout period.\")\n click.echo(\"Error: Neo4j did not terminate within the timeout period.\")\n # Optionally, force kill\n try:\n os.killpg(os.getpgid(neo4j_process.pid), signal.SIGKILL)\n neo4j_process.wait(timeout=5)\n logger.info(\"Neo4j database forcefully terminated.\")\n click.echo(\"Neo4j database forcefully terminated.\")\n except Exception as e:\n logger.error(f\"Failed to forcefully terminate Neo4j: {e}\")\n click.echo(f\"Error: Failed to forcefully terminate Neo4j: {e}\")\n except ProcessLookupError:\n logger.warning(\"Neo4j subprocess does not exist.\")\n click.echo(\"Warning: Neo4j subprocess does not exist. It may have already terminated.\")\n except Exception as e:\n logger.error(f\"Error stopping Neo4j: {e}\")\n click.echo(f\"Error: Failed to stop Neo4j: {e}\")\n
"},{"location":"reference/#src.aeiva.command.command_utils.validate_neo4j_home","title":"validate_neo4j_home(logger, neo4j_home)
","text":"Validates that the NEO4J_HOME path exists and contains the Neo4j executable.
Source code in src/aeiva/command/command_utils.py
def validate_neo4j_home(logger, neo4j_home):\n \"\"\"\n Validates that the NEO4J_HOME path exists and contains the Neo4j executable.\n \"\"\"\n if not os.path.isdir(neo4j_home):\n logger.error(f\"NEO4J_HOME path does not exist or is not a directory: {neo4j_home}\")\n click.echo(f\"Error: NEO4J_HOME path does not exist or is not a directory: {neo4j_home}\")\n sys.exit(1)\n\n neo4j_executable = os.path.join(neo4j_home, 'bin', 'neo4j')\n if not os.path.isfile(neo4j_executable) or not os.access(neo4j_executable, os.X_OK):\n logger.error(f\"Neo4j executable not found or not executable at: {neo4j_executable}\")\n click.echo(f\"Error: Neo4j executable not found or not executable at: {neo4j_executable}\")\n sys.exit(1)\n
"},{"location":"reference/#src.aeiva.command.maid_chat","title":"maid_chat
","text":""},{"location":"reference/#src.aeiva.command.maid_chat.run","title":"run(config, host, port, verbose)
","text":"Starts the Aeiva Agent Server and launches the Unity application.
Source code in src/aeiva/command/maid_chat.py
@click.command(name=\"maid-chat\")\n@click.option(\n '--config', '-c',\n default=None,\n help='Path to the configuration file (YAML or JSON).',\n type=click.Path(exists=True, dir_okay=False)\n)\n@click.option(\n '--host', '-H',\n default=\"0.0.0.0\",\n help='Host address to run the server on.',\n show_default=True\n)\n@click.option(\n '--port', '-p',\n default=8000,\n help='Port number to run the server on.',\n show_default=True\n)\n@click.option(\n '--verbose', '-v',\n is_flag=True,\n help='Enable verbose logging.'\n)\ndef run(config, host, port, verbose):\n \"\"\"\n Starts the Aeiva Agent Server and launches the Unity application.\n \"\"\"\n # Setup logging\n logger = setup_logging(get_log_dir() / 'maid-chat.log', verbose)\n\n # Load configuration\n if config is None:\n PACKAGE_ROOT = get_package_root()\n config_path = PACKAGE_ROOT / 'configs' / 'agent_config.yaml'\n else:\n config_path = Path(config)\n\n logger.info(f\"Loading configuration from {config_path}\")\n config_dict = from_json_or_yaml(config_path)\n\n # Validate and start Neo4j\n neo4j_home = os.getenv('NEO4J_HOME')\n if not neo4j_home:\n logger.error(\"NEO4J_HOME environment variable is not set.\")\n click.echo(\"Error: NEO4J_HOME environment variable is not set.\")\n sys.exit(1)\n\n validate_neo4j_home(logger, neo4j_home)\n neo4j_process = start_neo4j(logger, neo4j_home)\n\n # Initialize the Agent\n try:\n agent = Agent(config_dict)\n agent.setup()\n logger.info(\"Agent initialized successfully.\")\n except Exception as e:\n logger.error(f\"Failed to initialize Agent: {e}\")\n click.echo(f\"Error: Failed to initialize Agent: {e}\")\n stop_neo4j(logger, neo4j_process)\n sys.exit(1)\n\n # Read MAID_HOME environment variable\n maid_home = os.getenv('MAID_HOME')\n if not maid_home:\n logger.error(\"MAID_HOME environment variable is not set.\")\n click.echo(\"Error: MAID_HOME environment variable is not set.\")\n stop_neo4j(logger, neo4j_process)\n sys.exit(1)\n\n maid_home_path = Path(maid_home)\n if 
not maid_home_path.exists():\n logger.error(f\"Unity application not found at MAID_HOME: {maid_home}\")\n click.echo(f\"Error: Unity application not found at MAID_HOME: {maid_home}\")\n stop_neo4j(logger, neo4j_process)\n sys.exit(1)\n\n # Start the Unity application\n unity_process = start_unity_app(str(maid_home_path), logger)\n if unity_process is None:\n stop_neo4j(logger, neo4j_process)\n sys.exit(1)\n\n # Define the FastAPI app with lifespan\n @asynccontextmanager\n async def lifespan(app: FastAPI):\n app.state.agent = agent\n logger.info(\"Agent has been initialized and is ready to receive messages.\")\n try:\n yield\n finally:\n logger.info(\"Shutting down the agent server.\")\n # If the Agent class has a shutdown method, call it here\n if hasattr(app.state.agent, 'shutdown'):\n await app.state.agent.shutdown()\n stop_neo4j(logger, neo4j_process)\n # Terminate the Unity application\n stop_unity_app(unity_process, logger)\n logger.info(\"Agent server shut down gracefully.\")\n\n app = FastAPI(lifespan=lifespan)\n\n # Enable CORS for all origins (for development purposes)\n app.add_middleware(\n CORSMiddleware,\n allow_origins=[\"*\"], # Adjust in production\n allow_credentials=True,\n allow_methods=[\"*\"],\n allow_headers=[\"*\"],\n )\n\n # Define the endpoint\n @app.post(\"/process_text\", response_model=MessageResponse)\n async def process_text(request: MessageRequest):\n if not request.message:\n raise HTTPException(status_code=400, detail=\"No message provided\")\n\n logger.info(f\"Received message: {request.message}\")\n\n # Process the message using the agent\n try:\n response_text = await app.state.agent.process_input(request.message)\n logger.info(f\"Agent response: {response_text}\")\n return MessageResponse(response=response_text)\n except Exception as e:\n logger.error(f\"Error processing input: {e}\")\n raise HTTPException(status_code=500, detail=\"Internal Server Error\")\n\n # Register signal handlers for graceful shutdown using handle_exit\n 
for sig in [signal.SIGINT, signal.SIGTERM]:\n signal.signal(sig, lambda s, f: handle_exit(s, f, logger, neo4j_process, unity_process))\n\n # Run the FastAPI app using Uvicorn\n try:\n logger.info(f\"Starting server at http://{host}:{port}\")\n uvicorn.run(app, host=host, port=port)\n except Exception as e:\n logger.error(f\"Server encountered an error: {e}\")\n handle_exit(None, None, logger, neo4j_process, unity_process) # Ensure cleanup on exception\n sys.exit(1)\n finally:\n logger.info(\"Server has been stopped.\")\n
"},{"location":"reference/#src.aeiva.command.maid_chat.start_unity_app","title":"start_unity_app(maid_home, logger)
","text":"Starts the Unity application.
Parameters:
Name Type Description Default maid_home
str
Path to the Unity application executable.
required logger
Logger
Logger instance.
required Returns:
Type Description Optional[Popen]
Optional[subprocess.Popen]: The subprocess running the Unity application, or None if failed.
Source code in src/aeiva/command/maid_chat.py
def start_unity_app(maid_home: str, logger: logging.Logger) -> Optional[subprocess.Popen]:\n \"\"\"\n Starts the Unity application.\n\n Args:\n maid_home (str): Path to the Unity application executable.\n logger (logging.Logger): Logger instance.\n\n Returns:\n Optional[subprocess.Popen]: The subprocess running the Unity application, or None if failed.\n \"\"\"\n try:\n unity_process = subprocess.Popen(\n [maid_home],\n stdout=subprocess.DEVNULL,\n stderr=subprocess.DEVNULL,\n preexec_fn=os.setsid # Start the process in a new session\n )\n logger.info(f\"Unity application started from {maid_home}.\")\n click.echo(f\"Unity application started from {maid_home}.\")\n return unity_process\n except FileNotFoundError:\n logger.error(f\"Unity application not found at {maid_home}.\")\n click.echo(f\"Error: Unity application not found at {maid_home}.\")\n return None\n except Exception as e:\n logger.error(f\"Failed to start Unity application: {e}\")\n click.echo(f\"Error: Failed to start Unity application: {e}.\")\n return None\n
"},{"location":"reference/#src.aeiva.command.maid_chat.stop_unity_app","title":"stop_unity_app(unity_process, logger)
","text":"Stops the Unity application gracefully.
Parameters:
Name Type Description Default unity_process
Popen
The subprocess running the Unity application.
required logger
Logger
Logger instance.
required Source code in src/aeiva/command/maid_chat.py
def stop_unity_app(unity_process: subprocess.Popen, logger: logging.Logger):\n \"\"\"\n Stops the Unity application gracefully.\n\n Args:\n unity_process (subprocess.Popen): The subprocess running the Unity application.\n logger (logging.Logger): Logger instance.\n \"\"\"\n try:\n os.killpg(os.getpgid(unity_process.pid), signal.SIGTERM)\n unity_process.wait(timeout=10)\n logger.info(\"Unity application terminated gracefully.\")\n click.echo(\"Unity application terminated gracefully.\")\n except Exception as e:\n logger.error(f\"Error terminating Unity application: {e}\")\n click.echo(f\"Error: Failed to terminate Unity application: {e}.\")\n
"},{"location":"reference/#src.aeiva.common","title":"common
","text":""},{"location":"reference/#src.aeiva.common.decorators","title":"decorators
","text":""},{"location":"reference/#src.aeiva.common.decorators.import_submodules","title":"import_submodules(package, recursive=True)
","text":"Import all submodules of a module, recursively, including subpackages
Source code in src/aeiva/common/decorators.py
def import_submodules(package, recursive=True):\n \"\"\" Import all submodules of a module, recursively, including subpackages \"\"\"\n\n if isinstance(package, str):\n package = importlib.import_module(package)\n\n results = {}\n\n for loader, name, is_pkg in pkgutil.walk_packages(package.__path__):\n full_name = package.__name__ + \".\" + name\n results[full_name] = importlib.import_module(full_name)\n if recursive and is_pkg:\n results.update(import_submodules(full_name))\n\n return results\n
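A quick illustration of what `import_submodules` returns. The function is reproduced from above so the snippet is self-contained, and it is run against the stdlib `json` package purely for demonstration; any package works:

```python
import importlib
import pkgutil

def import_submodules(package, recursive=True):
    """Import all submodules of a module, recursively, including subpackages."""
    if isinstance(package, str):
        package = importlib.import_module(package)
    results = {}
    for loader, name, is_pkg in pkgutil.walk_packages(package.__path__):
        full_name = package.__name__ + "." + name
        results[full_name] = importlib.import_module(full_name)
        if recursive and is_pkg:
            results.update(import_submodules(full_name))
    return results

modules = import_submodules("json")  # stdlib package, chosen for illustration
```

The returned dict maps dotted module names (e.g. `json.decoder`) to the imported module objects, so a package can eagerly import and register everything it contains.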
"},{"location":"reference/#src.aeiva.common.id_generator","title":"id_generator
","text":""},{"location":"reference/#src.aeiva.common.id_generator.IDGenerator","title":"IDGenerator
","text":"A simple class to generate unique IDs for distinct names.
Attributes:
Name Type Description name_to_id
dict
A dictionary to map names to IDs.
next_id
int
The next ID to be assigned.
Source code in src/aeiva/common/id_generator.py
class IDGenerator:\n \"\"\"\n A simple class to generate unique IDs for distinct names.\n\n Attributes:\n name_to_id (dict): A dictionary to map names to IDs.\n next_id (int): The next ID to be assigned.\n \"\"\"\n\n def __init__(self):\n \"\"\"\n Constructs all the necessary attributes for the IDGenerator object.\n\n Attributes:\n name_to_id (dict): Initializes an empty dictionary to map names to IDs.\n next_id (int): Initializes the next ID to be assigned as 0.\n \"\"\"\n self.name_to_id = {}\n self.next_id = 0\n\n def get_id(self, name: str) -> int:\n \"\"\"\n Returns the ID of the 'name'. If 'name' does not exist, assigns a new ID.\n\n Parameters:\n name (str): The name for which the ID is required.\n\n Returns:\n int: The ID associated with the 'name'.\n \"\"\"\n if name not in self.name_to_id:\n self.name_to_id[name] = self.next_id\n self.next_id += 1\n return self.name_to_id[name]\n
"},{"location":"reference/#src.aeiva.common.id_generator.IDGenerator.__init__","title":"__init__()
","text":"Constructs all the necessary attributes for the IDGenerator object.
Attributes:
Name Type Description name_to_id
dict
Initializes an empty dictionary to map names to IDs.
next_id
int
Initializes the next ID to be assigned as 0.
Source code in src/aeiva/common/id_generator.py
def __init__(self):\n \"\"\"\n Constructs all the necessary attributes for the IDGenerator object.\n\n Attributes:\n name_to_id (dict): Initializes an empty dictionary to map names to IDs.\n next_id (int): Initializes the next ID to be assigned as 0.\n \"\"\"\n self.name_to_id = {}\n self.next_id = 0\n
"},{"location":"reference/#src.aeiva.common.id_generator.IDGenerator.get_id","title":"get_id(name)
","text":"Returns the ID of the 'name'. If 'name' does not exist, assigns a new ID.
Parameters:
Name Type Description Default name
str
The name for which the ID is required.
required Returns:
Name Type Description int
int
The ID associated with the 'name'.
Source code in src/aeiva/common/id_generator.py
def get_id(self, name: str) -> int:\n \"\"\"\n Returns the ID of the 'name'. If 'name' does not exist, assigns a new ID.\n\n Parameters:\n name (str): The name for which the ID is required.\n\n Returns:\n int: The ID associated with the 'name'.\n \"\"\"\n if name not in self.name_to_id:\n self.name_to_id[name] = self.next_id\n self.next_id += 1\n return self.name_to_id[name]\n
"},{"location":"reference/#src.aeiva.common.pipeline","title":"pipeline
","text":""},{"location":"reference/#src.aeiva.common.pipeline.Pipeline","title":"Pipeline
","text":"This class is used to turn a list of functions into a pipeline.
Source code in src/aeiva/common/pipeline.py
class Pipeline:\n r\"\"\"This class is used to turn a list of functions into a pipeline.\"\"\"\n def __init__(self, functions):\n self.functions = functions\n\n def run(self, *args, **kwargs):\n result = self.functions[0](*args, **kwargs)\n for f in self.functions[1:]:\n if isinstance(result, tuple):\n result = f(*result)\n else:\n result = f(result)\n return result\n\n def __call__(self, *args, **kwargs):\n return self.run(*args, **kwargs)\n
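Two details of `run` are worth seeing concretely: results flow left to right through the function list, and a tuple result is unpacked into the next function's positional arguments. A minimal sketch, with the class reproduced so it runs standalone:

```python
class Pipeline:
    """Turns a list of functions into a single callable pipeline."""
    def __init__(self, functions):
        self.functions = functions

    def run(self, *args, **kwargs):
        result = self.functions[0](*args, **kwargs)
        for f in self.functions[1:]:
            # A tuple result is splatted into the next function's arguments.
            result = f(*result) if isinstance(result, tuple) else f(result)
        return result

    def __call__(self, *args, **kwargs):
        return self.run(*args, **kwargs)

chained = Pipeline([lambda x: x + 1, lambda x: x * 2])(3)         # (3+1)*2 == 8
fanned = Pipeline([lambda x: (x, x + 1), lambda a, b: a * b])(3)  # 3*4 == 12
```

The tuple-splatting rule is what lets a stage fan out multiple values to the next stage without wrapping them in a container.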
"},{"location":"reference/#src.aeiva.common.types","title":"types
","text":""},{"location":"reference/#src.aeiva.common.types.DataBatch","title":"DataBatch
","text":" Bases: TypedDict
DataBatch is a batch of data items created by a dataloader.
Source code in src/aeiva/common/types.py
class DataBatch(TypedDict):\n r\"\"\"DataBatch is a batch of data items created by a dataloader.\n \"\"\"\n videos: Optional[torch.Tensor] # videos representation\n audios: Optional[torch.Tensor] # audios representation\n images: Optional[torch.Tensor] # images representation\n input_ids: Optional[torch.Tensor] # text token ids\n attention_mask: Optional[torch.Tensor] # attention mask\n image_starts: Optional[torch.Tensor] # image start token\n image_ends: Optional[torch.Tensor] # image end token\n audio_starts: Optional[torch.Tensor] # audio start token\n audio_ends: Optional[torch.Tensor] # audio end token\n video_starts: Optional[torch.Tensor] # video start token\n video_ends: Optional[torch.Tensor] # video end token\n labels: Optional[torch.Tensor] # labels\n
"},{"location":"reference/#src.aeiva.common.types.DataItem","title":"DataItem
","text":" Bases: TypedDict
DataItem is a dictionary that contains all the information for a single data item.
Source code in src/aeiva/common/types.py
class DataItem(TypedDict):\n r\"\"\"DataItem is a dictionary that contains all the information for a single data item.\n \"\"\"\n instruction: str # instruction text\n input: Optional[str] # input text\n output: Optional[str] # output text\n text: Optional[str] # text field. How it is formed depends on the task.\n\n image: Optional[str] # image name or path\n transformed_image: Optional[torch.Tensor] # transformed image tensor\n\n audio: Optional[str] # audio name or path\n audio_mels: Optional[torch.Tensor] # audio melspectrogram tensor\n\n video: Optional[str] # video name or path\n sampled_video_frame_indices: Optional[list[int]] # sampled video frame indices\n video_frames: Optional[torch.Tensor] # video frames tensor\n
"},{"location":"reference/#src.aeiva.common.types.DataSet","title":"DataSet
","text":" Bases: TypedDict
DataSet is a dictionary that contains data items and meta information.
Source code in src/aeiva/common/types.py
class DataSet(TypedDict):\n r\"\"\"DataSet is a dictionary that contains data items and meta information.\n \"\"\"\n data: list[DataItem]\n metadata: dict[str, Any]\n
"},{"location":"reference/#src.aeiva.common.types.ModelInput","title":"ModelInput
","text":" Bases: TypedDict
ModelInput is a dictionary that contains all the information for a model input. We use it to construct LEGO style models.
Source code in src/aeiva/common/types.py
class ModelInput(TypedDict):\n r\"\"\"ModelInput is a dictionary that contains all the information for a model input.\n We use it to construct LEGO style models.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.common.types.ModelOutput","title":"ModelOutput
","text":" Bases: TypedDict
ModelOutput is a dictionary that contains all the information for a model output. We use it to construct LEGO style models.
Source code in src/aeiva/common/types.py
class ModelOutput(TypedDict):\n r\"\"\"ModelOutput is a dictionary that contains all the information for a model output.\n We use it to construct LEGO style models.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.common.types.TaskContext","title":"TaskContext
","text":" Bases: TypedDict
TaskContext is a dictionary that contains all the information for a task.
Source code in src/aeiva/common/types.py
class TaskContext(TypedDict):\n r\"\"\"TaskContext is a dictionary that contains all the information for a task.\n \"\"\"\n config_path: Optional[str]\n config: Optional[OmniConfig]\n dataloader: Optional[torch.utils.data.DataLoader]\n tokenizer: Optional[Any]\n model: Optional[Any]\n logger: Optional[Any]\n trainer: Optional[Any]\n current_model_input: Optional[DataItem]\n current_model_output: Optional[Any]\n
"},{"location":"reference/#src.aeiva.config","title":"config
","text":""},{"location":"reference/#src.aeiva.config.DataConfig","title":"DataConfig
dataclass
","text":" Bases: BaseConfig
This class contains the data configuration.
Source code in src/aeiva/config/general_configs.py
@dataclass\nclass DataConfig(BaseConfig):\n \"\"\"This class contains the data configuration.\"\"\"\n dataset_path: Optional[str] = field(\n default=None, metadata={\"help\": \"The path of the dataset to use.\"}\n )\n dataset_name: Optional[str] = field(\n default=\"customized\", metadata={\"help\": \"Should be \\\"customized\\\"\"}\n )\n is_custom_dataset: Optional[bool] = field(\n default=False, metadata={\"help\": \"whether to use custom data\"}\n )\n customized_cache_dir: Optional[str] = field(\n default=\".cache/llm-ft/datasets\",\n metadata={\"help\": \"Where do you want to store the customized dataset caches\"},\n )\n dataset_config_name: Optional[str] = field(\n default=None, metadata={\"help\": \"The configuration name of the dataset to use (via the datasets library).\"}\n )\n train_file: Optional[str] = field(default=None, metadata={\"help\": \"The input training data file (a text file).\"})\n validation_file: Optional[str] = field(\n default=None,\n metadata={\"help\": \"An optional input evaluation data file to evaluate the perplexity on (a text file).\"},\n )\n max_train_samples: Optional[int] = field(\n default=None,\n metadata={\n \"help\": (\n \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n \"value if set.\"\n )\n },\n )\n max_eval_samples: Optional[int] = field(\n default=1e10,\n metadata={\n \"help\": (\n \"For debugging purposes or quicker training, truncate the number of evaluation examples to this \"\n \"value if set.\"\n )\n },\n )\n streaming: Optional[bool] = field(default=False, metadata={\"help\": \"Enable streaming mode\"})\n block_size: Optional[int] = field(\n default=512,\n metadata={\n \"help\": (\n \"Optional input sequence length after tokenization. \"\n \"The training dataset will be truncated in block of this size for training. 
\"\n \"Default to the model max input length for single sentence inputs (take into account special tokens).\"\n )\n },\n )\n overwrite_cache: Optional[bool] = field(\n default=False,\n metadata={\"help\": \"Overwrite the cached training and evaluation sets\"}\n )\n validation_split_percentage: Optional[int] = field(\n default=5,\n metadata={\n \"help\": \"The percentage of the train set used as validation set in case there's no validation split\"\n },\n )\n preprocessing_num_workers: Optional[int] = field(\n default=None,\n metadata={\"help\": \"The number of processes to use for the preprocessing.\"},\n )\n group_texts_batch_size: Optional[int] = field(\n default=1000,\n metadata={\n \"help\": (\n \"Number of samples that will be grouped together to go though\"\n \" `group_texts` operation. See `--disable_group_texts` for\"\n \" detailed explanation of this operation.\"\n )\n }\n )\n disable_group_texts: Optional[bool] = field(\n default=False,\n metadata={\n \"help\": (\n \"Whether we group original samples together to generate sample\"\n \" sequences of length `block_size`. By default, we group every\"\n \" 1000 tokenized sequences together, divide them into \"\n \" [{total_num_tokens} / {block_size}] sequences, each with\"\n \" `block_size` tokens (the remaining tokens are ommited.\"\n \" If this flag is set to True, we only group 1 tokenized\"\n \" sequence, i.e. 
cutting long sequence into chunks.\"\n )\n },\n )\n keep_linebreaks: Optional[bool] = field(\n default=True,\n metadata={\"help\": \"Whether to keep line breaks when using TXT files or not.\"}\n )\n test_file: Optional[str] = field(\n default=None,\n metadata={\"help\": \"Evaluation File Path\"},\n )\n\n def __post_init__(self):\n if self.streaming:\n require_version(\"datasets>=2.0.0\", \"The streaming feature requires `datasets>=2.0.0`\")\n\n if self.dataset_name is None and self.train_file is None and self.validation_file is None:\n raise ValueError(\"Need either a dataset name or a training/validation file.\")\n else:\n if self.train_file is not None:\n extension = self.train_file.split(\".\")[-1]\n assert extension in [\"csv\", \"json\", \"txt\"], \"`train_file` should be a csv, a json or a txt file.\"\n if self.validation_file is not None:\n extension = self.validation_file.split(\".\")[-1]\n assert extension in [\"csv\", \"json\", \"txt\"], \"`validation_file` should be a csv, a json or a txt file.\"\n
"},{"location":"reference/#src.aeiva.config.ExplicitEnum","title":"ExplicitEnum
","text":" Bases: str
, Enum
Enum with more explicit error message for missing values.
Source code in src/aeiva/config/general_configs.py
class ExplicitEnum(str, Enum):\n \"\"\"\n Enum with more explicit error message for missing values.\n \"\"\"\n @classmethod\n def _missing_(cls, value):\n raise ValueError(\n f\"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}\"\n )\n
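The payoff of overriding `_missing_` is the lookup error: instead of enum's generic message, an invalid value reports the full set of accepted values. A sketch with a hypothetical `Precision` subclass (not part of aeiva, used only for demonstration):

```python
from enum import Enum

class ExplicitEnum(str, Enum):
    """Enum with a more explicit error message for missing values."""
    @classmethod
    def _missing_(cls, value):
        raise ValueError(
            f"{value} is not a valid {cls.__name__}, "
            f"please select one of {list(cls._value2member_map_.keys())}"
        )

class Precision(ExplicitEnum):
    # Hypothetical subclass, used only for this demonstration.
    FP16 = "float16"
    FP32 = "float32"

ok = Precision("float16")  # normal lookup by value still works
try:
    Precision("int8")      # unknown value routes through _missing_
    error_message = None
except ValueError as exc:
    error_message = str(exc)
```

This is the mechanism `OptimizerNames` and the other config enums below rely on to give actionable errors for misspelled config values.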
"},{"location":"reference/#src.aeiva.config.ModelConfig","title":"ModelConfig
dataclass
","text":" Bases: BaseConfig
Model configuration class.
Source code in src/aeiva/config/general_configs.py
@dataclass\nclass ModelConfig(BaseConfig):\n \"\"\"Model configuration class.\"\"\"\n model_name_or_path: Optional[str] = field(\n default=None,\n metadata={\n \"help\": (\n \"The model checkpoint for weights initialization. Don't set if you want to train a model from scratch.\"\n )\n },\n )\n lora_model_path: Optional[str] = field(\n default=None,\n metadata={\n \"help\": (\n \"The incremental model diff introduced by LoRA finetuning.\"\n \" Along with the original non-finetuned model forms the whole\"\n \" finetuned model.\"\n )\n }\n )\n model_type: Optional[str] = field(\n default=None,\n metadata={\"help\": \"If training from scratch, pass a model type from the list: \" + \", \".join(MODEL_TYPES)},\n )\n arch_type: Optional[str] = field(\n default=\"decoder_only\",\n metadata={\"help\": \"The architecture type of the model. Currently supported decoder_only or encoder_decoder\"}\n )\n config_overrides: Optional[str] = field(\n default=None,\n metadata={\n \"help\": (\n \"Override some existing default config settings when a model is trained from scratch. Example: \"\n \"n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index\"\n )\n },\n )\n arch_type: Optional[str] = field(\n default=\"decoder_only\",\n metadata={\n \"help\": (\n \"Model architecture type, e.g. 
\\\"decoder_only\\\",\"\n                 \" \\\"encoder_decoder\\\"\"\n             ),\n             \"choices\": [\"decoder_only\", \"encoder_decoder\", \"text_regression\", \"vision_encoder_decoder\"],\n         },\n     )\n     config_name: Optional[str] = field(\n         default=None,\n         metadata={\"help\": \"Pretrained config name or path if not the same as model_name\"}\n     )\n     tokenizer_name: Optional[str] = field(\n         default=None,\n         metadata={\"help\": \"Pretrained tokenizer name or path if not the same as model_name\"}\n     )\n     cache_dir: Optional[str] = field(\n         default=None,\n         metadata={\"help\": \"Where do you want to store the pretrained models downloaded from huggingface.co\"},\n     )\n     use_fast_tokenizer: Optional[bool] = field(\n         default=True,\n         metadata={\"help\": \"Whether to use one of the fast tokenizer (backed by the tokenizers library) or not.\"},\n     )\n     model_revision: Optional[str] = field(\n         default=\"main\",\n         metadata={\"help\": \"The specific model version to use (can be a branch name, tag name or commit id).\"},\n     )\n     use_auth_token: Optional[bool] = field(\n         default=False,\n         metadata={\n             \"help\": (\n                 \"Will use the token generated when running `huggingface-cli login` (necessary to use this script \"\n                 \"with private models).\"\n             )\n         },\n     )\n     torch_dtype: Optional[str] = field(\n         default=None,\n         metadata={\n             \"help\": (\n                 \"Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the \"\n                 \"dtype will be automatically derived from the model's weights.\"\n             ),\n             \"choices\": [\"auto\", \"bfloat16\", \"float16\", \"float32\"],\n         },\n     )\n     use_lora: Optional[bool] = field(\n         default=False,\n         metadata={\"help\": \"Whether to use LoRA.\"},\n     )\n     lora_r: Optional[int] = field(\n         default=8,\n         metadata={\"help\": \"The rank of the LoRA parameters. The smaller lora_r is, the fewer parameters LoRA has.\"},\n     )\n     lora_alpha: Optional[int] = field(\n         default=32,\n         metadata={\"help\": \"Merging ratio between the fine-tuned model and the original. 
This is controlled by a parameter called alpha in the paper.\"},\n     )\n     lora_target_modules: Optional[list[str]] = field(\n         default=None,\n         metadata={\"help\": \"The names of the modules to apply LoRA to.\",\n                   }\n     )\n     lora_dropout: Optional[float] = field(\n         default=0.1,\n         metadata={\"help\": \"The dropout rate in lora.linear.\"},\n     )\n     save_aggregated_lora: Optional[bool] = field(\n         default=False,\n         metadata={\"help\": \"Whether to save aggregated lora.\"},\n     )\n     use_ram_optimized_load: Optional[bool] = field(\n         default=True,\n         metadata={\"help\": \"Whether to use disk mapping when memory is not enough.\"}\n     )\n     use_flash_attention: Optional[bool] = field(\n         default=False,\n         metadata={\n             \"help\": (\n                 \"Whether to use the flash attention layer to reduce GPU memory, at\"\n                 \" higher time cost.\"\n             )\n         }\n     )\n     use_int8: Optional[bool] = field(\n         default=False,\n         metadata={\"help\": \"Whether to load int8 quantization for inference.\"}\n     )\n     custom_model: Optional[bool] = field(\n         default=False,\n         metadata={\"help\": \"Whether the model is a custom model rather than one from HuggingFace.\"}\n     )\n     # below is added for macaw model\n     n_frames: Optional[int] = field(\n         default=6,\n         metadata={\n             \"help\": \"The number of frames for encoding a video.\"\n         },\n     )\n     attention_heads: Optional[int] = field(\n         default=220,\n         metadata={\n             \"help\": \"The number of attention heads used in multi-head-attention.\"\n         },\n     )\n     image_conv_kernel: Optional[int] = field(\n         default=48,\n         metadata={\n             \"help\": \"The size of the convolutional kernel for the image stream.\"\n         },\n     )\n     image_conv_stride: Optional[int] = field(\n         default=36,\n         metadata={\n             \"help\": \"The stride of the convolutional kernel for the image stream.\"\n         },\n     )\n     video_conv_kernel: Optional[int] = field(\n         default=36,\n         metadata={\n             \"help\": \"The size of the convolutional kernel for the video stream.\"\n         },\n     )\n     video_conv_stride: Optional[int] = field(\n         default=30,\n         metadata={\n             \"help\": \"The stride of the convolutional kernel 
for the video stream.\"\n },\n )\n audio_conv_kernel: Optional[int] = field(\n default=240,\n metadata={\n \"help\": \"The size of the convolutional kernel for the audio stream.\"\n },\n )\n audio_conv_stride: Optional[int] = field(\n default=220,\n metadata={\n \"help\": \"The stride of the convolutional kernel for the audio stream.\"\n },\n )\n freeze_multi_modal_encoder: bool = field(\n default=False,\n metadata={\n \"help\": (\n \"Whether to freeze the parameters of multi-modal encoders during training.).\"\n )\n },\n )\n\n def __post_init__(self):\n if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):\n raise ValueError(\n \"--config_overrides can't be used in combination with --config_name or --model_name_or_path\"\n )\n
"},{"location":"reference/#src.aeiva.config.OptimizerNames","title":"OptimizerNames
","text":" Bases: ExplicitEnum
Stores the acceptable string identifiers for optimizers.
Source code in src/aeiva/config/general_configs.py
class OptimizerNames(ExplicitEnum):\n \"\"\"\n Stores the acceptable string identifiers for optimizers.\n \"\"\"\n ADAMW_HF = \"adamw_hf\"\n ADAMW_TORCH = \"adamw_torch\"\n ADAMW_TORCH_FUSED = \"adamw_torch_fused\"\n ADAMW_TORCH_XLA = \"adamw_torch_xla\"\n ADAMW_APEX_FUSED = \"adamw_apex_fused\"\n ADAFACTOR = \"adafactor\"\n ADAMW_ANYPRECISION = \"adamw_anyprecision\"\n SGD = \"sgd\"\n ADAGRAD = \"adagrad\"\n ADAMW_BNB = \"adamw_bnb_8bit\"\n ADAMW_8BIT = \"adamw_8bit\" # just an alias for adamw_bnb_8bit\n LION_8BIT = \"lion_8bit\"\n LION = \"lion_32bit\"\n PAGED_ADAMW = \"paged_adamw_32bit\"\n PAGED_ADAMW_8BIT = \"paged_adamw_8bit\"\n PAGED_LION = \"paged_lion_32bit\"\n PAGED_LION_8BIT = \"paged_lion_8bit\"\n
"},{"location":"reference/#src.aeiva.config.base_config","title":"base_config
","text":"This module contains the base config classes.
We can define separate config classes for different modules, e.g., data, model, trainer, llm, etc. They will be automatically registered in the BaseConfig class.
Copyright (C) 2023 Bang Liu - All Rights Reserved. This source code is licensed under the license found in the LICENSE file in the root directory of this source tree.
"},{"location":"reference/#src.aeiva.config.base_config.BaseConfig","title":"BaseConfig
dataclass
","text":"Base class for all configuration classes.
Source code in src/aeiva/config/base_config.py
@dataclass\nclass BaseConfig:\n \"\"\"\n Base class for all configuration classes.\n \"\"\"\n subclasses = {} # Dictionary to store subclasses\n\n def __init_subclass__(cls, **kwargs):\n \"\"\"\n This method is called when a subclass is created.\n \"\"\"\n super().__init_subclass__(**kwargs)\n BaseConfig.subclasses[cls.__name__] = cls\n\n def __post_init__(self):\n \"\"\"\n Empty post-init to allow subclasses to call super().__post_init__().\n \"\"\"\n pass\n\n @classmethod\n def from_dict(cls, data: dict):\n \"\"\"\n Create a new instance of the class from a dictionary.\n \"\"\"\n try:\n return cls(**data)\n except TypeError as e:\n invalid_keys = [key.strip(\"'\") for key in re.findall(r\"'(\\w+)'\", str(e))]\n raise ValueError(f\"Invalid config keys provided: {invalid_keys}. Details: {e}\")\n\n def to_dict(self):\n \"\"\"\n Convert the instance to a dictionary.\n \"\"\"\n return {k: v for k, v in self.__dict__.items() if not k.startswith('_')}\n\n @classmethod\n def from_json(cls, json_path: str):\n \"\"\"\n Create a new instance of the class from a JSON file.\n \"\"\"\n with open(json_path, \"r\") as json_file:\n data = json.load(json_file)\n return cls.from_dict(data)\n\n def to_json(self, filepath: str):\n \"\"\"\n Convert the instance to a JSON file.\n \"\"\"\n with open(filepath, 'w') as json_file:\n json.dump(self.to_dict(), json_file, indent=4)\n\n @classmethod\n def from_yaml(cls, yaml_path: str):\n \"\"\"\n Create a new instance of the class from a YAML file.\n \"\"\"\n with open(yaml_path, \"r\") as yaml_file:\n data = yaml.safe_load(yaml_file)\n return cls.from_dict(data)\n\n def to_yaml(self, filepath: str):\n \"\"\"\n Convert the instance to a YAML file.\n \"\"\"\n with open(filepath, 'w') as yaml_file:\n yaml.dump(self.to_dict(), yaml_file)\n\n @classmethod\n def from_json_or_yaml(cls, file_path: str):\n \"\"\"\n Create a new instance of the class from a JSON or YAML file.\n \"\"\"\n _, file_extension = os.path.splitext(file_path)\n if 
file_extension == \".json\":\n return cls.from_json(file_path)\n elif file_extension == \".yaml\" or file_extension == \".yml\":\n return cls.from_yaml(file_path)\n else:\n raise ValueError(f\"Unsupported file extension: {file_extension}. Please use .json or .yaml\")\n\n def __str__(self):\n \"\"\"\n Return a string representation of the instance.\n \"\"\"\n return pprint.pformat(self.to_dict(), indent=4)\n
"},{"location":"reference/#src.aeiva.config.base_config.BaseConfig.__init_subclass__","title":"__init_subclass__(**kwargs)
","text":"This method is called when a subclass is created.
Source code in src/aeiva/config/base_config.py
def __init_subclass__(cls, **kwargs):\n \"\"\"\n This method is called when a subclass is created.\n \"\"\"\n super().__init_subclass__(**kwargs)\n BaseConfig.subclasses[cls.__name__] = cls\n
"},{"location":"reference/#src.aeiva.config.base_config.BaseConfig.__post_init__","title":"__post_init__()
","text":"Empty post-init to allow subclasses to call super().__post_init__().
Source code in src/aeiva/config/base_config.py
def __post_init__(self):\n \"\"\"\n Empty post-init to allow subclasses to call super().__post_init__().\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.config.base_config.BaseConfig.__str__","title":"__str__()
","text":"Return a string representation of the instance.
Source code in src/aeiva/config/base_config.py
def __str__(self):\n \"\"\"\n Return a string representation of the instance.\n \"\"\"\n return pprint.pformat(self.to_dict(), indent=4)\n
"},{"location":"reference/#src.aeiva.config.base_config.BaseConfig.from_dict","title":"from_dict(data)
classmethod
","text":"Create a new instance of the class from a dictionary.
Source code in src/aeiva/config/base_config.py
@classmethod\ndef from_dict(cls, data: dict):\n \"\"\"\n Create a new instance of the class from a dictionary.\n \"\"\"\n try:\n return cls(**data)\n except TypeError as e:\n invalid_keys = [key.strip(\"'\") for key in re.findall(r\"'(\\w+)'\", str(e))]\n raise ValueError(f\"Invalid config keys provided: {invalid_keys}. Details: {e}\")\n
"},{"location":"reference/#src.aeiva.config.base_config.BaseConfig.from_json","title":"from_json(json_path)
classmethod
","text":"Create a new instance of the class from a JSON file.
Source code in src/aeiva/config/base_config.py
@classmethod\ndef from_json(cls, json_path: str):\n \"\"\"\n Create a new instance of the class from a JSON file.\n \"\"\"\n with open(json_path, \"r\") as json_file:\n data = json.load(json_file)\n return cls.from_dict(data)\n
"},{"location":"reference/#src.aeiva.config.base_config.BaseConfig.from_json_or_yaml","title":"from_json_or_yaml(file_path)
classmethod
","text":"Create a new instance of the class from a JSON or YAML file.
Source code in src/aeiva/config/base_config.py
@classmethod\ndef from_json_or_yaml(cls, file_path: str):\n \"\"\"\n Create a new instance of the class from a JSON or YAML file.\n \"\"\"\n _, file_extension = os.path.splitext(file_path)\n if file_extension == \".json\":\n return cls.from_json(file_path)\n elif file_extension == \".yaml\" or file_extension == \".yml\":\n return cls.from_yaml(file_path)\n else:\n raise ValueError(f\"Unsupported file extension: {file_extension}. Please use .json or .yaml\")\n
"},{"location":"reference/#src.aeiva.config.base_config.BaseConfig.from_yaml","title":"from_yaml(yaml_path)
classmethod
","text":"Create a new instance of the class from a YAML file.
Source code in src/aeiva/config/base_config.py
@classmethod\ndef from_yaml(cls, yaml_path: str):\n \"\"\"\n Create a new instance of the class from a YAML file.\n \"\"\"\n with open(yaml_path, \"r\") as yaml_file:\n data = yaml.safe_load(yaml_file)\n return cls.from_dict(data)\n
"},{"location":"reference/#src.aeiva.config.base_config.BaseConfig.to_dict","title":"to_dict()
","text":"Convert the instance to a dictionary.
Source code in src/aeiva/config/base_config.py
def to_dict(self):\n \"\"\"\n Convert the instance to a dictionary.\n \"\"\"\n return {k: v for k, v in self.__dict__.items() if not k.startswith('_')}\n
"},{"location":"reference/#src.aeiva.config.base_config.BaseConfig.to_json","title":"to_json(filepath)
","text":"Convert the instance to a JSON file.
Source code in src/aeiva/config/base_config.py
def to_json(self, filepath: str):\n \"\"\"\n Convert the instance to a JSON file.\n \"\"\"\n with open(filepath, 'w') as json_file:\n json.dump(self.to_dict(), json_file, indent=4)\n
"},{"location":"reference/#src.aeiva.config.base_config.BaseConfig.to_yaml","title":"to_yaml(filepath)
","text":"Convert the instance to a YAML file.
Source code in src/aeiva/config/base_config.py
def to_yaml(self, filepath: str):\n \"\"\"\n Convert the instance to a YAML file.\n \"\"\"\n with open(filepath, 'w') as yaml_file:\n yaml.dump(self.to_dict(), yaml_file)\n
"},{"location":"reference/#src.aeiva.config.custom_configs","title":"custom_configs
","text":""},{"location":"reference/#src.aeiva.config.custom_configs.macaw_config","title":"macaw_config
","text":"This module contains the config for macaw model.
We can define separate config classes for different specific models/datasets/tasks.
Copyright (C) 2023 Bang Liu - All Rights Reserved. This source code is licensed under the license found in the LICENSE file in the root directory of this source tree.
"},{"location":"reference/#src.aeiva.config.custom_configs.macaw_config.MacawConfig","title":"MacawConfig
dataclass
","text":" Bases: BaseConfig
Define user-customized config here.
Source code in src/aeiva/config/custom_configs/macaw_config.py
@dataclass\nclass MacawConfig(BaseConfig):\n \"\"\"\n Define user-customized config here.\n \"\"\"\n image_dir: Optional[str] = field(\n default=None,\n metadata={\"help\": \"The directory of image data\"}\n )\n video_dir: Optional[str] = field(\n default=None,\n metadata={\"help\": \"The directory of video data\"}\n )\n frame_dir: Optional[str] = field(\n default=None,\n metadata={\"help\": \"The directory to save video frames\"}\n )\n audio_dir: Optional[str] = field(\n default=None,\n metadata={\"help\": \"The directory to save video audios\"}\n )\n num_frames_to_sample: Optional[int] = field(\n default=120,\n metadata={\"help\": \"The number of frames to sample from a video\"}\n )\n num_frames_to_load: Optional[int] = field(\n default=6,\n metadata={\"help\": \"The number of frames to load as a part of model inputs\"}\n )\n num_samples_per_dataset: Optional[int] = field(\n default=100,\n metadata={\"help\": \"The number of samples to load from each dataset\"}\n )\n num_samples_per_merged_dataset: Optional[int] = field(\n default=20,\n metadata={\"help\": \"The number of samples to save after merging datasets\"}\n )\n batch_size: Optional[int] = field(\n default=1,\n metadata={\"help\": \"The batch size of model inputs\"}\n )\n max_seq_len_for_preprocess: Optional[int] = field(\n default=256,\n metadata={\"help\": \"The maximum sequence length for preprocess\"}\n )\n run_time_cache_dir: Optional[str] = field(\n default=None,\n metadata={\"help\": \"The directory to save running time data, such as video frames, audios, and so on.\"}\n )\n tokenizer_name_or_path: Optional[str] = field(\n default=None,\n metadata={\"help\": \"The name or path of tokenizer\"}\n )\n clip_model_name_or_path: Optional[str] = field(\n default=None,\n metadata={\"help\": \"The name or path of clip model\"}\n )\n whisper_model_name_or_path: Optional[str] = field(\n default=None,\n metadata={\"help\": \"The name or path of whisper model\"}\n )\n llama7b_model_name_or_path: Optional[str] = 
field(\n default=None,\n metadata={\"help\": \"The name or path of llama7b model\"}\n )\n macaw_model_name_or_path: Optional[str] = field(\n default=None,\n metadata={\"help\": \"The name or path of macaw model\"}\n )\n mode: Optional[str] = field(\n default=\"train\",\n metadata={\"help\": \"The mode of train, eval, or inference\"}\n )\n model_name: Optional[str] = field(\n default=\"macaw\",\n metadata={\"help\": \"The name of model\"}\n )\n resource_ready: Optional[bool] = field(\n default=True,\n metadata={\"help\": \"Whether the pre-requisite resource is ready, e.g., download pretrained models and datasets\"}\n )\n
"},{"location":"reference/#src.aeiva.config.general_configs","title":"general_configs
","text":"This module contains some general config classes that can be used in deep learning projects.
E.g., data config, model config, trainer config, etc.
Copyright (C) 2023 Bang Liu - All Rights Reserved. This source code is licensed under the license found in the LICENSE file in the root directory of this source tree.
"},{"location":"reference/#src.aeiva.config.general_configs.DataConfig","title":"DataConfig
dataclass
","text":" Bases: BaseConfig
This class contains the data configuration.
Source code in src/aeiva/config/general_configs.py
@dataclass\nclass DataConfig(BaseConfig):\n \"\"\"This class contains the data configuration.\"\"\"\n dataset_path: Optional[str] = field(\n default=None, metadata={\"help\": \"The path of the dataset to use.\"}\n )\n dataset_name: Optional[str] = field(\n default=\"customized\", metadata={\"help\": \"Should be \\\"customized\\\"\"}\n )\n is_custom_dataset: Optional[bool] = field(\n default=False, metadata={\"help\": \"whether to use custom data\"}\n )\n customized_cache_dir: Optional[str] = field(\n default=\".cache/llm-ft/datasets\",\n metadata={\"help\": \"Where do you want to store the customized dataset caches\"},\n )\n dataset_config_name: Optional[str] = field(\n default=None, metadata={\"help\": \"The configuration name of the dataset to use (via the datasets library).\"}\n )\n train_file: Optional[str] = field(default=None, metadata={\"help\": \"The input training data file (a text file).\"})\n validation_file: Optional[str] = field(\n default=None,\n metadata={\"help\": \"An optional input evaluation data file to evaluate the perplexity on (a text file).\"},\n )\n max_train_samples: Optional[int] = field(\n default=None,\n metadata={\n \"help\": (\n \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n \"value if set.\"\n )\n },\n )\n max_eval_samples: Optional[int] = field(\n default=1e10,\n metadata={\n \"help\": (\n \"For debugging purposes or quicker training, truncate the number of evaluation examples to this \"\n \"value if set.\"\n )\n },\n )\n streaming: Optional[bool] = field(default=False, metadata={\"help\": \"Enable streaming mode\"})\n block_size: Optional[int] = field(\n default=512,\n metadata={\n \"help\": (\n \"Optional input sequence length after tokenization. \"\n \"The training dataset will be truncated in block of this size for training. 
\"\n \"Default to the model max input length for single sentence inputs (take into account special tokens).\"\n )\n },\n )\n overwrite_cache: Optional[bool] = field(\n default=False,\n metadata={\"help\": \"Overwrite the cached training and evaluation sets\"}\n )\n validation_split_percentage: Optional[int] = field(\n default=5,\n metadata={\n \"help\": \"The percentage of the train set used as validation set in case there's no validation split\"\n },\n )\n preprocessing_num_workers: Optional[int] = field(\n default=None,\n metadata={\"help\": \"The number of processes to use for the preprocessing.\"},\n )\n group_texts_batch_size: Optional[int] = field(\n default=1000,\n metadata={\n \"help\": (\n \"Number of samples that will be grouped together to go though\"\n \" `group_texts` operation. See `--disable_group_texts` for\"\n \" detailed explanation of this operation.\"\n )\n }\n )\n disable_group_texts: Optional[bool] = field(\n default=False,\n metadata={\n \"help\": (\n \"Whether we group original samples together to generate sample\"\n \" sequences of length `block_size`. By default, we group every\"\n \" 1000 tokenized sequences together, divide them into \"\n \" [{total_num_tokens} / {block_size}] sequences, each with\"\n \" `block_size` tokens (the remaining tokens are ommited.\"\n \" If this flag is set to True, we only group 1 tokenized\"\n \" sequence, i.e. 
cutting long sequence into chunks.\"\n )\n },\n )\n keep_linebreaks: Optional[bool] = field(\n default=True,\n metadata={\"help\": \"Whether to keep line breaks when using TXT files or not.\"}\n )\n test_file: Optional[str] = field(\n default=None,\n metadata={\"help\": \"Evaluation File Path\"},\n )\n\n def __post_init__(self):\n if self.streaming:\n require_version(\"datasets>=2.0.0\", \"The streaming feature requires `datasets>=2.0.0`\")\n\n if self.dataset_name is None and self.train_file is None and self.validation_file is None:\n raise ValueError(\"Need either a dataset name or a training/validation file.\")\n else:\n if self.train_file is not None:\n extension = self.train_file.split(\".\")[-1]\n assert extension in [\"csv\", \"json\", \"txt\"], \"`train_file` should be a csv, a json or a txt file.\"\n if self.validation_file is not None:\n extension = self.validation_file.split(\".\")[-1]\n assert extension in [\"csv\", \"json\", \"txt\"], \"`validation_file` should be a csv, a json or a txt file.\"\n
"},{"location":"reference/#src.aeiva.config.general_configs.ExplicitEnum","title":"ExplicitEnum
","text":" Bases: str
, Enum
Enum with a more explicit error message for missing values.
Source code in src/aeiva/config/general_configs.py
class ExplicitEnum(str, Enum):\n \"\"\"\n Enum with more explicit error message for missing values.\n \"\"\"\n @classmethod\n def _missing_(cls, value):\n raise ValueError(\n f\"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}\"\n )\n
"},{"location":"reference/#src.aeiva.config.general_configs.ModelConfig","title":"ModelConfig
dataclass
","text":" Bases: BaseConfig
Model configuration class.
Source code in src/aeiva/config/general_configs.py
@dataclass\nclass ModelConfig(BaseConfig):\n \"\"\"Model configuration class.\"\"\"\n model_name_or_path: Optional[str] = field(\n default=None,\n metadata={\n \"help\": (\n \"The model checkpoint for weights initialization. Don't set if you want to train a model from scratch.\"\n )\n },\n )\n lora_model_path: Optional[str] = field(\n default=None,\n metadata={\n \"help\": (\n \"The incremental model diff introduced by LoRA finetuning.\"\n \" Along with the original non-finetuned model forms the whole\"\n \" finetuned model.\"\n )\n }\n )\n model_type: Optional[str] = field(\n default=None,\n metadata={\"help\": \"If training from scratch, pass a model type from the list: \" + \", \".join(MODEL_TYPES)},\n )\n arch_type: Optional[str] = field(\n default=\"decoder_only\",\n metadata={\"help\": \"The architecture type of the model. Currently supported decoder_only or encoder_decoder\"}\n )\n config_overrides: Optional[str] = field(\n default=None,\n metadata={\n \"help\": (\n \"Override some existing default config settings when a model is trained from scratch. Example: \"\n \"n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index\"\n )\n },\n )\n arch_type: Optional[str] = field(\n default=\"decoder_only\",\n metadata={\n \"help\": (\n \"Model architecture type, e.g. 
\\\"decoder_only\\\",\"\n \" \\\"encoder_decoder\\\"\"\n ),\n \"choices\": [\"decoder_only\", \"encoder_decoder\", \"text_regression\", \"vision_encoder_decoder\"],\n },\n )\n config_name: Optional[str] = field(\n default=None,\n metadata={\"help\": \"Pretrained config name or path if not the same as model_name\"}\n )\n tokenizer_name: Optional[str] = field(\n default=None,\n metadata={\"help\": \"Pretrained tokenizer name or path if not the same as model_name\"}\n )\n cache_dir: Optional[str] = field(\n default=None,\n metadata={\"help\": \"Where do you want to store the pretrained models downloaded from huggingface.co\"},\n )\n use_fast_tokenizer: Optional[bool] = field(\n default=True,\n metadata={\"help\": \"Whether to use one of the fast tokenizer (backed by the tokenizers library) or not.\"},\n )\n model_revision: Optional[str] = field(\n default=\"main\",\n metadata={\"help\": \"The specific model version to use (can be a branch name, tag name or commit id).\"},\n )\n use_auth_token: Optional[bool] = field(\n default=False,\n metadata={\n \"help\": (\n \"Will use the token generated when running `huggingface-cli login` (necessary to use this script \"\n \"with private models).\"\n )\n },\n )\n torch_dtype: Optional[str] = field(\n default=None,\n metadata={\n \"help\": (\n \"Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the \"\n \"dtype will be automatically derived from the model's weights.\"\n ),\n \"choices\": [\"auto\", \"bfloat16\", \"float16\", \"float32\"],\n },\n )\n use_lora: Optional[bool] = field(\n default=False,\n metadata={\"help\": \"Whether to lora.\"},\n )\n lora_r: Optional[int] = field(\n default=8,\n metadata={\"help\": \"the rank of the lora parameters. The smaller lora_r is , the fewer parameters lora has.\"},\n )\n lora_alpha: Optional[int] = field(\n default=32,\n metadata={\"help\": \"Merging ratio between the fine-tuned model and the original. 
This is controlled by a parameter called alpha in the paper.\"},\n )\n lora_target_modules: Optional[list[str]] = field(\n default=None,\n metadata={\"help\": \"Pretrained config name or path if not the same as model_name\",\n }\n )\n lora_dropout: Optional[float] = field(\n default=0.1,\n metadata={\"help\": \"The dropout rate in lora.linear.\"},\n )\n save_aggregated_lora: Optional[bool] = field(\n default=False,\n metadata={\"help\": \"Whether to save aggregated lora.\"},\n )\n use_ram_optimized_load: Optional[bool] = field(\n default=True,\n metadata={\"help\": \"Whether use disk mapping when memory is not enough.\"}\n )\n use_flash_attention: Optional[bool] = field(\n default=False,\n metadata={\n \"help\": (\n \"whether use flash attention layer to reduce GPU memory with\"\n \" higher time cost.\"\n )\n }\n )\n use_int8: Optional[bool] = field(\n default=False,\n metadata={\"help\": \"whether to load int8 quantization for inference\"}\n )\n custom_model: Optional[bool] = field(\n default=False,\n metadata={\"help\": \"flag for the model from huggingface or not\"}\n )\n # below is added for macaw model\n n_frames: Optional[int] = field(\n default=6,\n metadata={\n \"help\": \"The number of frames for encoding a video.\"\n },\n )\n attention_heads: Optional[int] = field(\n default=220,\n metadata={\n \"help\": \"The number of attention heads used in multi-head-attention.\"\n },\n )\n image_conv_kernel: Optional[int] = field(\n default=48,\n metadata={\n \"help\": \"The size of the convolutional kernel for the image stream.\"\n },\n )\n image_conv_stride: Optional[int] = field(\n default=36,\n metadata={\n \"help\": \"The stride of the convolutional kernel for the image stream.\"\n },\n )\n video_conv_kernel: Optional[int] = field(\n default=36,\n metadata={\n \"help\": \"The size of the convolutional kernel for the video stream.\"\n },\n )\n video_conv_stride: Optional[int] = field(\n default=30,\n metadata={\n \"help\": \"The stride of the convolutional kernel 
for the video stream.\"\n },\n )\n audio_conv_kernel: Optional[int] = field(\n default=240,\n metadata={\n \"help\": \"The size of the convolutional kernel for the audio stream.\"\n },\n )\n audio_conv_stride: Optional[int] = field(\n default=220,\n metadata={\n \"help\": \"The stride of the convolutional kernel for the audio stream.\"\n },\n )\n freeze_multi_modal_encoder: bool = field(\n default=False,\n metadata={\n \"help\": (\n \"Whether to freeze the parameters of multi-modal encoders during training.).\"\n )\n },\n )\n\n def __post_init__(self):\n if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):\n raise ValueError(\n \"--config_overrides can't be used in combination with --config_name or --model_name_or_path\"\n )\n
"},{"location":"reference/#src.aeiva.config.general_configs.OptimizerNames","title":"OptimizerNames
","text":" Bases: ExplicitEnum
Stores the acceptable string identifiers for optimizers.
Source code in src/aeiva/config/general_configs.py
class OptimizerNames(ExplicitEnum):\n \"\"\"\n Stores the acceptable string identifiers for optimizers.\n \"\"\"\n ADAMW_HF = \"adamw_hf\"\n ADAMW_TORCH = \"adamw_torch\"\n ADAMW_TORCH_FUSED = \"adamw_torch_fused\"\n ADAMW_TORCH_XLA = \"adamw_torch_xla\"\n ADAMW_APEX_FUSED = \"adamw_apex_fused\"\n ADAFACTOR = \"adafactor\"\n ADAMW_ANYPRECISION = \"adamw_anyprecision\"\n SGD = \"sgd\"\n ADAGRAD = \"adagrad\"\n ADAMW_BNB = \"adamw_bnb_8bit\"\n ADAMW_8BIT = \"adamw_8bit\" # just an alias for adamw_bnb_8bit\n LION_8BIT = \"lion_8bit\"\n LION = \"lion_32bit\"\n PAGED_ADAMW = \"paged_adamw_32bit\"\n PAGED_ADAMW_8BIT = \"paged_adamw_8bit\"\n PAGED_LION = \"paged_lion_32bit\"\n PAGED_LION_8BIT = \"paged_lion_8bit\"\n
"},{"location":"reference/#src.aeiva.config.omni_config","title":"omni_config
","text":"This module contains the OmniConfig classes.
We can define separate config classes for different modules, e.g., data, model, trainer, etc. The OmniConfig class is the combination of all config classes. It can also accept command line arguments to update the config values.
Copyright (C) 2023 Bang Liu - All Rights Reserved. This source code is licensed under the license found in the LICENSE file in the root directory of this source tree.
"},{"location":"reference/#src.aeiva.config.omni_config.OmniConfig","title":"OmniConfig
dataclass
","text":" Bases: BaseConfig
Source code in src/aeiva/config/omni_config.py
@dataclass\nclass OmniConfig(BaseConfig):\n @staticmethod\n def create_omni_config():\n \"\"\"\n Initializes OmniConfig by aggregating all configuration classes.\n \"\"\"\n # Aggregating default values from all config classes\n defaults = {}\n for config_class_name, config_class in BaseConfig.subclasses.items():\n if config_class_name == \"OmniConfig\":\n continue\n for field_name, field_obj in config_class.__dataclass_fields__.items():\n if field_name in defaults:\n raise ValueError(f\"Overlapping config argument: '{field_name}' found in {config_class.__name__}\")\n default_value = getattr(config_class(), field_name, None)\n defaults[field_name] = default_value\n\n def __init__(self, **kwargs):\n for key, default_value in defaults.items():\n setattr(self, key, kwargs.get(key, default_value))\n\n OmniConfig.__init__ = __init__\n return OmniConfig\n\n def update_from_args(self, namespace_args: argparse.Namespace):\n \"\"\"\n Updates the configuration based on parsed command-line arguments.\n \"\"\"\n for key, value in vars(namespace_args).items():\n if hasattr(self, key) and value is not None:\n setattr(self, key, value)\n\n def get_argparse_parser(self):\n \"\"\"\n Creates an argument parser that can handle complex types.\n \"\"\"\n parser = argparse.ArgumentParser()\n for config_class_name, config_class in BaseConfig.subclasses.items():\n if config_class_name == \"OmniConfig\":\n continue\n for field_name, field_obj in config_class.__dataclass_fields__.items():\n field_type = field_obj.type\n\n # Handle Optional types\n if get_origin(field_type) is Union and type(None) in get_args(field_type):\n field_type = next(arg for arg in get_args(field_type) if arg is not type(None))\n\n arg_name = '--' + field_name\n help_msg = field_obj.metadata.get(\"help\", f\"{field_name} ({field_type})\")\n\n origin = get_origin(field_type)\n args = get_args(field_type)\n\n # Handle Enums\n if isinstance(field_type, type) and issubclass(field_type, enum.Enum):\n choices = [item.value 
for item in field_type]\n parser.add_argument(arg_name, type=str, choices=choices, help=help_msg)\n continue\n\n # Handle list types\n if origin is list:\n item_type = args[0]\n if item_type is str:\n parser.add_argument(arg_name, nargs='+', type=str, help=help_msg)\n elif item_type is int:\n parser.add_argument(arg_name, nargs='+', type=int, help=help_msg)\n else:\n # Default to strings if item type is not specifically handled\n parser.add_argument(arg_name, nargs='+', type=str, help=help_msg)\n continue\n\n # Handle tuple types\n if origin is tuple:\n # Accept comma-separated values and convert to tuple\n def tuple_type(s):\n try:\n return tuple(map(int, s.split(',')))\n except ValueError:\n raise argparse.ArgumentTypeError(\"Tuples must be comma-separated integers.\")\n\n parser.add_argument(arg_name, type=tuple_type, help=help_msg)\n continue\n\n # Handle dict types\n if origin is dict:\n # Expect JSON string\n def dict_type(s):\n try:\n return json.loads(s)\n except json.JSONDecodeError:\n raise argparse.ArgumentTypeError(\"Dictionaries must be valid JSON strings.\")\n\n parser.add_argument(arg_name, type=dict_type, help=help_msg)\n continue\n\n # Handle basic types\n if field_type is int:\n parser.add_argument(arg_name, type=int, help=help_msg)\n elif field_type is float:\n parser.add_argument(arg_name, type=float, help=help_msg)\n elif field_type is str:\n parser.add_argument(arg_name, type=str, help=help_msg)\n elif field_type is bool:\n parser.add_argument(arg_name, action='store_true', help=help_msg)\n else:\n print(f\"Warning: unsupported type {field_type} for field '{field_name}'\")\n return parser\n
"},{"location":"reference/#src.aeiva.config.omni_config.OmniConfig.create_omni_config","title":"create_omni_config()
staticmethod
","text":"Initializes OmniConfig by aggregating all configuration classes.
Source code in src/aeiva/config/omni_config.py
@staticmethod\ndef create_omni_config():\n \"\"\"\n Initializes OmniConfig by aggregating all configuration classes.\n \"\"\"\n # Aggregating default values from all config classes\n defaults = {}\n for config_class_name, config_class in BaseConfig.subclasses.items():\n if config_class_name == \"OmniConfig\":\n continue\n for field_name, field_obj in config_class.__dataclass_fields__.items():\n if field_name in defaults:\n raise ValueError(f\"Overlapping config argument: '{field_name}' found in {config_class.__name__}\")\n default_value = getattr(config_class(), field_name, None)\n defaults[field_name] = default_value\n\n def __init__(self, **kwargs):\n for key, default_value in defaults.items():\n setattr(self, key, kwargs.get(key, default_value))\n\n OmniConfig.__init__ = __init__\n return OmniConfig\n
"},{"location":"reference/#src.aeiva.config.omni_config.OmniConfig.get_argparse_parser","title":"get_argparse_parser()
","text":"Creates an argument parser that can handle complex types.
Source code in src/aeiva/config/omni_config.py
def get_argparse_parser(self):\n \"\"\"\n Creates an argument parser that can handle complex types.\n \"\"\"\n parser = argparse.ArgumentParser()\n for config_class_name, config_class in BaseConfig.subclasses.items():\n if config_class_name == \"OmniConfig\":\n continue\n for field_name, field_obj in config_class.__dataclass_fields__.items():\n field_type = field_obj.type\n\n # Handle Optional types\n if get_origin(field_type) is Union and type(None) in get_args(field_type):\n field_type = next(arg for arg in get_args(field_type) if arg is not type(None))\n\n arg_name = '--' + field_name\n help_msg = field_obj.metadata.get(\"help\", f\"{field_name} ({field_type})\")\n\n origin = get_origin(field_type)\n args = get_args(field_type)\n\n # Handle Enums\n if isinstance(field_type, type) and issubclass(field_type, enum.Enum):\n choices = [item.value for item in field_type]\n parser.add_argument(arg_name, type=str, choices=choices, help=help_msg)\n continue\n\n # Handle list types\n if origin is list:\n item_type = args[0]\n if item_type is str:\n parser.add_argument(arg_name, nargs='+', type=str, help=help_msg)\n elif item_type is int:\n parser.add_argument(arg_name, nargs='+', type=int, help=help_msg)\n else:\n # Default to strings if item type is not specifically handled\n parser.add_argument(arg_name, nargs='+', type=str, help=help_msg)\n continue\n\n # Handle tuple types\n if origin is tuple:\n # Accept comma-separated values and convert to tuple\n def tuple_type(s):\n try:\n return tuple(map(int, s.split(',')))\n except ValueError:\n raise argparse.ArgumentTypeError(\"Tuples must be comma-separated integers.\")\n\n parser.add_argument(arg_name, type=tuple_type, help=help_msg)\n continue\n\n # Handle dict types\n if origin is dict:\n # Expect JSON string\n def dict_type(s):\n try:\n return json.loads(s)\n except json.JSONDecodeError:\n raise argparse.ArgumentTypeError(\"Dictionaries must be valid JSON strings.\")\n\n parser.add_argument(arg_name, type=dict_type, 
help=help_msg)\n continue\n\n # Handle basic types\n if field_type is int:\n parser.add_argument(arg_name, type=int, help=help_msg)\n elif field_type is float:\n parser.add_argument(arg_name, type=float, help=help_msg)\n elif field_type is str:\n parser.add_argument(arg_name, type=str, help=help_msg)\n elif field_type is bool:\n parser.add_argument(arg_name, action='store_true', help=help_msg)\n else:\n print(f\"Warning: unsupported type {field_type} for field '{field_name}'\")\n return parser\n
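The comma-separated-tuple and JSON-dict converters above can be exercised directly with `argparse`; a minimal standalone sketch (the `--size` and `--options` argument names are illustrative, not part of OmniConfig):

```python
import argparse
import json

def tuple_type(s):
    # mirrors the comma-separated-integers converter in get_argparse_parser
    try:
        return tuple(map(int, s.split(',')))
    except ValueError:
        raise argparse.ArgumentTypeError("Tuples must be comma-separated integers.")

def dict_type(s):
    # mirrors the JSON-string converter in get_argparse_parser
    try:
        return json.loads(s)
    except json.JSONDecodeError:
        raise argparse.ArgumentTypeError("Dictionaries must be valid JSON strings.")

parser = argparse.ArgumentParser()
parser.add_argument('--size', type=tuple_type)
parser.add_argument('--options', type=dict_type)

ns = parser.parse_args(['--size', '3,4', '--options', '{"depth": 2}'])
```
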
"},{"location":"reference/#src.aeiva.config.omni_config.OmniConfig.update_from_args","title":"update_from_args(namespace_args)
","text":"Updates the configuration based on parsed command-line arguments.
Source code in src/aeiva/config/omni_config.py
def update_from_args(self, namespace_args: argparse.Namespace):\n \"\"\"\n Updates the configuration based on parsed command-line arguments.\n \"\"\"\n for key, value in vars(namespace_args).items():\n if hasattr(self, key) and value is not None:\n setattr(self, key, value)\n
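The overwrite rule — only attributes that already exist on the config and were actually supplied on the command line — can be sketched with a stand-in config class (`Config`, `learning_rate`, and `epochs` are illustrative names, not OmniConfig fields):

```python
import argparse

class Config:
    def __init__(self):
        self.learning_rate = 0.1
        self.epochs = 10

    def update_from_args(self, namespace_args):
        # only overwrite attributes that exist and are not None
        for key, value in vars(namespace_args).items():
            if hasattr(self, key) and value is not None:
                setattr(self, key, value)

cfg = Config()
cfg.update_from_args(argparse.Namespace(learning_rate=0.01, epochs=None, extra=5))
```

Unset arguments (`epochs=None`) and unknown keys (`extra`) leave the config untouched.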
"},{"location":"reference/#src.aeiva.data","title":"data
","text":""},{"location":"reference/#src.aeiva.data.processor","title":"processor
","text":"This module contains the data processor.
@Author: Bang Liu (chatsci.ai@gmail.com) @Date: 2023-07-11
Copyright (C) 2023 Bang Liu - All Rights Reserved. This source code is licensed under the license found in the LICENSE file in the root directory of this source tree.
"},{"location":"reference/#src.aeiva.data.processor.process_dataset","title":"process_dataset(formatted_dataset, pipeline, output_dir, dataset_name='')
","text":"Process a dataset with a pipeline of functions.
Parameters:
Name Type Description Default formatted_dataset
DataSet
the dataset to be processed.
required pipeline
list[Callable]
a list of functions to be applied to the dataset.
required output_dir
Optional[str]
the output directory to save the processed dataset.
required dataset_name
Optional[str]
the name of the dataset. Defaults to \"\".
''
Returns:
Name Type Description DataSet
DataSet
the processed dataset.
Source code in src/aeiva/data/processor.py
def process_dataset(formatted_dataset: DataSet,\n pipeline: list[Callable],\n output_dir: Optional[str],\n dataset_name: Optional[str] = \"\") -> DataSet:\n \"\"\"\n Process a dataset with a pipeline of functions.\n\n Args:\n formatted_dataset (DataSet): the dataset to be processed.\n pipeline (list[Callable]): a list of functions to be applied to the dataset.\n output_dir (Optional[str]): the output directory to save the processed dataset.\n dataset_name (Optional[str], optional): the name of the dataset. Defaults to \"\".\n\n Returns:\n DataSet: the processed dataset.\n \"\"\"\n processed_data = []\n pipeline = Pipeline(pipeline)\n for item in formatted_dataset[\"data\"]:\n processed_data.append(pipeline(item.copy()))\n\n output = {\"data\": processed_data, \"metadata\": formatted_dataset[\"metadata\"]}\n if output_dir is not None:\n ensure_dir(output_dir)\n dump_json(output, f\"{output_dir}/{dataset_name}_dataset.processed.json\")\n return output\n
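The processing loop reduces to composing the pipeline's callables over each item; a minimal sketch without the project's `Pipeline` wrapper (the `strip_text`/`lowercase_text` steps are illustrative):

```python
# Illustrative pipeline steps; process_dataset composes such callables per item.
def strip_text(item):
    item["text"] = item["text"].strip()
    return item

def lowercase_text(item):
    item["text"] = item["text"].lower()
    return item

formatted_dataset = {
    "data": [{"text": "  Hello World  "}],
    "metadata": {"name": "demo"},
}

processed_data = []
for item in formatted_dataset["data"]:
    result = item.copy()  # process a copy, as process_dataset does
    for step in [strip_text, lowercase_text]:
        result = step(result)
    processed_data.append(result)

output = {"data": processed_data, "metadata": formatted_dataset["metadata"]}
```
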
"},{"location":"reference/#src.aeiva.demo","title":"demo
","text":""},{"location":"reference/#src.aeiva.demo.chat_gradio","title":"chat_gradio
","text":""},{"location":"reference/#src.aeiva.demo.chat_gradio.bot","title":"bot(user_input, history)
async
","text":"Handles chatbot logic by emitting perception.gradio events to the Agent and retrieving responses.
Source code in src/aeiva/demo/chat_gradio.py
async def bot(user_input, history):\n \"\"\"\n Handles chatbot logic by emitting perception.gradio events to the Agent and retrieving responses.\n \"\"\"\n if agent is None:\n logger.error(\"Agent is not initialized.\")\n history.append({\"role\": \"assistant\", \"content\": \"Agent is not initialized.\"})\n yield history, ''\n return\n\n try:\n # Append user's message to history\n history.append({\"role\": \"user\", \"content\": user_input})\n # Append an empty assistant response\n history.append({\"role\": \"assistant\", \"content\": \"\"})\n yield history, '' # Display the user's message\n logger.info(f\"User input appended to history: {user_input}\")\n\n stream = config_dict[\"llm_gateway_config\"][\"llm_stream\"]\n use_async = config_dict[\"llm_gateway_config\"][\"llm_use_async\"]\n\n # Emit the 'perception.gradio' event with stream=True\n emit_future = asyncio.run_coroutine_threadsafe(\n agent.event_bus.emit('perception.gradio', payload=user_input), # TODO: maybe simplify payload, Agent can directly read stream and use_async from config.\n agent.event_bus.loop\n )\n emit_future.result() # Ensure the event is emitted\n logger.info(f\"Emitted 'perception.gradio' event with payload: {user_input} | Stream: {stream}\")\n\n assistant_message = ''\n if stream:\n while True:\n try:\n # Non-blocking response retrieval from the thread-safe queue with timeout\n response = await asyncio.wait_for(\n asyncio.to_thread(response_queue.get, True, 30),\n timeout=30\n )\n logger.info(f\"Retrieved response from queue: {response}\")\n if response == \"<END_OF_RESPONSE>\":\n logger.info(\"Received end of response signal.\")\n break\n assistant_message += response\n # Create a new history list to ensure Gradio detects the update\n new_history = history.copy()\n new_history[-1][\"content\"] = assistant_message\n logger.info(f\"Yielding updated history: {new_history}\")\n yield new_history, ''\n except asyncio.TimeoutError:\n logger.warning(\"Timeout: No response received from 
Agent.\")\n # Create a new history list to ensure Gradio detects the update\n new_history = history.copy()\n new_history[-1][\"content\"] = \"I'm sorry, I didn't receive a response in time.\"\n yield new_history, ''\n break\n else:\n try:\n # Non-blocking response retrieval from the thread-safe queue with timeout\n response = await asyncio.wait_for(\n asyncio.to_thread(response_queue.get, True, 30),\n timeout=30\n )\n logger.info(f\"Retrieved response from queue: {response}\")\n assistant_message += response\n # Create a new history list to ensure Gradio detects the update\n new_history = history.copy()\n new_history[-1][\"content\"] = assistant_message\n logger.info(f\"Yielding updated history: {new_history}\")\n yield new_history, ''\n except asyncio.TimeoutError:\n logger.warning(\"Timeout: No response received from Agent.\")\n # Create a new history list to ensure Gradio detects the update\n new_history = history.copy()\n new_history[-1][\"content\"] = \"I'm sorry, I didn't receive a response in time.\"\n yield new_history, ''\n\n except Exception as e:\n logger.error(f\"Unexpected Error in bot function: {e}\")\n # Create a new history list to ensure Gradio detects the update\n new_history = history.copy()\n new_history[-1][\"content\"] = \"An unexpected error occurred.\"\n yield new_history, ''\n
"},{"location":"reference/#src.aeiva.demo.chat_gradio.clear_media","title":"clear_media()
","text":"Clears the uploaded media paths.
Source code in src/aeiva/demo/chat_gradio.py
def clear_media():\n \"\"\"\n Clears the uploaded media paths.\n \"\"\"\n # Implement any necessary logic to clear media paths or data\n logger.info(\"Cleared uploaded media paths.\")\n return \"\"\n
"},{"location":"reference/#src.aeiva.demo.chat_gradio.handle_upload","title":"handle_upload(file)
","text":"Handles file uploads and delegates to specific handlers based on file type.
Parameters:
Name Type Description Default file
Uploaded file object.
required Returns:
Name Type Description str
Message indicating the upload status.
Source code in src/aeiva/demo/chat_gradio.py
def handle_upload(file):\n \"\"\"\n Handles file uploads and delegates to specific handlers based on file type.\n\n Args:\n file: Uploaded file object.\n\n Returns:\n str: Message indicating the upload status.\n \"\"\"\n if file is None:\n return \"\"\n if file.type.startswith(\"image\"):\n return handle_image_upload(file)\n elif file.type.startswith(\"video\"):\n return handle_video_upload(file)\n elif file.type.startswith(\"audio\"):\n return handle_audio_upload(file)\n else:\n logger.warning(f\"Unsupported file type uploaded: {file.type}\")\n return \"Unsupported file type uploaded.\"\n
"},{"location":"reference/#src.aeiva.demo.mm_chatbot","title":"mm_chatbot
","text":"This module defines a multimodal chatbot demo with gradio.
@Author: Bang Liu (chatsci.ai@gmail.com) @Date: 2023-07-13
Copyright (C) 2023 Bang Liu - All Rights Reserved. This source code is licensed under the license found in the LICENSE file in the root directory of this source tree.
"},{"location":"reference/#src.aeiva.environment","title":"environment
","text":""},{"location":"reference/#src.aeiva.environment.environment","title":"environment
","text":""},{"location":"reference/#src.aeiva.environment.environment.Environment","title":"Environment
","text":" Bases: ABC
Abstract base class for an environment in which an intelligent agent operates.
Each environment provides context, defines interactions, and manages its own state. Subclasses should implement specific methods for different types of environments.
Attributes:
Name Type Description config
EnvironmentConfig
Configuration settings for the environment.
state
Any
Current state of the environment, initialized from the config.
entities
List[Any]
Entities present within the environment.
constraints
Dict[str, Any]
Rules or limitations for interactions in the environment.
time
Optional[int]
Time progression within the environment, if enabled.
Source code in src/aeiva/environment/environment.py
class Environment(ABC):\n \"\"\"\n Abstract base class for an environment in which an intelligent agent operates.\n\n Each environment provides context, defines interactions, and manages its own state.\n Subclasses should implement specific methods for different types of environments.\n\n Attributes:\n config (EnvironmentConfig): Configuration settings for the environment.\n state (Any): Current state of the environment, initialized from the config.\n entities (List[Any]): Entities present within the environment.\n constraints (Dict[str, Any]): Rules or limitations for interactions in the environment.\n time (Optional[int]): Time progression within the environment, if enabled.\n \"\"\"\n\n def __init__(self, config: EnvironmentConfig):\n \"\"\"\n Initialize the environment with a given configuration.\n\n Args:\n config (EnvironmentConfig): Configuration settings for the environment.\n \"\"\"\n self.config = config\n self.state = config.initial_state\n self.entities = config.entities\n self.constraints = config.constraints\n self.time = 0 if config.time_enabled else None\n self.setup()\n\n @abstractmethod\n def setup(self):\n \"\"\"\n Set up the environment based on its configuration.\n Subclasses should define any initialization logic here.\n \"\"\"\n pass\n\n @abstractmethod\n def reset(self):\n \"\"\"\n Reset the environment to its initial state as defined by the configuration.\n \"\"\"\n self.state = self.config.initial_state\n self.time = 0 if self.config.time_enabled else None\n\n @abstractmethod\n def step(self, actions: Dict[Any, Any]):\n \"\"\"\n Advance the environment by one step based on actions taken by agents.\n\n Args:\n actions (Dict[Any, Any]): A dictionary of actions performed by agents.\n \"\"\"\n pass\n\n @abstractmethod\n def observe(self, agent: Any) -> Any:\n \"\"\"\n Provide observations to an agent based on the current state.\n\n Args:\n agent (Any): The agent requesting observation.\n\n Returns:\n Any: Observation data formatted according 
to the agent's perception capabilities.\n \"\"\"\n pass\n\n @abstractmethod\n def act(self, action: Any, target: Optional[Any] = None):\n \"\"\"\n Execute an action in the environment, potentially modifying its state.\n\n Args:\n action (Any): The action to be executed.\n target (Optional[Any]): Target entity for the action, if applicable.\n \"\"\"\n pass\n\n def render(self):\n \"\"\"\n Visualize or output the environment's current state. Optional for subclasses.\n \"\"\"\n print(f\"Environment State: {self.state}\")\n\n def get_context(self) -> Any:\n \"\"\"\n Retrieve relevant context information from the environment, useful for agent processing.\n\n Returns:\n Any: Contextual data or state relevant to the agent's tasks.\n \"\"\"\n return self.state\n\n def close(self):\n \"\"\"\n Clean up any resources tied to the environment when it's no longer needed.\n \"\"\"\n print(\"Closing environment and releasing resources.\")\n\n def __repr__(self) -> str:\n return (f\"Environment(type={self.config.environment_type}, \"\n f\"state={self.state}, \"\n f\"entities={self.entities}, \"\n f\"time={self.time}, \"\n f\"constraints={self.constraints})\")\n
"},{"location":"reference/#src.aeiva.environment.environment.Environment.__init__","title":"__init__(config)
","text":"Initialize the environment with a given configuration.
Parameters:
Name Type Description Default config
EnvironmentConfig
Configuration settings for the environment.
required Source code in src/aeiva/environment/environment.py
def __init__(self, config: EnvironmentConfig):\n \"\"\"\n Initialize the environment with a given configuration.\n\n Args:\n config (EnvironmentConfig): Configuration settings for the environment.\n \"\"\"\n self.config = config\n self.state = config.initial_state\n self.entities = config.entities\n self.constraints = config.constraints\n self.time = 0 if config.time_enabled else None\n self.setup()\n
"},{"location":"reference/#src.aeiva.environment.environment.Environment.act","title":"act(action, target=None)
abstractmethod
","text":"Execute an action in the environment, potentially modifying its state.
Parameters:
Name Type Description Default action
Any
The action to be executed.
required target
Optional[Any]
Target entity for the action, if applicable.
None
Source code in src/aeiva/environment/environment.py
@abstractmethod\ndef act(self, action: Any, target: Optional[Any] = None):\n \"\"\"\n Execute an action in the environment, potentially modifying its state.\n\n Args:\n action (Any): The action to be executed.\n target (Optional[Any]): Target entity for the action, if applicable.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.environment.environment.Environment.close","title":"close()
","text":"Clean up any resources tied to the environment when it's no longer needed.
Source code in src/aeiva/environment/environment.py
def close(self):\n \"\"\"\n Clean up any resources tied to the environment when it's no longer needed.\n \"\"\"\n print(\"Closing environment and releasing resources.\")\n
"},{"location":"reference/#src.aeiva.environment.environment.Environment.get_context","title":"get_context()
","text":"Retrieve relevant context information from the environment, useful for agent processing.
Returns:
Name Type Description Any
Any
Contextual data or state relevant to the agent's tasks.
Source code in src/aeiva/environment/environment.py
def get_context(self) -> Any:\n \"\"\"\n Retrieve relevant context information from the environment, useful for agent processing.\n\n Returns:\n Any: Contextual data or state relevant to the agent's tasks.\n \"\"\"\n return self.state\n
"},{"location":"reference/#src.aeiva.environment.environment.Environment.observe","title":"observe(agent)
abstractmethod
","text":"Provide observations to an agent based on the current state.
Parameters:
Name Type Description Default agent
Any
The agent requesting observation.
required Returns:
Name Type Description Any
Any
Observation data formatted according to the agent's perception capabilities.
Source code in src/aeiva/environment/environment.py
@abstractmethod\ndef observe(self, agent: Any) -> Any:\n \"\"\"\n Provide observations to an agent based on the current state.\n\n Args:\n agent (Any): The agent requesting observation.\n\n Returns:\n Any: Observation data formatted according to the agent's perception capabilities.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.environment.environment.Environment.render","title":"render()
","text":"Visualize or output the environment's current state. Optional for subclasses.
Source code in src/aeiva/environment/environment.py
def render(self):\n \"\"\"\n Visualize or output the environment's current state. Optional for subclasses.\n \"\"\"\n print(f\"Environment State: {self.state}\")\n
"},{"location":"reference/#src.aeiva.environment.environment.Environment.reset","title":"reset()
abstractmethod
","text":"Reset the environment to its initial state as defined by the configuration.
Source code in src/aeiva/environment/environment.py
@abstractmethod\ndef reset(self):\n \"\"\"\n Reset the environment to its initial state as defined by the configuration.\n \"\"\"\n self.state = self.config.initial_state\n self.time = 0 if self.config.time_enabled else None\n
"},{"location":"reference/#src.aeiva.environment.environment.Environment.setup","title":"setup()
abstractmethod
","text":"Set up the environment based on its configuration. Subclasses should define any initialization logic here.
Source code in src/aeiva/environment/environment.py
@abstractmethod\ndef setup(self):\n \"\"\"\n Set up the environment based on its configuration.\n Subclasses should define any initialization logic here.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.environment.environment.Environment.step","title":"step(actions)
abstractmethod
","text":"Advance the environment by one step based on actions taken by agents.
Parameters:
Name Type Description Default actions
Dict[Any, Any]
A dictionary of actions performed by agents.
required Source code in src/aeiva/environment/environment.py
@abstractmethod\ndef step(self, actions: Dict[Any, Any]):\n \"\"\"\n Advance the environment by one step based on actions taken by agents.\n\n Args:\n actions (Dict[Any, Any]): A dictionary of actions performed by agents.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.environment.environment_config","title":"environment_config
","text":""},{"location":"reference/#src.aeiva.environment.environment_config.EnvironmentConfig","title":"EnvironmentConfig
dataclass
","text":" Bases: BaseConfig
Configuration class for initializing an environment.
Attributes:
Name Type Description environment_type
str
Type of the environment, see EnvironmentType class.
initial_state
Optional[Any]
Optional initial state of the environment.
constraints
Dict[str, Any]
Rules or constraints governing the environment.
entities
List[Any]
Entities present within the environment.
time_enabled
bool
Whether the environment tracks time progression.
Source code in src/aeiva/environment/environment_config.py
@dataclass\nclass EnvironmentConfig(BaseConfig):\n \"\"\"\n Configuration class for initializing an environment.\n\n Attributes:\n environment_type (str): Type of the environment, see EnvironmentType class.\n initial_state (Optional[Any]): Optional initial state of the environment.\n constraints (Dict[str, Any]): Rules or constraints governing the environment.\n entities (List[Any]): Entities present within the environment.\n time_enabled (bool): Whether the environment tracks time progression.\n \"\"\"\n\n environment_type: str = field(\n metadata={\"help\": \"Type of the environment (e.g., 'user', 'document', 'game').\"}\n )\n initial_state: Optional[Any] = field(\n default=None,\n metadata={\"help\": \"Optional initial state of the environment.\"}\n )\n constraints: Dict[str, Any] = field(\n default_factory=dict,\n metadata={\"help\": \"Rules or constraints for the environment.\"}\n )\n entities: List[Any] = field(\n default_factory=list,\n metadata={\"help\": \"Entities within the environment.\"}\n )\n time_enabled: bool = field(\n default=False,\n metadata={\"help\": \"Flag to enable time progression.\"}\n )\n\n def __post_init__(self):\n super().__post_init__()\n # Perform any necessary validation\n if not self.environment_type:\n raise ValueError(\"Environment type must be provided.\")\n
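A standalone sketch of the same dataclass shape (with the `BaseConfig` base class and field metadata omitted for brevity) shows the required-field validation in `__post_init__`:

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

@dataclass
class EnvironmentConfig:
    # BaseConfig inheritance omitted in this sketch
    environment_type: str
    initial_state: Optional[Any] = None
    constraints: Dict[str, Any] = field(default_factory=dict)
    entities: List[Any] = field(default_factory=list)
    time_enabled: bool = False

    def __post_init__(self):
        if not self.environment_type:
            raise ValueError("Environment type must be provided.")

cfg = EnvironmentConfig(environment_type="Digital Environment", time_enabled=True)

try:
    EnvironmentConfig(environment_type="")
    raised = False
except ValueError:
    raised = True
```
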
"},{"location":"reference/#src.aeiva.environment.environment_type","title":"environment_type
","text":""},{"location":"reference/#src.aeiva.environment.environment_type.EnvironmentType","title":"EnvironmentType
","text":"A class to hold constants for various environment types, organized by broad categories to maximize generality while supporting diverse use cases.
Categories:
- Interaction-Based: Environments with user or agent interaction.
- Digital: Environments involving digital interfaces, applications, or software systems.
- Data-Based: Static or dynamic data collections or document repositories.
- Virtual/Simulated: Simulated, spatial, or immersive virtual environments.
- World-Level: Comprehensive real or virtual world environments.
Source code in src/aeiva/environment/environment_type.py
class EnvironmentType:\n \"\"\"\n A class to hold constants for various environment types, organized by broad categories\n to maximize generality while supporting diverse use cases.\n\n Categories:\n - Interaction-Based: Environments with user or agent interaction.\n - Digital: Environments involving digital interfaces, applications, or software systems.\n - Data-Based: Static or dynamic data collections or document repositories.\n - Virtual/Simulated: Simulated, spatial, or immersive virtual environments.\n - World-Level: Comprehensive real or virtual world environments.\n \"\"\"\n\n # Interaction-Based Environments\n INTERACTIVE = \"Interactive\" # Environments involving user or multi-agent interaction.\n\n # Digital Environments\n DIGITAL_ENVIRONMENT = \"Digital Environment\" # Digital workspaces, applications, OS, or software systems.\n\n # Data-Based Environments\n DATA_REPOSITORY = \"Data Repository\" # Static datasets, dynamic data streams, or document repositories (e.g., knowledge bases).\n\n # Virtual/Simulated Environments\n VIRTUAL_ENVIRONMENT = \"Virtual Environment\" # Simulated or immersive 3D spaces, including games and VR.\n\n # World-Level Environments\n FULL_WORLD = \"Full World\" # Comprehensive virtual or real-world environment.\n\n # Meta/Complex Environments\n HYBRID_ENVIRONMENT = \"Hybrid Environment\" # Combination of multiple types.\n\n # Custom environment type for unique or unspecified cases.\n CUSTOM = \"Custom\"\n
"},{"location":"reference/#src.aeiva.event","title":"event
","text":""},{"location":"reference/#src.aeiva.event.event","title":"event
","text":""},{"location":"reference/#src.aeiva.event.event.Event","title":"Event
dataclass
","text":"Represents an event in the event bus system.
Attributes:
Name Type Description name
str
The name of the event.
payload
Any
The data associated with the event.
timestamp
datetime
The time the event was created.
priority
int
The priority of the event.
Source code in src/aeiva/event/event.py
@dataclass\nclass Event:\n \"\"\"\n Represents an event in the event bus system.\n\n Attributes:\n name (str): The name of the event.\n payload (Any): The data associated with the event.\n timestamp (datetime): The time the event was created.\n priority (int): The priority of the event.\n \"\"\"\n name: str\n payload: Any = None\n timestamp: datetime = field(default_factory=datetime.utcnow)\n priority: int = 0\n
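Because `Event` is a plain dataclass, constructing one is direct; a standalone sketch with an illustrative event name:

```python
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any

@dataclass
class Event:
    name: str
    payload: Any = None
    timestamp: datetime = field(default_factory=datetime.utcnow)
    priority: int = 0

# timestamp defaults to creation time; payload and priority are optional
e = Event(name="perception.gradio", payload="hello", priority=5)
```
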
"},{"location":"reference/#src.aeiva.event.event_bus","title":"event_bus
","text":""},{"location":"reference/#src.aeiva.event.event_bus.EventBus","title":"EventBus
","text":"An asynchronous event bus for publishing and subscribing to events.
Features:
- Subscribers can use wildcard patterns to subscribe to multiple events.
- Subscribers can cancel event propagation.
- Subscribers can be set to auto-unsubscribe after one call.
- Event-level prioritization in the queue.
- Customizable error handling.
- Logging for key actions.
- emit, emit_after, and emit_only methods for flexible event emission.
Source code in src/aeiva/event/event_bus.py
class EventBus:\n \"\"\"\n An asynchronous event bus for publishing and subscribing to events.\n\n Features:\n - Subscribers can use wildcard patterns to subscribe to multiple events.\n - Subscribers can cancel event propagation.\n - Subscribers can be set to auto-unsubscribe after one call.\n - Event-level prioritization in the queue.\n - Customizable error handling.\n - Logging for key actions.\n - emit, emit_after, and emit_only methods for flexible event emission.\n \"\"\"\n\n def __init__(self):\n \"\"\"\n Initializes the event bus.\n \"\"\"\n self._subscribers: List[Dict] = [] # List of subscriber dictionaries\n self._event_queue = asyncio.PriorityQueue()\n self._processing_task: Optional[asyncio.Task] = None\n self._event_counter = 0 # Counter to maintain order of events with same priority\n self.loop = None\n\n def subscribe(\n self,\n event_pattern: str,\n callback: Callable[[Event], Any],\n *,\n priority: int = 0,\n once: bool = False\n ):\n \"\"\"\n Subscribes a callback function to events matching a pattern.\n\n Args:\n event_pattern (str): The event name or pattern to subscribe to.\n callback (Callable[[Event], Any]): The callback function.\n priority (int, optional): Priority of the callback.\n once (bool, optional): If True, unsubscribe after one call.\n \"\"\"\n subscriber = {\n 'pattern': re.compile(event_pattern.replace('*', '.*')),\n 'callback': callback,\n 'priority': priority,\n 'once': once\n }\n self._subscribers.append(subscriber)\n logger.info(f\"Subscribed '{callback.__name__}' to pattern '{event_pattern}' with priority {priority}.\")\n\n def unsubscribe(self, callback: Callable[[Event], Any]):\n \"\"\"\n Unsubscribes a callback function from all events.\n\n Args:\n callback (Callable[[Event], Any]): The callback function to remove.\n \"\"\"\n self._subscribers = [\n sub for sub in self._subscribers\n if sub['callback'] != callback\n ]\n logger.info(f\"Unsubscribed '{callback.__name__}' from all events.\")\n\n async def publish(self, 
event: Event, only: Union[str, List[str]] = None):\n \"\"\"\n Publishes an event to the event bus.\n\n Args:\n event (Event): The event to publish.\n only (str or List[str], optional): Names of specific subscribers to notify.\n \"\"\"\n self._event_counter += 1\n # Use a tuple of (priority, counter) to ensure proper ordering\n await self._event_queue.put((event.priority * -1, self._event_counter, event, only))\n logger.info(f\"Published event '{event.name}' with priority {event.priority}.\")\n\n async def _process_events(self):\n \"\"\"\n Internal coroutine that processes events from the queue and dispatches them to subscribers.\n \"\"\"\n while True:\n try:\n _, _, event, only = await self._event_queue.get()\n logger.info(f\"Processing event '{event.name}'.\")\n await self._dispatch_event(event, only)\n self._event_queue.task_done()\n except asyncio.CancelledError:\n # Exit the loop gracefully\n break\n except Exception as e:\n logger.error(f\"Error processing event: {e}\")\n self._event_queue.task_done()\n\n async def _dispatch_event(self, event: Event, only: Union[str, List[str]] = None):\n \"\"\"\n Dispatches an event to the appropriate subscribers.\n\n Args:\n event (Event): The event to dispatch.\n only (str or List[str], optional): Names of specific subscribers to notify.\n \"\"\"\n subscribers = sorted(\n [\n sub for sub in self._subscribers\n if sub['pattern'].fullmatch(event.name)\n and (only is None or sub['callback'].__name__ in (only if isinstance(only, list) else [only]))\n ],\n key=lambda x: x['priority'],\n reverse=True\n )\n for subscriber in subscribers:\n callback = subscriber['callback']\n try:\n if asyncio.iscoroutinefunction(callback):\n await callback(event)\n else:\n await asyncio.get_event_loop().run_in_executor(None, callback, event)\n except EventCancelled:\n logger.info(f\"Event '{event.name}' cancelled by '{callback.__name__}'.\")\n break # Stop further propagation\n except Exception as e:\n logger.error(f\"Error in callback 
'{callback.__name__}' for event '{event.name}': {e}\")\n self._handle_callback_exception(e, callback, event)\n finally:\n if subscriber.get('once'):\n self.unsubscribe(callback)\n\n def _handle_callback_exception(self, exception, callback, event):\n \"\"\"\n Handle exceptions raised by subscriber callbacks.\n\n Args:\n exception (Exception): The exception raised.\n callback (Callable): The subscriber callback.\n event (Event): The event being processed.\n \"\"\"\n # Default behavior is to log the exception.\n pass # Can be customized as needed.\n\n def start(self):\n \"\"\"\n Starts the event bus processing loop.\n \"\"\"\n if self._processing_task is None:\n self.loop = asyncio.get_running_loop()\n self._processing_task = asyncio.create_task(self._process_events())\n logger.info(\"Event bus started.\")\n\n def stop(self):\n \"\"\"\n Stops the event bus processing loop.\n \"\"\"\n if self._processing_task:\n self._processing_task.cancel()\n logger.info(\"Event bus stopped.\")\n\n def on(self, event_pattern: str, priority: int = 0, once: bool = False):\n \"\"\"\n Decorator for subscribing a function to events matching a pattern.\n\n Usage:\n @event_bus.on('event.*', priority=10)\n async def handler(event):\n ...\n\n Args:\n event_pattern (str): The event name or pattern to subscribe to.\n priority (int, optional): Priority of the callback.\n once (bool, optional): If True, unsubscribe after one call.\n\n Returns:\n Callable: The decorator function.\n \"\"\"\n def decorator(callback: Callable[[Event], Any]):\n self.subscribe(event_pattern, callback, priority=priority, once=once)\n return callback\n return decorator\n\n def emit_after(self, event_name: str, priority: int = 0):\n \"\"\"\n Decorator that emits an event after the decorated function is called.\n\n Usage:\n @event_bus.emit_after('event_name')\n def some_function():\n ...\n\n Args:\n event_name (str): The name of the event to emit after function execution.\n priority (int, optional): The priority of the 
event.\n\n Returns:\n Callable: The decorator function.\n \"\"\"\n def decorator(func: Callable):\n if asyncio.iscoroutinefunction(func):\n @wraps(func)\n async def async_wrapper(*args, **kwargs):\n result = await func(*args, **kwargs)\n await self.emit(event_name, priority=priority)\n return result\n return async_wrapper\n else:\n @wraps(func)\n def sync_wrapper(*args, **kwargs):\n result = func(*args, **kwargs)\n asyncio.create_task(self.emit(event_name, priority=priority))\n return result\n return sync_wrapper\n return decorator\n\n async def emit(self, event_name: str, payload: Any = None, priority: int = 0):\n \"\"\"\n Emits an event to all matching subscribers.\n\n Args:\n event_name (str): The name of the event to emit.\n payload (Any, optional): The payload of the event.\n priority (int, optional): The priority of the event.\n \"\"\"\n await self.publish(Event(name=event_name, payload=payload, priority=priority))\n\n async def emit_only(self, event_name: str, subscriber_names: Union[str, List[str]], payload: Any = None, priority: int = 0):\n \"\"\"\n Emits an event only to specified subscribers.\n\n Args:\n event_name (str): The name of the event to emit.\n subscriber_names (str or List[str]): The name(s) of subscribers to notify.\n payload (Any, optional): The payload of the event.\n priority (int, optional): The priority of the event.\n \"\"\"\n await self.publish(Event(name=event_name, payload=payload, priority=priority), only=subscriber_names)\n\n async def wait_until_all_events_processed(self):\n \"\"\"\n Waits until all events in the queue have been processed.\n \"\"\"\n await self._event_queue.join()\n
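`subscribe` compiles wildcard patterns by replacing `*` with `.*` and dispatch matches them with `fullmatch`; that rule can be checked in isolation:

```python
import re

def compile_pattern(event_pattern: str):
    # same translation EventBus.subscribe applies: '*' becomes '.*'
    return re.compile(event_pattern.replace('*', '.*'))

pattern = compile_pattern('perception.*')

matches_gradio = bool(pattern.fullmatch('perception.gradio'))
matches_action = bool(pattern.fullmatch('action.gradio'))
```
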
"},{"location":"reference/#src.aeiva.event.event_bus.EventBus.__init__","title":"__init__()
","text":"Initializes the event bus.
Source code in src/aeiva/event/event_bus.py
def __init__(self):\n \"\"\"\n Initializes the event bus.\n \"\"\"\n self._subscribers: List[Dict] = [] # List of subscriber dictionaries\n self._event_queue = asyncio.PriorityQueue()\n self._processing_task: Optional[asyncio.Task] = None\n self._event_counter = 0 # Counter to maintain order of events with same priority\n self.loop = None\n
"},{"location":"reference/#src.aeiva.event.event_bus.EventBus.emit","title":"emit(event_name, payload=None, priority=0)
async
","text":"Emits an event to all matching subscribers.
Parameters:
Name Type Description Default event_name
str
The name of the event to emit.
required payload
Any
The payload of the event.
None
priority
int
The priority of the event.
0
Source code in src/aeiva/event/event_bus.py
async def emit(self, event_name: str, payload: Any = None, priority: int = 0):\n \"\"\"\n Emits an event to all matching subscribers.\n\n Args:\n event_name (str): The name of the event to emit.\n payload (Any, optional): The payload of the event.\n priority (int, optional): The priority of the event.\n \"\"\"\n await self.publish(Event(name=event_name, payload=payload, priority=priority))\n
"},{"location":"reference/#src.aeiva.event.event_bus.EventBus.emit_after","title":"emit_after(event_name, priority=0)
","text":"Decorator that emits an event after the decorated function is called.
Usage: @event_bus.emit_after('event_name') def some_function(): ...
Parameters:
Name Type Description Default event_name
str
The name of the event to emit after function execution.
required priority
int
The priority of the event.
0
Returns:
Name Type Description Callable
The decorator function.
Source code in src/aeiva/event/event_bus.py
def emit_after(self, event_name: str, priority: int = 0):\n \"\"\"\n Decorator that emits an event after the decorated function is called.\n\n Usage:\n @event_bus.emit_after('event_name')\n def some_function():\n ...\n\n Args:\n event_name (str): The name of the event to emit after function execution.\n priority (int, optional): The priority of the event.\n\n Returns:\n Callable: The decorator function.\n \"\"\"\n def decorator(func: Callable):\n if asyncio.iscoroutinefunction(func):\n @wraps(func)\n async def async_wrapper(*args, **kwargs):\n result = await func(*args, **kwargs)\n await self.emit(event_name, priority=priority)\n return result\n return async_wrapper\n else:\n @wraps(func)\n def sync_wrapper(*args, **kwargs):\n result = func(*args, **kwargs)\n asyncio.create_task(self.emit(event_name, priority=priority))\n return result\n return sync_wrapper\n return decorator\n
"},{"location":"reference/#src.aeiva.event.event_bus.EventBus.emit_only","title":"emit_only(event_name, subscriber_names, payload=None, priority=0)
async
","text":"Emits an event only to specified subscribers.
Parameters:
Name Type Description Default event_name
str
The name of the event to emit.
required subscriber_names
str or List[str]
The name(s) of subscribers to notify.
required payload
Any
The payload of the event.
None
priority
int
The priority of the event.
0
Source code in src/aeiva/event/event_bus.py
async def emit_only(self, event_name: str, subscriber_names: Union[str, List[str]], payload: Any = None, priority: int = 0):\n \"\"\"\n Emits an event only to specified subscribers.\n\n Args:\n event_name (str): The name of the event to emit.\n subscriber_names (str or List[str]): The name(s) of subscribers to notify.\n payload (Any, optional): The payload of the event.\n priority (int, optional): The priority of the event.\n \"\"\"\n await self.publish(Event(name=event_name, payload=payload, priority=priority), only=subscriber_names)\n
"},{"location":"reference/#src.aeiva.event.event_bus.EventBus.on","title":"on(event_pattern, priority=0, once=False)
","text":"Decorator for subscribing a function to events matching a pattern.
Usage: @event_bus.on('event.*', priority=10) async def handler(event): ...
Parameters:
Name Type Description Default event_pattern
str
The event name or pattern to subscribe to.
required priority
int
Priority of the callback.
0
once
bool
If True, unsubscribe after one call.
False
Returns:
Name Type Description Callable
The decorator function.
Source code in src/aeiva/event/event_bus.py
def on(self, event_pattern: str, priority: int = 0, once: bool = False):\n \"\"\"\n Decorator for subscribing a function to events matching a pattern.\n\n Usage:\n @event_bus.on('event.*', priority=10)\n async def handler(event):\n ...\n\n Args:\n event_pattern (str): The event name or pattern to subscribe to.\n priority (int, optional): Priority of the callback.\n once (bool, optional): If True, unsubscribe after one call.\n\n Returns:\n Callable: The decorator function.\n \"\"\"\n def decorator(callback: Callable[[Event], Any]):\n self.subscribe(event_pattern, callback, priority=priority, once=once)\n return callback\n return decorator\n
"},{"location":"reference/#src.aeiva.event.event_bus.EventBus.publish","title":"publish(event, only=None)
async
","text":"Publishes an event to the event bus.
Parameters:
Name Type Description Default event
Event
The event to publish.
required only
str or List[str]
Names of specific subscribers to notify.
None
Source code in src/aeiva/event/event_bus.py
async def publish(self, event: Event, only: Union[str, List[str]] = None):\n \"\"\"\n Publishes an event to the event bus.\n\n Args:\n event (Event): The event to publish.\n only (str or List[str], optional): Names of specific subscribers to notify.\n \"\"\"\n self._event_counter += 1\n # Use a tuple of (priority, counter) to ensure proper ordering\n await self._event_queue.put((event.priority * -1, self._event_counter, event, only))\n logger.info(f\"Published event '{event.name}' with priority {event.priority}.\")\n
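The `publish` listing above enqueues `(event.priority * -1, counter, event, only)` tuples. Since `asyncio.PriorityQueue` pops the *smallest* item first, negating the priority makes higher-priority events dequeue first, and the monotonically increasing counter preserves FIFO order among events of equal priority. A minimal standalone sketch of that ordering trick (independent of the EventBus class itself):

```python
import asyncio

# Sketch of the (-priority, counter) ordering used by EventBus.publish:
# PriorityQueue pops the smallest tuple, so negated priorities sort
# high-priority events first; the counter breaks ties in FIFO order.
async def main():
    q = asyncio.PriorityQueue()
    counter = 0
    for name, priority in [("low", 0), ("high", 10), ("also-low", 0)]:
        counter += 1
        await q.put((-priority, counter, name))
    order = []
    while not q.empty():
        _, _, name = await q.get()
        order.append(name)
    return order

order = asyncio.run(main())
print(order)  # "high" first, then the two priority-0 events in insertion order
```

Without the counter, two events of equal priority would be compared by their third tuple element, which may not be orderable; the counter guarantees the comparison never reaches it.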
"},{"location":"reference/#src.aeiva.event.event_bus.EventBus.start","title":"start()
","text":"Starts the event bus processing loop.
Source code in src/aeiva/event/event_bus.py
def start(self):\n \"\"\"\n Starts the event bus processing loop.\n \"\"\"\n if self._processing_task is None:\n self.loop = asyncio.get_running_loop()\n self._processing_task = asyncio.create_task(self._process_events())\n logger.info(\"Event bus started.\")\n
"},{"location":"reference/#src.aeiva.event.event_bus.EventBus.stop","title":"stop()
","text":"Stops the event bus processing loop.
Source code in src/aeiva/event/event_bus.py
def stop(self):\n \"\"\"\n Stops the event bus processing loop.\n \"\"\"\n if self._processing_task:\n self._processing_task.cancel()\n logger.info(\"Event bus stopped.\")\n
"},{"location":"reference/#src.aeiva.event.event_bus.EventBus.subscribe","title":"subscribe(event_pattern, callback, *, priority=0, once=False)
","text":"Subscribes a callback function to events matching a pattern.
Parameters:
Name Type Description Default event_pattern
str
The event name or pattern to subscribe to.
required callback
Callable[[Event], Any]
The callback function.
required priority
int
Priority of the callback.
0
once
bool
If True, unsubscribe after one call.
False
Source code in src/aeiva/event/event_bus.py
def subscribe(\n self,\n event_pattern: str,\n callback: Callable[[Event], Any],\n *,\n priority: int = 0,\n once: bool = False\n):\n \"\"\"\n Subscribes a callback function to events matching a pattern.\n\n Args:\n event_pattern (str): The event name or pattern to subscribe to.\n callback (Callable[[Event], Any]): The callback function.\n priority (int, optional): Priority of the callback.\n once (bool, optional): If True, unsubscribe after one call.\n \"\"\"\n subscriber = {\n 'pattern': re.compile(event_pattern.replace('*', '.*')),\n 'callback': callback,\n 'priority': priority,\n 'once': once\n }\n self._subscribers.append(subscriber)\n logger.info(f\"Subscribed '{callback.__name__}' to pattern '{event_pattern}' with priority {priority}.\")\n
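The `subscribe` listing above compiles glob-style patterns into regexes with a plain string replace. A hedged standalone sketch of that matching behavior (the helper name here is illustrative, not part of the library):

```python
import re

def compile_pattern(event_pattern: str):
    # Same transformation as in EventBus.subscribe: glob '*' becomes
    # regex '.*'. Caveats worth knowing: other regex metacharacters in
    # the pattern are NOT escaped, and re.match only anchors the start,
    # so the pattern behaves as a prefix match unless it ends in '*'.
    return re.compile(event_pattern.replace('*', '.*'))

p = compile_pattern('user.*')
assert p.match('user.created') is not None
assert p.match('user.deleted') is not None
assert p.match('order.created') is None
```

This is why a subscription to `'event.*'` in the `on` example receives `event.started`, `event.finished`, and so on.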
"},{"location":"reference/#src.aeiva.event.event_bus.EventBus.unsubscribe","title":"unsubscribe(callback)
","text":"Unsubscribes a callback function from all events.
Parameters:
Name Type Description Default callback
Callable[[Event], Any]
The callback function to remove.
required Source code in src/aeiva/event/event_bus.py
def unsubscribe(self, callback: Callable[[Event], Any]):\n \"\"\"\n Unsubscribes a callback function from all events.\n\n Args:\n callback (Callable[[Event], Any]): The callback function to remove.\n \"\"\"\n self._subscribers = [\n sub for sub in self._subscribers\n if sub['callback'] != callback\n ]\n logger.info(f\"Unsubscribed '{callback.__name__}' from all events.\")\n
"},{"location":"reference/#src.aeiva.event.event_bus.EventBus.wait_until_all_events_processed","title":"wait_until_all_events_processed()
async
","text":"Waits until all events in the queue have been processed.
Source code in src/aeiva/event/event_bus.py
async def wait_until_all_events_processed(self):\n \"\"\"\n Waits until all events in the queue have been processed.\n \"\"\"\n await self._event_queue.join()\n
"},{"location":"reference/#src.aeiva.event.event_bus.EventCancelled","title":"EventCancelled
","text":" Bases: Exception
Exception to indicate that an event has been cancelled.
Source code in src/aeiva/event/event_bus.py
class EventCancelled(Exception):\n \"\"\"Exception to indicate that an event has been cancelled.\"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.hypergraph","title":"hypergraph
","text":""},{"location":"reference/#src.aeiva.hypergraph.exceptions","title":"exceptions
","text":""},{"location":"reference/#src.aeiva.hypergraph.exceptions.HypergraphError","title":"HypergraphError
","text":" Bases: Exception
Custom exception class for Hypergraph-related errors.
Source code in src/aeiva/hypergraph/exceptions.py
class HypergraphError(Exception):\n \"\"\"\n Custom exception class for Hypergraph-related errors.\n \"\"\"\n def __init__(self, message: str = \"An error occurred in the Hypergraph module.\"):\n super().__init__(message)\n
"},{"location":"reference/#src.aeiva.hypergraph.hyperedge","title":"hyperedge
","text":""},{"location":"reference/#src.aeiva.hypergraph.hyperedge.HyperEdge","title":"HyperEdge
","text":"Represents a hyperedge in the hypergraph, encapsulating its properties and connected nodes.
Source code in src/aeiva/hypergraph/hyperedge.py
class HyperEdge:\n \"\"\"\n Represents a hyperedge in the hypergraph, encapsulating its properties and connected nodes.\n \"\"\"\n\n def __init__(\n self,\n id: Any,\n nodes: Optional[Iterable[Any]] = None,\n properties: Optional[Dict[str, Any]] = None\n ) -> None:\n \"\"\"\n Initializes a HyperEdge.\n\n Parameters:\n id: Unique identifier for the hyperedge.\n nodes: (Optional) Iterable of node identifiers connected by the hyperedge.\n properties: (Optional) Dictionary of properties.\n \"\"\"\n self.id: Any = id\n self.nodes: Set[Any] = set(nodes) if nodes else set()\n self.properties: Dict[str, Any] = properties.copy() if properties else {}\n\n def add_node(self, node_id: Any) -> None:\n \"\"\"\n Adds a node to the hyperedge.\n\n Parameters:\n node_id: Identifier of the node to add.\n \"\"\"\n self.nodes.add(node_id)\n\n def remove_node(self, node_id: Any) -> None:\n \"\"\"\n Removes a node from the hyperedge.\n\n Parameters:\n node_id: Identifier of the node to remove.\n \"\"\"\n if node_id in self.nodes:\n self.nodes.remove(node_id)\n else:\n raise HypergraphError(f\"Node '{node_id}' not found in HyperEdge '{self.id}'.\")\n\n def add_property(self, key: str, value: Any) -> None:\n \"\"\"\n Adds or updates a property of the hyperedge.\n\n Parameters:\n key: Property name.\n value: Property value.\n \"\"\"\n self.properties[key] = value\n\n def get_property(self, key: str) -> Any:\n \"\"\"\n Retrieves a property of the hyperedge.\n\n Parameters:\n key: Property name.\n\n Returns:\n The value of the property.\n\n Raises:\n HypergraphError: If the property does not exist.\n \"\"\"\n if key in self.properties:\n return self.properties[key]\n else:\n raise HypergraphError(f\"Property '{key}' does not exist for HyperEdge '{self.id}'.\")\n\n def remove_property(self, key: str) -> None:\n \"\"\"\n Removes a property from the hyperedge.\n\n Parameters:\n key: Property name.\n\n Raises:\n HypergraphError: If the property does not exist.\n \"\"\"\n if key in 
self.properties:\n del self.properties[key]\n else:\n raise HypergraphError(f\"Property '{key}' does not exist for HyperEdge '{self.id}'.\")\n\n def to_dict(self):\n return {\n \"id\": self.id,\n \"nodes\": self.nodes,\n \"properties\": self.properties\n }\n
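To make the behavior of the listing concrete, here is a minimal standalone replica condensed from the `HyperEdge` source above (the real class in `src/aeiva/hypergraph/hyperedge.py` also implements the node/property helper methods and error handling):

```python
from typing import Any, Dict, Iterable, Optional, Set

class HyperEdge:
    # Condensed replica of the constructor shown in the listing above.
    def __init__(self, id: Any, nodes: Optional[Iterable[Any]] = None,
                 properties: Optional[Dict[str, Any]] = None) -> None:
        self.id = id
        # Nodes are stored as a set, so duplicates in the input collapse.
        self.nodes: Set[Any] = set(nodes) if nodes else set()
        # Properties are copied so later mutation of the caller's dict
        # does not leak into the hyperedge.
        self.properties: Dict[str, Any] = dict(properties) if properties else {}

he = HyperEdge("e1", nodes=["a", "b", "b"], properties={"weight": 2})
he.nodes.add("c")                    # what add_node() does internally
assert he.nodes == {"a", "b", "c"}   # duplicate "b" was deduplicated
assert he.properties["weight"] == 2
```

Note that because `nodes` is a set, the `to_dict` output shown in the listing is not directly JSON-serializable without converting the set to a list first.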
"},{"location":"reference/#src.aeiva.hypergraph.hyperedge.HyperEdge.__init__","title":"__init__(id, nodes=None, properties=None)
","text":"Initializes a HyperEdge.
Parameters:
Name Type Description Default id
Any
Unique identifier for the hyperedge.
required nodes
Optional[Iterable[Any]]
(Optional) Iterable of node identifiers connected by the hyperedge.
None
properties
Optional[Dict[str, Any]]
(Optional) Dictionary of properties.
None
Source code in src/aeiva/hypergraph/hyperedge.py
def __init__(\n self,\n id: Any,\n nodes: Optional[Iterable[Any]] = None,\n properties: Optional[Dict[str, Any]] = None\n) -> None:\n \"\"\"\n Initializes a HyperEdge.\n\n Parameters:\n id: Unique identifier for the hyperedge.\n nodes: (Optional) Iterable of node identifiers connected by the hyperedge.\n properties: (Optional) Dictionary of properties.\n \"\"\"\n self.id: Any = id\n self.nodes: Set[Any] = set(nodes) if nodes else set()\n self.properties: Dict[str, Any] = properties.copy() if properties else {}\n
"},{"location":"reference/#src.aeiva.hypergraph.hyperedge.HyperEdge.add_node","title":"add_node(node_id)
","text":"Adds a node to the hyperedge.
Parameters:
Name Type Description Default node_id
Any
Identifier of the node to add.
required Source code in src/aeiva/hypergraph/hyperedge.py
def add_node(self, node_id: Any) -> None:\n \"\"\"\n Adds a node to the hyperedge.\n\n Parameters:\n node_id: Identifier of the node to add.\n \"\"\"\n self.nodes.add(node_id)\n
"},{"location":"reference/#src.aeiva.hypergraph.hyperedge.HyperEdge.add_property","title":"add_property(key, value)
","text":"Adds or updates a property of the hyperedge.
Parameters:
Name Type Description Default key
str
Property name.
required value
Any
Property value.
required Source code in src/aeiva/hypergraph/hyperedge.py
def add_property(self, key: str, value: Any) -> None:\n \"\"\"\n Adds or updates a property of the hyperedge.\n\n Parameters:\n key: Property name.\n value: Property value.\n \"\"\"\n self.properties[key] = value\n
"},{"location":"reference/#src.aeiva.hypergraph.hyperedge.HyperEdge.get_property","title":"get_property(key)
","text":"Retrieves a property of the hyperedge.
Parameters:
Name Type Description Default key
str
Property name.
required Returns:
Type Description Any
The value of the property.
Raises:
Type Description HypergraphError
If the property does not exist.
Source code in src/aeiva/hypergraph/hyperedge.py
def get_property(self, key: str) -> Any:\n \"\"\"\n Retrieves a property of the hyperedge.\n\n Parameters:\n key: Property name.\n\n Returns:\n The value of the property.\n\n Raises:\n HypergraphError: If the property does not exist.\n \"\"\"\n if key in self.properties:\n return self.properties[key]\n else:\n raise HypergraphError(f\"Property '{key}' does not exist for HyperEdge '{self.id}'.\")\n
"},{"location":"reference/#src.aeiva.hypergraph.hyperedge.HyperEdge.remove_node","title":"remove_node(node_id)
","text":"Removes a node from the hyperedge.
Parameters:
Name Type Description Default node_id
Any
Identifier of the node to remove.
required Source code in src/aeiva/hypergraph/hyperedge.py
def remove_node(self, node_id: Any) -> None:\n \"\"\"\n Removes a node from the hyperedge.\n\n Parameters:\n node_id: Identifier of the node to remove.\n \"\"\"\n if node_id in self.nodes:\n self.nodes.remove(node_id)\n else:\n raise HypergraphError(f\"Node '{node_id}' not found in HyperEdge '{self.id}'.\")\n
"},{"location":"reference/#src.aeiva.hypergraph.hyperedge.HyperEdge.remove_property","title":"remove_property(key)
","text":"Removes a property from the hyperedge.
Parameters:
Name Type Description Default key
str
Property name.
required Raises:
Type Description HypergraphError
If the property does not exist.
Source code in src/aeiva/hypergraph/hyperedge.py
def remove_property(self, key: str) -> None:\n \"\"\"\n Removes a property from the hyperedge.\n\n Parameters:\n key: Property name.\n\n Raises:\n HypergraphError: If the property does not exist.\n \"\"\"\n if key in self.properties:\n del self.properties[key]\n else:\n raise HypergraphError(f\"Property '{key}' does not exist for HyperEdge '{self.id}'.\")\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph","title":"hypergraph
","text":""},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph","title":"Hypergraph
","text":"A simplified Hypergraph class using dictionaries and NetworkX for management.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph--parameters","title":"Parameters","text":"hyperedges : Dict[Any, Dict[str, Any]] A dictionary where keys are hyperedge identifiers and values are dictionaries containing: - 'nodes': Iterable of node identifiers connected by the hyperedge. - 'properties': (Optional) Dictionary of properties for the hyperedge.
node_properties : Optional[Dict[Any, Dict[str, Any]]] = None A dictionary where keys are node identifiers and values are dictionaries of node properties.
hyperedge_properties : Optional[Dict[Any, Dict[str, Any]]] = None A dictionary where keys are hyperedge identifiers and values are dictionaries of hyperedge properties.
name : Optional[str] = None Name assigned to the hypergraph.
Source code in src/aeiva/hypergraph/hypergraph.py
class Hypergraph:\n \"\"\"\n A simplified Hypergraph class using dictionaries and NetworkX for management.\n\n Parameters\n ----------\n hyperedges : Dict[Any, Dict[str, Any]]\n A dictionary where keys are hyperedge identifiers and values are dictionaries containing:\n - 'nodes': Iterable of node identifiers connected by the hyperedge.\n - 'properties': (Optional) Dictionary of properties for the hyperedge.\n\n node_properties : Optional[Dict[Any, Dict[str, Any]]] = None\n A dictionary where keys are node identifiers and values are dictionaries of node properties.\n\n hyperedge_properties : Optional[Dict[Any, Dict[str, Any]]] = None\n A dictionary where keys are hyperedge identifiers and values are dictionaries of hyperedge properties.\n\n name : Optional[str] = None\n Name assigned to the hypergraph.\n \"\"\"\n\n def __init__(\n self,\n hyperedges: Dict[Any, Dict[str, Any]],\n node_properties: Optional[Dict[Any, Dict[str, Any]]] = None,\n hyperedge_properties: Optional[Dict[Any, Dict[str, Any]]] = None,\n name: Optional[str] = None\n ):\n self.name = name\n self.graph = nx.Graph()\n self.bipartite_nodes: Set[Any] = set()\n\n # Initialize node and hyperedge properties using deep copies to ensure full duplication\n self.node_properties: Dict[Any, Dict[str, Any]] = copy.deepcopy(node_properties) if node_properties else {}\n self.hyperedge_properties: Dict[Any, Dict[str, Any]] = copy.deepcopy(hyperedge_properties) if hyperedge_properties else {}\n\n # Add hyperedges and their connections to nodes\n self.hyperedges: Dict[Any, HyperEdge] = {}\n for he_id, he_data in hyperedges.items():\n nodes = he_data.get('nodes', [])\n properties = he_data.get('properties', {})\n hyperedge = HyperEdge(id=he_id, nodes=nodes, properties=properties)\n self.hyperedges[he_id] = hyperedge\n\n # Add hyperedge to bipartite graph with properties\n self.graph.add_node(he_id, bipartite='hyperedge', **self.hyperedge_properties.get(he_id, {}))\n self.bipartite_nodes.add(he_id)\n\n # Add edges 
between hyperedge and nodes with node properties\n for node in hyperedge.nodes:\n if node not in self.graph:\n self.graph.add_node(node, bipartite='node', **self.node_properties.get(node, {}))\n self.graph.add_edge(he_id, node)\n\n def dual(self, name: Optional[str] = None) -> \"Hypergraph\":\n \"\"\"\n Constructs the dual of the current hypergraph by reversing the roles of nodes and hyperedges.\n\n Parameters\n ----------\n name : Optional[str], default=None\n Name for the dual hypergraph. If None, defaults to the original hypergraph's name with '_dual' appended.\n\n Returns\n -------\n Hypergraph\n A new Hypergraph instance representing the dual of the current hypergraph.\n \"\"\"\n # Initialize dual hyperedges, which will correspond to original nodes\n dual_hyperedges = {}\n\n # Invert the node-hyperedge structure\n for he_id, hyperedge in self.hyperedges.items():\n for node in hyperedge.nodes:\n # Each original node becomes a hyperedge in the dual\n if node not in dual_hyperedges:\n dual_hyperedges[node] = {'nodes': [], 'properties': self.node_properties.get(node, {})}\n # The new hyperedge (original node) connects to the original hyperedge id as a \"node\"\n dual_hyperedges[node]['nodes'].append(he_id)\n\n # Define node properties in the dual as the original hyperedge properties\n dual_node_properties = {he_id: self.hyperedge_properties.get(he_id, {}) for he_id in self.hyperedges}\n\n # Create and return the dual Hypergraph\n return Hypergraph(\n hyperedges=dual_hyperedges,\n node_properties=dual_node_properties,\n hyperedge_properties=self.node_properties, # Properties of original nodes now apply to dual hyperedges\n name=name or (self.name + \"_dual\" if self.name else \"dual\")\n )\n\n def nodes(self) -> List[Any]:\n \"\"\"\n Returns a list of all unique node identifiers in the hypergraph.\n\n Returns\n -------\n List[Any]\n List of node IDs.\n \"\"\"\n return list(self.node_properties.keys())\n\n def node_memberships(self) -> Dict[Any, List[Any]]:\n 
\"\"\"\n Returns a dictionary where each key is a node ID and the value is a list of hyperedge IDs that include the node.\n\n Returns\n -------\n Dict[Any, List[Any]]\n Dictionary mapping node IDs to the hyperedge IDs they belong to.\n \"\"\"\n memberships = {}\n for he_id, hyperedge in self.hyperedges.items():\n for node in hyperedge.nodes:\n memberships.setdefault(node, []).append(he_id)\n return memberships\n\n def edges(self) -> List[Any]:\n \"\"\"\n Returns a list of all hyperedge identifiers in the hypergraph.\n\n Returns\n -------\n List[Any]\n List of hyperedge IDs.\n \"\"\"\n return list(self.hyperedges.keys())\n\n def edge_elements(self) -> Dict[Any, List[Any]]:\n \"\"\"\n Returns a dictionary where each key is a hyperedge ID and the value is a list of node IDs within that hyperedge.\n\n Returns\n -------\n Dict[Any, List[Any]]\n Dictionary mapping hyperedge IDs to lists of node IDs they contain.\n \"\"\"\n return {he_id: hyperedge.nodes for he_id, hyperedge in self.hyperedges.items()}\n\n def __str__(self) -> str:\n \"\"\"\n String representation of the hypergraph.\n\n Returns\n -------\n str\n A string describing the hypergraph with its name, number of nodes, and hyperedges.\n \"\"\"\n return f\"Hypergraph '{self.name}' with {len(self)} nodes and {len(self.hyperedges)} hyperedges.\"\n\n def __repr__(self) -> str:\n \"\"\"\n Official string representation of the hypergraph.\n\n Returns\n -------\n str\n A detailed string describing the hypergraph with its name, number of nodes, and hyperedges.\n \"\"\"\n return (\n f\"Hypergraph(name={self.name!r}, \"\n f\"nodes={len(self)}, hyperedges={len(self.hyperedges)})\"\n )\n\n def __len__(self) -> int:\n \"\"\"\n Returns the number of nodes in the hypergraph.\n\n Returns\n -------\n int\n Number of nodes.\n \"\"\"\n return len(self.node_properties)\n\n def __iter__(self) -> Iterator[Any]:\n \"\"\"\n Allows iteration over the nodes of the hypergraph.\n\n Yields\n ------\n Any\n Node identifiers.\n \"\"\"\n return 
iter(self.node_properties)\n\n def __contains__(self, item: Any) -> bool:\n \"\"\"\n Checks if a node is in the hypergraph.\n\n Parameters\n ----------\n item : Any\n The node identifier to check.\n\n Returns\n -------\n bool\n True if the node exists in the hypergraph, False otherwise.\n \"\"\"\n return item in self.node_properties\n\n def __getitem__(self, node: Any) -> Iterable[Any]:\n \"\"\"\n Retrieves the neighbors of a node in the hypergraph.\n\n Neighbors are nodes that share at least one hyperedge with the given node.\n\n Parameters\n ----------\n node : Any\n The node identifier.\n\n Returns\n -------\n Iterable[Any]\n An iterator over neighboring node identifiers.\n\n Raises\n ------\n HypergraphError\n If the node does not exist in the hypergraph.\n \"\"\"\n if node not in self.node_properties:\n raise HypergraphError(f\"Node '{node}' does not exist in the hypergraph.\")\n\n # Get all hyperedges that include the node\n hyperedges = set(self.graph.neighbors(node))\n\n # Get all nodes connected by these hyperedges\n neighbors = set()\n for he_id in hyperedges:\n neighbors.update(self.hyperedges[he_id].nodes)\n\n neighbors.discard(node) # Remove the node itself\n return neighbors\n\n def __eq__(self, other: Any) -> bool:\n \"\"\"\n Checks if two hypergraphs are equal based on their hyperedges and nodes.\n\n Parameters\n ----------\n other : Any\n The other object to compare.\n\n Returns\n -------\n bool\n True if both hypergraphs have identical nodes and hyperedges with the same properties, False otherwise.\n \"\"\"\n if not isinstance(other, Hypergraph):\n return False\n\n # Compare nodes and their properties\n if self.node_properties != other.node_properties:\n return False\n\n # Compare hyperedges and their properties\n if self.hyperedges.keys() != other.hyperedges.keys():\n return False\n\n for he_id in self.hyperedges:\n if self.hyperedges[he_id].nodes != other.hyperedges[he_id].nodes:\n return False\n if self.hyperedge_properties.get(he_id, {}) != 
other.hyperedge_properties.get(he_id, {}):\n return False\n\n return True\n\n def copy(self, name: Optional[str] = None) -> 'Hypergraph':\n \"\"\"\n Creates a deep copy of the hypergraph instance.\n\n Parameters\n ----------\n name : Optional[str], default=None\n The name for the copied Hypergraph. If not provided, retains the original name.\n\n Returns\n -------\n Hypergraph\n A new Hypergraph instance that is a deep copy of the original.\n \"\"\"\n\n # Deep copy hyperedges\n hyperedges_dict = {}\n for he_id, he in self.hyperedges.items():\n hyperedges_dict[he_id] = {\n 'nodes': list(he.nodes),\n 'properties': copy.deepcopy(he.properties)\n }\n\n # Deep copy node_properties and hyperedge_properties\n node_properties_copy = copy.deepcopy(self.node_properties)\n hyperedge_properties_copy = copy.deepcopy(self.hyperedge_properties)\n\n # Create a new Hypergraph instance with the copied data\n return Hypergraph(\n hyperedges=hyperedges_dict,\n node_properties=node_properties_copy,\n hyperedge_properties=hyperedge_properties_copy,\n name=name if name is not None else self.name\n )\n\n def deepcopy(self, name: Optional[str] = None) -> 'Hypergraph':\n \"\"\"\n Creates a deep copy of the hypergraph.\n\n Parameters\n ----------\n name : Optional[str], default=None\n The name assigned to the cloned hypergraph. 
If None, defaults to the original hypergraph's name suffixed with '_clone'.\n\n Returns\n -------\n Hypergraph\n A deep copy of the hypergraph.\n \"\"\"\n\n # Deep copy hyperedges\n hyperedges_copy = {\n he_id: {\n 'nodes': hyperedge.nodes.copy(),\n 'properties': copy.deepcopy(hyperedge.properties)\n }\n for he_id, hyperedge in self.hyperedges.items()\n }\n\n # Deep copy node properties\n node_properties_copy = copy.deepcopy(self.node_properties)\n\n # Deep copy hyperedge properties\n hyperedge_properties_copy = copy.deepcopy(self.hyperedge_properties)\n\n # Set name\n cloned_name = f\"{self.name}_deepcopy\" if name is None else name\n\n # Initialize the cloned hypergraph\n cloned_H = Hypergraph(\n hyperedges=hyperedges_copy,\n node_properties=node_properties_copy,\n hyperedge_properties=hyperedge_properties_copy,\n name=cloned_name\n )\n\n return cloned_H\n\n # Adding and Removing Hyperedges and Nodes\n\n def add_hyperedge(\n self,\n he_id: Any,\n nodes: Iterable[Any],\n properties: Optional[Dict[str, Any]] = None\n ) -> None:\n \"\"\"\n Adds a hyperedge to the hypergraph.\n\n Parameters\n ----------\n he_id : Any\n Unique identifier for the hyperedge.\n nodes : Iterable[Any]\n Nodes connected by the hyperedge.\n properties : Optional[Dict[str, Any]] = None\n Properties of the hyperedge.\n\n Raises\n ------\n HypergraphError\n If the hyperedge ID already exists.\n \"\"\"\n if he_id in self.hyperedges:\n raise HypergraphError(f\"Hyperedge '{he_id}' already exists.\")\n\n hyperedge = HyperEdge(id=he_id, nodes=nodes, properties=properties)\n self.hyperedges[he_id] = hyperedge\n self.hyperedge_properties[he_id] = copy.deepcopy(properties) if properties else {}\n\n # Add hyperedge to bipartite graph\n self.graph.add_node(he_id, bipartite='hyperedge', **self.hyperedge_properties[he_id])\n self.bipartite_nodes.add(he_id)\n\n # Add edges between hyperedge and nodes\n for node in hyperedge.nodes:\n if node not in self.graph:\n self.graph.add_node(node, bipartite='node', 
**self.node_properties.get(node, {}))
            self.graph.add_edge(he_id, node)

    def remove_hyperedge(self, he_id: Any) -> None:
        """
        Removes a hyperedge from the hypergraph.

        Parameters
        ----------
        he_id : Any
            Identifier of the hyperedge to remove.

        Raises
        ------
        HypergraphError
            If the hyperedge does not exist.
        """
        if he_id not in self.hyperedges:
            raise HypergraphError(f"Hyperedge '{he_id}' does not exist.")

        # Remove hyperedge from the graph, which also removes all incidences
        self.graph.remove_node(he_id)
        self.bipartite_nodes.discard(he_id)

        # Remove from internal structures
        del self.hyperedges[he_id]
        self.hyperedge_properties.pop(he_id, None)

    def add_hyperedges_from(
        self,
        hyperedges: Iterable[Union[Any, Tuple[Any, Dict[str, Any]]]],
        inplace: bool = True
    ) -> 'Hypergraph':
        """
        Adds multiple hyperedges with attributes to the hypergraph.

        Parameters
        ----------
        hyperedges : Iterable[Union[Any, Tuple[Any, Dict[str, Any]]]]
            An iterable of hyperedge identifiers or tuples of (he_id, attributes).
        inplace : bool, default=True
            If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the added hyperedges.

        Returns
        -------
        Hypergraph
            The updated or new Hypergraph instance.

        Raises
        ------
        HypergraphError
            If any hyperedge ID already exists.
        ValueError
            If any tuple does not contain exactly two elements or if attributes are not dictionaries.
        """
        new_hyperedges = []
        for item in hyperedges:
            if isinstance(item, tuple):
                if len(item) != 2 or not isinstance(item[1], dict):
                    raise ValueError(f"Each tuple must be of the form (he_id, attributes). Invalid tuple: {item}")
                he_id, attrs = item
            else:
                he_id, attrs = item, {}

            if he_id in self.hyperedges:
                raise HypergraphError(f"Hyperedge '{he_id}' already exists.")

            hyperedge = HyperEdge(id=he_id, nodes=[], properties=attrs.copy())
            new_hyperedges.append(hyperedge)

        if inplace:
            for hyperedge in new_hyperedges:
                self.hyperedges[hyperedge.id] = hyperedge
                self.hyperedge_properties[hyperedge.id] = copy.deepcopy(hyperedge.properties)
                self.graph.add_node(hyperedge.id, bipartite='hyperedge', **self.hyperedge_properties[hyperedge.id])
                self.bipartite_nodes.add(hyperedge.id)
            return self
        else:
            # Create a new Hypergraph instance with added hyperedges
            new_hyperedges_dict = copy.deepcopy(self.hyperedges)
            new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)
            new_node_properties = copy.deepcopy(self.node_properties)
            new_graph = copy.deepcopy(self.graph)
            new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)

            for hyperedge in new_hyperedges:
                new_hyperedges_dict[hyperedge.id] = hyperedge
                new_hyperedge_properties[hyperedge.id] = copy.deepcopy(hyperedge.properties)
                new_graph.add_node(hyperedge.id, bipartite='hyperedge', **new_hyperedge_properties[hyperedge.id])
                new_bipartite_nodes.add(hyperedge.id)

            # Reconstruct hyperedges dict for __init__
            hyperedges_dict = {
                he_id: {
                    'nodes': list(he.nodes),
                    'properties': he.properties.copy()
                } for he_id, he in new_hyperedges_dict.items()
            }

            return Hypergraph(
                hyperedges=hyperedges_dict,
                node_properties=new_node_properties,
                hyperedge_properties=new_hyperedge_properties,
                name=self.name
            )

    def add_node(
        self,
        node_id: Any,
        properties: Optional[Dict[str, Any]] = None,
        inplace: bool = True
    ) -> 'Hypergraph':
        """
        Adds a node to the hypergraph.

        Parameters
        ----------
        node_id : Any
            Identifier for the node.
        properties : Optional[Dict[str, Any]] = None
            Properties of the node.
        inplace : bool, default=True
            If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the added node.

        Returns
        -------
        Hypergraph
            The updated or new Hypergraph instance.

        Raises
        ------
        HypergraphError
            If the node ID already exists.
        """
        if node_id in self.node_properties:
            raise HypergraphError(f"Node '{node_id}' already exists in the hypergraph.")

        if inplace:
            self.node_properties[node_id] = copy.deepcopy(properties) if properties else {}
            self.graph.add_node(node_id, bipartite='node', **self.node_properties[node_id])
            return self
        else:
            # Create a new Hypergraph instance with the added node
            new_hyperedges = copy.deepcopy(self.hyperedges)
            new_node_properties = copy.deepcopy(self.node_properties)
            new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)
            new_graph = copy.deepcopy(self.graph)
            new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)

            new_node_properties[node_id] = copy.deepcopy(properties) if properties else {}
            new_graph.add_node(node_id, bipartite='node', **new_node_properties[node_id])

            return Hypergraph(
                hyperedges={
                    he_id: {
                        'nodes': list(he.nodes),
                        'properties': he.properties.copy()
                    } for he_id, he in new_hyperedges.items()
                },
                node_properties=new_node_properties,
                hyperedge_properties=new_hyperedge_properties,
                name=self.name
            )

    def remove_node(self, node_id: Any, inplace: bool = True) -> 'Hypergraph':
        """
        Removes a node from the hypergraph.

        Parameters
        ----------
        node_id : Any
            Identifier of the node to remove.
        inplace : bool, default=True
            If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the node removed.

        Returns
        -------
        Hypergraph
            The updated or new Hypergraph instance.

        Raises
        ------
        HypergraphError
            If the node does not exist.
        """
        if node_id not in self.node_properties:
            raise HypergraphError(f"Node '{node_id}' does not exist in the hypergraph.")

        if inplace:
            # Remove node from node_properties
            del self.node_properties[node_id]
            # Remove node from all hyperedges
            for hyperedge in self.hyperedges.values():
                if node_id in hyperedge.nodes:
                    hyperedge.remove_node(node_id)
            # Remove node from graph, which also removes all incidences
            self.graph.remove_node(node_id)
            return self
        else:
            # Create a new Hypergraph instance with the node removed
            new_hyperedges = copy.deepcopy(self.hyperedges)
            new_node_properties = copy.deepcopy(self.node_properties)
            new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)
            new_graph = copy.deepcopy(self.graph)
            new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)

            # Remove node from node_properties
            del new_node_properties[node_id]
            # Remove node from all hyperedges
            for hyperedge in new_hyperedges.values():
                if node_id in hyperedge.nodes:
                    hyperedge.remove_node(node_id)
            # Remove node from graph, which also removes all incidences
            new_graph.remove_node(node_id)

            # Remove nodes not connected to any hyperedges
            retained_nodes = set()
            for hyperedge in new_hyperedges.values():
                retained_nodes.update(hyperedge.nodes)

            new_node_properties = {node: props for node, props in new_node_properties.items() if node in retained_nodes}

            # Reconstruct hyperedges dict for __init__
            hyperedges_dict = {
                he_id: {
                    'nodes': list(he.nodes),
                    'properties': he.properties.copy()
                } for he_id, he in new_hyperedges.items()
            }

            return Hypergraph(
                hyperedges=hyperedges_dict,
                node_properties=new_node_properties,
                hyperedge_properties=new_hyperedge_properties,
                name=self.name
            )

    def add_nodes_from(
        self,
        nodes: Iterable[Union[Any, Tuple[Any, Dict[str, Any]]]],
        inplace: bool = True
    ) -> 'Hypergraph':
        """
        Adds multiple nodes with attributes to the hypergraph.

        Parameters
        ----------
        nodes : Iterable[Union[Any, Tuple[Any, Dict[str, Any]]]]
            An iterable of node identifiers or tuples of (node_id, attributes).
        inplace : bool, default=True
            If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the added nodes.

        Returns
        -------
        Hypergraph
            The updated or new Hypergraph instance.

        Raises
        ------
        HypergraphError
            If any node ID already exists.
        ValueError
            If any tuple does not contain exactly two elements or if attributes are not dictionaries.
        """
        new_nodes = {}
        for item in nodes:
            if isinstance(item, tuple):
                if len(item) != 2 or not isinstance(item[1], dict):
                    raise ValueError(f"Each tuple must be of the form (node_id, attributes). Invalid tuple: {item}")
                node_id, attrs = item
            else:
                node_id, attrs = item, {}

            if node_id in self.node_properties:
                raise HypergraphError(f"Node '{node_id}' already exists in the hypergraph.")

            new_nodes[node_id] = copy.deepcopy(attrs)

        if inplace:
            for node_id, attrs in new_nodes.items():
                self.node_properties[node_id] = attrs
                self.graph.add_node(node_id, bipartite='node', **self.node_properties[node_id])
            return self
        else:
            # Create a new Hypergraph instance with the added nodes
            new_hyperedges = copy.deepcopy(self.hyperedges)
            new_node_properties = copy.deepcopy(self.node_properties)
            new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)
            new_graph = copy.deepcopy(self.graph)
            new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)

            for node_id, attrs in new_nodes.items():
                new_node_properties[node_id] = attrs
                new_graph.add_node(node_id, bipartite='node', **new_node_properties[node_id])

            return Hypergraph(
                hyperedges={
                    he_id: {
                        'nodes': list(he.nodes),
                        'properties': he.properties.copy()
                    } for he_id, he in new_hyperedges.items()
                },
                node_properties=new_node_properties,
                hyperedge_properties=new_hyperedge_properties,
                name=self.name
            )

    def remove_hyperedges(self, he_ids: Union[Any, Iterable[Any]], inplace: bool = True) -> 'Hypergraph':
        """
        Removes the specified hyperedges from the hypergraph.

        Parameters
        ----------
        he_ids : Any | Iterable[Any]
            Hyperedge identifier(s) to remove.
        inplace : bool, default=True
            If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the hyperedges removed.

        Returns
        -------
        Hypergraph
            The updated or new Hypergraph instance.

        Raises
        ------
        HypergraphError
            If any hyperedge ID does not exist.
        """
        if isinstance(he_ids, (str, int)):
            he_ids = [he_ids]
        else:
            he_ids = list(he_ids)

        non_existing = set(he_ids) - set(self.hyperedges.keys())
        if non_existing:
            raise HypergraphError(f"Hyperedges {non_existing} do not exist in the hypergraph.")

        if inplace:
            for he_id in he_ids:
                self.remove_hyperedge(he_id)
            return self
        else:
            # Create a new Hypergraph instance with hyperedges removed
            new_hyperedges = copy.deepcopy(self.hyperedges)
            new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)
            new_node_properties = copy.deepcopy(self.node_properties)
            new_graph = copy.deepcopy(self.graph)
            new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)

            for he_id in he_ids:
                del new_hyperedges[he_id]
                new_hyperedge_properties.pop(he_id, None)
                new_graph.remove_node(he_id)
                new_bipartite_nodes.discard(he_id)

            # Remove nodes not connected to any hyperedges
            retained_nodes = set()
            for hyperedge in new_hyperedges.values():
                retained_nodes.update(hyperedge.nodes)

            new_node_properties = {node: props for node, props in new_node_properties.items() if node in retained_nodes}

            # Reconstruct hyperedges dict for __init__
            hyperedges_dict = {
                he_id: {
                    'nodes': list(he.nodes),
                    'properties': he.properties.copy()
                } for he_id, he in new_hyperedges.items()
            }

            return Hypergraph(
                hyperedges=hyperedges_dict,
                node_properties=new_node_properties,
                hyperedge_properties=new_hyperedge_properties,
                name=self.name
            )

    def remove_nodes_from(
        self,
        nodes: Union[Any, Iterable[Any]],
        inplace: bool = True
    ) -> 'Hypergraph':
        """
        Removes the specified nodes from the hypergraph.

        Parameters
        ----------
        nodes : Any | Iterable[Any]
            Node identifier(s) to remove.
        inplace : bool, default=True
            If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the nodes removed.

        Returns
        -------
        Hypergraph
            The updated or new Hypergraph instance.

        Raises
        ------
        HypergraphError
            If any node ID does not exist.
        """
        if isinstance(nodes, (str, int)):
            nodes = [nodes]
        else:
            nodes = list(nodes)

        non_existing = set(nodes) - set(self.node_properties.keys())
        if non_existing:
            raise HypergraphError(f"Nodes {non_existing} do not exist in the hypergraph.")

        if inplace:
            for node_id in nodes:
                self.remove_node(node_id)
            return self
        else:
            # Create a new Hypergraph instance with nodes removed
            new_hyperedges = copy.deepcopy(self.hyperedges)
            new_node_properties = copy.deepcopy(self.node_properties)
            new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)
            new_graph = copy.deepcopy(self.graph)
            new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)

            for node_id in nodes:
                del new_node_properties[node_id]
                # Remove node from all hyperedges
                for hyperedge in new_hyperedges.values():
                    if node_id in hyperedge.nodes:
                        hyperedge.remove_node(node_id)
                # Remove node from graph, which also removes all incidences
                new_graph.remove_node(node_id)

            # Remove nodes not connected to any hyperedges
            retained_nodes = set()
            for hyperedge in new_hyperedges.values():
                retained_nodes.update(hyperedge.nodes)

            new_node_properties = {node: props for node, props in new_node_properties.items() if node in retained_nodes}

            # Reconstruct hyperedges dict for __init__
            hyperedges_dict = {
                he_id: {
                    'nodes': list(he.nodes),
                    'properties': he.properties.copy()
                } for he_id, he in new_hyperedges.items()
            }

            return Hypergraph(
                hyperedges=hyperedges_dict,
                node_properties=new_node_properties,
                hyperedge_properties=new_hyperedge_properties,
                name=self.name
            )

    def add_incidence(
        self,
        he_id: Any,
        node_id: Any,
        attributes: Optional[Dict[str, Any]] = None,
        inplace: bool = True
    ) -> 'Hypergraph':
        """
        Adds a single incidence with attributes to the hypergraph.

        Parameters
        ----------
        he_id : Any
            Identifier of the hyperedge.
        node_id : Any
            Identifier of the node.
        attributes : Optional[Dict[str, Any]] = None
            Properties to add to the incidence as key-value pairs.
        inplace : bool, default=True
            If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the added incidence.

        Returns
        -------
        Hypergraph
            The updated or new Hypergraph instance.

        Raises
        ------
        HypergraphError
            If the hyperedge or node does not exist, or if the incidence already exists.
        """
        if he_id not in self.hyperedges:
            raise HypergraphError(f"Hyperedge '{he_id}' does not exist in the hypergraph.")
        if node_id not in self.node_properties:
            raise HypergraphError(f"Node '{node_id}' does not exist in the hypergraph.")
        if node_id in self.hyperedges[he_id].nodes:
            raise HypergraphError(f"Incidence between hyperedge '{he_id}' and node '{node_id}' already exists.")

        if inplace:
            # Add node to HyperEdge's nodes
            self.hyperedges[he_id].add_node(node_id)
            # Update hyperedge_properties if attributes provided
            if attributes:
                self.hyperedge_properties[he_id].update(attributes)
            # Add edge in graph with attributes
            self.graph.add_edge(he_id, node_id, **(attributes if attributes else {}))
            return self
        else:
            # Create a new Hypergraph instance with the incidence added
            new_hyperedges = copy.deepcopy(self.hyperedges)
            new_node_properties = copy.deepcopy(self.node_properties)
            new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)
            new_graph = copy.deepcopy(self.graph)
            new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)

            # Add node to HyperEdge's nodes
            new_hyperedges[he_id].add_node(node_id)
            # Update hyperedge_properties if attributes provided
            if attributes:
                new_hyperedge_properties[he_id].update(attributes)
            # Add edge in graph with attributes
            new_graph.add_edge(he_id, node_id, **(attributes if attributes else {}))

            # Reconstruct hyperedges dict for __init__
            hyperedges_dict = {
                he_id: {
                    'nodes': list(he.nodes),
                    'properties': he.properties.copy()
                } for he_id, he in new_hyperedges.items()
            }

            return Hypergraph(
                hyperedges=hyperedges_dict,
                node_properties=new_node_properties,
                hyperedge_properties=new_hyperedge_properties,
                name=self.name
            )

    def remove_incidence(
        self,
        he_id: Any,
        node_id: Any,
        inplace: bool = True
    ) -> 'Hypergraph':
        """
        Removes a single incidence from the hypergraph.

        Parameters
        ----------
        he_id : Any
            Identifier of the hyperedge.
        node_id : Any
            Identifier of the node.
        inplace : bool, default=True
            If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the incidence removed.

        Returns
        -------
        Hypergraph
            The updated or new Hypergraph instance.

        Raises
        ------
        HypergraphError
            If the hyperedge or node does not exist, or if the incidence does not exist.
        """
        if he_id not in self.hyperedges:
            raise HypergraphError(f"Hyperedge '{he_id}' does not exist in the hypergraph.")
        if node_id not in self.node_properties:
            raise HypergraphError(f"Node '{node_id}' does not exist in the hypergraph.")
        if node_id not in self.hyperedges[he_id].nodes:
            raise HypergraphError(f"Incidence between hyperedge '{he_id}' and node '{node_id}' does not exist.")

        if inplace:
            # Remove node from HyperEdge's nodes
            self.hyperedges[he_id].remove_node(node_id)
            # Remove edge from graph
            self.graph.remove_edge(he_id, node_id)
            return self
        else:
            # Create a new Hypergraph instance with the incidence removed
            new_hyperedges = copy.deepcopy(self.hyperedges)
            new_node_properties = copy.deepcopy(self.node_properties)
            new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)
            new_graph = copy.deepcopy(self.graph)
            new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)

            # Remove node from HyperEdge's nodes
            new_hyperedges[he_id].remove_node(node_id)
            # Remove edge from graph
            new_graph.remove_edge(he_id, node_id)

            # Reconstruct hyperedges dict for __init__
            hyperedges_dict = {
                he_id: {
                    'nodes': list(he.nodes),
                    'properties': he.properties.copy()
                } for he_id, he in new_hyperedges.items()
            }

            return Hypergraph(
                hyperedges=hyperedges_dict,
                node_properties=new_node_properties,
                hyperedge_properties=new_hyperedge_properties,
                name=self.name
            )

    # Managing Properties and Incidences

    def adjacency_matrix(self, s: int = 1, index: bool = False) -> Tuple[Optional[csr_matrix], Dict[int, Any]]:
        """
        Generates the adjacency matrix for nodes based on s-node connectivity.

        Parameters
        ----------
        s : int, optional, default=1
            The number of shared hyperedges required for nodes to be considered adjacent.
        index : bool, optional, default=False
            If True, also returns the mapping from node IDs to matrix indices.

        Returns
        -------
        Tuple[Optional[csr_matrix], Dict[int, Any]]
            - The adjacency matrix in CSR format, or None if the hypergraph has no nodes.
            - A dictionary mapping node IDs to matrix indices.
        """
        from scipy.sparse import lil_matrix

        node_ids = list(self.node_properties.keys())
        node_index = {node_id: idx for idx, node_id in enumerate(node_ids)}
        size = len(node_ids)
        if size == 0:
            return None, {}

        A = lil_matrix((size, size), dtype=int)
        for he in self.hyperedges.values():
            nodes = list(he.nodes)
            for i in range(len(nodes)):
                for j in range(i + 1, len(nodes)):
                    # Count shared hyperedges symmetrically in both triangles
                    A[node_index[nodes[i]], node_index[nodes[j]]] += 1
                    A[node_index[nodes[j]], node_index[nodes[i]]] += 1

        # Apply the threshold s and convert to binary
        A = (A >= s).astype(int)
        A = A.tocsr()

        if index:
            return A, node_index
        return A, {}

    def hyperedge_adjacency_matrix(self, s: int = 1, index: bool = False) -> Tuple[Optional[csr_matrix], Dict[int, Any]]:
        """
        Generates the adjacency matrix for hyperedges based on s-hyperedge connectivity.

        Parameters
        ----------
        s : int, optional, default=1
            The number of shared nodes required for hyperedges to be considered adjacent.
        index : bool, optional, default=False
            If True, returns a mapping from matrix indices to hyperedge IDs.

        Returns
        -------
        Tuple[Optional[csr_matrix], Dict[int, Any]]
            - The adjacency matrix in CSR format.
            - A dictionary mapping matrix indices to hyperedge IDs.
        """
        from scipy.sparse import lil_matrix

        hyperedge_ids = list(self.hyperedges.keys())
        he_index = {he_id: idx for idx, he_id in enumerate(hyperedge_ids)}
        size = len(hyperedge_ids)
        if size == 0:
            return None, {}

        A = lil_matrix((size, size), dtype=int)
        for i, he1 in enumerate(hyperedge_ids):
            nodes1 = self.hyperedges[he1].nodes
            for j in range(i + 1, size):
                he2 = hyperedge_ids[j]
                nodes2 = self.hyperedges[he2].nodes
                shared_nodes = nodes1 & nodes2
                if len(shared_nodes) >= s:
                    A[i, j] = 1
                    A[j, i] = 1

        A = A.tocsr()

        if index:
            return A, he_index
        return A, {}

    def get_hyperedges_of_node(self, node_id: Any) -> Set[Any]:
        """
        Retrieves all hyperedges that a given node is part of.

        Parameters
        ----------
        node_id : Any
            The node identifier.

        Returns
        -------
        Set[Any]
            A set of hyperedge IDs that the node belongs to.

        Raises
        ------
        HypergraphError
            If the node does not exist in the hypergraph.
        """
        if node_id not in self.node_properties:
            raise HypergraphError(f"Node '{node_id}' does not exist in the hypergraph.")
        return {he.id for he in self.hyperedges.values() if node_id in he.nodes}

    def collapse_duplicate_hyperedges(
        self,
        name: Optional[str] = None,
        use_uids: Optional[List[Any]] = None,
        use_counts: bool = False,
        return_counts: bool = True,
        return_equivalence_classes: bool = False,
        aggregate_properties_by: Optional[Dict[str, str]] = None,
    ) -> Union['Hypergraph', Tuple['Hypergraph', Dict[Any, Set[Any]]]]:
        """
        Collapses duplicate hyperedges (hyperedges with identical node memberships) into single hyperedges.

        Parameters
        ----------
        name : Optional[str], default=None
            The name assigned to the collapsed hypergraph. If None, defaults to the original name suffixed with '_collapsed_hyperedges'.

        use_uids : Optional[List[Any]] = None
            Specifies the hyperedge identifiers to use as representatives for each equivalence class.
            If two identifiers occur in the same equivalence class, the first one found in `use_uids` is used.
            If None, the first encountered hyperedge in each class is used as the representative.

        use_counts : bool, optional, default=False
            If True, renames the equivalence class representatives by appending the size of the class (e.g., 'HE1:3').

        return_counts : bool, optional, default=True
            If True, adds the size of each equivalence class to the properties of the representative hyperedge under the key 'equivalence_class_size'.

        return_equivalence_classes : bool, optional, default=False
            If True, returns a tuple containing the new collapsed hypergraph and a dictionary mapping representatives to their equivalence classes.

        aggregate_properties_by : Optional[Dict[str, str]] = None
            A dictionary specifying aggregation methods for hyperedge properties. Keys are property names, and values are aggregation functions (e.g., {'weight': 'sum'}).
            Properties not specified will use the 'first' aggregation.

        Returns
        -------
        Hypergraph or Tuple[Hypergraph, Dict[Any, Set[Any]]]
            - If `return_equivalence_classes=False`, returns the new collapsed hypergraph.
            - If `return_equivalence_classes=True`, returns a tuple containing the collapsed hypergraph and a dictionary of equivalence classes.

        Raises
        ------
        HypergraphError
            If the hypergraph is empty or improperly structured.
        """
        if not self.hyperedges:
            raise HypergraphError("Cannot collapse hyperedges in an empty hypergraph.")

        # Identify equivalence classes based on identical node memberships
        membership_to_hyperedges: Dict[frozenset, Set[Any]] = {}
        for he_id, hyperedge in self.hyperedges.items():
            key = frozenset(hyperedge.nodes)
            membership_to_hyperedges.setdefault(key, set()).add(he_id)

        # Filter out classes with only one hyperedge (no duplicates)
        equivalence_classes = [hes for hes in membership_to_hyperedges.values() if len(hes) > 1]
        if not equivalence_classes:
            # No duplicates to collapse; return the original hypergraph
            return self if not return_equivalence_classes else (self, {})

        # Prepare aggregation methods
        aggregate_properties_by = aggregate_properties_by if aggregate_properties_by is not None else {"weight": "sum"}

        # Initialize mapping from old hyperedges to new hyperedges
        hyperedge_mapping: Dict[Any, Any] = {}
        equivalence_class_dict: Dict[Any, Set[Any]] = {}

        for eq_class in equivalence_classes:
            # Determine representative
            if use_uids:
                # Select the first UID from use_uids that is in the equivalence class
                representative = next((uid for uid in use_uids if uid in eq_class), None)
                if representative is None:
                    # Fallback to the first hyperedge in the equivalence class
                    representative = next(iter(eq_class))
            else:
                # Use the first hyperedge in the equivalence class as representative
                representative = next(iter(eq_class))

            # Optionally rename with counts
            if use_counts:
                new_representative = f"{representative}:{len(eq_class)}"
            else:
                new_representative = representative

            # Map all hyperedges in the class to the representative
            for he in eq_class:
                hyperedge_mapping[he] = new_representative

            # Store the equivalence class
            equivalence_class_dict[new_representative] = eq_class

        # Replace hyperedge IDs in incidences based on mapping
        new_hyperedges = {}
        for he_id, hyperedge in self.hyperedges.items():
            new_he_id = hyperedge_mapping.get(he_id, he_id)
            if new_he_id not in new_hyperedges:
                new_hyperedges[new_he_id] = HyperEdge(id=new_he_id, nodes=hyperedge.nodes.copy(), properties=copy.deepcopy(hyperedge.properties))
            else:
                new_hyperedges[new_he_id].nodes.update(hyperedge.nodes)

        # Aggregate hyperedge properties
        for he_id, hyperedge in new_hyperedges.items():
            if he_id in equivalence_class_dict:
                aggregated_props = {}
                for prop, agg_func in aggregate_properties_by.items():
                    values = [self.hyperedge_properties[old_he].get(prop, 0) for old_he in equivalence_class_dict[he_id]]
                    if agg_func == 'sum':
                        aggregated_props[prop] = sum(values)
                    elif agg_func == 'mean':
                        aggregated_props[prop] = sum(values) / len(values) if values else 0
                    elif agg_func == 'max':
                        aggregated_props[prop] = max(values) if values else None
                    elif agg_func == 'min':
                        aggregated_props[prop] = min(values) if values else None
                    else:
                        aggregated_props[prop] = values[0] if values else None  # Default to first
                new_hyperedges[he_id].properties.update(aggregated_props)

        # Handle equivalence class size
        if use_counts:
            for he_id in equivalence_class_dict:
                new_hyperedges[he_id].properties['equivalence_class_size'] = len(equivalence_class_dict[he_id])
        elif return_counts:
            for he_id in new_hyperedges:
                if he_id in equivalence_class_dict:
                    new_hyperedges[he_id].properties['equivalence_class_size'] = len(equivalence_class_dict[he_id])
                else:
                    new_hyperedges[he_id].properties['equivalence_class_size'] = 1

        # Initialize the collapsed hypergraph
        collapsed_hypergraph = Hypergraph(
            hyperedges={
                he_id: {
                    'nodes': list(he.nodes),
                    'properties': he.properties.copy()
                } for he_id, he in new_hyperedges.items()
            },
            node_properties=copy.deepcopy(self.node_properties),
            hyperedge_properties={
                he_id: copy.deepcopy(he.properties) for he_id, he in new_hyperedges.items()
            },
            name=name if name else f"{self.name}_collapsed_hyperedges"
        )

        if return_equivalence_classes:
            return collapsed_hypergraph, equivalence_class_dict
        else:
            return collapsed_hypergraph

    def restrict_to_specific_hyperedges(
        self,
        hyperedges_to_retain: Iterable[Any],
        name: Optional[str] = None
    ) -> 'Hypergraph':
        """
        Creates a new hypergraph by retaining only the specified hyperedges and removing all others.

        Parameters
        ----------
        hyperedges_to_retain : Iterable[Any]
            An iterable of hyperedge identifiers to retain in the new hypergraph.

        name : Optional[str], default=None
            The name assigned to the restricted hypergraph. If None, defaults to the original name suffixed with '_restricted_hyperedges'.

        Returns
        -------
        Hypergraph
            A new hypergraph containing only the specified hyperedges and their associated nodes.

        Raises
        ------
        HypergraphError
            If none of the specified hyperedges exist in the hypergraph.
        """
        hyperedges_to_retain = set(hyperedges_to_retain)
        existing_hyperedges = set(self.hyperedges.keys())
        invalid_hyperedges = hyperedges_to_retain - existing_hyperedges
        if invalid_hyperedges:
            raise HypergraphError(f"The following hyperedges do not exist and cannot be retained: {invalid_hyperedges}")

        # Determine hyperedges to remove
        hyperedges_to_remove = existing_hyperedges - hyperedges_to_retain
        if not hyperedges_to_remove:
            # No hyperedges to remove; return the original hypergraph
            return self

        # Remove hyperedges using the existing remove_hyperedges method
        restricted_hypergraph = self.remove_hyperedges(hyperedges_to_remove, inplace=False)
        restricted_hypergraph.name = name if name else f"{self.name}_restricted_hyperedges"

        return restricted_hypergraph

    def restrict_to_specific_nodes(
        self,
        nodes_to_retain: Iterable[Any],
        name: Optional[str] = None
    ) -> 'Hypergraph':
        """
        Creates a new hypergraph by retaining only the specified nodes and removing all others.

        Parameters
        ----------
        nodes_to_retain : Iterable[Any]
            An iterable of node identifiers to retain in the new hypergraph.

        name : Optional[str], default=None
            The name assigned to the restricted hypergraph. If None, defaults to the original name suffixed with '_restricted_nodes'.

        Returns
        -------
        Hypergraph
            A new hypergraph containing only the specified nodes and their associated hyperedges.

        Raises
        ------
        HypergraphError
            If none of the specified nodes exist in the hypergraph.
        """
        nodes_to_retain = set(nodes_to_retain)
        existing_nodes = set(self.node_properties.keys())
        invalid_nodes = nodes_to_retain - existing_nodes
        if invalid_nodes:
            raise HypergraphError(f"The following nodes do not exist and cannot be retained: {invalid_nodes}")

        # Determine nodes to remove
        nodes_to_remove = existing_nodes - nodes_to_retain
        if not nodes_to_remove:
            # No nodes to remove; return the original hypergraph
            return self

        # Remove nodes using the existing remove_nodes_from method
        restricted_hypergraph = self.remove_nodes_from(nodes_to_remove, inplace=False)
        restricted_hypergraph.name = name if name else f"{self.name}_restricted_nodes"

        return restricted_hypergraph

    def add_incidences_from(
        self,
        incidences: Iterable[Union[Tuple[Any, Any], Tuple[Any, Any, Dict[str, Any]]]],
        inplace: bool = True
    ) -> 'Hypergraph':
        """
        Adds a collection of incidences to the hypergraph.

        Parameters
        ----------
        incidences : Iterable[Union[Tuple[Any, Any], Tuple[Any, Any, Dict[str, Any]]]]
            Incidence tuples as:
            - (he_id, node_id)
            - (he_id, node_id, attributes)

        inplace : bool, default=True
            If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the added incidences.

        Returns
        -------
        Hypergraph
            The updated or new Hypergraph instance.

        Raises
        ------
        HypergraphError
            If any hyperedge or node does not exist, or if any incidence already exists.
        ValueError
            If the structure of any incidence tuple is invalid.
        """
        new_incidences = []
        for pr in incidences:
            if not isinstance(pr, tuple):
                raise ValueError(f"Each incidence must be a tuple, got {type(pr)}")
            if len(pr) == 2:
                he_id, node_id = pr
                attrs = {}
            elif len(pr) == 3:
                he_id, node_id, attrs = pr
                if not isinstance(attrs, dict):
                    raise ValueError(f"Attributes must be a dictionary, got {type(attrs)}")
            else:
                raise ValueError(f"Incidence tuples must be of length 2 or 3, got {len(pr)}")

            if he_id not in self.hyperedges:
                raise HypergraphError(f"Hyperedge '{he_id}' does not exist in the hypergraph.")
            if node_id not in self.node_properties:
                raise HypergraphError(f"Node '{node_id}' does not exist in the hypergraph.")
            if node_id in self.hyperedges[he_id].nodes:
                raise HypergraphError(f"Incidence between hyperedge '{he_id}' and node '{node_id}' already exists.")

            new_incidences.append((he_id, node_id, attrs.copy()))

        if inplace:
            for he_id, node_id, attrs in new_incidences:
                # Add node to HyperEdge's nodes
                self.hyperedges[he_id].add_node(node_id)
                # Update hyperedge_properties if attributes provided
                if attrs:
                    self.hyperedge_properties[he_id].update(attrs)
                # Add edge in graph with attributes
                self.graph.add_edge(he_id, node_id, **(attrs if attrs else {}))
            return self
        else:
            # Create a new Hypergraph instance with the incidences added
            new_hyperedges = copy.deepcopy(self.hyperedges)
            new_node_properties = copy.deepcopy(self.node_properties)
            new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)
            new_graph = copy.deepcopy(self.graph)
            new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)

            for he_id, node_id, attrs in new_incidences:
                # Add node to HyperEdge's nodes
                new_hyperedges[he_id].add_node(node_id)
                # Update hyperedge_properties if attributes provided
                if attrs:
                    new_hyperedge_properties[he_id].update(attrs)
                # Add edge in graph with attributes
                new_graph.add_edge(he_id, node_id, **(attrs if attrs else {}))

            # Reconstruct hyperedges dict for __init__
            hyperedges_dict = {
                he_id: {
                    'nodes': list(he.nodes),
                    'properties': he.properties.copy()
                } for he_id, he in new_hyperedges.items()
            }

            return Hypergraph(
                hyperedges=hyperedges_dict,
                node_properties=new_node_properties,
                hyperedge_properties=new_hyperedge_properties,
                name=self.name
            )

    def remove_incidences(
        self,
        incidences: Iterable[Tuple[Any, Any]],
        inplace: bool = True
    ) -> 'Hypergraph':
        """
        Removes the specified incidences from the hypergraph.

        Parameters
        ----------
        incidences : Iterable[Tuple[Any, Any]]
            Incidence identifiers as tuples of (he_id, node_id).
        inplace : bool, default=True
            If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the incidences removed.

        Returns
        -------
        Hypergraph
            The updated or new Hypergraph instance.

        Raises
        ------
        HypergraphError
            If any incidence does not exist.
        """
        incidence_ids = list(incidences)

        # Check existence of incidences
        for he_id, node_id in incidence_ids:
            if he_id not in self.hyperedges:
                raise HypergraphError(f"Hyperedge '{he_id}' does not exist in the hypergraph.")
            if node_id not in self.node_properties:
                raise HypergraphError(f"Node '{node_id}' does not exist in the hypergraph.")
            if node_id not in self.hyperedges[he_id].nodes:
                raise HypergraphError(f"Incidence between hyperedge '{he_id}' and node '{node_id}' does not exist.")

        if inplace:
            for he_id, node_id in incidence_ids:
                # Remove node from HyperEdge's nodes
                self.hyperedges[he_id].remove_node(node_id)
                # Remove edge from graph
                self.graph.remove_edge(he_id, node_id)
            return self
        else:
            # Create a new Hypergraph instance with the incidences removed
            new_hyperedges = copy.deepcopy(self.hyperedges)
            new_node_properties = copy.deepcopy(self.node_properties)
            new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)
            new_graph = copy.deepcopy(self.graph)
            new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)

            for he_id, node_id in incidence_ids:
                # Remove node from HyperEdge's nodes
                new_hyperedges[he_id].remove_node(node_id)
                # Remove edge from graph
                new_graph.remove_edge(he_id, node_id)

            # Reconstruct hyperedges dict for __init__
            hyperedges_dict = {
                he_id: {
                    'nodes': list(he.nodes),
                    'properties': he.properties.copy()
                } for he_id, he in new_hyperedges.items()
            }

            return Hypergraph(
                hyperedges=hyperedges_dict,
                node_properties=new_node_properties,
                hyperedge_properties=new_hyperedge_properties,
                name=self.name
            )

    def collapse_duplicate_nodes(
        self,
        name: Optional[str] = None,
        use_uids: Optional[List[Any]] = None,
        use_counts: bool = False,
        return_counts: bool = True,
        return_equivalence_classes: bool = False,
        aggregate_properties_by: Optional[Dict[str, str]] = None,
    ) -> Union['Hypergraph', Tuple['Hypergraph', Dict[Any, Set[Any]]]]:
        """
        Collapses duplicate nodes (nodes with identical hyperedge memberships) into single nodes.

        Parameters
        ----------
        name : Optional[str], default=None
            The name assigned to the collapsed hypergraph. If None, defaults to the original name suffixed with '_collapsed_nodes'.

        use_uids : Optional[List[Any]] = None
            Specifies the node identifiers to use as representatives for each equivalence class.
            If two identifiers occur in the same equivalence class, the first one found in `use_uids` is used.
            If None, the first encountered node in each class is used as the representative.

        use_counts : bool, optional, default=False
            If True, renames the equivalence class representatives by appending the size of the class (e.g., 'N1:3').

        return_counts : bool, optional, default=True
            If True, adds the size of each equivalence class to the properties of the representative node under the key 'equivalence_class_size'.

        return_equivalence_classes : bool, optional, default=False
            If True, returns a tuple containing the new collapsed hypergraph and a dictionary mapping representatives to their equivalence classes.

        aggregate_properties_by : Optional[Dict[str, str]] = None
            A dictionary specifying aggregation methods for node properties. Keys are property names, and values are aggregation functions (e.g., {'weight': 'sum'}).
            Properties not specified will use the 'first' aggregation.

        Returns
        -------
        Hypergraph or Tuple[Hypergraph, Dict[Any, Set[Any]]]
            - If `return_equivalence_classes=False`, returns the new collapsed hypergraph.
            - If `return_equivalence_classes=True`, returns a tuple containing the collapsed hypergraph and a dictionary of equivalence classes.

        Raises
        ------
        HypergraphError
            If the hypergraph is empty or improperly structured.
        """
        if not self.node_properties:
            raise HypergraphError("Cannot collapse nodes in an empty hypergraph.")

        # Identify equivalence classes based on identical hyperedge memberships
        membership_to_nodes: Dict[frozenset, Set[Any]] = {}
        for node_id, node_props in self.node_properties.items():
            key = frozenset(self.get_hyperedges_of_node(node_id))
            membership_to_nodes.setdefault(key, set()).add(node_id)

        # Filter out classes with only one node (no duplicates)
        equivalence_classes = [nodes for nodes in membership_to_nodes.values() if len(nodes) > 1]
        if not equivalence_classes:
            # No duplicates to collapse; return the original hypergraph
            return self if not return_equivalence_classes else (self, {})

        # Prepare aggregation methods
        aggregate_properties_by = aggregate_properties_by if aggregate_properties_by is not None else {"weight": "sum"}

        # Initialize mapping from old nodes to new nodes
        node_mapping: Dict[Any, Any] = {}
        equivalence_class_dict: Dict[Any, Set[Any]] = {}

        for eq_class in equivalence_classes:
            # Determine representative
            if use_uids:
                # Select the first UID from use_uids that is in the equivalence class
                representative = next((uid for uid in use_uids if uid in eq_class), None)
                if representative is None:
                    # Fallback to the first node in the equivalence class
                    representative = next(iter(eq_class))
            else:
                # Use the first node in the equivalence class as representative
                representative
= next(iter(eq_class))\n\n        # Optionally rename with counts\n        if use_counts:\n            new_representative = f\"{representative}:{len(eq_class)}\"\n        else:\n            new_representative = representative\n\n        # Map all nodes in the class to the representative\n        for node in eq_class:\n            node_mapping[node] = new_representative\n\n        # Store the equivalence class\n        equivalence_class_dict[new_representative] = eq_class\n\n        # Replace node IDs in hyperedges based on mapping\n        new_hyperedges = {}\n        for he_id, hyperedge in self.hyperedges.items():\n            new_nodes = set()\n            for node_id in hyperedge.nodes:\n                new_node_id = node_mapping.get(node_id, node_id)\n                new_nodes.add(new_node_id)\n            new_hyperedges[he_id] = HyperEdge(id=he_id, nodes=new_nodes, properties=copy.deepcopy(hyperedge.properties))\n\n        # Aggregate node properties\n        new_node_properties = {}\n        for node_id, node_props in self.node_properties.items():\n            new_node_id = node_mapping.get(node_id, node_id)\n            if new_node_id not in new_node_properties:\n                new_node_properties[new_node_id] = copy.deepcopy(node_props)\n            else:\n                for prop, agg_func in aggregate_properties_by.items():\n                    if prop in node_props:\n                        if agg_func == 'sum':\n                            new_node_properties[new_node_id][prop] = new_node_properties[new_node_id].get(prop, 0) + node_props[prop]\n                        elif agg_func == 'mean':\n                            # To calculate mean, store a running sum and count\n                            if 'sum_' + prop not in new_node_properties[new_node_id]:\n                                # Seed with the value already stored for the first node in the class,\n                                # so the first node is not excluded from the mean\n                                new_node_properties[new_node_id]['sum_' + prop] = new_node_properties[new_node_id].get(prop, 0) + node_props[prop]\n                                new_node_properties[new_node_id]['count_' + prop] = 2\n                            else:\n                                new_node_properties[new_node_id]['sum_' + prop] += node_props[prop]\n                                new_node_properties[new_node_id]['count_' + prop] += 1\n                            # Mean is finalized after the loop\n                        elif agg_func == 'max':\n                            current_max = new_node_properties[new_node_id].get(prop, float('-inf'))\n                            new_node_properties[new_node_id][prop] = max(current_max, node_props[prop])\n                        elif agg_func == 'min':\n                            current_min = new_node_properties[new_node_id].get(prop, float('inf'))\n                            
new_node_properties[new_node_id][prop] = min(current_min, node_props[prop])\n else:\n new_node_properties[new_node_id][prop] = node_props[prop] # Default to last\n # Finalize mean calculations\n for node_id, props in new_node_properties.items():\n for prop in list(props.keys()):\n if prop.startswith('sum_'):\n base_prop = prop[4:]\n sum_val = props[prop]\n count_val = props.get('count_' + base_prop, 1)\n new_node_properties[node_id][base_prop] = sum_val / count_val if count_val > 0 else 0\n del new_node_properties[node_id][prop]\n del new_node_properties[node_id]['count_' + base_prop]\n\n # Handle equivalence class size\n if use_counts:\n for node_id in equivalence_class_dict:\n new_node_properties[node_id]['equivalence_class_size'] = len(equivalence_class_dict[node_id])\n elif return_counts:\n for node_id in new_node_properties:\n if node_id in equivalence_class_dict:\n new_node_properties[node_id]['equivalence_class_size'] = len(equivalence_class_dict[node_id])\n else:\n new_node_properties[node_id]['equivalence_class_size'] = 1\n\n # Initialize the collapsed hypergraph\n collapsed_hypergraph = Hypergraph(\n hyperedges={\n he_id: {\n 'nodes': list(he.nodes),\n 'properties': he.properties.copy()\n } for he_id, he in new_hyperedges.items()\n },\n node_properties=new_node_properties,\n hyperedge_properties={\n he_id: copy.deepcopy(he.properties) for he_id, he in new_hyperedges.items()\n },\n name=name if name else f\"{self.name}_collapsed_nodes\"\n )\n\n if return_equivalence_classes:\n return collapsed_hypergraph, equivalence_class_dict\n else:\n return collapsed_hypergraph\n\n # Analyzing and Querying the Hypergraph\n\n def get_toplexes(self, return_hypergraph: bool = False) -> Union[List[Any], 'Hypergraph']:\n \"\"\"\n Computes a maximal collection of toplexes for the hypergraph.\n A :term:`toplex` is a hyperedge that is not contained in any other hyperedge.\n\n Parameters\n ----------\n return_hypergraph : bool, optional, default=False\n If True, returns a new 
Hypergraph consisting only of the toplexes.\n\n Returns\n -------\n List[Any] or Hypergraph\n - A list of toplex hyperedge IDs.\n - If `return_hypergraph=True`, returns a Hypergraph containing only the toplexes.\n \"\"\"\n toplexes = []\n hyperedges = list(self.hyperedges.values())\n\n for he in hyperedges:\n if not any(he.nodes < other_he.nodes for other_he in hyperedges if he.id != other_he.id):\n toplexes.append(he.id)\n\n if return_hypergraph:\n return self.restrict_to_specific_hyperedges(toplexes, name=\"Toplexes\")\n return toplexes\n\n def is_node_connected(self, s: int = 1) -> bool:\n \"\"\"\n Determines if the hypergraph is s-node-connected.\n\n Parameters\n ----------\n s : int, optional, default=1\n The connectivity level to check.\n\n Returns\n -------\n bool\n True if the hypergraph is s-node-connected, False otherwise.\n \"\"\"\n return self._is_connected(s=s, hyperedges=False)\n\n def is_hyperedge_connected(self, s: int = 1) -> bool:\n \"\"\"\n Determines if the hypergraph is s-hyperedge-connected.\n\n Parameters\n ----------\n s : int, optional, default=1\n The connectivity level to check.\n\n Returns\n -------\n bool\n True if the hypergraph is s-hyperedge-connected, False otherwise.\n \"\"\"\n return self._is_connected(s=s, hyperedges=True)\n\n def _is_connected(self, s: int = 1, hyperedges: bool = False) -> bool:\n \"\"\"\n Internal method to determine connectivity based on nodes or hyperedges.\n\n Parameters\n ----------\n s : int, optional, default=1\n The connectivity level to check.\n hyperedges : bool, optional, default=False\n If True, checks for s-hyperedge-connectedness. 
Otherwise, checks for s-node-connectedness.\n\n Returns\n -------\n bool\n Connectivity status.\n \"\"\"\n if hyperedges:\n # Create hyperedge connectivity graph: hyperedges are nodes, connect if they share >= s nodes\n hyperedge_graph = nx.Graph()\n hyperedge_ids = list(self.hyperedges.keys())\n hyperedge_graph.add_nodes_from(hyperedge_ids)\n\n for i, he1 in enumerate(hyperedge_ids):\n nodes1 = self.hyperedges[he1].nodes\n for he2 in hyperedge_ids[i+1:]:\n nodes2 = self.hyperedges[he2].nodes\n shared_nodes = nodes1 & nodes2\n if len(shared_nodes) >= s:\n hyperedge_graph.add_edge(he1, he2)\n\n try:\n return nx.is_connected(hyperedge_graph)\n except nx.NetworkXPointlessConcept:\n return False\n else:\n # Create node connectivity graph: nodes are nodes, connect if they share >= s hyperedges\n node_graph = nx.Graph()\n node_ids = list(self.node_properties.keys())\n node_graph.add_nodes_from(node_ids)\n\n for i, node1 in enumerate(node_ids):\n hyperedges1 = {he.id for he in self.hyperedges.values() if node1 in he.nodes}\n for node2 in node_ids[i+1:]:\n hyperedges2 = {he.id for he in self.hyperedges.values() if node2 in he.nodes}\n shared_hyperedges = hyperedges1 & hyperedges2\n if len(shared_hyperedges) >= s:\n node_graph.add_edge(node1, node2)\n\n try:\n return nx.is_connected(node_graph)\n except nx.NetworkXPointlessConcept:\n return False\n\n def get_node_connected_components(\n self, s: int = 1, return_singletons: bool = False\n ) -> Iterator[Set[Any]]:\n \"\"\"\n Yields the s-node-connected components of the hypergraph.\n\n Parameters\n ----------\n s : int, optional, default=1\n The connectivity level to check.\n return_singletons : bool, optional, default=False\n If True, includes singleton components. 
Otherwise, excludes them.\n\n Yields\n ------\n Set[Any]\n Sets of node IDs representing each connected component.\n \"\"\"\n return self.s_connected_components(s=s, hyperedges=False, return_singletons=return_singletons)\n\n def get_hyperedge_connected_components(\n self, s: int = 1, return_singletons: bool = False\n ) -> Iterator[Set[Any]]:\n \"\"\"\n Yields the s-hyperedge-connected components of the hypergraph.\n\n Parameters\n ----------\n s : int, optional, default=1\n The connectivity level to check.\n return_singletons : bool, optional, default=False\n If True, includes singleton components. Otherwise, excludes them.\n\n Yields\n ------\n Set[Any]\n Sets of hyperedge IDs representing each connected component.\n \"\"\"\n return self.s_connected_components(s=s, hyperedges=True, return_singletons=return_singletons)\n\n def get_node_connected_subgraphs(\n self, s: int = 1, return_singletons: bool = False, name: Optional[str] = None\n ) -> Iterator['Hypergraph']:\n \"\"\"\n Yields subgraphs corresponding to each s-node-connected component.\n\n Parameters\n ----------\n s : int, optional, default=1\n The connectivity level to check.\n return_singletons : bool, optional, default=False\n If True, includes singleton components. Otherwise, excludes them.\n name : Optional[str], default=None\n Base name for the subgraphs. 
Each subgraph will have a unique name appended.\n\n Yields\n ------\n Hypergraph\n Subgraphs representing each connected component.\n \"\"\"\n return self.s_component_subgraphs(\n s=s,\n hyperedges=False,\n return_singletons=return_singletons,\n name=name\n )\n\n def get_hyperedge_connected_subgraphs(\n self, s: int = 1, return_singletons: bool = False, name: Optional[str] = None\n ) -> Iterator['Hypergraph']:\n \"\"\"\n Yields subgraphs corresponding to each s-hyperedge-connected component.\n\n Parameters\n ----------\n s : int, optional, default=1\n The connectivity level to check.\n return_singletons : bool, optional, default=False\n If True, includes singleton components. Otherwise, excludes them.\n name : Optional[str], default=None\n Base name for the subgraphs. Each subgraph will have a unique name appended.\n\n Yields\n ------\n Hypergraph\n Subgraphs representing each connected component.\n \"\"\"\n return self.s_component_subgraphs(\n s=s,\n hyperedges=True,\n return_singletons=return_singletons,\n name=name\n )\n\n def get_singleton_hyperedges(self) -> List[Any]:\n \"\"\"\n Returns a list of singleton hyperedges.\n A singleton hyperedge is a hyperedge of size 1 where its sole node has degree 1.\n\n Returns\n -------\n List[Any]\n A list of singleton hyperedge IDs.\n \"\"\"\n singletons = []\n for he in self.hyperedges.values():\n if len(he.nodes) == 1:\n node = next(iter(he.nodes))\n node_degree = sum(1 for hyperedge in self.hyperedges.values() if node in hyperedge.nodes)\n if node_degree == 1:\n singletons.append(he.id)\n return singletons\n\n def remove_singleton_hyperedges(self, name: Optional[str] = None) -> 'Hypergraph':\n \"\"\"\n Constructs a clone of the hypergraph with singleton hyperedges removed.\n \"\"\"\n singletons = self.get_singleton_hyperedges()\n if not singletons:\n return self.copy(name=name)\n\n new_hypergraph = self.remove_hyperedges(singletons, inplace=False)\n new_hypergraph.name = name if name else 
f\"{self.name}_no_singleton_hyperedges\"\n return new_hypergraph\n\n def s_connected_components(\n self, \n s: int = 1, \n hyperedges: bool = True, \n return_singletons: bool = False\n ) -> Iterator[Set[Any]]:\n \"\"\"\n Yields the s-hyperedge-connected or s-node-connected components of the hypergraph.\n\n Parameters\n ----------\n s : int, optional, default=1\n The connectivity level to check.\n hyperedges : bool, optional, default=True\n If True, yields s-hyperedge-connected components. Otherwise, yields s-node-connected components.\n return_singletons : bool, optional, default=False\n If True, includes singleton components. Otherwise, excludes them.\n\n Yields\n ------\n Set[Any]\n Sets of hyperedge IDs or node IDs representing each connected component.\n \"\"\"\n if hyperedges:\n # s-hyperedge-connected: hyperedges are connected if they share at least s nodes\n hyperedge_graph = nx.Graph()\n hyperedge_ids = list(self.hyperedges.keys())\n hyperedge_graph.add_nodes_from(hyperedge_ids)\n\n for i, he1 in enumerate(hyperedge_ids):\n nodes1 = self.hyperedges[he1].nodes\n for he2 in hyperedge_ids[i + 1:]:\n nodes2 = self.hyperedges[he2].nodes\n shared_nodes = nodes1 & nodes2\n if len(shared_nodes) >= s:\n hyperedge_graph.add_edge(he1, he2)\n\n components = nx.connected_components(hyperedge_graph)\n for component in components:\n if not return_singletons and len(component) == 1:\n continue\n yield component\n else:\n # s-node-connected: nodes are connected if they share at least s hyperedges\n node_graph = nx.Graph()\n node_ids = list(self.node_properties.keys())\n node_graph.add_nodes_from(node_ids)\n\n for i, node1 in enumerate(node_ids):\n hyperedges1 = {he.id for he in self.hyperedges.values() if node1 in he.nodes}\n for node2 in node_ids[i + 1:]:\n hyperedges2 = {he.id for he in self.hyperedges.values() if node2 in he.nodes}\n shared_hyperedges = hyperedges1 & hyperedges2\n if len(shared_hyperedges) >= s:\n node_graph.add_edge(node1, node2)\n\n components = 
nx.connected_components(node_graph)\n for component in components:\n if not return_singletons and len(component) == 1:\n continue\n yield component\n\n def s_component_subgraphs(\n self,\n s: int = 1,\n hyperedges: bool = True,\n return_singletons: bool = False,\n name: Optional[str] = None\n ) -> Iterator['Hypergraph']:\n \"\"\"\n Yields subgraphs corresponding to each s-hyperedge-connected or s-node-connected component.\n\n Parameters\n ----------\n s : int, optional, default=1\n The connectivity level to check.\n hyperedges : bool, optional, default=True\n If True, yields subgraphs of s-hyperedge-connected components. Otherwise, yields subgraphs of s-node-connected components.\n return_singletons : bool, optional, default=False\n If True, includes singleton components. Otherwise, excludes them.\n name : Optional[str], default=None\n Base name for the subgraphs. Each subgraph will have a unique name appended.\n\n Yields\n ------\n Hypergraph\n Subgraphs representing each connected component.\n \"\"\"\n for idx, component in enumerate(\n self.s_connected_components(s=s, hyperedges=hyperedges, return_singletons=return_singletons)\n ):\n if hyperedges:\n yield self.restrict_to_specific_hyperedges(\n hyperedges_to_retain=component, \n name=f\"{name or self.name}_component_{idx}\"\n )\n else:\n yield self.restrict_to_specific_nodes(\n nodes_to_retain=component, \n name=f\"{name or self.name}_component_{idx}\"\n )\n\n def compute_node_diameters(self, s: int = 1) -> Tuple[int, List[int], List[Set[Any]]]:\n \"\"\"\n Returns the node diameters of the connected components in the hypergraph.\n\n Parameters\n ----------\n s : int, optional, default=1\n The number of shared hyperedges required for nodes to be considered adjacent.\n\n Returns\n -------\n Tuple[int, List[int], List[Set[Any]]]\n - Maximum diameter among all connected components.\n - List of diameters for each s-node connected component.\n - List of sets, each containing node IDs in an s-node connected 
component.\n\n Raises\n ------\n HypergraphError\n If the hypergraph is not s-connected or has no nodes.\n \"\"\"\n A, node_id_map = self.adjacency_matrix(s=s, index=True)\n if A is None or A.shape[0] == 0:\n raise HypergraphError(\"The hypergraph has no nodes to compute diameters.\")\n\n graph = nx.from_scipy_sparse_array(A)\n\n if not nx.is_connected(graph):\n raise HypergraphError(f\"Hypergraph is not s-node-connected. s={s}\")\n\n diams = []\n comps = []\n for component in nx.connected_components(graph):\n subgraph = graph.subgraph(component)\n if len(subgraph) == 1:\n diamc = 0 # Diameter of a single node is 0\n else:\n try:\n diamc = nx.diameter(subgraph)\n except nx.NetworkXError:\n diamc = float('inf') # Infinite diameter if the subgraph is not connected\n diams.append(diamc)\n component_nodes = {node_id_map[node] for node in component}\n comps.append(component_nodes)\n\n if not diams:\n raise HypergraphError(\"No connected components found to compute diameters.\")\n\n max_diam = max(diams)\n return max_diam, diams, comps\n\n def compute_hyperedge_diameters(self, s: int = 1) -> Tuple[int, List[int], List[Set[Any]]]:\n \"\"\"\n Returns the hyperedge diameters of the s-hyperedge-connected component subgraphs in the hypergraph.\n\n Parameters\n ----------\n s : int, optional, default=1\n The number of shared nodes required for hyperedges to be considered adjacent.\n\n Returns\n -------\n Tuple[int, List[int], List[Set[Any]]]\n - Maximum diameter among all s-hyperedge-connected components.\n - List of diameters for each s-hyperedge connected component.\n - List of sets, each containing hyperedge IDs in an s-hyperedge connected component.\n\n Raises\n ------\n HypergraphError\n If the hypergraph is not s-hyperedge-connected or has no hyperedges.\n \"\"\"\n A, he_id_map = self.hyperedge_adjacency_matrix(s=s, index=True)\n if A is None or A.shape[0] == 0:\n raise HypergraphError(\"The hypergraph has no hyperedges to compute diameters.\")\n\n graph = 
nx.from_scipy_sparse_array(A)\n\n if not nx.is_connected(graph):\n raise HypergraphError(f\"Hypergraph is not s-hyperedge-connected. s={s}\")\n\n diams = []\n comps = []\n for component in nx.connected_components(graph):\n subgraph = graph.subgraph(component)\n if len(subgraph) == 1:\n diamc = 0 # Diameter of a single hyperedge is 0\n else:\n try:\n diamc = nx.diameter(subgraph)\n except nx.NetworkXError:\n diamc = float('inf') # Infinite diameter if the subgraph is not connected\n diams.append(diamc)\n component_hyperedges = {he_id_map[he] for he in component}\n comps.append(component_hyperedges)\n\n if not diams:\n raise HypergraphError(\"No connected components found to compute hyperedge diameters.\")\n\n max_diam = max(diams)\n return max_diam, diams, comps\n\n def compute_node_diameter(self, s: int = 1) -> int:\n \"\"\"\n Returns the diameter of the hypergraph based on s-node connectivity.\n\n Parameters\n ----------\n s : int, optional, default=1\n The number of shared hyperedges required for nodes to be considered adjacent.\n\n Returns\n -------\n int\n The diameter of the hypergraph.\n\n Raises\n ------\n HypergraphError\n If the hypergraph is not s-node-connected or has no nodes.\n \"\"\"\n A, _ = self.adjacency_matrix(s=s, index=True)\n if A is None or A.shape[0] == 0:\n raise HypergraphError(\"The hypergraph has no nodes to compute diameter.\")\n\n graph = nx.from_scipy_sparse_array(A)\n if not nx.is_connected(graph):\n raise HypergraphError(f\"Hypergraph is not s-node-connected. 
s={s}\")\n\n try:\n return nx.diameter(graph)\n except nx.NetworkXError as e:\n raise HypergraphError(f\"Could not compute diameter: {e}\")\n\n def compute_hyperedge_diameter(self, s: int = 1) -> int:\n \"\"\"\n Returns the diameter of the hypergraph based on s-hyperedge connectivity.\n\n Parameters\n ----------\n s : int, optional, default=1\n The number of shared nodes required for hyperedges to be considered adjacent.\n\n Returns\n -------\n int\n The diameter of the hypergraph based on hyperedge connectivity.\n\n Raises\n ------\n HypergraphError\n If the hypergraph is not s-hyperedge-connected or has no hyperedges.\n \"\"\"\n A, _ = self.hyperedge_adjacency_matrix(s=s, index=True)\n if A is None or A.shape[0] == 0:\n raise HypergraphError(\"The hypergraph has no hyperedges to compute diameter.\")\n\n graph = nx.from_scipy_sparse_array(A)\n if not nx.is_connected(graph):\n raise HypergraphError(f\"Hypergraph is not s-hyperedge-connected. s={s}\")\n\n try:\n return nx.diameter(graph)\n except nx.NetworkXError as e:\n raise HypergraphError(f\"Could not compute hyperedge diameter: {e}\")\n\n def get_node_distance(self, source: Any, target: Any, s: int = 1) -> Union[int, float]:\n \"\"\"\n Returns the shortest s-walk distance between two nodes in the hypergraph.\n\n Parameters\n ----------\n source : Any\n A node identifier in the hypergraph.\n target : Any\n A node identifier in the hypergraph.\n s : int, optional, default=1\n The number of shared hyperedges required for nodes to be considered adjacent.\n\n Returns\n -------\n Union[int, float]\n The shortest s-walk distance between the source and target nodes.\n Returns `float('inf')` if no path exists.\n\n Raises\n ------\n HypergraphError\n If either the source or target node does not exist in the hypergraph.\n \"\"\"\n if source not in self.node_properties:\n raise HypergraphError(f\"Source node '{source}' does not exist in the hypergraph.\")\n if target not in self.node_properties:\n raise 
HypergraphError(f\"Target node '{target}' does not exist in the hypergraph.\")\n\n        A, node_id_map = self.adjacency_matrix(s=s, index=True)\n        if A is None:\n            raise HypergraphError(\"Adjacency matrix could not be generated.\")\n\n        graph = nx.from_scipy_sparse_array(A)\n        # Relabel integer matrix indices back to original node IDs so lookup by ID works\n        graph = nx.relabel_nodes(graph, node_id_map)\n\n        try:\n            distance = nx.shortest_path_length(graph, source=source, target=target)\n            return distance\n        except (nx.NetworkXNoPath, nx.NodeNotFound):\n            warnings.warn(f\"No s-walk path between '{source}' and '{target}'. Returning infinity.\")\n            return float('inf')\n\n    def get_hyperedge_distance(self, source: Any, target: Any, s: int = 1) -> Union[int, float]:\n        \"\"\"\n        Returns the shortest s-walk distance between two hyperedges in the hypergraph.\n\n        Parameters\n        ----------\n        source : Any\n            A hyperedge identifier in the hypergraph.\n        target : Any\n            A hyperedge identifier in the hypergraph.\n        s : int, optional, default=1\n            The number of shared nodes required for hyperedges to be considered adjacent.\n\n        Returns\n        -------\n        Union[int, float]\n            The shortest s-walk distance between the source and target hyperedges.\n            Returns `float('inf')` if no path exists.\n\n        Raises\n        ------\n        HypergraphError\n            If either the source or target hyperedge does not exist in the hypergraph.\n        \"\"\"\n        if source not in self.hyperedges:\n            raise HypergraphError(f\"Source hyperedge '{source}' does not exist in the hypergraph.\")\n        if target not in self.hyperedges:\n            raise HypergraphError(f\"Target hyperedge '{target}' does not exist in the hypergraph.\")\n\n        A, he_id_map = self.hyperedge_adjacency_matrix(s=s, index=True)\n        if A is None:\n            raise HypergraphError(\"Hyperedge adjacency matrix could not be generated.\")\n\n        graph = nx.from_scipy_sparse_array(A)\n        # Relabel integer matrix indices back to original hyperedge IDs so lookup by ID works\n        graph = nx.relabel_nodes(graph, he_id_map)\n\n        try:\n            distance = nx.shortest_path_length(graph, source=source, target=target)\n            return distance\n        except (nx.NetworkXNoPath, nx.NodeNotFound):\n            warnings.warn(f\"No s-walk path between hyperedges '{source}' and '{target}'. 
Returning infinity.\")\n return float('inf')\n\n # Advanced Operations and Transformations\n\n def union(self, other: 'Hypergraph', inplace: bool = False, name: Optional[str] = None) -> 'Hypergraph':\n \"\"\"\n Returns the union of the current hypergraph with another hypergraph.\n The union combines all nodes and hyperedges from both hypergraphs.\n\n Parameters\n ----------\n other : Hypergraph\n The hypergraph to union with.\n inplace : bool, optional, default=False\n If True, modifies the current hypergraph. Otherwise, returns a new Hypergraph instance.\n name : Optional[str], default=None\n The name for the resulting hypergraph. If None, defaults to 'Union_of_{self.name}_{other.name}'.\n\n Returns\n -------\n Hypergraph\n The resulting union hypergraph.\n\n Raises\n ------\n TypeError\n If `other` is not an instance of Hypergraph.\n \"\"\"\n if not isinstance(other, Hypergraph):\n raise TypeError(\"The `other` parameter must be an instance of Hypergraph.\")\n\n if inplace:\n # Add nodes from other\n for node_id, props in other.node_properties.items():\n if node_id not in self.node_properties:\n self.add_node(node_id, properties=props, inplace=True)\n else:\n # Optionally, merge properties\n self.node_properties[node_id].update(props)\n self.graph.nodes[node_id].update(props)\n\n # Add hyperedges from other\n for he_id, hyperedge in other.hyperedges.items():\n if he_id not in self.hyperedges:\n self.add_hyperedge(he_id, hyperedge.nodes, properties=hyperedge.properties)\n else:\n # Optionally, merge properties and nodes\n self.hyperedges[he_id].nodes.update(hyperedge.nodes)\n self.hyperedge_properties[he_id].update(hyperedge.properties)\n for node in hyperedge.nodes:\n if node not in self.graph:\n self.add_node(node)\n self.graph.add_edge(he_id, node)\n if name:\n self.name = name\n return self\n else:\n # Create a new Hypergraph instance\n new_hyperedges = copy.deepcopy(self.hyperedges)\n new_node_properties = copy.deepcopy(self.node_properties)\n 
new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)\n new_graph = copy.deepcopy(self.graph)\n new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)\n new_name = name if name else f\"Union_of_{self.name}_{other.name}\"\n\n # Add nodes from other\n for node_id, props in other.node_properties.items():\n if node_id not in new_node_properties:\n new_node_properties[node_id] = copy.deepcopy(props)\n new_graph.add_node(node_id, bipartite='node', **props)\n\n # Add hyperedges from other\n for he_id, hyperedge in other.hyperedges.items():\n if he_id not in new_hyperedges:\n new_hyperedges[he_id] = copy.deepcopy(hyperedge)\n new_hyperedge_properties[he_id] = copy.deepcopy(other.hyperedge_properties[he_id])\n new_graph.add_node(he_id, bipartite='hyperedge', **new_hyperedge_properties[he_id])\n new_bipartite_nodes.add(he_id)\n for node in hyperedge.nodes:\n new_graph.add_edge(he_id, node)\n else:\n # Merge nodes and properties\n new_hyperedges[he_id].nodes.update(hyperedge.nodes)\n new_hyperedge_properties[he_id].update(other.hyperedge_properties[he_id])\n for node in hyperedge.nodes:\n new_graph.add_edge(he_id, node)\n\n # Construct the new Hypergraph\n return Hypergraph(\n hyperedges={\n he_id: {\n 'nodes': list(he.nodes),\n 'properties': he.properties.copy()\n } for he_id, he in new_hyperedges.items()\n },\n node_properties=new_node_properties,\n hyperedge_properties=new_hyperedge_properties,\n name=new_name\n )\n\n def intersection(self, other: 'Hypergraph', inplace: bool = False, name: Optional[str] = None) -> 'Hypergraph':\n \"\"\"\n Returns the intersection of the current hypergraph with another hypergraph.\n The intersection includes only nodes and hyperedges present in both hypergraphs.\n\n Parameters\n ----------\n other : Hypergraph\n The hypergraph to intersect with.\n inplace : bool, optional, default=False\n If True, modifies the current hypergraph to keep only the intersecting elements.\n Otherwise, returns a new Hypergraph instance.\n name : 
Optional[str], default=None\n The name for the resulting hypergraph. If None, defaults to 'Intersection_of_{self.name}_{other.name}'.\n\n Returns\n -------\n Hypergraph\n The resulting intersection hypergraph.\n\n Raises\n ------\n TypeError\n If `other` is not an instance of Hypergraph.\n \"\"\"\n if not isinstance(other, Hypergraph):\n raise TypeError(\"The `other` parameter must be an instance of Hypergraph.\")\n\n intersect_nodes = set(self.node_properties.keys()) & set(other.node_properties.keys())\n intersect_hyperedges = set(self.hyperedges.keys()) & set(other.hyperedges.keys())\n\n if inplace:\n # Remove non-intersecting nodes and hyperedges\n nodes_to_remove = set(self.node_properties.keys()) - intersect_nodes\n hyperedges_to_remove = set(self.hyperedges.keys()) - intersect_hyperedges\n self.remove_nodes_from(nodes_to_remove, inplace=True)\n self.remove_hyperedges(hyperedges_to_remove, inplace=True)\n return self\n else:\n # Create a new Hypergraph instance\n new_hyperedges = {}\n new_node_properties = {node_id: copy.deepcopy(self.node_properties[node_id]) for node_id in intersect_nodes}\n new_hyperedge_properties = {}\n new_graph = nx.Graph()\n new_bipartite_nodes = set()\n\n for he_id in intersect_hyperedges:\n he_self = self.hyperedges[he_id]\n he_other = other.hyperedges[he_id]\n # Intersection hyperedges have the same nodes and merged properties\n new_nodes = set(he_self.nodes) & set(he_other.nodes)\n if not new_nodes:\n continue # Skip hyperedges with no common nodes\n new_hyperedges[he_id] = HyperEdge(id=he_id, nodes=new_nodes, properties={})\n # Merge properties (could define specific rules)\n new_hyperedge_properties[he_id] = {**self.hyperedge_properties.get(he_id, {}), \n **other.hyperedge_properties.get(he_id, {})}\n new_graph.add_node(he_id, bipartite='hyperedge', **new_hyperedge_properties[he_id])\n new_bipartite_nodes.add(he_id)\n for node in new_nodes:\n new_graph.add_edge(he_id, node)\n\n return Hypergraph(\n hyperedges={\n he_id: {\n 
'nodes': list(he.nodes),\n 'properties': he.properties.copy()\n } for he_id, he in new_hyperedges.items()\n },\n node_properties=new_node_properties,\n hyperedge_properties=new_hyperedge_properties,\n name=name if name else f\"Intersection_of_{self.name}_{other.name}\"\n )\n\n def difference(self, other: 'Hypergraph', inplace: bool = False, name: Optional[str] = None) -> 'Hypergraph':\n \"\"\"\n Returns the difference of the current hypergraph with another hypergraph.\n The difference includes nodes and hyperedges present in the current hypergraph but not in the other.\n\n Parameters\n ----------\n other : Hypergraph\n The hypergraph to subtract.\n inplace : bool, optional, default=False\n If True, modifies the current hypergraph by removing elements found in `other`.\n Otherwise, returns a new Hypergraph instance.\n name : Optional[str], default=None\n The name for the resulting hypergraph. If None, defaults to 'Difference_of_{self.name}_{other.name}'.\n\n Returns\n -------\n Hypergraph\n The resulting difference hypergraph.\n\n Raises\n ------\n TypeError\n If `other` is not an instance of Hypergraph.\n \"\"\"\n if not isinstance(other, Hypergraph):\n raise TypeError(\"The `other` parameter must be an instance of Hypergraph.\")\n\n if inplace:\n # Remove hyperedges present in other\n hyperedges_to_remove = set(self.hyperedges.keys()) & set(other.hyperedges.keys())\n self.remove_hyperedges(hyperedges_to_remove, inplace=True)\n # Remove nodes present in other\n nodes_to_remove = set(self.node_properties.keys()) & set(other.node_properties.keys())\n self.remove_nodes_from(nodes_to_remove, inplace=True)\n return self\n else:\n # Create a new Hypergraph instance\n new_hyperedges = {he_id: copy.deepcopy(he) for he_id, he in self.hyperedges.items() if he_id not in other.hyperedges}\n new_hyperedge_properties = {he_id: copy.deepcopy(props) for he_id, props in self.hyperedge_properties.items() if he_id not in other.hyperedges}\n new_node_properties = {node_id: 
copy.deepcopy(props) for node_id, props in self.node_properties.items() if node_id not in other.node_properties}\n\n # Reconstruct graph\n new_graph = nx.Graph()\n new_bipartite_nodes = set()\n for he_id, hyperedge in new_hyperedges.items():\n new_graph.add_node(he_id, bipartite='hyperedge', **new_hyperedge_properties[he_id])\n new_bipartite_nodes.add(he_id)\n for node in hyperedge.nodes:\n if node in new_node_properties:\n new_graph.add_edge(he_id, node)\n\n return Hypergraph(\n hyperedges={\n he_id: {\n 'nodes': list(he.nodes),\n 'properties': he.properties.copy()\n } for he_id, he in new_hyperedges.items()\n },\n node_properties=new_node_properties,\n hyperedge_properties=new_hyperedge_properties,\n name=name if name else f\"Difference_of_{self.name}_{other.name}\"\n )\n\n def symmetric_difference(self, other: 'Hypergraph', inplace: bool = False, name: Optional[str] = None) -> 'Hypergraph':\n \"\"\"\n Returns the symmetric difference of the current hypergraph with another hypergraph.\n The symmetric difference includes elements present in either hypergraph but not in both.\n\n Parameters\n ----------\n other : Hypergraph\n The hypergraph to symmetric difference with.\n inplace : bool, optional, default=False\n If True, modifies the current hypergraph to keep only the symmetric difference elements.\n Otherwise, returns a new Hypergraph instance.\n name : Optional[str], default=None\n The name for the resulting hypergraph. 
If None, defaults to 'SymmetricDifference_of_{self.name}_{other.name}'.\n\n Returns\n -------\n Hypergraph\n The resulting symmetric difference hypergraph.\n\n Raises\n ------\n TypeError\n If `other` is not an instance of Hypergraph.\n \"\"\"\n if not isinstance(other, Hypergraph):\n raise TypeError(\"The `other` parameter must be an instance of Hypergraph.\")\n\n if inplace:\n # Hyperedges symmetric difference\n hyperedges_to_add = set(other.hyperedges.keys()) - set(self.hyperedges.keys())\n hyperedges_to_remove = set(self.hyperedges.keys()) & set(other.hyperedges.keys())\n self.remove_hyperedges(hyperedges_to_remove, inplace=True)\n for he_id in hyperedges_to_add:\n hyperedge = other.hyperedges[he_id]\n self.add_hyperedge(he_id, hyperedge.nodes, properties=hyperedge.properties)\n\n # Nodes symmetric difference\n nodes_to_add = set(other.node_properties.keys()) - set(self.node_properties.keys())\n nodes_to_remove = set(self.node_properties.keys()) & set(other.node_properties.keys())\n self.remove_nodes_from(nodes_to_remove, inplace=True)\n for node_id in nodes_to_add:\n props = other.node_properties[node_id]\n self.add_node(node_id, properties=props, inplace=True)\n\n if name:\n self.name = name\n return self\n else:\n # Create a new Hypergraph instance\n union_hg = self.union(other)\n intersection_hg = self.intersection(other)\n return union_hg.difference(intersection_hg, name=name if name else f\"SymmetricDifference_of_{self.name}_{other.name}\")\n\n def transpose(self, name: Optional[str] = None) -> 'Hypergraph':\n \"\"\"\n Transposes the hypergraph by swapping the roles of nodes and hyperedges.\n The resulting hypergraph has hyperedges corresponding to the original nodes and vice versa.\n\n Parameters\n ----------\n name : Optional[str], default=None\n The name assigned to the transposed hypergraph. 
If None, defaults to the original name suffixed with '_transposed'.\n\n Returns\n -------\n Hypergraph\n The transposed hypergraph.\n \"\"\"\n transposed_hyperedges = {node_id: HyperEdge(id=node_id, nodes=set(), properties=copy.deepcopy(props))\n for node_id, props in self.node_properties.items()}\n transposed_node_properties = {he_id: copy.deepcopy(props) for he_id, props in self.hyperedge_properties.items()}\n\n for he_id, hyperedge in self.hyperedges.items():\n for node in hyperedge.nodes:\n if node in transposed_hyperedges:\n transposed_hyperedges[node].nodes.add(he_id)\n\n # Construct the transposed hypergraph\n return Hypergraph(\n hyperedges={\n he_id: {\n 'nodes': list(he.nodes),\n 'properties': he.properties.copy()\n } for he_id, he in transposed_hyperedges.items()\n },\n node_properties=transposed_node_properties,\n hyperedge_properties={he_id: he.properties.copy() for he_id, he in transposed_hyperedges.items()},\n name=name if name else f\"{self.name}_transposed\"\n )\n\n def to_bipartite_graph(self, keep_data=False, directed=False) -> nx.Graph:\n \"\"\"\n Creates a bipartite NetworkX graph from the hypergraph.\n The nodes and hyperedges of the hypergraph become nodes in the bipartite graph.\n For every hyperedge in the hypergraph and each node it connects to, there\n is an edge in the bipartite graph.\n\n Parameters\n ----------\n keep_data : bool, optional, default = False\n If True, includes the node and hyperedge properties in the NetworkX graph.\n directed : bool, optional, default = False\n If True, the edges in the graph are directed with hyperedges as sources and nodes as targets.\n\n Returns\n -------\n networkx.Graph or networkx.DiGraph\n The bipartite graph representation of the hypergraph.\n \"\"\"\n # Choose graph type based on directed flag\n B = nx.DiGraph() if directed else nx.Graph()\n\n if not keep_data:\n # Add nodes with bipartite attributes, where 0 indicates hyperedges and 1 indicates regular nodes\n 
B.add_nodes_from(self.hyperedges.keys(), bipartite=0) # hyperedges\n B.add_nodes_from(self.node_properties.keys(), bipartite=1) # nodes\n\n # Add edges between hyperedges and nodes based on hyperedges data\n for he_id, hyperedge in self.hyperedges.items():\n for node in hyperedge.nodes:\n B.add_edge(he_id, node)\n else:\n # Add nodes with properties if keep_data is True\n for node_id, properties in self.node_properties.items():\n B.add_node(node_id, bipartite=1, **properties)\n\n for he_id, hyperedge in self.hyperedges.items():\n B.add_node(he_id, bipartite=0, **self.hyperedge_properties.get(he_id, {}))\n for node in hyperedge.nodes:\n # Add edges with optional properties if keep_data is True\n B.add_edge(he_id, node)\n\n return B\n\n @classmethod\n def from_bipartite_graph(cls, bipartite_graph: nx.Graph, hyperedge_prefix: str = \"HE\", node_prefix: str = \"N\", name: Optional[str] = None) -> 'Hypergraph':\n \"\"\"\n Constructs a Hypergraph instance from a bipartite graph.\n\n Parameters\n ----------\n bipartite_graph : nx.Graph\n A bipartite graph where one set of nodes represents hyperedges and the other represents regular nodes.\n hyperedge_prefix : str, optional, default=\"HE\"\n The prefix to identify hyperedge nodes in the bipartite graph.\n node_prefix : str, optional, default=\"N\"\n The prefix to identify regular nodes in the bipartite graph.\n name : Optional[str], default=None\n The name assigned to the new Hypergraph. 
If None, defaults to 'FromBipartiteGraph'.\n\n Returns\n -------\n Hypergraph\n The constructed Hypergraph instance.\n\n Raises\n ------\n ValueError\n If the bipartite graph does not contain two distinct sets of nodes identifiable by the provided prefixes.\n \"\"\"\n hyperedges = {}\n node_properties = {}\n hyperedge_properties = {}\n name = name if name else \"FromBipartiteGraph\"\n\n for node in bipartite_graph.nodes(data=True):\n node_id, attrs = node\n if node_id.startswith(hyperedge_prefix):\n # It's a hyperedge\n hyperedges[node_id] = HyperEdge(id=node_id, nodes=set(), properties=attrs)\n hyperedge_properties[node_id] = copy.deepcopy(attrs)\n elif node_id.startswith(node_prefix):\n # It's a regular node\n node_properties[node_id] = copy.deepcopy(attrs)\n else:\n raise ValueError(f\"Node '{node_id}' does not start with either hyperedge_prefix '{hyperedge_prefix}' or node_prefix '{node_prefix}'.\")\n\n # Assign nodes to hyperedges based on edges in bipartite graph\n for he_id in hyperedges:\n connected_nodes = set(bipartite_graph.neighbors(he_id))\n hyperedges[he_id].nodes = connected_nodes\n\n # Construct hyperedges dict for __init__\n hyperedges_dict = {\n he_id: {\n 'nodes': list(he.nodes),\n 'properties': he.properties.copy()\n } for he_id, he in hyperedges.items()\n }\n\n return cls(\n hyperedges=hyperedges_dict,\n node_properties=node_properties,\n hyperedge_properties=hyperedge_properties,\n name=name\n )\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.__contains__","title":"__contains__(item)
","text":"Checks if a node is in the hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.__contains__--parameters","title":"Parameters","text":"item : Any The node identifier to check.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.__contains__--returns","title":"Returns","text":"bool True if the node exists in the hypergraph, False otherwise.
Source code in src/aeiva/hypergraph/hypergraph.py
def __contains__(self, item: Any) -> bool:\n \"\"\"\n Checks if a node is in the hypergraph.\n\n Parameters\n ----------\n item : Any\n The node identifier to check.\n\n Returns\n -------\n bool\n True if the node exists in the hypergraph, False otherwise.\n \"\"\"\n return item in self.node_properties\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.__eq__","title":"__eq__(other)
","text":"Checks if two hypergraphs are equal based on their hyperedges and nodes.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.__eq__--parameters","title":"Parameters","text":"other : Any The other object to compare.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.__eq__--returns","title":"Returns","text":"bool True if both hypergraphs have identical nodes and hyperedges with the same properties, False otherwise.
Source code in src/aeiva/hypergraph/hypergraph.py
def __eq__(self, other: Any) -> bool:\n \"\"\"\n Checks if two hypergraphs are equal based on their hyperedges and nodes.\n\n Parameters\n ----------\n other : Any\n The other object to compare.\n\n Returns\n -------\n bool\n True if both hypergraphs have identical nodes and hyperedges with the same properties, False otherwise.\n \"\"\"\n if not isinstance(other, Hypergraph):\n return False\n\n # Compare nodes and their properties\n if self.node_properties != other.node_properties:\n return False\n\n # Compare hyperedges and their properties\n if self.hyperedges.keys() != other.hyperedges.keys():\n return False\n\n for he_id in self.hyperedges:\n if self.hyperedges[he_id].nodes != other.hyperedges[he_id].nodes:\n return False\n if self.hyperedge_properties.get(he_id, {}) != other.hyperedge_properties.get(he_id, {}):\n return False\n\n return True\n
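"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.__eq__--example","title":"Example","text":"Note that equality compares node properties, hyperedge memberships, and hyperedge properties; the name attribute is ignored. Illustrative sketch with hypothetical instances:\nhg1 == hg2  # True only if both share identical nodes and hyperedges\n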
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.__getitem__","title":"__getitem__(node)
","text":"Retrieves the neighbors of a node in the hypergraph.
Neighbors are nodes that share at least one hyperedge with the given node.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.__getitem__--parameters","title":"Parameters","text":"node : Any The node identifier.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.__getitem__--returns","title":"Returns","text":"Iterable[Any] An iterator over neighboring node identifiers.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.__getitem__--raises","title":"Raises","text":"HypergraphError If the node does not exist in the hypergraph.
Source code in src/aeiva/hypergraph/hypergraph.py
def __getitem__(self, node: Any) -> Iterable[Any]:\n \"\"\"\n Retrieves the neighbors of a node in the hypergraph.\n\n Neighbors are nodes that share at least one hyperedge with the given node.\n\n Parameters\n ----------\n node : Any\n The node identifier.\n\n Returns\n -------\n Iterable[Any]\n An iterator over neighboring node identifiers.\n\n Raises\n ------\n HypergraphError\n If the node does not exist in the hypergraph.\n \"\"\"\n if node not in self.node_properties:\n raise HypergraphError(f\"Node '{node}' does not exist in the hypergraph.\")\n\n # Get all hyperedges that include the node\n hyperedges = set(self.graph.neighbors(node))\n\n # Get all nodes connected by these hyperedges\n neighbors = set()\n for he_id in hyperedges:\n neighbors.update(self.hyperedges[he_id].nodes)\n\n neighbors.discard(node) # Remove the node itself\n return neighbors\n
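"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.__getitem__--example","title":"Example","text":"Illustrative sketch of neighbor lookup via indexing (node identifiers are hypothetical):\nneighbors = hg['n1']  # nodes sharing at least one hyperedge with 'n1'\nfor n in neighbors:\n    print(n)\n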
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.__iter__","title":"__iter__()
","text":"Allows iteration over the nodes of the hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.__iter__--yields","title":"Yields","text":"Any Node identifiers.
Source code in src/aeiva/hypergraph/hypergraph.py
def __iter__(self) -> Iterator[Any]:\n \"\"\"\n Allows iteration over the nodes of the hypergraph.\n\n Yields\n ------\n Any\n Node identifiers.\n \"\"\"\n return iter(self.node_properties)\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.__len__","title":"__len__()
","text":"Returns the number of nodes in the hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.__len__--returns","title":"Returns","text":"int Number of nodes.
Source code in src/aeiva/hypergraph/hypergraph.py
def __len__(self) -> int:\n \"\"\"\n Returns the number of nodes in the hypergraph.\n\n Returns\n -------\n int\n Number of nodes.\n \"\"\"\n return len(self.node_properties)\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.__repr__","title":"__repr__()
","text":"Official string representation of the hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.__repr__--returns","title":"Returns","text":"str A detailed string describing the hypergraph with its name, number of nodes, and hyperedges.
Source code in src/aeiva/hypergraph/hypergraph.py
def __repr__(self) -> str:\n \"\"\"\n Official string representation of the hypergraph.\n\n Returns\n -------\n str\n A detailed string describing the hypergraph with its name, number of nodes, and hyperedges.\n \"\"\"\n return (\n f\"Hypergraph(name={self.name!r}, \"\n f\"nodes={len(self)}, hyperedges={len(self.hyperedges)})\"\n )\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.__str__","title":"__str__()
","text":"String representation of the hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.__str__--returns","title":"Returns","text":"str A string describing the hypergraph with its name, number of nodes, and hyperedges.
Source code in src/aeiva/hypergraph/hypergraph.py
def __str__(self) -> str:\n \"\"\"\n String representation of the hypergraph.\n\n Returns\n -------\n str\n A string describing the hypergraph with its name, number of nodes, and hyperedges.\n \"\"\"\n return f\"Hypergraph '{self.name}' with {len(self)} nodes and {len(self.hyperedges)} hyperedges.\"\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.add_hyperedge","title":"add_hyperedge(he_id, nodes, properties=None)
","text":"Adds a hyperedge to the hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.add_hyperedge--parameters","title":"Parameters","text":"he_id : Any Unique identifier for the hyperedge. nodes : Iterable[Any] Nodes connected by the hyperedge. properties : Optional[Dict[str, Any]] = None Properties of the hyperedge.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.add_hyperedge--raises","title":"Raises","text":"HypergraphError If the hyperedge ID already exists.
Source code in src/aeiva/hypergraph/hypergraph.py
def add_hyperedge(\n self,\n he_id: Any,\n nodes: Iterable[Any],\n properties: Optional[Dict[str, Any]] = None\n) -> None:\n \"\"\"\n Adds a hyperedge to the hypergraph.\n\n Parameters\n ----------\n he_id : Any\n Unique identifier for the hyperedge.\n nodes : Iterable[Any]\n Nodes connected by the hyperedge.\n properties : Optional[Dict[str, Any]] = None\n Properties of the hyperedge.\n\n Raises\n ------\n HypergraphError\n If the hyperedge ID already exists.\n \"\"\"\n if he_id in self.hyperedges:\n raise HypergraphError(f\"Hyperedge '{he_id}' already exists.\")\n\n hyperedge = HyperEdge(id=he_id, nodes=nodes, properties=properties)\n self.hyperedges[he_id] = hyperedge\n self.hyperedge_properties[he_id] = copy.deepcopy(properties) if properties else {}\n\n # Add hyperedge to bipartite graph\n self.graph.add_node(he_id, bipartite='hyperedge', **self.hyperedge_properties[he_id])\n self.bipartite_nodes.add(he_id)\n\n # Add edges between hyperedge and nodes\n for node in hyperedge.nodes:\n if node not in self.graph:\n self.graph.add_node(node, bipartite='node', **self.node_properties.get(node, {}))\n self.graph.add_edge(he_id, node)\n
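"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.add_hyperedge--example","title":"Example","text":"A minimal usage sketch (identifiers are hypothetical; assumes a Hypergraph constructed as shown elsewhere in this reference):\nhg.add_node('n1')\nhg.add_node('n2')\nhg.add_hyperedge('he1', nodes=['n1', 'n2'], properties={'weight': 1})\n# 'he1' is now a bipartite-graph node connected to 'n1' and 'n2'\n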
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.add_hyperedges_from","title":"add_hyperedges_from(hyperedges, inplace=True)
","text":"Adds multiple hyperedges with attributes to the hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.add_hyperedges_from--parameters","title":"Parameters","text":"hyperedges : Iterable[Union[Any, Tuple[Any, Dict[str, Any]]]] An iterable of hyperedge identifiers or tuples of (he_id, attributes). inplace : bool, default=True If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the added hyperedges.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.add_hyperedges_from--returns","title":"Returns","text":"Hypergraph The updated or new Hypergraph instance.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.add_hyperedges_from--raises","title":"Raises","text":"HypergraphError If any hyperedge ID already exists. ValueError If any tuple does not contain exactly two elements or if attributes are not dictionaries.
Source code in src/aeiva/hypergraph/hypergraph.py
def add_hyperedges_from(\n self,\n hyperedges: Iterable[Union[Any, Tuple[Any, Dict[str, Any]]]],\n inplace: bool = True\n) -> 'Hypergraph':\n \"\"\"\n Adds multiple hyperedges with attributes to the hypergraph.\n\n Parameters\n ----------\n hyperedges : Iterable[Union[Any, Tuple[Any, Dict[str, Any]]]]\n An iterable of hyperedge identifiers or tuples of (he_id, attributes).\n inplace : bool, default=True\n If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the added hyperedges.\n\n Returns\n -------\n Hypergraph\n The updated or new Hypergraph instance.\n\n Raises\n ------\n HypergraphError\n If any hyperedge ID already exists.\n ValueError\n If any tuple does not contain exactly two elements or if attributes are not dictionaries.\n \"\"\"\n new_hyperedges = []\n for item in hyperedges:\n if isinstance(item, tuple):\n if len(item) != 2 or not isinstance(item[1], dict):\n raise ValueError(f\"Each tuple must be of the form (he_id, attributes). Invalid tuple: {item}\")\n he_id, attrs = item\n else:\n he_id, attrs = item, {}\n\n if he_id in self.hyperedges:\n raise HypergraphError(f\"Hyperedge '{he_id}' already exists.\")\n\n hyperedge = HyperEdge(id=he_id, nodes=[], properties=attrs.copy())\n new_hyperedges.append(hyperedge)\n\n if inplace:\n for hyperedge in new_hyperedges:\n self.hyperedges[hyperedge.id] = hyperedge\n self.hyperedge_properties[hyperedge.id] = copy.deepcopy(hyperedge.properties)\n self.graph.add_node(hyperedge.id, bipartite='hyperedge', **self.hyperedge_properties[hyperedge.id])\n self.bipartite_nodes.add(hyperedge.id)\n return self\n else:\n # Create a new Hypergraph instance with added hyperedges\n new_hyperedges_dict = copy.deepcopy(self.hyperedges)\n new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)\n new_node_properties = copy.deepcopy(self.node_properties)\n new_graph = copy.deepcopy(self.graph)\n new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)\n\n for hyperedge in new_hyperedges:\n 
new_hyperedges_dict[hyperedge.id] = hyperedge\n new_hyperedge_properties[hyperedge.id] = copy.deepcopy(hyperedge.properties)\n new_graph.add_node(hyperedge.id, bipartite='hyperedge', **new_hyperedge_properties[hyperedge.id])\n new_bipartite_nodes.add(hyperedge.id)\n\n # Reconstruct hyperedges dict for __init__\n hyperedges_dict = {\n he_id: {\n 'nodes': list(he.nodes),\n 'properties': he.properties.copy()\n } for he_id, he in new_hyperedges_dict.items()\n }\n\n return Hypergraph(\n hyperedges=hyperedges_dict,\n node_properties=new_node_properties,\n hyperedge_properties=new_hyperedge_properties,\n name=self.name\n )\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.add_incidence","title":"add_incidence(he_id, node_id, attributes=None, inplace=True)
","text":"Adds a single incidence with attributes to the hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.add_incidence--parameters","title":"Parameters","text":"he_id : Any Identifier of the hyperedge. node_id : Any Identifier of the node. attributes : Optional[Dict[str, Any]] = None Properties to add to the incidence as key-value pairs. inplace : bool, default=True If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the added incidence.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.add_incidence--returns","title":"Returns","text":"Hypergraph The updated or new Hypergraph instance.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.add_incidence--raises","title":"Raises","text":"HypergraphError If the hyperedge or node does not exist, or if the incidence already exists.
Source code in src/aeiva/hypergraph/hypergraph.py
def add_incidence(\n self,\n he_id: Any,\n node_id: Any,\n attributes: Optional[Dict[str, Any]] = None,\n inplace: bool = True\n) -> 'Hypergraph':\n \"\"\"\n Adds a single incidence with attributes to the hypergraph.\n\n Parameters\n ----------\n he_id : Any\n Identifier of the hyperedge.\n node_id : Any\n Identifier of the node.\n attributes : Optional[Dict[str, Any]] = None\n Properties to add to the incidence as key-value pairs.\n inplace : bool, default=True\n If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the added incidence.\n\n Returns\n -------\n Hypergraph\n The updated or new Hypergraph instance.\n\n Raises\n ------\n HypergraphError\n If the hyperedge or node does not exist, or if the incidence already exists.\n \"\"\"\n if he_id not in self.hyperedges:\n raise HypergraphError(f\"Hyperedge '{he_id}' does not exist in the hypergraph.\")\n if node_id not in self.node_properties:\n raise HypergraphError(f\"Node '{node_id}' does not exist in the hypergraph.\")\n if node_id in self.hyperedges[he_id].nodes:\n raise HypergraphError(f\"Incidence between hyperedge '{he_id}' and node '{node_id}' already exists.\")\n\n if inplace:\n # Add node to HyperEdge's nodes\n self.hyperedges[he_id].add_node(node_id)\n # Update hyperedge_properties if attributes provided\n if attributes:\n self.hyperedge_properties[he_id].update(attributes)\n # Add edge in graph with attributes\n self.graph.add_edge(he_id, node_id, **(attributes if attributes else {}))\n return self\n else:\n # Create a new Hypergraph instance with the incidence added\n new_hyperedges = copy.deepcopy(self.hyperedges)\n new_node_properties = copy.deepcopy(self.node_properties)\n new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)\n new_graph = copy.deepcopy(self.graph)\n new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)\n\n # Add node to HyperEdge's nodes\n new_hyperedges[he_id].add_node(node_id)\n # Update hyperedge_properties if attributes 
provided\n if attributes:\n new_hyperedge_properties[he_id].update(attributes)\n # Add edge in graph with attributes\n new_graph.add_edge(he_id, node_id, **(attributes if attributes else {}))\n\n # Reconstruct hyperedges dict for __init__\n hyperedges_dict = {\n he_id: {\n 'nodes': list(he.nodes),\n 'properties': he.properties.copy()\n } for he_id, he in new_hyperedges.items()\n }\n\n return Hypergraph(\n hyperedges=hyperedges_dict,\n node_properties=new_node_properties,\n hyperedge_properties=new_hyperedge_properties,\n name=self.name\n )\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.add_incidences_from","title":"add_incidences_from(incidences, inplace=True)
","text":"Adds a collection of incidences to the hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.add_incidences_from--parameters","title":"Parameters","text":"incidences : Iterable[Union[Tuple[Any, Any], Tuple[Any, Any, Dict[str, Any]]]] Incidence tuples as: - (he_id, node_id) - (he_id, node_id, attributes)
inplace : bool, default=True If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the added incidences.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.add_incidences_from--returns","title":"Returns","text":"Hypergraph The updated or new Hypergraph instance.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.add_incidences_from--raises","title":"Raises","text":"HypergraphError If any hyperedge or node does not exist, or if any incidence already exists. ValueError If the structure of any incidence tuple is invalid.
Source code in src/aeiva/hypergraph/hypergraph.py
def add_incidences_from(\n self,\n incidences: Iterable[Union[Tuple[Any, Any], Tuple[Any, Any, Dict[str, Any]]]],\n inplace: bool = True\n) -> 'Hypergraph':\n \"\"\"\n Adds a collection of incidences to the hypergraph.\n\n Parameters\n ----------\n incidences : Iterable[Union[Tuple[Any, Any], Tuple[Any, Any, Dict[str, Any]]]]\n Incidence tuples as:\n - (he_id, node_id)\n - (he_id, node_id, attributes)\n\n inplace : bool, default=True\n If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the added incidences.\n\n Returns\n -------\n Hypergraph\n The updated or new Hypergraph instance.\n\n Raises\n ------\n HypergraphError\n If any hyperedge or node does not exist, or if any incidence already exists.\n ValueError\n If the structure of any incidence tuple is invalid.\n \"\"\"\n new_incidences = []\n for pr in incidences:\n if not isinstance(pr, tuple):\n raise ValueError(f\"Each incidence must be a tuple, got {type(pr)}\")\n if len(pr) == 2:\n he_id, node_id = pr\n attrs = {}\n elif len(pr) == 3:\n he_id, node_id, attrs = pr\n if not isinstance(attrs, dict):\n raise ValueError(f\"Attributes must be a dictionary, got {type(attrs)}\")\n else:\n raise ValueError(f\"Incidence tuples must be of length 2 or 3, got {len(pr)}\")\n\n if he_id not in self.hyperedges:\n raise HypergraphError(f\"Hyperedge '{he_id}' does not exist in the hypergraph.\")\n if node_id not in self.node_properties:\n raise HypergraphError(f\"Node '{node_id}' does not exist in the hypergraph.\")\n if node_id in self.hyperedges[he_id].nodes:\n raise HypergraphError(f\"Incidence between hyperedge '{he_id}' and node '{node_id}' already exists.\")\n\n new_incidences.append((he_id, node_id, attrs.copy()))\n\n if inplace:\n for he_id, node_id, attrs in new_incidences:\n # Add node to HyperEdge's nodes\n self.hyperedges[he_id].add_node(node_id)\n # Update hyperedge_properties if attributes provided\n if attrs:\n self.hyperedge_properties[he_id].update(attrs)\n # Add edge in 
graph with attributes\n self.graph.add_edge(he_id, node_id, **(attrs if attrs else {}))\n return self\n else:\n # Create a new Hypergraph instance with the incidences added\n new_hyperedges = copy.deepcopy(self.hyperedges)\n new_node_properties = copy.deepcopy(self.node_properties)\n new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)\n new_graph = copy.deepcopy(self.graph)\n new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)\n\n for he_id, node_id, attrs in new_incidences:\n # Add node to HyperEdge's nodes\n new_hyperedges[he_id].add_node(node_id)\n # Update hyperedge_properties if attributes provided\n if attrs:\n new_hyperedge_properties[he_id].update(attrs)\n # Add edge in graph with attributes\n new_graph.add_edge(he_id, node_id, **(attrs if attrs else {}))\n\n # Reconstruct hyperedges dict for __init__\n hyperedges_dict = {\n he_id: {\n 'nodes': list(he.nodes),\n 'properties': he.properties.copy()\n } for he_id, he in new_hyperedges.items()\n }\n\n return Hypergraph(\n hyperedges=hyperedges_dict,\n node_properties=new_node_properties,\n hyperedge_properties=new_hyperedge_properties,\n name=self.name\n )\n
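"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.add_incidences_from--example","title":"Example","text":"A minimal usage sketch (hypothetical identifiers; the hyperedges and nodes must already exist, otherwise HypergraphError is raised):\nhg.add_incidences_from([\n    ('he1', 'n3'),\n    ('he1', 'n4', {'role': 'member'}),\n])\n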
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.add_node","title":"add_node(node_id, properties=None, inplace=True)
","text":"Adds a node to the hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.add_node--parameters","title":"Parameters","text":"node_id : Any Identifier for the node. properties : Optional[Dict[str, Any]] = None Properties of the node. inplace : bool, default=True If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the added node.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.add_node--returns","title":"Returns","text":"Hypergraph The updated or new Hypergraph instance.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.add_node--raises","title":"Raises","text":"HypergraphError If the node ID already exists.
Source code in src/aeiva/hypergraph/hypergraph.py
def add_node(\n self,\n node_id: Any,\n properties: Optional[Dict[str, Any]] = None,\n inplace: bool = True\n) -> 'Hypergraph':\n \"\"\"\n Adds a node to the hypergraph.\n\n Parameters\n ----------\n node_id : Any\n Identifier for the node.\n properties : Optional[Dict[str, Any]] = None\n Properties of the node.\n inplace : bool, default=True\n If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the added node.\n\n Returns\n -------\n Hypergraph\n The updated or new Hypergraph instance.\n\n Raises\n ------\n HypergraphError\n If the node ID already exists.\n \"\"\"\n if node_id in self.node_properties:\n raise HypergraphError(f\"Node '{node_id}' already exists in the hypergraph.\")\n\n if inplace:\n self.node_properties[node_id] = copy.deepcopy(properties) if properties else {}\n self.graph.add_node(node_id, bipartite='node', **self.node_properties[node_id])\n return self\n else:\n # Create a new Hypergraph instance with the added node\n new_hyperedges = copy.deepcopy(self.hyperedges)\n new_node_properties = copy.deepcopy(self.node_properties)\n new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)\n new_graph = copy.deepcopy(self.graph)\n new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)\n\n new_node_properties[node_id] = copy.deepcopy(properties) if properties else {}\n new_graph.add_node(node_id, bipartite='node', **new_node_properties[node_id])\n\n return Hypergraph(\n hyperedges={\n he_id: {\n 'nodes': list(he.nodes),\n 'properties': he.properties.copy()\n } for he_id, he in new_hyperedges.items()\n },\n node_properties=new_node_properties,\n hyperedge_properties=new_hyperedge_properties,\n name=self.name\n )\n
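"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.add_node--example","title":"Example","text":"A minimal sketch of in-place versus copy semantics (identifiers are hypothetical):\nhg.add_node('n3', properties={'color': 'red'})  # modifies hg in place\nhg2 = hg.add_node('n4', inplace=False)  # hg is unchanged; hg2 contains 'n4'\n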
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.add_nodes_from","title":"add_nodes_from(nodes, inplace=True)
","text":"Adds multiple nodes with attributes to the hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.add_nodes_from--parameters","title":"Parameters","text":"nodes : Iterable[Union[Any, Tuple[Any, Dict[str, Any]]]] An iterable of node identifiers or tuples of (node_id, attributes). inplace : bool, default=True If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the added nodes.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.add_nodes_from--returns","title":"Returns","text":"Hypergraph The updated or new Hypergraph instance.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.add_nodes_from--raises","title":"Raises","text":"HypergraphError If any node ID already exists. ValueError If any tuple does not contain exactly two elements or if attributes are not dictionaries.
Source code in src/aeiva/hypergraph/hypergraph.py
def add_nodes_from(\n self,\n nodes: Iterable[Union[Any, Tuple[Any, Dict[str, Any]]]],\n inplace: bool = True\n) -> 'Hypergraph':\n \"\"\"\n Adds multiple nodes with attributes to the hypergraph.\n\n Parameters\n ----------\n nodes : Iterable[Union[Any, Tuple[Any, Dict[str, Any]]]]\n An iterable of node identifiers or tuples of (node_id, attributes).\n inplace : bool, default=True\n If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the added nodes.\n\n Returns\n -------\n Hypergraph\n The updated or new Hypergraph instance.\n\n Raises\n ------\n HypergraphError\n If any node ID already exists.\n ValueError\n If any tuple does not contain exactly two elements or if attributes are not dictionaries.\n \"\"\"\n new_nodes = {}\n for item in nodes:\n if isinstance(item, tuple):\n if len(item) != 2 or not isinstance(item[1], dict):\n raise ValueError(f\"Each tuple must be of the form (node_id, attributes). Invalid tuple: {item}\")\n node_id, attrs = item\n else:\n node_id, attrs = item, {}\n\n if node_id in self.node_properties:\n raise HypergraphError(f\"Node '{node_id}' already exists in the hypergraph.\")\n\n new_nodes[node_id] = copy.deepcopy(attrs)\n\n if inplace:\n for node_id, attrs in new_nodes.items():\n self.node_properties[node_id] = attrs\n self.graph.add_node(node_id, bipartite='node', **self.node_properties[node_id])\n return self\n else:\n # Create a new Hypergraph instance with the added nodes\n new_hyperedges = copy.deepcopy(self.hyperedges)\n new_node_properties = copy.deepcopy(self.node_properties)\n new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)\n new_graph = copy.deepcopy(self.graph)\n new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)\n\n for node_id, attrs in new_nodes.items():\n new_node_properties[node_id] = attrs\n new_graph.add_node(node_id, bipartite='node', **new_node_properties[node_id])\n\n return Hypergraph(\n hyperedges={\n he_id: {\n 'nodes': list(he.nodes),\n 'properties': 
he.properties.copy()\n } for he_id, he in new_hyperedges.items()\n },\n node_properties=new_node_properties,\n hyperedge_properties=new_hyperedge_properties,\n name=self.name\n )\n
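The two accepted input shapes (a bare node id, or a `(node_id, attributes)` tuple) can be sketched standalone. `normalize_nodes` below is a hypothetical helper that mirrors only the input validation of `add_nodes_from`, not part of the library:

```python
import copy

def normalize_nodes(nodes):
    """Mirror add_nodes_from's input handling: each item is either
    a bare node id or a (node_id, attributes) tuple."""
    out = {}
    for item in nodes:
        if isinstance(item, tuple):
            if len(item) != 2 or not isinstance(item[1], dict):
                raise ValueError(f"Invalid tuple: {item}")
            node_id, attrs = item
        else:
            node_id, attrs = item, {}
        if node_id in out:  # duplicate ids are rejected, as in the method
            raise KeyError(f"Node '{node_id}' already exists.")
        out[node_id] = copy.deepcopy(attrs)
    return out

result = normalize_nodes(["a", ("b", {"weight": 2})])
```

Attributes are deep-copied so later mutation of the caller's dicts does not leak into stored node properties.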
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.adjacency_matrix","title":"adjacency_matrix(s=1, index=False)
","text":"Generates the adjacency matrix for nodes based on s-node connectivity.
Source code in src/aeiva/hypergraph/hypergraph.py
def adjacency_matrix(self, s: int = 1, index: bool = False) -> Tuple[Optional[csr_matrix], Dict[int, Any]]:\n \"\"\"\n Generates the adjacency matrix for nodes based on s-node connectivity.\n \"\"\"\n from scipy.sparse import lil_matrix\n\n node_ids = list(self.node_properties.keys())\n node_index = {node_id: idx for idx, node_id in enumerate(node_ids)}\n size = len(node_ids)\n if size == 0:\n return None, {}\n\n A = lil_matrix((size, size), dtype=int)\n for he in self.hyperedges.values():\n nodes = list(he.nodes)\n for i in range(len(nodes)):\n for j in range(i + 1, len(nodes)):\n A[node_index[nodes[i]], node_index[nodes[j]]] += 1\n\n # Apply the threshold s and convert to binary\n A = (A >= s).astype(int)\n A = A.tocsr()\n\n if index:\n return A, node_index\n return A, {}\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.collapse_duplicate_hyperedges","title":"collapse_duplicate_hyperedges(name=None, use_uids=None, use_counts=False, return_counts=True, return_equivalence_classes=False, aggregate_properties_by=None)
","text":"Collapses duplicate hyperedges (hyperedges with identical node memberships) into single hyperedges.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.collapse_duplicate_hyperedges--parameters","title":"Parameters","text":"name : Optional[str], default=None The name assigned to the collapsed hypergraph. If None, defaults to the original name suffixed with '_collapsed_hyperedges'.
use_uids : Optional[List[Any]] = None Specifies the hyperedge identifiers to use as representatives for each equivalence class. If two identifiers occur in the same equivalence class, the first one found in use_uids
is used. If None, the first encountered hyperedge in each class is used as the representative.
use_counts : bool, optional, default=False If True, renames the equivalence class representatives by appending the size of the class (e.g., 'HE1:3').
return_counts : bool, optional, default=True If True, adds the size of each equivalence class to the properties of the representative hyperedge under the key 'equivalence_class_size'.
return_equivalence_classes : bool, optional, default=False If True, returns a tuple containing the new collapsed hypergraph and a dictionary mapping representatives to their equivalence classes.
aggregate_properties_by : Optional[Dict[str, str]] = None A dictionary specifying aggregation methods for hyperedge properties. Keys are property names, and values are aggregation functions (e.g., {'weight': 'sum'}). Properties not specified will use the 'first' aggregation.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.collapse_duplicate_hyperedges--returns","title":"Returns","text":"Hypergraph or Tuple[Hypergraph, Dict[Any, Set[Any]]] - If return_equivalence_classes=False
, returns the new collapsed hypergraph. - If return_equivalence_classes=True
, returns a tuple containing the collapsed hypergraph and a dictionary of equivalence classes.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.collapse_duplicate_hyperedges--raises","title":"Raises","text":"HypergraphError If the hypergraph is empty or improperly structured.
Source code in src/aeiva/hypergraph/hypergraph.py
def collapse_duplicate_hyperedges(\n self,\n name: Optional[str] = None,\n use_uids: Optional[List[Any]] = None,\n use_counts: bool = False,\n return_counts: bool = True,\n return_equivalence_classes: bool = False,\n aggregate_properties_by: Optional[Dict[str, str]] = None,\n) -> Union['Hypergraph', Tuple['Hypergraph', Dict[Any, Set[Any]]]]:\n \"\"\"\n Collapses duplicate hyperedges (hyperedges with identical node memberships) into single hyperedges.\n\n Parameters\n ----------\n name : Optional[str], default=None\n The name assigned to the collapsed hypergraph. If None, defaults to the original name suffixed with '_collapsed_hyperedges'.\n\n use_uids : Optional[List[Any]] = None\n Specifies the hyperedge identifiers to use as representatives for each equivalence class.\n If two identifiers occur in the same equivalence class, the first one found in `use_uids` is used.\n If None, the first encountered hyperedge in each class is used as the representative.\n\n use_counts : bool, optional, default=False\n If True, renames the equivalence class representatives by appending the size of the class (e.g., 'HE1:3').\n\n return_counts : bool, optional, default=True\n If True, adds the size of each equivalence class to the properties of the representative hyperedge under the key 'equivalence_class_size'.\n\n return_equivalence_classes : bool, optional, default=False\n If True, returns a tuple containing the new collapsed hypergraph and a dictionary mapping representatives to their equivalence classes.\n\n aggregate_properties_by : Optional[Dict[str, str]] = None\n A dictionary specifying aggregation methods for hyperedge properties. 
Keys are property names, and values are aggregation functions (e.g., {'weight': 'sum'}).\n Properties not specified will use the 'first' aggregation.\n\n Returns\n -------\n Hypergraph or Tuple[Hypergraph, Dict[Any, Set[Any]]]\n - If `return_equivalence_classes=False`, returns the new collapsed hypergraph.\n - If `return_equivalence_classes=True`, returns a tuple containing the collapsed hypergraph and a dictionary of equivalence classes.\n\n Raises\n ------\n HypergraphError\n If the hypergraph is empty or improperly structured.\n \"\"\"\n if not self.hyperedges:\n raise HypergraphError(\"Cannot collapse hyperedges in an empty hypergraph.\")\n\n # Identify equivalence classes based on identical node memberships\n membership_to_hyperedges: Dict[frozenset, Set[Any]] = {}\n for he_id, hyperedge in self.hyperedges.items():\n key = frozenset(hyperedge.nodes)\n membership_to_hyperedges.setdefault(key, set()).add(he_id)\n\n # Filter out classes with only one hyperedge (no duplicates)\n equivalence_classes = [hes for hes in membership_to_hyperedges.values() if len(hes) > 1]\n if not equivalence_classes:\n # No duplicates to collapse; return the original hypergraph\n return self if not return_equivalence_classes else (self, {})\n\n # Prepare aggregation methods\n aggregate_properties_by = aggregate_properties_by if aggregate_properties_by is not None else {\"weight\": \"sum\"}\n\n # Initialize mapping from old hyperedges to new hyperedges\n hyperedge_mapping: Dict[Any, Any] = {}\n equivalence_class_dict: Dict[Any, Set[Any]] = {}\n\n for eq_class in equivalence_classes:\n # Determine representative\n if use_uids:\n # Select the first UID from use_uids that is in the equivalence class\n representative = next((uid for uid in use_uids if uid in eq_class), None)\n if not representative:\n # Fallback to the first hyperedge in the equivalence class\n representative = next(iter(eq_class))\n else:\n # Use the first hyperedge in the equivalence class as representative\n 
representative = next(iter(eq_class))\n\n # Optionally rename with counts\n if use_counts:\n new_representative = f\"{representative}:{len(eq_class)}\"\n else:\n new_representative = representative\n\n # Map all hyperedges in the class to the representative\n for he in eq_class:\n hyperedge_mapping[he] = new_representative\n\n # Store the equivalence class\n equivalence_class_dict[new_representative] = eq_class\n\n # Replace hyperedge IDs in incidences based on mapping\n new_hyperedges = {}\n for he_id, hyperedge in self.hyperedges.items():\n new_he_id = hyperedge_mapping.get(he_id, he_id)\n if new_he_id not in new_hyperedges:\n new_hyperedges[new_he_id] = HyperEdge(id=new_he_id, nodes=hyperedge.nodes.copy(), properties=copy.deepcopy(hyperedge.properties))\n else:\n new_hyperedges[new_he_id].nodes.update(hyperedge.nodes)\n\n # Aggregate hyperedge properties\n for he_id, hyperedge in new_hyperedges.items():\n if he_id in equivalence_class_dict:\n aggregated_props = {}\n for prop, agg_func in aggregate_properties_by.items():\n values = [self.hyperedge_properties[old_he].get(prop, 0) for old_he in equivalence_class_dict[he_id]]\n if agg_func == 'sum':\n aggregated_props[prop] = sum(values)\n elif agg_func == 'mean':\n aggregated_props[prop] = sum(values) / len(values) if values else 0\n elif agg_func == 'max':\n aggregated_props[prop] = max(values) if values else None\n elif agg_func == 'min':\n aggregated_props[prop] = min(values) if values else None\n else:\n aggregated_props[prop] = values[0] if values else None # Default to first\n new_hyperedges[he_id].properties.update(aggregated_props)\n\n # Handle equivalence class size\n if use_counts:\n for he_id in equivalence_class_dict:\n new_hyperedges[he_id].properties['equivalence_class_size'] = len(equivalence_class_dict[he_id])\n elif return_counts:\n for he_id in new_hyperedges:\n if he_id in equivalence_class_dict:\n new_hyperedges[he_id].properties['equivalence_class_size'] = len(equivalence_class_dict[he_id])\n 
else:\n new_hyperedges[he_id].properties['equivalence_class_size'] = 1\n\n # Initialize the collapsed hypergraph\n collapsed_hypergraph = Hypergraph(\n hyperedges={\n he_id: {\n 'nodes': list(he.nodes),\n 'properties': he.properties.copy()\n } for he_id, he in new_hyperedges.items()\n },\n node_properties=copy.deepcopy(self.node_properties),\n hyperedge_properties={\n he_id: copy.deepcopy(he.properties) for he_id, he in new_hyperedges.items()\n },\n name=name if name else f\"{self.name}_collapsed_hyperedges\"\n )\n\n if return_equivalence_classes:\n return collapsed_hypergraph, equivalence_class_dict\n else:\n return collapsed_hypergraph\n
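The equivalence classes that get collapsed are simply hyperedges keyed by the same frozenset of nodes. A minimal standalone sketch of that grouping step (hypothetical helper, not the library API):

```python
def duplicate_hyperedge_classes(hyperedges):
    """Group hyperedge ids by identical node membership and keep only
    classes with more than one member -- the classes that
    collapse_duplicate_hyperedges merges into one representative."""
    groups = {}
    for he_id, nodes in hyperedges.items():
        groups.setdefault(frozenset(nodes), set()).add(he_id)
    return [ids for ids in groups.values() if len(ids) > 1]

hes = {"e1": {"a", "b"}, "e2": {"a", "b"}, "e3": {"b", "c"}}
classes = duplicate_hyperedge_classes(hes)  # e1 and e2 are duplicates
```

Using `frozenset` as the dict key is what makes membership comparison order-insensitive.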
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.collapse_duplicate_nodes","title":"collapse_duplicate_nodes(name=None, use_uids=None, use_counts=False, return_counts=True, return_equivalence_classes=False, aggregate_properties_by=None)
","text":"Collapses duplicate nodes (nodes with identical hyperedge memberships) into single nodes.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.collapse_duplicate_nodes--parameters","title":"Parameters","text":"name : Optional[str], default=None The name assigned to the collapsed hypergraph. If None, defaults to the original name suffixed with '_collapsed_nodes'.
use_uids : Optional[List[Any]] = None Specifies the node identifiers to use as representatives for each equivalence class. If two identifiers occur in the same equivalence class, the first one found in use_uids
is used. If None, the first encountered node in each class is used as the representative.
use_counts : bool, optional, default=False If True, renames the equivalence class representatives by appending the size of the class (e.g., 'N1:3').
return_counts : bool, optional, default=True If True, adds the size of each equivalence class to the properties of the representative node under the key 'equivalence_class_size'.
return_equivalence_classes : bool, optional, default=False If True, returns a tuple containing the new collapsed hypergraph and a dictionary mapping representatives to their equivalence classes.
aggregate_properties_by : Optional[Dict[str, str]] = None A dictionary specifying aggregation methods for node properties. Keys are property names, and values are aggregation functions (e.g., {'weight': 'sum'}). Properties not specified will use the 'first' aggregation.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.collapse_duplicate_nodes--returns","title":"Returns","text":"Hypergraph or Tuple[Hypergraph, Dict[Any, Set[Any]]] - If return_equivalence_classes=False
, returns the new collapsed hypergraph. - If return_equivalence_classes=True
, returns a tuple containing the collapsed hypergraph and a dictionary of equivalence classes.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.collapse_duplicate_nodes--raises","title":"Raises","text":"HypergraphError If the hypergraph is empty or improperly structured.
Source code in src/aeiva/hypergraph/hypergraph.py
def collapse_duplicate_nodes(\n self,\n name: Optional[str] = None,\n use_uids: Optional[List[Any]] = None,\n use_counts: bool = False,\n return_counts: bool = True,\n return_equivalence_classes: bool = False,\n aggregate_properties_by: Optional[Dict[str, str]] = None,\n) -> Union['Hypergraph', Tuple['Hypergraph', Dict[Any, Set[Any]]]]:\n \"\"\"\n Collapses duplicate nodes (nodes with identical hyperedge memberships) into single nodes.\n\n Parameters\n ----------\n name : Optional[str], default=None\n The name assigned to the collapsed hypergraph. If None, defaults to the original name suffixed with '_collapsed_nodes'.\n\n use_uids : Optional[List[Any]] = None\n Specifies the node identifiers to use as representatives for each equivalence class.\n If two identifiers occur in the same equivalence class, the first one found in `use_uids` is used.\n If None, the first encountered node in each class is used as the representative.\n\n use_counts : bool, optional, default=False\n If True, renames the equivalence class representatives by appending the size of the class (e.g., 'N1:3').\n\n return_counts : bool, optional, default=True\n If True, adds the size of each equivalence class to the properties of the representative node under the key 'equivalence_class_size'.\n\n return_equivalence_classes : bool, optional, default=False\n If True, returns a tuple containing the new collapsed hypergraph and a dictionary mapping representatives to their equivalence classes.\n\n aggregate_properties_by : Optional[Dict[str, str]] = None\n A dictionary specifying aggregation methods for node properties. 
Keys are property names, and values are aggregation functions (e.g., {'weight': 'sum'}).\n Properties not specified will use the 'first' aggregation.\n\n Returns\n -------\n Hypergraph or Tuple[Hypergraph, Dict[Any, Set[Any]]]\n - If `return_equivalence_classes=False`, returns the new collapsed hypergraph.\n - If `return_equivalence_classes=True`, returns a tuple containing the collapsed hypergraph and a dictionary of equivalence classes.\n\n Raises\n ------\n HypergraphError\n If the hypergraph is empty or improperly structured.\n \"\"\"\n if not self.node_properties:\n raise HypergraphError(\"Cannot collapse nodes in an empty hypergraph.\")\n\n # Identify equivalence classes based on identical hyperedge memberships\n membership_to_nodes: Dict[frozenset, Set[Any]] = {}\n for node_id, node_props in self.node_properties.items():\n key = frozenset(self.get_hyperedges_of_node(node_id))\n membership_to_nodes.setdefault(key, set()).add(node_id)\n\n # Filter out classes with only one node (no duplicates)\n equivalence_classes = [nodes for nodes in membership_to_nodes.values() if len(nodes) > 1]\n if not equivalence_classes:\n # No duplicates to collapse; return the original hypergraph\n return self if not return_equivalence_classes else (self, {})\n\n # Prepare aggregation methods\n aggregate_properties_by = aggregate_properties_by if aggregate_properties_by is not None else {\"weight\": \"sum\"}\n\n # Initialize mapping from old nodes to new nodes\n node_mapping: Dict[Any, Any] = {}\n equivalence_class_dict: Dict[Any, Set[Any]] = {}\n\n for eq_class in equivalence_classes:\n # Determine representative\n if use_uids:\n # Select the first UID from use_uids that is in the equivalence class\n representative = next((uid for uid in use_uids if uid in eq_class), None)\n if not representative:\n # Fallback to the first node in the equivalence class\n representative = next(iter(eq_class))\n else:\n # Use the first node in the equivalence class as representative\n representative 
= next(iter(eq_class))\n\n # Optionally rename with counts\n if use_counts:\n new_representative = f\"{representative}:{len(eq_class)}\"\n else:\n new_representative = representative\n\n # Map all nodes in the class to the representative\n for node in eq_class:\n node_mapping[node] = new_representative\n\n # Store the equivalence class\n equivalence_class_dict[new_representative] = eq_class\n\n # Replace node IDs in hyperedges based on mapping\n new_hyperedges = {}\n for he_id, hyperedge in self.hyperedges.items():\n new_nodes = set()\n for node_id in hyperedge.nodes:\n new_node_id = node_mapping.get(node_id, node_id)\n new_nodes.add(new_node_id)\n new_hyperedges[he_id] = HyperEdge(id=he_id, nodes=new_nodes, properties=copy.deepcopy(hyperedge.properties))\n\n # Aggregate node properties\n new_node_properties = {}\n for node_id, node_props in self.node_properties.items():\n new_node_id = node_mapping.get(node_id, node_id)\n if new_node_id not in new_node_properties:\n new_node_properties[new_node_id] = copy.deepcopy(node_props)\n else:\n for prop, agg_func in aggregate_properties_by.items():\n if prop in node_props:\n if agg_func == 'sum':\n new_node_properties[new_node_id][prop] = new_node_properties[new_node_id].get(prop, 0) + node_props[prop]\n elif agg_func == 'mean':\n # To calculate mean, store sum and count\n if 'sum_' + prop not in new_node_properties[new_node_id]:\n new_node_properties[new_node_id]['sum_' + prop] = node_props[prop]\n new_node_properties[new_node_id]['count_' + prop] = 1\n else:\n new_node_properties[new_node_id]['sum_' + prop] += node_props[prop]\n new_node_properties[new_node_id]['count_' + prop] += 1\n # Calculate mean at the end\n elif agg_func == 'max':\n current_max = new_node_properties[new_node_id].get(prop, float('-inf'))\n new_node_properties[new_node_id][prop] = max(current_max, node_props[prop])\n elif agg_func == 'min':\n current_min = new_node_properties[new_node_id].get(prop, float('inf'))\n 
new_node_properties[new_node_id][prop] = min(current_min, node_props[prop])\n else:\n new_node_properties[new_node_id][prop] = node_props[prop] # Default to last\n # Finalize mean calculations\n for node_id, props in new_node_properties.items():\n for prop in list(props.keys()):\n if prop.startswith('sum_'):\n base_prop = prop[4:]\n sum_val = props[prop]\n count_val = props.get('count_' + base_prop, 1)\n new_node_properties[node_id][base_prop] = sum_val / count_val if count_val > 0 else 0\n del new_node_properties[node_id][prop]\n del new_node_properties[node_id]['count_' + base_prop]\n\n # Handle equivalence class size\n if use_counts:\n for node_id in equivalence_class_dict:\n new_node_properties[node_id]['equivalence_class_size'] = len(equivalence_class_dict[node_id])\n elif return_counts:\n for node_id in new_node_properties:\n if node_id in equivalence_class_dict:\n new_node_properties[node_id]['equivalence_class_size'] = len(equivalence_class_dict[node_id])\n else:\n new_node_properties[node_id]['equivalence_class_size'] = 1\n\n # Initialize the collapsed hypergraph\n collapsed_hypergraph = Hypergraph(\n hyperedges={\n he_id: {\n 'nodes': list(he.nodes),\n 'properties': he.properties.copy()\n } for he_id, he in new_hyperedges.items()\n },\n node_properties=new_node_properties,\n hyperedge_properties={\n he_id: copy.deepcopy(he.properties) for he_id, he in new_hyperedges.items()\n },\n name=name if name else f\"{self.name}_collapsed_nodes\"\n )\n\n if return_equivalence_classes:\n return collapsed_hypergraph, equivalence_class_dict\n else:\n return collapsed_hypergraph\n
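The node version is the dual construction: invert the incidence first, then group nodes by identical hyperedge membership. A hypothetical standalone sketch of that grouping (not the library API):

```python
def duplicate_node_classes(hyperedges):
    """Invert the incidence (node -> set of hyperedge ids), then group
    nodes sharing an identical membership set -- the classes that
    collapse_duplicate_nodes merges."""
    membership = {}
    for he_id, nodes in hyperedges.items():
        for n in nodes:
            membership.setdefault(n, set()).add(he_id)
    groups = {}
    for node, hes in membership.items():
        groups.setdefault(frozenset(hes), set()).add(node)
    return [ns for ns in groups.values() if len(ns) > 1]

hes = {"e1": {"a", "b", "c"}, "e2": {"a", "b"}}
classes = duplicate_node_classes(hes)  # a and b share {e1, e2}
```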
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.compute_hyperedge_diameter","title":"compute_hyperedge_diameter(s=1)
","text":"Returns the diameter of the hypergraph based on s-hyperedge connectivity.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.compute_hyperedge_diameter--parameters","title":"Parameters","text":"s : int, optional, default=1 The number of shared nodes required for hyperedges to be considered adjacent.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.compute_hyperedge_diameter--returns","title":"Returns","text":"int The diameter of the hypergraph based on hyperedge connectivity.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.compute_hyperedge_diameter--raises","title":"Raises","text":"HypergraphError If the hypergraph is not s-hyperedge-connected or has no hyperedges.
Source code in src/aeiva/hypergraph/hypergraph.py
def compute_hyperedge_diameter(self, s: int = 1) -> int:\n \"\"\"\n Returns the diameter of the hypergraph based on s-hyperedge connectivity.\n\n Parameters\n ----------\n s : int, optional, default=1\n The number of shared nodes required for hyperedges to be considered adjacent.\n\n Returns\n -------\n int\n The diameter of the hypergraph based on hyperedge connectivity.\n\n Raises\n ------\n HypergraphError\n If the hypergraph is not s-hyperedge-connected or has no hyperedges.\n \"\"\"\n A, _ = self.hyperedge_adjacency_matrix(s=s, index=True)\n if A is None or A.shape[0] == 0:\n raise HypergraphError(\"The hypergraph has no hyperedges to compute diameter.\")\n\n graph = nx.from_scipy_sparse_array(A)\n if not nx.is_connected(graph):\n raise HypergraphError(f\"Hypergraph is not s-hyperedge-connected. s={s}\")\n\n try:\n return nx.diameter(graph)\n except nx.NetworkXError as e:\n raise HypergraphError(f\"Could not compute hyperedge diameter: {e}\")\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.compute_hyperedge_diameters","title":"compute_hyperedge_diameters(s=1)
","text":"Returns the hyperedge diameters of the s-hyperedge-connected component subgraphs in the hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.compute_hyperedge_diameters--parameters","title":"Parameters","text":"s : int, optional, default=1 The number of shared nodes required for hyperedges to be considered adjacent.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.compute_hyperedge_diameters--returns","title":"Returns","text":"Tuple[int, List[int], List[Set[Any]]] - Maximum diameter among all s-hyperedge-connected components. - List of diameters for each s-hyperedge connected component. - List of sets, each containing hyperedge IDs in an s-hyperedge connected component.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.compute_hyperedge_diameters--raises","title":"Raises","text":"HypergraphError If the hypergraph is not s-hyperedge-connected or has no hyperedges.
Source code in src/aeiva/hypergraph/hypergraph.py
def compute_hyperedge_diameters(self, s: int = 1) -> Tuple[int, List[int], List[Set[Any]]]:\n \"\"\"\n Returns the hyperedge diameters of the s-hyperedge-connected component subgraphs in the hypergraph.\n\n Parameters\n ----------\n s : int, optional, default=1\n The number of shared nodes required for hyperedges to be considered adjacent.\n\n Returns\n -------\n Tuple[int, List[int], List[Set[Any]]]\n - Maximum diameter among all s-hyperedge-connected components.\n - List of diameters for each s-hyperedge connected component.\n - List of sets, each containing hyperedge IDs in an s-hyperedge connected component.\n\n Raises\n ------\n HypergraphError\n If the hypergraph is not s-hyperedge-connected or has no hyperedges.\n \"\"\"\n A, he_id_map = self.hyperedge_adjacency_matrix(s=s, index=True)\n if A is None or A.shape[0] == 0:\n raise HypergraphError(\"The hypergraph has no hyperedges to compute diameters.\")\n\n graph = nx.from_scipy_sparse_array(A)\n\n if not nx.is_connected(graph):\n raise HypergraphError(f\"Hypergraph is not s-hyperedge-connected. s={s}\")\n\n diams = []\n comps = []\n for component in nx.connected_components(graph):\n subgraph = graph.subgraph(component)\n if len(subgraph) == 1:\n diamc = 0 # Diameter of a single hyperedge is 0\n else:\n try:\n diamc = nx.diameter(subgraph)\n except nx.NetworkXError:\n diamc = float('inf') # Infinite diameter if the subgraph is not connected\n diams.append(diamc)\n component_hyperedges = {he_id_map[he] for he in component}\n comps.append(component_hyperedges)\n\n if not diams:\n raise HypergraphError(\"No connected components found to compute hyperedge diameters.\")\n\n max_diam = max(diams)\n return max_diam, diams, comps\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.compute_node_diameter","title":"compute_node_diameter(s=1)
","text":"Returns the diameter of the hypergraph based on s-node connectivity.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.compute_node_diameter--parameters","title":"Parameters","text":"s : int, optional, default=1 The number of shared hyperedges required for nodes to be considered adjacent.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.compute_node_diameter--returns","title":"Returns","text":"int The diameter of the hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.compute_node_diameter--raises","title":"Raises","text":"HypergraphError If the hypergraph is not s-node-connected or has no nodes.
Source code in src/aeiva/hypergraph/hypergraph.py
def compute_node_diameter(self, s: int = 1) -> int:\n \"\"\"\n Returns the diameter of the hypergraph based on s-node connectivity.\n\n Parameters\n ----------\n s : int, optional, default=1\n The number of shared hyperedges required for nodes to be considered adjacent.\n\n Returns\n -------\n int\n The diameter of the hypergraph.\n\n Raises\n ------\n HypergraphError\n If the hypergraph is not s-node-connected or has no nodes.\n \"\"\"\n A, _ = self.adjacency_matrix(s=s, index=True)\n if A is None or A.shape[0] == 0:\n raise HypergraphError(\"The hypergraph has no nodes to compute diameter.\")\n\n graph = nx.from_scipy_sparse_array(A)\n if not nx.is_connected(graph):\n raise HypergraphError(f\"Hypergraph is not s-node-connected. s={s}\")\n\n try:\n return nx.diameter(graph)\n except nx.NetworkXError as e:\n raise HypergraphError(f\"Could not compute diameter: {e}\")\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.compute_node_diameters","title":"compute_node_diameters(s=1)
","text":"Returns the node diameters of the connected components in the hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.compute_node_diameters--parameters","title":"Parameters","text":"s : int, optional, default=1 The number of shared hyperedges required for nodes to be considered adjacent.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.compute_node_diameters--returns","title":"Returns","text":"Tuple[int, List[int], List[Set[Any]]] - Maximum diameter among all connected components. - List of diameters for each s-node connected component. - List of sets, each containing node IDs in an s-node connected component.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.compute_node_diameters--raises","title":"Raises","text":"HypergraphError If the hypergraph is not s-connected or has no nodes.
Source code in src/aeiva/hypergraph/hypergraph.py
def compute_node_diameters(self, s: int = 1) -> Tuple[int, List[int], List[Set[Any]]]:\n \"\"\"\n Returns the node diameters of the connected components in the hypergraph.\n\n Parameters\n ----------\n s : int, optional, default=1\n The number of shared hyperedges required for nodes to be considered adjacent.\n\n Returns\n -------\n Tuple[int, List[int], List[Set[Any]]]\n - Maximum diameter among all connected components.\n - List of diameters for each s-node connected component.\n - List of sets, each containing node IDs in an s-node connected component.\n\n Raises\n ------\n HypergraphError\n If the hypergraph is not s-connected or has no nodes.\n \"\"\"\n A, node_id_map = self.adjacency_matrix(s=s, index=True)\n if A is None or A.shape[0] == 0:\n raise HypergraphError(\"The hypergraph has no nodes to compute diameters.\")\n\n graph = nx.from_scipy_sparse_array(A)\n\n if not nx.is_connected(graph):\n raise HypergraphError(f\"Hypergraph is not s-node-connected. s={s}\")\n\n diams = []\n comps = []\n for component in nx.connected_components(graph):\n subgraph = graph.subgraph(component)\n if len(subgraph) == 1:\n diamc = 0 # Diameter of a single node is 0\n else:\n try:\n diamc = nx.diameter(subgraph)\n except nx.NetworkXError:\n diamc = float('inf') # Infinite diameter if the subgraph is not connected\n diams.append(diamc)\n component_nodes = {node_id_map[node] for node in component}\n comps.append(component_nodes)\n\n if not diams:\n raise HypergraphError(\"No connected components found to compute diameters.\")\n\n max_diam = max(diams)\n return max_diam, diams, comps\n
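The per-component diameter computation can be reproduced on a plain adjacency dict with BFS instead of networkx. `component_diameters` is a hypothetical sketch of the same idea, assuming an undirected adjacency mapping:

```python
from collections import deque

def component_diameters(adj):
    """Return (diameter, component) pairs for each connected component
    of an undirected graph given as {node: set(neighbors)}."""
    def bfs(src):
        dist = {src: 0}
        q = deque([src])
        while q:
            u = q.popleft()
            for v in adj[u]:
                if v not in dist:
                    dist[v] = dist[u] + 1
                    q.append(v)
        return dist
    seen, out = set(), []
    for start in adj:
        if start in seen:
            continue
        comp = set(bfs(start))
        seen |= comp
        # diameter = longest shortest path: max BFS eccentricity in comp
        diam = max(max(bfs(n).values()) for n in comp)
        out.append((diam, comp))
    return out

adj = {"a": {"b"}, "b": {"a", "c"}, "c": {"b"}, "d": set()}
results = component_diameters(adj)
```

A singleton component has diameter 0, matching the `len(subgraph) == 1` branch above.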
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.copy","title":"copy(name=None)
","text":"Creates a deep copy of the hypergraph instance.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.copy--parameters","title":"Parameters","text":"name : Optional[str], default=None The name for the copied Hypergraph. If not provided, retains the original name.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.copy--returns","title":"Returns","text":"Hypergraph A new Hypergraph instance that is a deep copy of the original.
Source code in src/aeiva/hypergraph/hypergraph.py
def copy(self, name: Optional[str] = None) -> 'Hypergraph':\n \"\"\"\n Creates a deep copy of the hypergraph instance.\n\n Parameters\n ----------\n name : Optional[str], default=None\n The name for the copied Hypergraph. If not provided, retains the original name.\n\n Returns\n -------\n Hypergraph\n A new Hypergraph instance that is a deep copy of the original.\n \"\"\"\n\n # Deep copy hyperedges\n hyperedges_dict = {}\n for he_id, he in self.hyperedges.items():\n hyperedges_dict[he_id] = {\n 'nodes': list(he.nodes),\n 'properties': copy.deepcopy(he.properties)\n }\n\n # Deep copy node_properties and hyperedge_properties\n node_properties_copy = copy.deepcopy(self.node_properties)\n hyperedge_properties_copy = copy.deepcopy(self.hyperedge_properties)\n\n # Create a new Hypergraph instance with the copied data\n return Hypergraph(\n hyperedges=hyperedges_dict,\n node_properties=node_properties_copy,\n hyperedge_properties=hyperedge_properties_copy,\n name=name if name is not None else self.name\n )\n
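The independence guarantee of `copy` rests on deep-copying each field, so mutating the clone never touches the original. A minimal standalone check of that semantics (plain dicts standing in for hyperedge data):

```python
import copy

# Structure shaped like the hyperedges dict that Hypergraph.copy rebuilds
he = {"e1": {"nodes": ["a", "b"], "properties": {"weight": 1}}}
clone = copy.deepcopy(he)

clone["e1"]["properties"]["weight"] = 99   # mutate the copy only
original_weight = he["e1"]["properties"]["weight"]  # stays 1
```

A shallow `dict(he)` would share the nested `properties` dicts and break this guarantee, which is why the method deep-copies node and hyperedge properties separately.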
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.deepcopy","title":"deepcopy(name=None)
","text":"Creates a deep copy of the hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.deepcopy--parameters","title":"Parameters","text":"name : Optional[str], default=None The name assigned to the cloned hypergraph. If None, defaults to the original hypergraph's name suffixed with '_clone'.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.deepcopy--returns","title":"Returns","text":"Hypergraph A deep copy of the hypergraph.
Source code in src/aeiva/hypergraph/hypergraph.py
def deepcopy(self, name: Optional[str] = None) -> 'Hypergraph':\n \"\"\"\n Creates a deep copy of the hypergraph.\n\n Parameters\n ----------\n name : Optional[str], default=None\n The name assigned to the cloned hypergraph. If None, defaults to the original hypergraph's name suffixed with '_clone'.\n\n Returns\n -------\n Hypergraph\n A deep copy of the hypergraph.\n \"\"\"\n\n # Deep copy hyperedges\n hyperedges_copy = {\n he_id: {\n 'nodes': hyperedge.nodes.copy(),\n 'properties': copy.deepcopy(hyperedge.properties)\n }\n for he_id, hyperedge in self.hyperedges.items()\n }\n\n # Deep copy node properties\n node_properties_copy = copy.deepcopy(self.node_properties)\n\n # Deep copy hyperedge properties\n hyperedge_properties_copy = copy.deepcopy(self.hyperedge_properties)\n\n # Set name\n cloned_name = f\"{self.name}_deepcopy\" if name is None else name\n\n # Initialize the cloned hypergraph\n cloned_H = Hypergraph(\n hyperedges=hyperedges_copy,\n node_properties=node_properties_copy,\n hyperedge_properties=hyperedge_properties_copy,\n name=cloned_name\n )\n\n return cloned_H\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.difference","title":"difference(other, inplace=False, name=None)
","text":"Returns the difference of the current hypergraph with another hypergraph. The difference includes nodes and hyperedges present in the current hypergraph but not in the other.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.difference--parameters","title":"Parameters","text":"other : Hypergraph The hypergraph to subtract. inplace : bool, optional, default=False If True, modifies the current hypergraph by removing elements found in other
. Otherwise, returns a new Hypergraph instance. name : Optional[str], default=None The name for the resulting hypergraph. If None, defaults to 'Difference_of_{self.name}_{other.name}'.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.difference--returns","title":"Returns","text":"Hypergraph The resulting difference hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.difference--raises","title":"Raises","text":"TypeError If other
is not an instance of Hypergraph.
Source code in src/aeiva/hypergraph/hypergraph.py
def difference(self, other: 'Hypergraph', inplace: bool = False, name: Optional[str] = None) -> 'Hypergraph':\n \"\"\"\n Returns the difference of the current hypergraph with another hypergraph.\n The difference includes nodes and hyperedges present in the current hypergraph but not in the other.\n\n Parameters\n ----------\n other : Hypergraph\n The hypergraph to subtract.\n inplace : bool, optional, default=False\n If True, modifies the current hypergraph by removing elements found in `other`.\n Otherwise, returns a new Hypergraph instance.\n name : Optional[str], default=None\n The name for the resulting hypergraph. If None, defaults to 'Difference_of_{self.name}_{other.name}'.\n\n Returns\n -------\n Hypergraph\n The resulting difference hypergraph.\n\n Raises\n ------\n TypeError\n If `other` is not an instance of Hypergraph.\n \"\"\"\n if not isinstance(other, Hypergraph):\n raise TypeError(\"The `other` parameter must be an instance of Hypergraph.\")\n\n if inplace:\n # Remove hyperedges present in other\n hyperedges_to_remove = set(self.hyperedges.keys()) & set(other.hyperedges.keys())\n self.remove_hyperedges(hyperedges_to_remove, inplace=True)\n # Remove nodes present in other\n nodes_to_remove = set(self.node_properties.keys()) & set(other.node_properties.keys())\n self.remove_nodes_from(nodes_to_remove, inplace=True)\n return self\n else:\n # Create a new Hypergraph instance\n new_hyperedges = {he_id: copy.deepcopy(he) for he_id, he in self.hyperedges.items() if he_id not in other.hyperedges}\n new_hyperedge_properties = {he_id: copy.deepcopy(props) for he_id, props in self.hyperedge_properties.items() if he_id not in other.hyperedges}\n new_node_properties = {node_id: copy.deepcopy(props) for node_id, props in self.node_properties.items() if node_id not in other.node_properties}\n\n # Reconstruct graph\n new_graph = nx.Graph()\n new_bipartite_nodes = set()\n for he_id, hyperedge in new_hyperedges.items():\n new_graph.add_node(he_id, 
bipartite='hyperedge', **new_hyperedge_properties[he_id])\n new_bipartite_nodes.add(he_id)\n for node in hyperedge.nodes:\n if node in new_node_properties:\n new_graph.add_edge(he_id, node)\n\n return Hypergraph(\n hyperedges={\n he_id: {\n 'nodes': list(he.nodes),\n 'properties': he.properties.copy()\n } for he_id, he in new_hyperedges.items()\n },\n node_properties=new_node_properties,\n hyperedge_properties=new_hyperedge_properties,\n name=name if name else f\"Difference_of_{self.name}_{other.name}\"\n )\n
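As a minimal stand-alone illustration of the ID-level set difference the method computes (ignoring property copying and graph rebuilding), with plain `{hyperedge_id: set_of_nodes}` dicts standing in for `Hypergraph` instances — the helper name below is illustrative, not part of the library API:

```python
def hypergraph_difference(h1, h2):
    """Keep only hyperedges of h1 whose IDs do not appear in h2.

    h1, h2: {hyperedge_id: set_of_nodes} dicts, a simplified stand-in
    for the Hypergraph class (the real method also removes shared nodes).
    """
    return {hid: set(nodes) for hid, nodes in h1.items() if hid not in h2}


h1 = {"e1": {"a", "b"}, "e2": {"c"}}
h2 = {"e2": {"c"}, "e3": {"d"}}
diff = hypergraph_difference(h1, h2)
```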
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.dual","title":"dual(name=None)
","text":"Constructs the dual of the current hypergraph by reversing the roles of nodes and hyperedges.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.dual--parameters","title":"Parameters","text":"name : Optional[str], default=None Name for the dual hypergraph. If None, defaults to the original hypergraph's name with '_dual' appended.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.dual--returns","title":"Returns","text":"Hypergraph A new Hypergraph instance representing the dual of the current hypergraph.
Source code in src/aeiva/hypergraph/hypergraph.py
def dual(self, name: Optional[str] = None) -> \"Hypergraph\":\n \"\"\"\n Constructs the dual of the current hypergraph by reversing the roles of nodes and hyperedges.\n\n Parameters\n ----------\n name : Optional[str], default=None\n Name for the dual hypergraph. If None, defaults to the original hypergraph's name with '_dual' appended.\n\n Returns\n -------\n Hypergraph\n A new Hypergraph instance representing the dual of the current hypergraph.\n \"\"\"\n # Initialize dual hyperedges, which will correspond to original nodes\n dual_hyperedges = {}\n\n # Invert the node-hyperedge structure\n for he_id, hyperedge in self.hyperedges.items():\n for node in hyperedge.nodes:\n # Each original node becomes a hyperedge in the dual\n if node not in dual_hyperedges:\n dual_hyperedges[node] = {'nodes': [], 'properties': self.node_properties.get(node, {})}\n # The new hyperedge (original node) connects to the original hyperedge id as a \"node\"\n dual_hyperedges[node]['nodes'].append(he_id)\n\n # Define node properties in the dual as the original hyperedge properties\n dual_node_properties = {he_id: self.hyperedge_properties.get(he_id, {}) for he_id in self.hyperedges}\n\n # Create and return the dual Hypergraph\n return Hypergraph(\n hyperedges=dual_hyperedges,\n node_properties=dual_node_properties,\n hyperedge_properties=self.node_properties, # Properties of original nodes now apply to dual hyperedges\n name=name or (self.name + \"_dual\" if self.name else \"dual\")\n )\n
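The core of `dual` is an inversion of the node–hyperedge incidence relation. A minimal sketch with plain dicts in place of the `Hypergraph` class (the function name is illustrative, not library API):

```python
def dual_incidence(hyperedges):
    """Invert a {hyperedge_id: set_of_nodes} mapping.

    Each original node becomes a dual hyperedge whose members are the
    IDs of the original hyperedges that contained it.
    """
    dual = {}
    for he_id, nodes in hyperedges.items():
        for node in nodes:
            dual.setdefault(node, set()).add(he_id)
    return dual


H = {"e1": {"a", "b"}, "e2": {"b", "c"}}
D = dual_incidence(H)
```

Inverting twice recovers the original incidence structure, which is the usual sanity check for a hypergraph dual.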
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.edge_elements","title":"edge_elements()
","text":"Returns a dictionary where each key is a hyperedge ID and the value is a list of node IDs within that hyperedge.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.edge_elements--returns","title":"Returns","text":"Dict[Any, List[Any]] Dictionary mapping hyperedge IDs to lists of node IDs they contain.
Source code in src/aeiva/hypergraph/hypergraph.py
def edge_elements(self) -> Dict[Any, List[Any]]:\n \"\"\"\n Returns a dictionary where each key is a hyperedge ID and the value is a list of node IDs within that hyperedge.\n\n Returns\n -------\n Dict[Any, List[Any]]\n Dictionary mapping hyperedge IDs to lists of node IDs they contain.\n \"\"\"\n return {he_id: hyperedge.nodes for he_id, hyperedge in self.hyperedges.items()}\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.edges","title":"edges()
","text":"Returns a list of all hyperedge identifiers in the hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.edges--returns","title":"Returns","text":"List[Any] List of hyperedge IDs.
Source code in src/aeiva/hypergraph/hypergraph.py
def edges(self) -> List[Any]:\n \"\"\"\n Returns a list of all hyperedge identifiers in the hypergraph.\n\n Returns\n -------\n List[Any]\n List of hyperedge IDs.\n \"\"\"\n return list(self.hyperedges.keys())\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.from_bipartite_graph","title":"from_bipartite_graph(bipartite_graph, hyperedge_prefix='HE', node_prefix='N', name=None)
classmethod
","text":"Constructs a Hypergraph instance from a bipartite graph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.from_bipartite_graph--parameters","title":"Parameters","text":"bipartite_graph : nx.Graph A bipartite graph where one set of nodes represents hyperedges and the other represents regular nodes. hyperedge_prefix : str, optional, default=\"HE\" The prefix to identify hyperedge nodes in the bipartite graph. node_prefix : str, optional, default=\"N\" The prefix to identify regular nodes in the bipartite graph. name : Optional[str], default=None The name assigned to the new Hypergraph. If None, defaults to 'FromBipartiteGraph'.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.from_bipartite_graph--returns","title":"Returns","text":"Hypergraph The constructed Hypergraph instance.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.from_bipartite_graph--raises","title":"Raises","text":"ValueError If the bipartite graph does not contain two distinct sets of nodes identifiable by the provided prefixes.
Source code in src/aeiva/hypergraph/hypergraph.py
@classmethod\ndef from_bipartite_graph(cls, bipartite_graph: nx.Graph, hyperedge_prefix: str = \"HE\", node_prefix: str = \"N\", name: Optional[str] = None) -> 'Hypergraph':\n \"\"\"\n Constructs a Hypergraph instance from a bipartite graph.\n\n Parameters\n ----------\n bipartite_graph : nx.Graph\n A bipartite graph where one set of nodes represents hyperedges and the other represents regular nodes.\n hyperedge_prefix : str, optional, default=\"HE\"\n The prefix to identify hyperedge nodes in the bipartite graph.\n node_prefix : str, optional, default=\"N\"\n The prefix to identify regular nodes in the bipartite graph.\n name : Optional[str], default=None\n The name assigned to the new Hypergraph. If None, defaults to 'FromBipartiteGraph'.\n\n Returns\n -------\n Hypergraph\n The constructed Hypergraph instance.\n\n Raises\n ------\n ValueError\n If the bipartite graph does not contain two distinct sets of nodes identifiable by the provided prefixes.\n \"\"\"\n hyperedges = {}\n node_properties = {}\n hyperedge_properties = {}\n name = name if name else \"FromBipartiteGraph\"\n\n for node in bipartite_graph.nodes(data=True):\n node_id, attrs = node\n if node_id.startswith(hyperedge_prefix):\n # It's a hyperedge\n hyperedges[node_id] = HyperEdge(id=node_id, nodes=set(), properties=attrs)\n hyperedge_properties[node_id] = copy.deepcopy(attrs)\n elif node_id.startswith(node_prefix):\n # It's a regular node\n node_properties[node_id] = copy.deepcopy(attrs)\n else:\n raise ValueError(f\"Node '{node_id}' does not start with either hyperedge_prefix '{hyperedge_prefix}' or node_prefix '{node_prefix}'.\")\n\n # Assign nodes to hyperedges based on edges in bipartite graph\n for he_id in hyperedges:\n connected_nodes = set(bipartite_graph.neighbors(he_id))\n hyperedges[he_id].nodes = connected_nodes\n\n # Construct hyperedges dict for __init__\n hyperedges_dict = {\n he_id: {\n 'nodes': list(he.nodes),\n 'properties': he.properties.copy()\n } for he_id, he in 
hyperedges.items()\n }\n\n return cls(\n hyperedges=hyperedges_dict,\n node_properties=node_properties,\n hyperedge_properties=hyperedge_properties,\n name=name\n )\n
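The prefix-based classification that `from_bipartite_graph` performs can be sketched without networkx, using a plain edge list. This is a simplified stand-in (names are illustrative, not library API); the real method also carries over node and hyperedge properties:

```python
def hypergraph_from_bipartite(edges, hyperedge_prefix="HE", node_prefix="N"):
    """Group bipartite edges (label, label) into {hyperedge: set_of_nodes},
    classifying each endpoint by its prefix, as the classmethod above does."""
    hyperedges = {}
    for a, b in edges:
        # Decide which endpoint is the hyperedge by prefix
        he, node = (a, b) if a.startswith(hyperedge_prefix) else (b, a)
        if not (he.startswith(hyperedge_prefix) and node.startswith(node_prefix)):
            raise ValueError(f"Cannot classify edge ({a}, {b}) by prefix")
        hyperedges.setdefault(he, set()).add(node)
    return hyperedges


hg = hypergraph_from_bipartite([("HE1", "N1"), ("N2", "HE1"), ("HE2", "N2")])
```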
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.get_hyperedge_connected_components","title":"get_hyperedge_connected_components(s=1, return_singletons=False)
","text":"Yields the s-hyperedge-connected components of the hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.get_hyperedge_connected_components--parameters","title":"Parameters","text":"s : int, optional, default=1 The connectivity level to check. return_singletons : bool, optional, default=False If True, includes singleton components. Otherwise, excludes them.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.get_hyperedge_connected_components--yields","title":"Yields","text":"Set[Any] Sets of hyperedge IDs representing each connected component.
Source code in src/aeiva/hypergraph/hypergraph.py
def get_hyperedge_connected_components(\n self, s: int = 1, return_singletons: bool = False\n) -> Iterator[Set[Any]]:\n \"\"\"\n Yields the s-hyperedge-connected components of the hypergraph.\n\n Parameters\n ----------\n s : int, optional, default=1\n The connectivity level to check.\n return_singletons : bool, optional, default=False\n If True, includes singleton components. Otherwise, excludes them.\n\n Yields\n ------\n Set[Any]\n Sets of hyperedge IDs representing each connected component.\n \"\"\"\n return self.s_connected_components(s=s, hyperedges=True, return_singletons=return_singletons)\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.get_hyperedge_connected_subgraphs","title":"get_hyperedge_connected_subgraphs(s=1, return_singletons=False, name=None)
","text":"Yields subgraphs corresponding to each s-hyperedge-connected component.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.get_hyperedge_connected_subgraphs--parameters","title":"Parameters","text":"s : int, optional, default=1 The connectivity level to check. return_singletons : bool, optional, default=False If True, includes singleton components. Otherwise, excludes them. name : Optional[str], default=None Base name for the subgraphs. Each subgraph will have a unique name appended.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.get_hyperedge_connected_subgraphs--yields","title":"Yields","text":"Hypergraph Subgraphs representing each connected component.
Source code in src/aeiva/hypergraph/hypergraph.py
def get_hyperedge_connected_subgraphs(\n self, s: int = 1, return_singletons: bool = False, name: Optional[str] = None\n) -> Iterator['Hypergraph']:\n \"\"\"\n Yields subgraphs corresponding to each s-hyperedge-connected component.\n\n Parameters\n ----------\n s : int, optional, default=1\n The connectivity level to check.\n return_singletons : bool, optional, default=False\n If True, includes singleton components. Otherwise, excludes them.\n name : Optional[str], default=None\n Base name for the subgraphs. Each subgraph will have a unique name appended.\n\n Yields\n ------\n Hypergraph\n Subgraphs representing each connected component.\n \"\"\"\n return self.s_component_subgraphs(\n s=s,\n hyperedges=True,\n return_singletons=return_singletons,\n name=name\n )\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.get_hyperedge_distance","title":"get_hyperedge_distance(source, target, s=1)
","text":"Returns the shortest s-walk distance between two hyperedges in the hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.get_hyperedge_distance--parameters","title":"Parameters","text":"source : Any A hyperedge identifier in the hypergraph. target : Any A hyperedge identifier in the hypergraph. s : int, optional, default=1 The number of shared nodes required for hyperedges to be considered adjacent.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.get_hyperedge_distance--returns","title":"Returns","text":"Union[int, float] The shortest s-walk distance between the source and target hyperedges. Returns float('inf')
if no path exists.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.get_hyperedge_distance--raises","title":"Raises","text":"HypergraphError If either the source or target hyperedge does not exist in the hypergraph.
Source code in src/aeiva/hypergraph/hypergraph.py
def get_hyperedge_distance(self, source: Any, target: Any, s: int = 1) -> Union[int, float]:\n \"\"\"\n Returns the shortest s-walk distance between two hyperedges in the hypergraph.\n\n Parameters\n ----------\n source : Any\n A hyperedge identifier in the hypergraph.\n target : Any\n A hyperedge identifier in the hypergraph.\n s : int, optional, default=1\n The number of shared nodes required for hyperedges to be considered adjacent.\n\n Returns\n -------\n Union[int, float]\n The shortest s-walk distance between the source and target hyperedges.\n Returns `float('inf')` if no path exists.\n\n Raises\n ------\n HypergraphError\n If either the source or target hyperedge does not exist in the hypergraph.\n \"\"\"\n if source not in self.hyperedges:\n raise HypergraphError(f\"Source hyperedge '{source}' does not exist in the hypergraph.\")\n if target not in self.hyperedges:\n raise HypergraphError(f\"Target hyperedge '{target}' does not exist in the hypergraph.\")\n\n A, he_id_map = self.hyperedge_adjacency_matrix(s=s, index=True)\n if A is None:\n raise HypergraphError(\"Hyperedge adjacency matrix could not be generated.\")\n\n graph = nx.from_scipy_sparse_array(A)\n # The matrix rows are integer indices; relabel graph nodes back to hyperedge IDs\n graph = nx.relabel_nodes(graph, {idx: he_id for he_id, idx in he_id_map.items()})\n\n try:\n distance = nx.shortest_path_length(graph, source=source, target=target)\n return distance\n except (nx.NetworkXNoPath, nx.NodeNotFound):\n warnings.warn(f\"No s-walk path between hyperedges '{source}' and '{target}'. Returning infinity.\")\n return float('inf')\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.get_hyperedges_of_node","title":"get_hyperedges_of_node(node_id)
","text":"Retrieves all hyperedges that a given node is part of.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.get_hyperedges_of_node--parameters","title":"Parameters","text":"node_id : Any The node identifier.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.get_hyperedges_of_node--returns","title":"Returns","text":"Set[Any] A set of hyperedge IDs that the node belongs to.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.get_hyperedges_of_node--raises","title":"Raises","text":"HypergraphError If the node does not exist in the hypergraph.
Source code in src/aeiva/hypergraph/hypergraph.py
def get_hyperedges_of_node(self, node_id: Any) -> Set[Any]:\n \"\"\"\n Retrieves all hyperedges that a given node is part of.\n\n Parameters\n ----------\n node_id : Any\n The node identifier.\n\n Returns\n -------\n Set[Any]\n A set of hyperedge IDs that the node belongs to.\n\n Raises\n ------\n HypergraphError\n If the node does not exist in the hypergraph.\n \"\"\"\n if node_id not in self.node_properties:\n raise HypergraphError(f\"Node '{node_id}' does not exist in the hypergraph.\")\n return {he.id for he in self.hyperedges.values() if node_id in he.nodes}\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.get_node_connected_components","title":"get_node_connected_components(s=1, return_singletons=False)
","text":"Yields the s-node-connected components of the hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.get_node_connected_components--parameters","title":"Parameters","text":"s : int, optional, default=1 The connectivity level to check. return_singletons : bool, optional, default=False If True, includes singleton components. Otherwise, excludes them.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.get_node_connected_components--yields","title":"Yields","text":"Set[Any] Sets of node IDs representing each connected component.
Source code in src/aeiva/hypergraph/hypergraph.py
def get_node_connected_components(\n self, s: int = 1, return_singletons: bool = False\n) -> Iterator[Set[Any]]:\n \"\"\"\n Yields the s-node-connected components of the hypergraph.\n\n Parameters\n ----------\n s : int, optional, default=1\n The connectivity level to check.\n return_singletons : bool, optional, default=False\n If True, includes singleton components. Otherwise, excludes them.\n\n Yields\n ------\n Set[Any]\n Sets of node IDs representing each connected component.\n \"\"\"\n return self.s_connected_components(s=s, hyperedges=False, return_singletons=return_singletons)\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.get_node_connected_subgraphs","title":"get_node_connected_subgraphs(s=1, return_singletons=False, name=None)
","text":"Yields subgraphs corresponding to each s-node-connected component.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.get_node_connected_subgraphs--parameters","title":"Parameters","text":"s : int, optional, default=1 The connectivity level to check. return_singletons : bool, optional, default=False If True, includes singleton components. Otherwise, excludes them. name : Optional[str], default=None Base name for the subgraphs. Each subgraph will have a unique name appended.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.get_node_connected_subgraphs--yields","title":"Yields","text":"Hypergraph Subgraphs representing each connected component.
Source code in src/aeiva/hypergraph/hypergraph.py
def get_node_connected_subgraphs(\n self, s: int = 1, return_singletons: bool = False, name: Optional[str] = None\n) -> Iterator['Hypergraph']:\n \"\"\"\n Yields subgraphs corresponding to each s-node-connected component.\n\n Parameters\n ----------\n s : int, optional, default=1\n The connectivity level to check.\n return_singletons : bool, optional, default=False\n If True, includes singleton components. Otherwise, excludes them.\n name : Optional[str], default=None\n Base name for the subgraphs. Each subgraph will have a unique name appended.\n\n Yields\n ------\n Hypergraph\n Subgraphs representing each connected component.\n \"\"\"\n return self.s_component_subgraphs(\n s=s,\n hyperedges=False,\n return_singletons=return_singletons,\n name=name\n )\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.get_node_distance","title":"get_node_distance(source, target, s=1)
","text":"Returns the shortest s-walk distance between two nodes in the hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.get_node_distance--parameters","title":"Parameters","text":"source : Any A node identifier in the hypergraph. target : Any A node identifier in the hypergraph. s : int, optional, default=1 The number of shared hyperedges required for nodes to be considered adjacent.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.get_node_distance--returns","title":"Returns","text":"Union[int, float] The shortest s-walk distance between the source and target nodes. Returns float('inf')
if no path exists.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.get_node_distance--raises","title":"Raises","text":"HypergraphError If either the source or target node does not exist in the hypergraph.
Source code in src/aeiva/hypergraph/hypergraph.py
def get_node_distance(self, source: Any, target: Any, s: int = 1) -> Union[int, float]:\n \"\"\"\n Returns the shortest s-walk distance between two nodes in the hypergraph.\n\n Parameters\n ----------\n source : Any\n A node identifier in the hypergraph.\n target : Any\n A node identifier in the hypergraph.\n s : int, optional, default=1\n The number of shared hyperedges required for nodes to be considered adjacent.\n\n Returns\n -------\n Union[int, float]\n The shortest s-walk distance between the source and target nodes.\n Returns `float('inf')` if no path exists.\n\n Raises\n ------\n HypergraphError\n If either the source or target node does not exist in the hypergraph.\n \"\"\"\n if source not in self.node_properties:\n raise HypergraphError(f\"Source node '{source}' does not exist in the hypergraph.\")\n if target not in self.node_properties:\n raise HypergraphError(f\"Target node '{target}' does not exist in the hypergraph.\")\n\n A, node_id_map = self.adjacency_matrix(s=s, index=True)\n if A is None:\n raise HypergraphError(\"Adjacency matrix could not be generated.\")\n\n graph = nx.from_scipy_sparse_array(A)\n # The matrix rows are integer indices; relabel graph nodes back to node IDs\n graph = nx.relabel_nodes(graph, {idx: node_id for node_id, idx in node_id_map.items()})\n\n try:\n distance = nx.shortest_path_length(graph, source=source, target=target)\n return distance\n except (nx.NetworkXNoPath, nx.NodeNotFound):\n warnings.warn(f\"No s-walk path between '{source}' and '{target}'. Returning infinity.\")\n return float('inf')\n
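The s-walk distance can also be computed directly with a breadth-first search over the s-adjacency relation (two nodes are adjacent when they co-occur in at least `s` hyperedges), without building a sparse matrix. A self-contained sketch using plain dicts in place of the `Hypergraph` class:

```python
from collections import deque


def s_node_distance(hyperedges, source, target, s=1):
    """Shortest s-walk distance between two nodes in a
    {hyperedge_id: set_of_nodes} dict; float('inf') if unreachable."""
    # membership: node -> set of hyperedge IDs containing it
    member = {}
    for he_id, nodes in hyperedges.items():
        for n in nodes:
            member.setdefault(n, set()).add(he_id)

    # BFS: hop between nodes sharing at least s hyperedges
    dist = {source: 0}
    queue = deque([source])
    while queue:
        u = queue.popleft()
        if u == target:
            return dist[u]
        for v in member:
            if v != u and v not in dist and len(member[u] & member[v]) >= s:
                dist[v] = dist[u] + 1
                queue.append(v)
    return float("inf")
```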
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.get_singleton_hyperedges","title":"get_singleton_hyperedges()
","text":"Returns a list of singleton hyperedges. A singleton hyperedge is a hyperedge of size 1 where its sole node has degree 1.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.get_singleton_hyperedges--returns","title":"Returns","text":"List[Any] A list of singleton hyperedge IDs.
Source code in src/aeiva/hypergraph/hypergraph.py
def get_singleton_hyperedges(self) -> List[Any]:\n \"\"\"\n Returns a list of singleton hyperedges.\n A singleton hyperedge is a hyperedge of size 1 where its sole node has degree 1.\n\n Returns\n -------\n List[Any]\n A list of singleton hyperedge IDs.\n \"\"\"\n singletons = []\n for he in self.hyperedges.values():\n if len(he.nodes) == 1:\n node = next(iter(he.nodes))\n node_degree = sum(1 for hyperedge in self.hyperedges.values() if node in hyperedge.nodes)\n if node_degree == 1:\n singletons.append(he.id)\n return singletons\n
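The degree check above can be made linear by counting node degrees once instead of rescanning all hyperedges per candidate. A stand-alone sketch of the same logic over plain dicts (the function name is illustrative, not library API):

```python
def singleton_hyperedges(hyperedges):
    """IDs of size-1 hyperedges whose sole node appears in no other
    hyperedge, given a {hyperedge_id: set_of_nodes} dict."""
    # Count how many hyperedges each node belongs to
    degree = {}
    for nodes in hyperedges.values():
        for n in nodes:
            degree[n] = degree.get(n, 0) + 1
    return [he_id for he_id, nodes in hyperedges.items()
            if len(nodes) == 1 and degree[next(iter(nodes))] == 1]
```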
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.get_toplexes","title":"get_toplexes(return_hypergraph=False)
","text":"Computes a maximal collection of toplexes for the hypergraph. A :term:toplex
is a hyperedge that is not contained in any other hyperedge.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.get_toplexes--parameters","title":"Parameters","text":"return_hypergraph : bool, optional, default=False If True, returns a new Hypergraph consisting only of the toplexes.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.get_toplexes--returns","title":"Returns","text":"List[Any] or Hypergraph - A list of toplex hyperedge IDs. - If return_hypergraph=True
, returns a Hypergraph containing only the toplexes.
Source code in src/aeiva/hypergraph/hypergraph.py
def get_toplexes(self, return_hypergraph: bool = False) -> Union[List[Any], 'Hypergraph']:\n \"\"\"\n Computes a maximal collection of toplexes for the hypergraph.\n A :term:`toplex` is a hyperedge that is not contained in any other hyperedge.\n\n Parameters\n ----------\n return_hypergraph : bool, optional, default=False\n If True, returns a new Hypergraph consisting only of the toplexes.\n\n Returns\n -------\n List[Any] or Hypergraph\n - A list of toplex hyperedge IDs.\n - If `return_hypergraph=True`, returns a Hypergraph containing only the toplexes.\n \"\"\"\n toplexes = []\n hyperedges = list(self.hyperedges.values())\n\n for he in hyperedges:\n if not any(he.nodes < other_he.nodes for other_he in hyperedges if he.id != other_he.id):\n toplexes.append(he.id)\n\n if return_hypergraph:\n return self.restrict_to_specific_hyperedges(toplexes, name=\"Toplexes\")\n return toplexes\n
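The toplex test reduces to a strict-subset check between node sets. A minimal stand-alone sketch with plain dicts (illustrative names, not library API); note that, as in the method above, duplicate hyperedges are both kept because `<` tests for *proper* containment:

```python
def toplexes(hyperedges):
    """Hyperedge IDs not strictly contained in any other hyperedge,
    given a {hyperedge_id: set_of_nodes} dict."""
    return [hid for hid, nodes in hyperedges.items()
            if not any(nodes < other
                       for oid, other in hyperedges.items() if oid != hid)]
```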
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.hyperedge_adjacency_matrix","title":"hyperedge_adjacency_matrix(s=1, index=False)
","text":"Generates the adjacency matrix for hyperedges based on s-hyperedge connectivity.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.hyperedge_adjacency_matrix--parameters","title":"Parameters","text":"s : int, optional, default=1 The number of shared nodes required for hyperedges to be considered adjacent. index : bool, optional, default=False If True, returns a mapping from matrix indices to hyperedge IDs.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.hyperedge_adjacency_matrix--returns","title":"Returns","text":"Tuple[Optional[csr_matrix], Dict[int, Any]] - The adjacency matrix in CSR format. - A dictionary mapping matrix indices to hyperedge IDs.
Source code in src/aeiva/hypergraph/hypergraph.py
def hyperedge_adjacency_matrix(self, s: int = 1, index: bool = False) -> Tuple[Optional[csr_matrix], Dict[int, Any]]:\n \"\"\"\n Generates the adjacency matrix for hyperedges based on s-hyperedge connectivity.\n\n Parameters\n ----------\n s : int, optional, default=1\n The number of shared nodes required for hyperedges to be considered adjacent.\n index : bool, optional, default=False\n If True, returns a mapping from matrix indices to hyperedge IDs.\n\n Returns\n -------\n Tuple[Optional[csr_matrix], Dict[int, Any]]\n - The adjacency matrix in CSR format.\n - A dictionary mapping matrix indices to hyperedge IDs.\n \"\"\"\n from scipy.sparse import lil_matrix\n\n hyperedge_ids = list(self.hyperedges.keys())\n he_index = {he_id: idx for idx, he_id in enumerate(hyperedge_ids)}\n size = len(hyperedge_ids)\n if size == 0:\n return None, {}\n\n A = lil_matrix((size, size), dtype=int)\n for i, he1 in enumerate(hyperedge_ids):\n nodes1 = self.hyperedges[he1].nodes\n for j in range(i + 1, size):\n he2 = hyperedge_ids[j]\n nodes2 = self.hyperedges[he2].nodes\n shared_nodes = nodes1 & nodes2\n if len(shared_nodes) >= s:\n A[i, j] = 1\n A[j, i] = 1\n\n A = A.tocsr()\n\n if index:\n return A, he_index\n return A, {}\n
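The same pairwise shared-node construction can be shown without scipy, using a dense list-of-lists matrix. A stand-alone sketch over a plain `{hyperedge_id: set_of_nodes}` dict (names are illustrative, not library API):

```python
def s_adjacency(hyperedges, s=1):
    """Dense 0/1 s-adjacency matrix over hyperedges plus an
    {hyperedge_id: index} map, mirroring the sparse construction above."""
    ids = list(hyperedges)
    n = len(ids)
    A = [[0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1, n):
            # Adjacent when the two hyperedges share at least s nodes
            if len(hyperedges[ids[i]] & hyperedges[ids[j]]) >= s:
                A[i][j] = A[j][i] = 1
    return A, {hid: k for k, hid in enumerate(ids)}
```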
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.intersection","title":"intersection(other, inplace=False, name=None)
","text":"Returns the intersection of the current hypergraph with another hypergraph. The intersection includes only nodes and hyperedges present in both hypergraphs.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.intersection--parameters","title":"Parameters","text":"other : Hypergraph The hypergraph to intersect with. inplace : bool, optional, default=False If True, modifies the current hypergraph to keep only the intersecting elements. Otherwise, returns a new Hypergraph instance. name : Optional[str], default=None The name for the resulting hypergraph. If None, defaults to 'Intersection_of_{self.name}_{other.name}'.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.intersection--returns","title":"Returns","text":"Hypergraph The resulting intersection hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.intersection--raises","title":"Raises","text":"TypeError If other
is not an instance of Hypergraph.
Source code in src/aeiva/hypergraph/hypergraph.py
def intersection(self, other: 'Hypergraph', inplace: bool = False, name: Optional[str] = None) -> 'Hypergraph':\n \"\"\"\n Returns the intersection of the current hypergraph with another hypergraph.\n The intersection includes only nodes and hyperedges present in both hypergraphs.\n\n Parameters\n ----------\n other : Hypergraph\n The hypergraph to intersect with.\n inplace : bool, optional, default=False\n If True, modifies the current hypergraph to keep only the intersecting elements.\n Otherwise, returns a new Hypergraph instance.\n name : Optional[str], default=None\n The name for the resulting hypergraph. If None, defaults to 'Intersection_of_{self.name}_{other.name}'.\n\n Returns\n -------\n Hypergraph\n The resulting intersection hypergraph.\n\n Raises\n ------\n TypeError\n If `other` is not an instance of Hypergraph.\n \"\"\"\n if not isinstance(other, Hypergraph):\n raise TypeError(\"The `other` parameter must be an instance of Hypergraph.\")\n\n intersect_nodes = set(self.node_properties.keys()) & set(other.node_properties.keys())\n intersect_hyperedges = set(self.hyperedges.keys()) & set(other.hyperedges.keys())\n\n if inplace:\n # Remove non-intersecting nodes and hyperedges\n nodes_to_remove = set(self.node_properties.keys()) - intersect_nodes\n hyperedges_to_remove = set(self.hyperedges.keys()) - intersect_hyperedges\n self.remove_nodes_from(nodes_to_remove, inplace=True)\n self.remove_hyperedges(hyperedges_to_remove, inplace=True)\n return self\n else:\n # Create a new Hypergraph instance\n new_hyperedges = {}\n new_node_properties = {node_id: copy.deepcopy(self.node_properties[node_id]) for node_id in intersect_nodes}\n new_hyperedge_properties = {}\n new_graph = nx.Graph()\n new_bipartite_nodes = set()\n\n for he_id in intersect_hyperedges:\n he_self = self.hyperedges[he_id]\n he_other = other.hyperedges[he_id]\n # Intersection hyperedges have the same nodes and merged properties\n new_nodes = set(he_self.nodes) & set(he_other.nodes)\n if not 
new_nodes:\n continue # Skip hyperedges with no common nodes\n new_hyperedges[he_id] = HyperEdge(id=he_id, nodes=new_nodes, properties={})\n # Merge properties (could define specific rules)\n new_hyperedge_properties[he_id] = {**self.hyperedge_properties.get(he_id, {}), \n **other.hyperedge_properties.get(he_id, {})}\n new_graph.add_node(he_id, bipartite='hyperedge', **new_hyperedge_properties[he_id])\n new_bipartite_nodes.add(he_id)\n for node in new_nodes:\n new_graph.add_edge(he_id, node)\n\n return Hypergraph(\n hyperedges={\n he_id: {\n 'nodes': list(he.nodes),\n 'properties': he.properties.copy()\n } for he_id, he in new_hyperedges.items()\n },\n node_properties=new_node_properties,\n hyperedge_properties=new_hyperedge_properties,\n name=name if name else f\"Intersection_of_{self.name}_{other.name}\"\n )\n
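The heart of the non-inplace branch is restricting shared hyperedge IDs to their common nodes and dropping those that become empty. A minimal stand-alone sketch over plain dicts (illustrative names, not library API; property merging is omitted):

```python
def hypergraph_intersection(h1, h2):
    """Shared hyperedge IDs restricted to their common nodes; IDs whose
    node sets become empty are dropped, as in the method above.

    h1, h2: {hyperedge_id: set_of_nodes} dicts.
    """
    out = {}
    for hid in h1.keys() & h2.keys():
        common = h1[hid] & h2[hid]
        if common:  # skip hyperedges with no common nodes
            out[hid] = common
    return out
```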
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.is_hyperedge_connected","title":"is_hyperedge_connected(s=1)
","text":"Determines if the hypergraph is s-hyperedge-connected.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.is_hyperedge_connected--parameters","title":"Parameters","text":"s : int, optional, default=1 The connectivity level to check.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.is_hyperedge_connected--returns","title":"Returns","text":"bool True if the hypergraph is s-hyperedge-connected, False otherwise.
Source code in src/aeiva/hypergraph/hypergraph.py
def is_hyperedge_connected(self, s: int = 1) -> bool:\n \"\"\"\n Determines if the hypergraph is s-hyperedge-connected.\n\n Parameters\n ----------\n s : int, optional, default=1\n The connectivity level to check.\n\n Returns\n -------\n bool\n True if the hypergraph is s-hyperedge-connected, False otherwise.\n \"\"\"\n return self._is_connected(s=s, hyperedges=True)\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.is_node_connected","title":"is_node_connected(s=1)
","text":"Determines if the hypergraph is s-node-connected.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.is_node_connected--parameters","title":"Parameters","text":"s : int, optional, default=1 The connectivity level to check.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.is_node_connected--returns","title":"Returns","text":"bool True if the hypergraph is s-node-connected, False otherwise.
Source code in src/aeiva/hypergraph/hypergraph.py
def is_node_connected(self, s: int = 1) -> bool:\n \"\"\"\n Determines if the hypergraph is s-node-connected.\n\n Parameters\n ----------\n s : int, optional, default=1\n The connectivity level to check.\n\n Returns\n -------\n bool\n True if the hypergraph is s-node-connected, False otherwise.\n \"\"\"\n return self._is_connected(s=s, hyperedges=False)\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.node_memberships","title":"node_memberships()
","text":"Returns a dictionary where each key is a node ID and the value is a list of hyperedge IDs that include the node.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.node_memberships--returns","title":"Returns","text":"Dict[Any, List[Any]] Dictionary mapping node IDs to the hyperedge IDs they belong to.
Source code in src/aeiva/hypergraph/hypergraph.py
def node_memberships(self) -> Dict[Any, List[Any]]:\n \"\"\"\n Returns a dictionary where each key is a node ID and the value is a list of hyperedge IDs that include the node.\n\n Returns\n -------\n Dict[Any, List[Any]]\n Dictionary mapping node IDs to the hyperedge IDs they belong to.\n \"\"\"\n memberships = {}\n for he_id, hyperedge in self.hyperedges.items():\n for node in hyperedge.nodes:\n memberships.setdefault(node, []).append(he_id)\n return memberships\n
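The inversion performed by `node_memberships` is easy to see with a standalone sketch over a plain `{hyperedge_id: nodes}` dict (illustrative only, not the aeiva API):

```python
# Invert {hyperedge_id: nodes} into {node_id: [hyperedge_ids]},
# mirroring the setdefault/append loop in node_memberships.

def node_memberships(hyperedges):
    memberships = {}
    for he_id, nodes in hyperedges.items():
        for node in nodes:
            memberships.setdefault(node, []).append(he_id)
    return memberships

hes = {"e1": ["a", "b"], "e2": ["b", "c"]}
print(node_memberships(hes))  # {'a': ['e1'], 'b': ['e1', 'e2'], 'c': ['e2']}
```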
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.nodes","title":"nodes()
","text":"Returns a list of all unique node identifiers in the hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.nodes--returns","title":"Returns","text":"List[Any] List of node IDs.
Source code in src/aeiva/hypergraph/hypergraph.py
def nodes(self) -> List[Any]:\n \"\"\"\n Returns a list of all unique node identifiers in the hypergraph.\n\n Returns\n -------\n List[Any]\n List of node IDs.\n \"\"\"\n return list(self.node_properties.keys())\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.remove_hyperedge","title":"remove_hyperedge(he_id)
","text":"Removes a hyperedge from the hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.remove_hyperedge--parameters","title":"Parameters","text":"he_id : Any Identifier of the hyperedge to remove.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.remove_hyperedge--raises","title":"Raises","text":"HypergraphError If the hyperedge does not exist.
Source code in src/aeiva/hypergraph/hypergraph.py
def remove_hyperedge(self, he_id: Any) -> None:\n \"\"\"\n Removes a hyperedge from the hypergraph.\n\n Parameters\n ----------\n he_id : Any\n Identifier of the hyperedge to remove.\n\n Raises\n ------\n HypergraphError\n If the hyperedge does not exist.\n \"\"\"\n if he_id not in self.hyperedges:\n raise HypergraphError(f\"Hyperedge '{he_id}' does not exist.\")\n\n # Remove hyperedge from the graph, which also removes all incidences\n self.graph.remove_node(he_id)\n self.bipartite_nodes.discard(he_id)\n\n # Remove from internal structures\n del self.hyperedges[he_id]\n self.hyperedge_properties.pop(he_id, None)\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.remove_hyperedges","title":"remove_hyperedges(he_ids, inplace=True)
","text":"Removes the specified hyperedges from the hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.remove_hyperedges--parameters","title":"Parameters","text":"he_ids : Any | Iterable[Any] Hyperedge identifier(s) to remove. inplace : bool, default=True If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the hyperedges removed.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.remove_hyperedges--returns","title":"Returns","text":"Hypergraph The updated or new Hypergraph instance.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.remove_hyperedges--raises","title":"Raises","text":"HypergraphError If any hyperedge ID does not exist.
Source code in src/aeiva/hypergraph/hypergraph.py
def remove_hyperedges(self, he_ids: Union[Any, Iterable[Any]], inplace: bool = True) -> 'Hypergraph':\n \"\"\"\n Removes the specified hyperedges from the hypergraph.\n\n Parameters\n ----------\n he_ids : Any | Iterable[Any]\n Hyperedge identifier(s) to remove.\n inplace : bool, default=True\n If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the hyperedges removed.\n\n Returns\n -------\n Hypergraph\n The updated or new Hypergraph instance.\n\n Raises\n ------\n HypergraphError\n If any hyperedge ID does not exist.\n \"\"\"\n if isinstance(he_ids, (str, int)):\n he_ids = [he_ids]\n else:\n he_ids = list(he_ids)\n\n non_existing = set(he_ids) - set(self.hyperedges.keys())\n if non_existing:\n raise HypergraphError(f\"Hyperedges {non_existing} do not exist in the hypergraph.\")\n\n if inplace:\n for he_id in he_ids:\n self.remove_hyperedge(he_id)\n return self\n else:\n # Create a new Hypergraph instance with hyperedges removed\n new_hyperedges = copy.deepcopy(self.hyperedges)\n new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)\n new_node_properties = copy.deepcopy(self.node_properties)\n new_graph = copy.deepcopy(self.graph)\n new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)\n\n for he_id in he_ids:\n del new_hyperedges[he_id]\n new_hyperedge_properties.pop(he_id, None)\n new_graph.remove_node(he_id)\n new_bipartite_nodes.discard(he_id)\n\n # Remove nodes not connected to any hyperedges\n retained_nodes = set()\n for hyperedge in new_hyperedges.values():\n retained_nodes.update(hyperedge.nodes)\n\n new_node_properties = {node: props for node, props in new_node_properties.items() if node in retained_nodes}\n\n # Reconstruct hyperedges dict for __init__\n hyperedges_dict = {\n he_id: {\n 'nodes': list(he.nodes),\n 'properties': he.properties.copy()\n } for he_id, he in new_hyperedges.items()\n }\n\n return Hypergraph(\n hyperedges=hyperedges_dict,\n node_properties=new_node_properties,\n 
hyperedge_properties=new_hyperedge_properties,\n name=self.name\n )\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.remove_incidence","title":"remove_incidence(he_id, node_id, inplace=True)
","text":"Removes a single incidence from the hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.remove_incidence--parameters","title":"Parameters","text":"he_id : Any Identifier of the hyperedge. node_id : Any Identifier of the node. inplace : bool, default=True If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the incidence removed.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.remove_incidence--returns","title":"Returns","text":"Hypergraph The updated or new Hypergraph instance.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.remove_incidence--raises","title":"Raises","text":"HypergraphError If the hyperedge or node does not exist, or if the incidence does not exist.
Source code in src/aeiva/hypergraph/hypergraph.py
def remove_incidence(\n self,\n he_id: Any,\n node_id: Any,\n inplace: bool = True\n) -> 'Hypergraph':\n \"\"\"\n Removes a single incidence from the hypergraph.\n\n Parameters\n ----------\n he_id : Any\n Identifier of the hyperedge.\n node_id : Any\n Identifier of the node.\n inplace : bool, default=True\n If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the incidence removed.\n\n Returns\n -------\n Hypergraph\n The updated or new Hypergraph instance.\n\n Raises\n ------\n HypergraphError\n If the hyperedge or node does not exist, or if the incidence does not exist.\n \"\"\"\n if he_id not in self.hyperedges:\n raise HypergraphError(f\"Hyperedge '{he_id}' does not exist in the hypergraph.\")\n if node_id not in self.node_properties:\n raise HypergraphError(f\"Node '{node_id}' does not exist in the hypergraph.\")\n if node_id not in self.hyperedges[he_id].nodes:\n raise HypergraphError(f\"Incidence between hyperedge '{he_id}' and node '{node_id}' does not exist.\")\n\n if inplace:\n # Remove node from HyperEdge's nodes\n self.hyperedges[he_id].remove_node(node_id)\n # Remove edge from graph\n self.graph.remove_edge(he_id, node_id)\n return self\n else:\n # Create a new Hypergraph instance with the incidence removed\n new_hyperedges = copy.deepcopy(self.hyperedges)\n new_node_properties = copy.deepcopy(self.node_properties)\n new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)\n new_graph = copy.deepcopy(self.graph)\n new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)\n\n # Remove node from HyperEdge's nodes\n new_hyperedges[he_id].remove_node(node_id)\n # Remove edge from graph\n new_graph.remove_edge(he_id, node_id)\n\n # Reconstruct hyperedges dict for __init__\n hyperedges_dict = {\n he_id: {\n 'nodes': list(he.nodes),\n 'properties': he.properties.copy()\n } for he_id, he in new_hyperedges.items()\n }\n\n return Hypergraph(\n hyperedges=hyperedges_dict,\n node_properties=new_node_properties,\n 
hyperedge_properties=new_hyperedge_properties,\n name=self.name\n )\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.remove_incidences","title":"remove_incidences(incidences, inplace=True)
","text":"Removes the specified incidences from the hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.remove_incidences--parameters","title":"Parameters","text":"incidences : Iterable[Tuple[Any, Any]] Incidence identifiers as tuples of (he_id, node_id). inplace : bool, default=True If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the incidences removed.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.remove_incidences--returns","title":"Returns","text":"Hypergraph The updated or new Hypergraph instance.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.remove_incidences--raises","title":"Raises","text":"HypergraphError If any incidence does not exist.
Source code in src/aeiva/hypergraph/hypergraph.py
def remove_incidences(\n self,\n incidences: Iterable[Tuple[Any, Any]],\n inplace: bool = True\n) -> 'Hypergraph':\n \"\"\"\n Removes the specified incidences from the hypergraph.\n\n Parameters\n ----------\n incidences : Iterable[Tuple[Any, Any]]\n Incidence identifiers as tuples of (he_id, node_id).\n inplace : bool, default=True\n If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the incidences removed.\n\n Returns\n -------\n Hypergraph\n The updated or new Hypergraph instance.\n\n Raises\n ------\n HypergraphError\n If any incidence does not exist.\n \"\"\"\n incidence_ids = list(incidences)\n\n # Check existence of incidences\n for he_id, node_id in incidence_ids:\n if he_id not in self.hyperedges:\n raise HypergraphError(f\"Hyperedge '{he_id}' does not exist in the hypergraph.\")\n if node_id not in self.node_properties:\n raise HypergraphError(f\"Node '{node_id}' does not exist in the hypergraph.\")\n if node_id not in self.hyperedges[he_id].nodes:\n raise HypergraphError(f\"Incidence between hyperedge '{he_id}' and node '{node_id}' does not exist.\")\n\n if inplace:\n for he_id, node_id in incidence_ids:\n # Remove node from HyperEdge's nodes\n self.hyperedges[he_id].remove_node(node_id)\n # Remove edge from graph\n self.graph.remove_edge(he_id, node_id)\n return self\n else:\n # Create a new Hypergraph instance with the incidences removed\n new_hyperedges = copy.deepcopy(self.hyperedges)\n new_node_properties = copy.deepcopy(self.node_properties)\n new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)\n new_graph = copy.deepcopy(self.graph)\n new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)\n\n for he_id, node_id in incidence_ids:\n # Remove node from HyperEdge's nodes\n new_hyperedges[he_id].remove_node(node_id)\n # Remove edge from graph\n new_graph.remove_edge(he_id, node_id)\n\n # Reconstruct hyperedges dict for __init__\n hyperedges_dict = {\n he_id: {\n 'nodes': list(he.nodes),\n 
'properties': he.properties.copy()\n } for he_id, he in new_hyperedges.items()\n }\n\n return Hypergraph(\n hyperedges=hyperedges_dict,\n node_properties=new_node_properties,\n hyperedge_properties=new_hyperedge_properties,\n name=self.name\n )\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.remove_node","title":"remove_node(node_id, inplace=True)
","text":"Removes a node from the hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.remove_node--parameters","title":"Parameters","text":"node_id : Any Identifier of the node to remove. inplace : bool, default=True If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the node removed.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.remove_node--returns","title":"Returns","text":"Hypergraph The updated or new Hypergraph instance.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.remove_node--raises","title":"Raises","text":"HypergraphError If the node does not exist.
Source code in src/aeiva/hypergraph/hypergraph.py
def remove_node(self, node_id: Any, inplace: bool = True) -> 'Hypergraph':\n \"\"\"\n Removes a node from the hypergraph.\n\n Parameters\n ----------\n node_id : Any\n Identifier of the node to remove.\n inplace : bool, default=True\n If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the node removed.\n\n Returns\n -------\n Hypergraph\n The updated or new Hypergraph instance.\n\n Raises\n ------\n HypergraphError\n If the node does not exist.\n \"\"\"\n if node_id not in self.node_properties:\n raise HypergraphError(f\"Node '{node_id}' does not exist in the hypergraph.\")\n\n if inplace:\n # Remove node from node_properties\n del self.node_properties[node_id]\n # Remove node from all hyperedges\n for hyperedge in self.hyperedges.values():\n if node_id in hyperedge.nodes:\n hyperedge.remove_node(node_id)\n # Remove node from graph, which also removes all incidences\n self.graph.remove_node(node_id)\n return self\n else:\n # Create a new Hypergraph instance with the node removed\n new_hyperedges = copy.deepcopy(self.hyperedges)\n new_node_properties = copy.deepcopy(self.node_properties)\n new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)\n new_graph = copy.deepcopy(self.graph)\n new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)\n\n # Remove node from node_properties\n del new_node_properties[node_id]\n # Remove node from all hyperedges\n for hyperedge in new_hyperedges.values():\n if node_id in hyperedge.nodes:\n hyperedge.remove_node(node_id)\n # Remove node from graph, which also removes all incidences\n new_graph.remove_node(node_id)\n\n # Remove nodes not connected to any hyperedges\n retained_nodes = set()\n for hyperedge in new_hyperedges.values():\n retained_nodes.update(hyperedge.nodes)\n\n new_node_properties = {node: props for node, props in new_node_properties.items() if node in retained_nodes}\n\n # Reconstruct hyperedges dict for __init__\n hyperedges_dict = {\n he_id: {\n 'nodes': 
list(he.nodes),\n 'properties': he.properties.copy()\n } for he_id, he in new_hyperedges.items()\n }\n\n return Hypergraph(\n hyperedges=hyperedges_dict,\n node_properties=new_node_properties,\n hyperedge_properties=new_hyperedge_properties,\n name=self.name\n )\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.remove_nodes_from","title":"remove_nodes_from(nodes, inplace=True)
","text":"Removes the specified nodes from the hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.remove_nodes_from--parameters","title":"Parameters","text":"nodes : Any | Iterable[Any] Node identifier(s) to remove. inplace : bool, default=True If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the nodes removed.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.remove_nodes_from--returns","title":"Returns","text":"Hypergraph The updated or new Hypergraph instance.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.remove_nodes_from--raises","title":"Raises","text":"HypergraphError If any node ID does not exist.
Source code in src/aeiva/hypergraph/hypergraph.py
def remove_nodes_from(\n self,\n nodes: Union[Any, Iterable[Any]],\n inplace: bool = True\n) -> 'Hypergraph':\n \"\"\"\n Removes the specified nodes from the hypergraph.\n\n Parameters\n ----------\n nodes : Any | Iterable[Any]\n Node identifier(s) to remove.\n inplace : bool, default=True\n If True, modifies the existing Hypergraph. Otherwise, creates a new Hypergraph with the nodes removed.\n\n Returns\n -------\n Hypergraph\n The updated or new Hypergraph instance.\n\n Raises\n ------\n HypergraphError\n If any node ID does not exist.\n \"\"\"\n if isinstance(nodes, (str, int)):\n nodes = [nodes]\n else:\n nodes = list(nodes)\n\n non_existing = set(nodes) - set(self.node_properties.keys())\n if non_existing:\n raise HypergraphError(f\"Nodes {non_existing} do not exist in the hypergraph.\")\n\n if inplace:\n for node_id in nodes:\n self.remove_node(node_id)\n return self\n else:\n # Create a new Hypergraph instance with nodes removed\n new_hyperedges = copy.deepcopy(self.hyperedges)\n new_node_properties = copy.deepcopy(self.node_properties)\n new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)\n new_graph = copy.deepcopy(self.graph)\n new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)\n\n for node_id in nodes:\n del new_node_properties[node_id]\n # Remove node from all hyperedges\n for hyperedge in new_hyperedges.values():\n if node_id in hyperedge.nodes:\n hyperedge.remove_node(node_id)\n # Remove node from graph, which also removes all incidences\n new_graph.remove_node(node_id)\n\n # Remove nodes not connected to any hyperedges\n retained_nodes = set()\n for hyperedge in new_hyperedges.values():\n retained_nodes.update(hyperedge.nodes)\n\n new_node_properties = {node: props for node, props in new_node_properties.items() if node in retained_nodes}\n\n # Reconstruct hyperedges dict for __init__\n hyperedges_dict = {\n he_id: {\n 'nodes': list(he.nodes),\n 'properties': he.properties.copy()\n } for he_id, he in new_hyperedges.items()\n 
}\n\n return Hypergraph(\n hyperedges=hyperedges_dict,\n node_properties=new_node_properties,\n hyperedge_properties=new_hyperedge_properties,\n name=self.name\n )\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.remove_singleton_hyperedges","title":"remove_singleton_hyperedges(name=None)
","text":"Constructs a clone of the hypergraph with singleton hyperedges removed.
Source code in src/aeiva/hypergraph/hypergraph.py
def remove_singleton_hyperedges(self, name: Optional[str] = None) -> 'Hypergraph':\n \"\"\"\n Constructs a clone of the hypergraph with singleton hyperedges removed.\n \"\"\"\n singletons = self.get_singleton_hyperedges()\n if not singletons:\n return self.copy(name=name)\n\n new_hypergraph = self.remove_hyperedges(singletons, inplace=False)\n new_hypergraph.name = name if name else f\"{self.name}_no_singleton_hyperedges\"\n return new_hypergraph\n
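The effect of removing singleton hyperedges (and the pruning of nodes left in no hyperedge, as done by `remove_hyperedges(..., inplace=False)`) can be sketched with plain dicts; `remove_singletons` is a hypothetical helper, not the aeiva API:

```python
# Drop hyperedges containing a single node, then keep only nodes that
# still appear in some hyperedge. Plain-dict sketch of the behavior.

def remove_singletons(hyperedges):
    kept = {he: set(ns) for he, ns in hyperedges.items() if len(set(ns)) > 1}
    retained_nodes = set().union(*kept.values()) if kept else set()
    return kept, retained_nodes

hes = {"e1": {"a", "b"}, "e2": {"c"}, "e3": {"b", "c"}}
kept, nodes = remove_singletons(hes)
print(sorted(kept))   # ['e1', 'e3']
print(sorted(nodes))  # ['a', 'b', 'c']
```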
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.restrict_to_specific_hyperedges","title":"restrict_to_specific_hyperedges(hyperedges_to_retain, name=None)
","text":"Creates a new hypergraph by retaining only the specified hyperedges and removing all others.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.restrict_to_specific_hyperedges--parameters","title":"Parameters","text":"hyperedges_to_retain : Iterable[Any] An iterable of hyperedge identifiers to retain in the new hypergraph.
name : Optional[str], default=None The name assigned to the restricted hypergraph. If None, defaults to the original name suffixed with '_restricted_hyperedges'.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.restrict_to_specific_hyperedges--returns","title":"Returns","text":"Hypergraph A new hypergraph containing only the specified hyperedges and their associated nodes.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.restrict_to_specific_hyperedges--raises","title":"Raises","text":"HypergraphError If none of the specified hyperedges exist in the hypergraph.
Source code in src/aeiva/hypergraph/hypergraph.py
def restrict_to_specific_hyperedges(\n self,\n hyperedges_to_retain: Iterable[Any],\n name: Optional[str] = None\n) -> 'Hypergraph':\n \"\"\"\n Creates a new hypergraph by retaining only the specified hyperedges and removing all others.\n\n Parameters\n ----------\n hyperedges_to_retain : Iterable[Any]\n An iterable of hyperedge identifiers to retain in the new hypergraph.\n\n name : Optional[str], default=None\n The name assigned to the restricted hypergraph. If None, defaults to the original name suffixed with '_restricted_hyperedges'.\n\n Returns\n -------\n Hypergraph\n A new hypergraph containing only the specified hyperedges and their associated nodes.\n\n Raises\n ------\n HypergraphError\n If none of the specified hyperedges exist in the hypergraph.\n \"\"\"\n hyperedges_to_retain = set(hyperedges_to_retain)\n existing_hyperedges = set(self.hyperedges.keys())\n invalid_hyperedges = hyperedges_to_retain - existing_hyperedges\n if invalid_hyperedges:\n raise HypergraphError(f\"The following hyperedges do not exist and cannot be retained: {invalid_hyperedges}\")\n\n # Determine hyperedges to remove\n hyperedges_to_remove = existing_hyperedges - hyperedges_to_retain\n if not hyperedges_to_remove:\n # No hyperedges to remove; return the original hypergraph\n return self\n\n # Remove hyperedges using the existing remove_hyperedges method\n restricted_hypergraph = self.remove_hyperedges(hyperedges_to_remove, inplace=False)\n restricted_hypergraph.name = name if name else f\"{self.name}_restricted_hyperedges\"\n\n return restricted_hypergraph\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.restrict_to_specific_nodes","title":"restrict_to_specific_nodes(nodes_to_retain, name=None)
","text":"Creates a new hypergraph by retaining only the specified nodes and removing all others.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.restrict_to_specific_nodes--parameters","title":"Parameters","text":"nodes_to_retain : Iterable[Any] An iterable of node identifiers to retain in the new hypergraph.
name : Optional[str], default=None The name assigned to the restricted hypergraph. If None, defaults to the original name suffixed with '_restricted_nodes'.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.restrict_to_specific_nodes--returns","title":"Returns","text":"Hypergraph A new hypergraph containing only the specified nodes and their associated hyperedges.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.restrict_to_specific_nodes--raises","title":"Raises","text":"HypergraphError If none of the specified nodes exist in the hypergraph.
Source code in src/aeiva/hypergraph/hypergraph.py
def restrict_to_specific_nodes(\n self,\n nodes_to_retain: Iterable[Any],\n name: Optional[str] = None\n) -> 'Hypergraph':\n \"\"\"\n Creates a new hypergraph by retaining only the specified nodes and removing all others.\n\n Parameters\n ----------\n nodes_to_retain : Iterable[Any]\n An iterable of node identifiers to retain in the new hypergraph.\n\n name : Optional[str], default=None\n The name assigned to the restricted hypergraph. If None, defaults to the original name suffixed with '_restricted_nodes'.\n\n Returns\n -------\n Hypergraph\n A new hypergraph containing only the specified nodes and their associated hyperedges.\n\n Raises\n ------\n HypergraphError\n If none of the specified nodes exist in the hypergraph.\n \"\"\"\n nodes_to_retain = set(nodes_to_retain)\n existing_nodes = set(self.node_properties.keys())\n invalid_nodes = nodes_to_retain - existing_nodes\n if invalid_nodes:\n raise HypergraphError(f\"The following nodes do not exist and cannot be retained: {invalid_nodes}\")\n\n # Determine nodes to remove\n nodes_to_remove = existing_nodes - nodes_to_retain\n if not nodes_to_remove:\n # No nodes to remove; return the original hypergraph\n return self\n\n # Remove nodes using the existing remove_nodes_from method\n restricted_hypergraph = self.remove_nodes_from(nodes_to_remove, inplace=False)\n restricted_hypergraph.name = name if name else f\"{self.name}_restricted_nodes\"\n\n return restricted_hypergraph\n
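The restriction logic can be sketched standalone: keep only the chosen nodes and trim each hyperedge to them. This simplified version also drops hyperedges left empty; it uses plain dicts and a hypothetical helper name, not the aeiva API:

```python
# Sketch of node restriction: trim every hyperedge to the retained
# node set, discarding hyperedges that end up empty (a simplification).

def restrict_to_nodes(hyperedges, nodes_to_retain):
    keep = set(nodes_to_retain)
    trimmed = {he: set(ns) & keep for he, ns in hyperedges.items()}
    return {he: ns for he, ns in trimmed.items() if ns}

hes = {"e1": {"a", "b"}, "e2": {"c", "d"}, "e3": {"b", "d"}}
print(restrict_to_nodes(hes, {"a", "b"}))  # e2 is dropped; e3 shrinks to {'b'}
```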
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.s_component_subgraphs","title":"s_component_subgraphs(s=1, hyperedges=True, return_singletons=False, name=None)
","text":"Yields subgraphs corresponding to each s-hyperedge-connected or s-node-connected component.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.s_component_subgraphs--parameters","title":"Parameters","text":"s : int, optional, default=1 The connectivity level to check. hyperedges : bool, optional, default=True If True, yields subgraphs of s-hyperedge-connected components. Otherwise, yields subgraphs of s-node-connected components. return_singletons : bool, optional, default=False If True, includes singleton components. Otherwise, excludes them. name : Optional[str], default=None Base name for the subgraphs. Each subgraph will have a unique name appended.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.s_component_subgraphs--yields","title":"Yields","text":"Hypergraph Subgraphs representing each connected component.
Source code in src/aeiva/hypergraph/hypergraph.py
def s_component_subgraphs(\n self,\n s: int = 1,\n hyperedges: bool = True,\n return_singletons: bool = False,\n name: Optional[str] = None\n) -> Iterator['Hypergraph']:\n \"\"\"\n Yields subgraphs corresponding to each s-hyperedge-connected or s-node-connected component.\n\n Parameters\n ----------\n s : int, optional, default=1\n The connectivity level to check.\n hyperedges : bool, optional, default=True\n If True, yields subgraphs of s-hyperedge-connected components. Otherwise, yields subgraphs of s-node-connected components.\n return_singletons : bool, optional, default=False\n If True, includes singleton components. Otherwise, excludes them.\n name : Optional[str], default=None\n Base name for the subgraphs. Each subgraph will have a unique name appended.\n\n Yields\n ------\n Hypergraph\n Subgraphs representing each connected component.\n \"\"\"\n for idx, component in enumerate(\n self.s_connected_components(s=s, hyperedges=hyperedges, return_singletons=return_singletons)\n ):\n if hyperedges:\n yield self.restrict_to_specific_hyperedges(\n hyperedges_to_retain=component, \n name=f\"{name or self.name}_component_{idx}\"\n )\n else:\n yield self.restrict_to_specific_nodes(\n nodes_to_retain=component, \n name=f\"{name or self.name}_component_{idx}\"\n )\n
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.s_connected_components","title":"s_connected_components(s=1, hyperedges=True, return_singletons=False)
","text":"Yields the s-hyperedge-connected or s-node-connected components of the hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.s_connected_components--parameters","title":"Parameters","text":"s : int, optional, default=1 The connectivity level to check. hyperedges : bool, optional, default=True If True, yields s-hyperedge-connected components. Otherwise, yields s-node-connected components. return_singletons : bool, optional, default=False If True, includes singleton components. Otherwise, excludes them.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.s_connected_components--yields","title":"Yields","text":"Set[Any] Sets of hyperedge IDs or node IDs representing each connected component.
Source code in src/aeiva/hypergraph/hypergraph.py
def s_connected_components(\n self, \n s: int = 1, \n hyperedges: bool = True, \n return_singletons: bool = False\n) -> Iterator[Set[Any]]:\n \"\"\"\n Yields the s-hyperedge-connected or s-node-connected components of the hypergraph.\n\n Parameters\n ----------\n s : int, optional, default=1\n The connectivity level to check.\n hyperedges : bool, optional, default=True\n If True, yields s-hyperedge-connected components. Otherwise, yields s-node-connected components.\n return_singletons : bool, optional, default=False\n If True, includes singleton components. Otherwise, excludes them.\n\n Yields\n ------\n Set[Any]\n Sets of hyperedge IDs or node IDs representing each connected component.\n \"\"\"\n if hyperedges:\n # s-hyperedge-connected: hyperedges are connected if they share at least s nodes\n hyperedge_graph = nx.Graph()\n hyperedge_ids = list(self.hyperedges.keys())\n hyperedge_graph.add_nodes_from(hyperedge_ids)\n\n for i, he1 in enumerate(hyperedge_ids):\n nodes1 = self.hyperedges[he1].nodes\n for he2 in hyperedge_ids[i + 1:]:\n nodes2 = self.hyperedges[he2].nodes\n shared_nodes = nodes1 & nodes2\n if len(shared_nodes) >= s:\n hyperedge_graph.add_edge(he1, he2)\n\n components = nx.connected_components(hyperedge_graph)\n for component in components:\n if not return_singletons and len(component) == 1:\n continue\n yield component\n else:\n # s-node-connected: nodes are connected if they share at least s hyperedges\n node_graph = nx.Graph()\n node_ids = list(self.node_properties.keys())\n node_graph.add_nodes_from(node_ids)\n\n for i, node1 in enumerate(node_ids):\n hyperedges1 = {he.id for he in self.hyperedges.values() if node1 in he.nodes}\n for node2 in node_ids[i + 1:]:\n hyperedges2 = {he.id for he in self.hyperedges.values() if node2 in he.nodes}\n shared_hyperedges = hyperedges1 & hyperedges2\n if len(shared_hyperedges) >= s:\n node_graph.add_edge(node1, node2)\n\n components = nx.connected_components(node_graph)\n for component in components:\n if not 
return_singletons and len(component) == 1:\n continue\n yield component\n
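The s-hyperedge-connected component construction above (link two hyperedges when they share at least `s` nodes, then take connected components) can be sketched without networkx, using a BFS over plain dicts. Illustrative only; not the aeiva API:

```python
# s-hyperedge-connected components: hyperedges are adjacent when they
# share >= s nodes; components found by BFS instead of networkx.
from collections import deque

def s_hyperedge_components(hyperedges, s=1):
    ids = list(hyperedges)
    adj = {he: [] for he in ids}
    for i, h1 in enumerate(ids):
        for h2 in ids[i + 1:]:
            if len(set(hyperedges[h1]) & set(hyperedges[h2])) >= s:
                adj[h1].append(h2)
                adj[h2].append(h1)
    seen, components = set(), []
    for start in ids:
        if start in seen:
            continue
        queue, comp = deque([start]), set()
        while queue:
            he = queue.popleft()
            if he in comp:
                continue
            comp.add(he)
            queue.extend(adj[he])
        seen |= comp
        components.append(comp)
    return components

hes = {"e1": {"a", "b"}, "e2": {"b", "c"}, "e3": {"x", "y"}}
print(s_hyperedge_components(hes, s=1))  # two components: {e1, e2} and {e3}
```

With `s=2` the same input yields three singleton components, since no pair of hyperedges shares two nodes.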
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.symmetric_difference","title":"symmetric_difference(other, inplace=False, name=None)
","text":"Returns the symmetric difference of the current hypergraph with another hypergraph. The symmetric difference includes elements present in either hypergraph but not in both.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.symmetric_difference--parameters","title":"Parameters","text":"other : Hypergraph The hypergraph to symmetric difference with. inplace : bool, optional, default=False If True, modifies the current hypergraph to keep only the symmetric difference elements. Otherwise, returns a new Hypergraph instance. name : Optional[str], default=None The name for the resulting hypergraph. If None, defaults to 'SymmetricDifference_of_{self.name}_{other.name}'.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.symmetric_difference--returns","title":"Returns","text":"Hypergraph The resulting symmetric difference hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.symmetric_difference--raises","title":"Raises","text":"TypeError If other
is not an instance of Hypergraph.
Source code in src/aeiva/hypergraph/hypergraph.py
def symmetric_difference(self, other: 'Hypergraph', inplace: bool = False, name: Optional[str] = None) -> 'Hypergraph':\n \"\"\"\n Returns the symmetric difference of the current hypergraph with another hypergraph.\n The symmetric difference includes elements present in either hypergraph but not in both.\n\n Parameters\n ----------\n other : Hypergraph\n The hypergraph to symmetric difference with.\n inplace : bool, optional, default=False\n If True, modifies the current hypergraph to keep only the symmetric difference elements.\n Otherwise, returns a new Hypergraph instance.\n name : Optional[str], default=None\n The name for the resulting hypergraph. If None, defaults to 'SymmetricDifference_of_{self.name}_{other.name}'.\n\n Returns\n -------\n Hypergraph\n The resulting symmetric difference hypergraph.\n\n Raises\n ------\n TypeError\n If `other` is not an instance of Hypergraph.\n \"\"\"\n if not isinstance(other, Hypergraph):\n raise TypeError(\"The `other` parameter must be an instance of Hypergraph.\")\n\n if inplace:\n # Hyperedges symmetric difference\n hyperedges_to_add = set(other.hyperedges.keys()) - set(self.hyperedges.keys())\n hyperedges_to_remove = set(self.hyperedges.keys()) & set(other.hyperedges.keys())\n self.remove_hyperedges(hyperedges_to_remove, inplace=True)\n for he_id in hyperedges_to_add:\n hyperedge = other.hyperedges[he_id]\n self.add_hyperedge(he_id, hyperedge.nodes, properties=hyperedge.properties)\n\n # Nodes symmetric difference\n nodes_to_add = set(other.node_properties.keys()) - set(self.node_properties.keys())\n nodes_to_remove = set(self.node_properties.keys()) & set(other.node_properties.keys())\n self.remove_nodes_from(nodes_to_remove, inplace=True)\n for node_id in nodes_to_add:\n props = other.node_properties[node_id]\n self.add_node(node_id, properties=props, inplace=True)\n\n if name:\n self.name = name\n return self\n else:\n # Create a new Hypergraph instance\n union_hg = self.union(other)\n intersection_hg = 
self.intersection(other)\n return union_hg.difference(intersection_hg, name=name if name else f\"SymmetricDifference_of_{self.name}_{other.name}\")\n
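The non-inplace branch above builds the result as union minus intersection. On plain id sets (standing in for hyperedge ids of two hypergraphs; the variable names are illustrative) that identity reduces to Python's own symmetric-difference operator:

```python
# Hypothetical hyperedge id sets for two hypergraphs.
h1_edges = {"e1", "e2", "e3"}
h2_edges = {"e2", "e3", "e4"}

# union(other).difference(intersection(other)) on ids is exactly
# the set symmetric difference: present in either, but not both.
via_union = (h1_edges | h2_edges) - (h1_edges & h2_edges)
direct = h1_edges ^ h2_edges
```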
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.to_bipartite_graph","title":"to_bipartite_graph(keep_data=False, directed=False)
","text":"Creates a bipartite NetworkX graph from the hypergraph. The nodes and hyperedges of the hypergraph become nodes in the bipartite graph. For every hyperedge in the hypergraph and each node it connects to, there is an edge in the bipartite graph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.to_bipartite_graph--parameters","title":"Parameters","text":"keep_data : bool, optional, default = False If True, includes the node and hyperedge properties in the NetworkX graph. directed : bool, optional, default = False If True, the edges in the graph are directed with hyperedges as sources and nodes as targets.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.to_bipartite_graph--returns","title":"Returns","text":"networkx.Graph or networkx.DiGraph The bipartite graph representation of the hypergraph.
Source code in src/aeiva/hypergraph/hypergraph.py
def to_bipartite_graph(self, keep_data=False, directed=False) -> nx.Graph:\n \"\"\"\n Creates a bipartite NetworkX graph from the hypergraph.\n The nodes and hyperedges of the hypergraph become nodes in the bipartite graph.\n For every hyperedge in the hypergraph and each node it connects to, there\n is an edge in the bipartite graph.\n\n Parameters\n ----------\n keep_data : bool, optional, default = False\n If True, includes the node and hyperedge properties in the NetworkX graph.\n directed : bool, optional, default = False\n If True, the edges in the graph are directed with hyperedges as sources and nodes as targets.\n\n Returns\n -------\n networkx.Graph or networkx.DiGraph\n The bipartite graph representation of the hypergraph.\n \"\"\"\n # Choose graph type based on directed flag\n B = nx.DiGraph() if directed else nx.Graph()\n\n if not keep_data:\n # Add nodes with bipartite attributes, where 0 indicates hyperedges and 1 indicates regular nodes\n B.add_nodes_from(self.hyperedges.keys(), bipartite=0) # hyperedges\n B.add_nodes_from(self.node_properties.keys(), bipartite=1) # nodes\n\n # Add edges between hyperedges and nodes based on hyperedges data\n for he_id, hyperedge in self.hyperedges.items():\n for node in hyperedge.nodes:\n B.add_edge(he_id, node)\n else:\n # Add nodes with properties if keep_data is True\n for node_id, properties in self.node_properties.items():\n B.add_node(node_id, bipartite=1, **properties)\n\n for he_id, hyperedge in self.hyperedges.items():\n B.add_node(he_id, bipartite=0, **self.hyperedge_properties.get(he_id, {}))\n for node in hyperedge.nodes:\n # Add edges with optional properties if keep_data is True\n B.add_edge(he_id, node)\n\n return B\n
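The bipartite construction above can be mimicked without NetworkX: each hyperedge and each node becomes a vertex, tagged with its partition, and one undirected edge is added per membership. `to_bipartite` below is a stdlib-only sketch, not the library function:

```python
def to_bipartite(hyperedges):
    """Return (partition, adjacency) for the bipartite view of a hypergraph.

    hyperedges: dict mapping hyperedge id -> iterable of node ids.
    partition[x] is 0 for hyperedges and 1 for nodes, mirroring the
    bipartite= attributes set on the NetworkX graph above.
    """
    partition, adjacency = {}, {}
    for he_id, nodes in hyperedges.items():
        partition[he_id] = 0
        adjacency.setdefault(he_id, set())
        for node in nodes:
            partition[node] = 1
            adjacency.setdefault(node, set())
            # One bipartite edge per hyperedge-node membership.
            adjacency[he_id].add(node)
            adjacency[node].add(he_id)
    return partition, adjacency

part, adj = to_bipartite({"e1": ["a", "b"], "e2": ["b", "c"]})
```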
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.transpose","title":"transpose(name=None)
","text":"Transposes the hypergraph by swapping the roles of nodes and hyperedges. The resulting hypergraph has hyperedges corresponding to the original nodes and vice versa.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.transpose--parameters","title":"Parameters","text":"name : Optional[str], default=None The name assigned to the transposed hypergraph. If None, defaults to the original name suffixed with '_transposed'.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.transpose--returns","title":"Returns","text":"Hypergraph The transposed hypergraph.
Source code in src/aeiva/hypergraph/hypergraph.py
def transpose(self, name: Optional[str] = None) -> 'Hypergraph':\n \"\"\"\n Transposes the hypergraph by swapping the roles of nodes and hyperedges.\n The resulting hypergraph has hyperedges corresponding to the original nodes and vice versa.\n\n Parameters\n ----------\n name : Optional[str], default=None\n The name assigned to the transposed hypergraph. If None, defaults to the original name suffixed with '_transposed'.\n\n Returns\n -------\n Hypergraph\n The transposed hypergraph.\n \"\"\"\n transposed_hyperedges = {node_id: HyperEdge(id=node_id, nodes=set(), properties=copy.deepcopy(props))\n for node_id, props in self.node_properties.items()}\n transposed_node_properties = {he_id: copy.deepcopy(props) for he_id, props in self.hyperedge_properties.items()}\n\n for he_id, hyperedge in self.hyperedges.items():\n for node in hyperedge.nodes:\n if node in transposed_hyperedges:\n transposed_hyperedges[node].nodes.add(he_id)\n\n # Construct the transposed hypergraph\n return Hypergraph(\n hyperedges={\n he_id: {\n 'nodes': list(he.nodes),\n 'properties': he.properties.copy()\n } for he_id, he in transposed_hyperedges.items()\n },\n node_properties=transposed_node_properties,\n hyperedge_properties={he_id: he.properties.copy() for he_id, he in transposed_hyperedges.items()},\n name=name if name else f\"{self.name}_transposed\"\n )\n
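The core of the transposition is a single incidence flip: every original node becomes a hyperedge whose members are the ids of the hyperedges it belonged to. A minimal sketch (ignoring properties and isolated nodes, which the method above carries over via node_properties):

```python
def transpose(hyperedges):
    """Swap the roles of nodes and hyperedges.

    hyperedges: dict mapping hyperedge id -> set of node ids.
    Returns a dict mapping each original node to the set of
    hyperedge ids that contained it.
    """
    transposed = {}
    for he_id, nodes in hyperedges.items():
        for node in nodes:
            transposed.setdefault(node, set()).add(he_id)
    return transposed

t = transpose({"e1": {"a", "b"}, "e2": {"b", "c"}})
```

Applying the sketch twice recovers the original incidence structure, which is the sense in which the operation is a transpose.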
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.union","title":"union(other, inplace=False, name=None)
","text":"Returns the union of the current hypergraph with another hypergraph. The union combines all nodes and hyperedges from both hypergraphs.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.union--parameters","title":"Parameters","text":"other : Hypergraph The hypergraph to union with. inplace : bool, optional, default=False If True, modifies the current hypergraph. Otherwise, returns a new Hypergraph instance. name : Optional[str], default=None The name for the resulting hypergraph. If None, defaults to 'Union_of_{self.name}_{other.name}'.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.union--returns","title":"Returns","text":"Hypergraph The resulting union hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.hypergraph.Hypergraph.union--raises","title":"Raises","text":"TypeError If other
is not an instance of Hypergraph.
Source code in src/aeiva/hypergraph/hypergraph.py
def union(self, other: 'Hypergraph', inplace: bool = False, name: Optional[str] = None) -> 'Hypergraph':\n \"\"\"\n Returns the union of the current hypergraph with another hypergraph.\n The union combines all nodes and hyperedges from both hypergraphs.\n\n Parameters\n ----------\n other : Hypergraph\n The hypergraph to union with.\n inplace : bool, optional, default=False\n If True, modifies the current hypergraph. Otherwise, returns a new Hypergraph instance.\n name : Optional[str], default=None\n The name for the resulting hypergraph. If None, defaults to 'Union_of_{self.name}_{other.name}'.\n\n Returns\n -------\n Hypergraph\n The resulting union hypergraph.\n\n Raises\n ------\n TypeError\n If `other` is not an instance of Hypergraph.\n \"\"\"\n if not isinstance(other, Hypergraph):\n raise TypeError(\"The `other` parameter must be an instance of Hypergraph.\")\n\n if inplace:\n # Add nodes from other\n for node_id, props in other.node_properties.items():\n if node_id not in self.node_properties:\n self.add_node(node_id, properties=props, inplace=True)\n else:\n # Optionally, merge properties\n self.node_properties[node_id].update(props)\n self.graph.nodes[node_id].update(props)\n\n # Add hyperedges from other\n for he_id, hyperedge in other.hyperedges.items():\n if he_id not in self.hyperedges:\n self.add_hyperedge(he_id, hyperedge.nodes, properties=hyperedge.properties)\n else:\n # Optionally, merge properties and nodes\n self.hyperedges[he_id].nodes.update(hyperedge.nodes)\n self.hyperedge_properties[he_id].update(hyperedge.properties)\n for node in hyperedge.nodes:\n if node not in self.graph:\n self.add_node(node)\n self.graph.add_edge(he_id, node)\n if name:\n self.name = name\n return self\n else:\n # Create a new Hypergraph instance\n new_hyperedges = copy.deepcopy(self.hyperedges)\n new_node_properties = copy.deepcopy(self.node_properties)\n new_hyperedge_properties = copy.deepcopy(self.hyperedge_properties)\n new_graph = copy.deepcopy(self.graph)\n 
new_bipartite_nodes = copy.deepcopy(self.bipartite_nodes)\n new_name = name if name else f\"Union_of_{self.name}_{other.name}\"\n\n # Add nodes from other\n for node_id, props in other.node_properties.items():\n if node_id not in new_node_properties:\n new_node_properties[node_id] = copy.deepcopy(props)\n new_graph.add_node(node_id, bipartite='node', **props)\n\n # Add hyperedges from other\n for he_id, hyperedge in other.hyperedges.items():\n if he_id not in new_hyperedges:\n new_hyperedges[he_id] = copy.deepcopy(hyperedge)\n new_hyperedge_properties[he_id] = copy.deepcopy(other.hyperedge_properties[he_id])\n new_graph.add_node(he_id, bipartite='hyperedge', **new_hyperedge_properties[he_id])\n new_bipartite_nodes.add(he_id)\n for node in hyperedge.nodes:\n new_graph.add_edge(he_id, node)\n else:\n # Merge nodes and properties\n new_hyperedges[he_id].nodes.update(hyperedge.nodes)\n new_hyperedge_properties[he_id].update(other.hyperedge_properties[he_id])\n for node in hyperedge.nodes:\n new_graph.add_edge(he_id, node)\n\n # Construct the new Hypergraph\n return Hypergraph(\n hyperedges={\n he_id: {\n 'nodes': list(he.nodes),\n 'properties': he.properties.copy()\n } for he_id, he in new_hyperedges.items()\n },\n node_properties=new_node_properties,\n hyperedge_properties=new_hyperedge_properties,\n name=new_name\n )\n
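The merge branch of `union` combines node sets when a hyperedge id exists in both hypergraphs. Reduced to bare hyperedge dicts (properties omitted; `union_edges` is an illustrative helper, not the library API), that behavior looks like:

```python
def union_edges(a, b):
    """Merge two hyperedge dicts.

    Ids present in both have their node sets combined, mirroring the
    merge branch of union above; ids unique to either side are kept.
    """
    merged = {he: set(nodes) for he, nodes in a.items()}
    for he, nodes in b.items():
        merged.setdefault(he, set()).update(nodes)
    return merged

u = union_edges({"e1": {1, 2}}, {"e1": {2, 3}, "e2": {4}})
```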
"},{"location":"reference/#src.aeiva.hypergraph.visualization","title":"visualization
","text":""},{"location":"reference/#src.aeiva.hypergraph.visualization.draw_hyper_edge_labels","title":"draw_hyper_edge_labels(H, pos, polys, labels={}, edge_labels_on_edge=True, ax=None, **kwargs)
","text":"Draws a label on the hyper edge boundary.
Should be passed a Matplotlib PolyCollection representing the hyper-edges; see the return value of draw_hyper_edges.
The label will be drawn on the least curvy part of the polygon, and will be aligned parallel to the orientation of the polygon where it is drawn.
"},{"location":"reference/#src.aeiva.hypergraph.visualization.draw_hyper_edge_labels--parameters","title":"Parameters","text":"H: hnx.Hypergraph the entity to be drawn polys: PolyCollection collection of polygons returned by draw_hyper_edges labels: dict mapping of node id to string label ax: Axis matplotlib axis on which the plot is rendered kwargs: dict Keyword arguments are passed through to Matplotlib's annotate function.
Source code in src/aeiva/hypergraph/visualization.py
def draw_hyper_edge_labels(\n H, pos, polys, labels={}, edge_labels_on_edge=True, ax=None, **kwargs\n):\n \"\"\"\n Draws a label on the hyper edge boundary.\n\n Should be passed Matplotlib PolyCollection representing the hyper-edges, see\n the return value of draw_hyper_edges.\n\n The label will be draw on the least curvy part of the polygon, and will be\n aligned parallel to the orientation of the polygon where it is drawn.\n\n Parameters\n ----------\n H: hnx.Hypergraph\n the entity to be drawn\n polys: PolyCollection\n collection of polygons returned by draw_hyper_edges\n labels: dict\n mapping of node id to string label\n ax: Axis\n matplotlib axis on which the plot is rendered\n kwargs: dict\n Keyword arguments are passed through to Matplotlib's annotate function.\n\n \"\"\"\n ax = ax or plt.gca()\n\n params = transpose_inflated_kwargs(inflate_kwargs(H.edges(), kwargs))\n\n for edge, path, params in zip(H.edges(), polys.get_paths(), params):\n s = labels.get(edge, edge)\n\n theta = 0\n xy = None\n\n if edge_labels_on_edge:\n # calculate the xy location of the annotation\n # this is the midpoint of the pair of adjacent points the most distant\n d = ((path.vertices[:-1] - path.vertices[1:]) ** 2).sum(axis=1)\n i = d.argmax()\n\n x1, x2 = path.vertices[i : i + 2]\n x, y = x2 - x1\n theta = 360 * np.arctan2(y, x) / (2 * np.pi)\n theta = (theta + 360) % 360\n\n while theta > 90:\n theta -= 180\n\n xy = (x1 + x2) / 2\n else:\n xy = pos[edge]\n\n # the string is a comma separated list of the edge uid\n ax.annotate(s, xy, rotation=theta, ha=\"center\", va=\"center\", **params)\n
"},{"location":"reference/#src.aeiva.hypergraph.visualization.draw_hyper_edges","title":"draw_hyper_edges(H, pos, ax=None, node_radius={}, contain_hyper_edges=False, dr=None, **kwargs)
","text":"Draws a convex hull around the nodes contained within each edge in H
"},{"location":"reference/#src.aeiva.hypergraph.visualization.draw_hyper_edges--parameters","title":"Parameters","text":"H: hnx.Hypergraph the entity to be drawn pos: dict mapping of node and edge positions to R^2 node_radius: dict mapping of node to R^1 (radius of each node) dr: float the spacing between concentric rings ax: Axis matplotlib axis on which the plot is rendered kwargs: dict keyword arguments, e.g., linewidth, facecolors, are passed through to the PolyCollection constructor
"},{"location":"reference/#src.aeiva.hypergraph.visualization.draw_hyper_edges--returns","title":"Returns","text":"PolyCollection a Matplotlib PolyCollection that can be further styled
Source code in src/aeiva/hypergraph/visualization.py
def draw_hyper_edges(\n H, pos, ax=None, node_radius={}, contain_hyper_edges=False, dr=None, **kwargs\n):\n \"\"\"\n Draws a convex hull around the nodes contained within each edge in H\n\n Parameters\n ----------\n H: hnx.Hypergraph\n the entity to be drawn\n pos: dict\n mapping of node and edge positions to R^2\n node_radius: dict\n mapping of node to R^1 (radius of each node)\n dr: float\n the spacing between concentric rings\n ax: Axis\n matplotlib axis on which the plot is rendered\n kwargs: dict\n keyword arguments, e.g., linewidth, facecolors, are passed through to the PolyCollection constructor\n\n Returns\n -------\n PolyCollection\n a Matplotlib PolyCollection that can be further styled\n \"\"\"\n points = layout_hyper_edges(\n H, pos, node_radius=node_radius, dr=dr, contain_hyper_edges=contain_hyper_edges\n )\n\n polys = PolyCollection(points, **inflate_kwargs(H.edges(), kwargs))\n\n (ax or plt.gca()).add_collection(polys)\n\n return polys\n
"},{"location":"reference/#src.aeiva.hypergraph.visualization.draw_hyper_edges_two_column","title":"draw_hyper_edges_two_column(H, pos, ax=None, **kwargs)
","text":"Renders hyper edges for the two column layout.
Each node-hyper edge membership is rendered as a line connecting the node in the left column to the edge in the right column.
"},{"location":"reference/#src.aeiva.hypergraph.visualization.draw_hyper_edges_two_column--parameters","title":"Parameters","text":"H: hnx.Hypergraph the entity to be drawn pos: dict mapping of node and edge positions to R^2 ax: Axis matplotlib axis on which the plot is rendered kwargs: dict keyword arguments passed to matplotlib.LineCollection
"},{"location":"reference/#src.aeiva.hypergraph.visualization.draw_hyper_edges_two_column--returns","title":"Returns","text":"LineCollection the hyper edges
Source code in src/aeiva/hypergraph/visualization.py
def draw_hyper_edges_two_column(H, pos, ax=None, **kwargs):\n \"\"\"\n Renders hyper edges for the two column layout.\n\n Each node-hyper edge membership is rendered as a line connecting the node\n in the left column to the edge in the right column.\n\n Parameters\n ----------\n H: hnx.Hypergraph\n the entity to be drawn\n pos: dict\n mapping of node and edge positions to R^2\n ax: Axis\n matplotlib axis on which the plot is rendered\n kwargs: dict\n keyword arguments passed to matplotlib.LineCollection\n\n Returns\n -------\n LineCollection\n the hyper edges\n \"\"\"\n ax = ax or plt.gca()\n\n pairs = [(v, e) for e in H.edges() for v in H.edge_elements()[e]]\n\n kwargs = {\n k: v if type(v) != dict else [v.get(e) for _, e in pairs]\n for k, v in kwargs.items()\n }\n\n lines = LineCollection([(pos[u], pos[v]) for u, v in pairs], **kwargs)\n\n ax.add_collection(lines)\n\n return lines\n
"},{"location":"reference/#src.aeiva.hypergraph.visualization.draw_hyper_labels","title":"draw_hyper_labels(H, pos, node_radius={}, ax=None, labels={}, **kwargs)
","text":"Draws text labels for the hypergraph nodes.
The label is drawn to the right of the node. The node radius is needed (see draw_hyper_nodes) so the text can be offset appropriately as the node size changes.
The text label can be customized by passing in a dictionary, labels, mapping a node to its custom label. By default, the label is the string representation of the node.
Keyword arguments are passed through to Matplotlib's annotate function.
"},{"location":"reference/#src.aeiva.hypergraph.visualization.draw_hyper_labels--parameters","title":"Parameters","text":"H: hnx.Hypergraph the entity to be drawn pos: dict mapping of node and edge positions to R^2 node_radius: dict mapping of node to R^1 (radius of each node) ax: Axis matplotlib axis on which the plot is rendered labels: dict mapping of node to text label kwargs: dict keyword arguments passed to matplotlib.annotate
Source code in src/aeiva/hypergraph/visualization.py
def draw_hyper_labels(H, pos, node_radius={}, ax=None, labels={}, **kwargs):\n \"\"\"\n Draws text labels for the hypergraph nodes.\n\n The label is drawn to the right of the node. The node radius is needed (see\n draw_hyper_nodes) so the text can be offset appropriately as the node size\n changes.\n\n The text label can be customized by passing in a dictionary, labels, mapping\n a node to its custom label. By default, the label is the string\n representation of the node.\n\n Keyword arguments are passed through to Matplotlib's annotate function.\n\n Parameters\n ----------\n H: hnx.Hypergraph\n the entity to be drawn\n pos: dict\n mapping of node and edge positions to R^2\n node_radius: dict\n mapping of node to R^1 (radius of each node)\n ax: Axis\n matplotlib axis on which the plot is rendered\n labels: dict\n mapping of node to text label\n kwargs: dict\n keyword arguments passed to matplotlib.annotate\n\n \"\"\"\n ax = ax or plt.gca()\n params = transpose_inflated_kwargs(inflate_kwargs(H.nodes(), kwargs))\n\n for v, v_kwargs in zip(iter(H.nodes()), params):\n xy = np.array([node_radius.get(v, 0), 0]) + pos[v]\n ax.annotate(\n labels.get(v, v),\n xy,\n **{\n k: (\n d[v]\n if hasattr(d, \"__getitem__\") and type(d) not in {str, tuple}\n else d\n )\n for k, d in kwargs.items()\n }\n )\n
"},{"location":"reference/#src.aeiva.hypergraph.visualization.draw_hyper_labels_two_column","title":"draw_hyper_labels_two_column(H, pos, labels={}, with_node_labels=True, with_edge_labels=True, ax=None)
","text":"Renders hyper labels (nodes and edges) for the two column layout.
"},{"location":"reference/#src.aeiva.hypergraph.visualization.draw_hyper_labels_two_column--parameters","title":"Parameters","text":"H: hnx.Hypergraph the entity to be drawn pos: dict mapping of node and edge positions to R^2 labels: dict custom labels for nodes and edges can be supplied with_node_labels: bool False to disable node labels with_edge_labels: bool False to disable edge labels ax: Axis matplotlib axis on which the plot is rendered kwargs: dict keyword arguments passed to matplotlib.LineCollection
Source code in src/aeiva/hypergraph/visualization.py
def draw_hyper_labels_two_column(\n H, pos, labels={}, with_node_labels=True, with_edge_labels=True, ax=None\n):\n \"\"\"\n Renders hyper labels (nodes and edges) for the two column layout.\n\n Parameters\n ----------\n H: hnx.Hypergraph\n the entity to be drawn\n pos: dict\n mapping of node and edge positions to R^2\n labels: dict\n custom labels for nodes and edges can be supplied\n with_node_labels: bool\n False to disable node labels\n with_edge_labels: bool\n False to disable edge labels\n ax: Axis\n matplotlib axis on which the plot is rendered\n kwargs: dict\n keyword arguments passed to matplotlib.LineCollection\n\n \"\"\"\n\n ax = ax or plt.gca()\n\n to_draw = []\n if with_node_labels:\n to_draw.append((list(H.nodes()), \"right\"))\n\n if with_edge_labels:\n to_draw.append((list(H.edges()), \"left\"))\n\n for points, ha in to_draw:\n for p in points:\n ax.annotate(labels.get(p, p), pos[p], ha=ha, va=\"center\")\n
"},{"location":"reference/#src.aeiva.hypergraph.visualization.draw_hyper_nodes","title":"draw_hyper_nodes(H, pos, node_radius={}, r0=None, ax=None, **kwargs)
","text":"Draws a circle for each node in H.
The position of each node is specified by a dictionary or list-like, pos, where pos[v] is the xy-coordinate for the vertex. The radius of each node can be specified as a dictionary where node_radius[v] is the radius. If a node is missing from this dictionary, or the node_radius is not specified at all, a sensible default radius is chosen based on distances between nodes given by pos.
"},{"location":"reference/#src.aeiva.hypergraph.visualization.draw_hyper_nodes--parameters","title":"Parameters","text":"H: hnx.Hypergraph the entity to be drawn pos: dict mapping of node and edge positions to R^2 node_radius: dict mapping of node to R^1 (radius of each node) r0: float minimum distance that concentric rings start from the node position ax: Axis matplotlib axis on which the plot is rendered kwargs: dict keyword arguments, e.g., linewidth, facecolors, are passed through to the PolyCollection constructor
"},{"location":"reference/#src.aeiva.hypergraph.visualization.draw_hyper_nodes--returns","title":"Returns","text":"PolyCollection a Matplotlib PolyCollection that can be further styled
Source code in src/aeiva/hypergraph/visualization.py
def draw_hyper_nodes(H, pos, node_radius={}, r0=None, ax=None, **kwargs):\n \"\"\"\n Draws a circle for each node in H.\n\n The position of each node is specified by the a dictionary/list-like, pos,\n where pos[v] is the xy-coordinate for the vertex. The radius of each node\n can be specified as a dictionary where node_radius[v] is the radius. If a\n node is missing from this dictionary, or the node_radius is not specified at\n all, a sensible default radius is chosen based on distances between nodes\n given by pos.\n\n Parameters\n ----------\n H: hnx.Hypergraph\n the entity to be drawn\n pos: dict\n mapping of node and edge positions to R^2\n node_radius: dict\n mapping of node to R^1 (radius of each node)\n r0: float\n minimum distance that concentric rings start from the node position\n ax: Axis\n matplotlib axis on which the plot is rendered\n kwargs: dict\n keyword arguments, e.g., linewidth, facecolors, are passed through to the PolyCollection constructor\n\n Returns\n -------\n PolyCollection\n a Matplotlib PolyCollection that can be further styled\n \"\"\"\n\n ax = ax or plt.gca()\n\n r0 = r0 or get_default_radius(H, pos)\n\n points = [node_radius.get(v, r0) * cp + pos[v] for v in H.nodes()]\n\n kwargs.setdefault(\"facecolors\", \"black\")\n\n circles = PolyCollection(points, **inflate_kwargs(H, kwargs))\n\n ax.add_collection(circles)\n\n return circles\n
"},{"location":"reference/#src.aeiva.hypergraph.visualization.draw_rubber_band","title":"draw_rubber_band(H, pos=None, with_color=True, with_node_counts=False, with_edge_counts=False, layout=nx.spring_layout, layout_kwargs={}, ax=None, node_radius=None, edges_kwargs={}, nodes_kwargs={}, edge_labels_on_edge=True, edge_labels={}, edge_labels_kwargs={}, node_labels={}, node_labels_kwargs={}, with_edge_labels=True, with_node_labels=True, node_label_alpha=0.35, edge_label_alpha=0.35, with_additional_edges=None, contain_hyper_edges=False, additional_edges_kwargs={}, return_pos=False)
","text":"Draw a hypergraph as a Matplotlib figure
By default this will draw a colorful \"rubber band\" like hypergraph, where convex hulls represent edges and are drawn around the nodes they contain.
This is a convenience function that wraps calls with sensible parameters to the following lower-level drawing functions:
- draw_hyper_edges,
- draw_hyper_edge_labels,
- draw_hyper_labels, and
- draw_hyper_nodes
The default layout algorithm is nx.spring_layout, but other layouts can be passed in. The Hypergraph is converted to a bipartite graph, and the layout algorithm is passed the bipartite graph.
If you have a pre-determined layout, you can pass in a \"pos\" dictionary. This is a dictionary mapping from node id's to x-y coordinates. For example:
>>> pos = {\n>>> 'A': (0, 0),\n>>> 'B': (1, 2),\n>>> 'C': (5, -3)\n>>> }\n
will position the nodes {A, B, C} manually at the locations specified. The coordinate system is in Matplotlib \"data coordinates\", and the figure will be centered within the figure.
By default, this will draw in a new figure, but the axis to render in can be specified using ax.
This approach works well for small hypergraphs, and does not guarantee a rigorously \"correct\" drawing. Overlapping of sets in the drawing generally implies that the sets intersect, but sets may sometimes overlap even when they do not intersect. It is not possible, in general, to draw a \"correct\" hypergraph this way for an arbitrary hypergraph, in the same way that not all graphs have planar drawings.
"},{"location":"reference/#src.aeiva.hypergraph.visualization.draw_rubber_band--parameters","title":"Parameters","text":"H: hnx.Hypergraph the entity to be drawn pos: dict mapping of node and edge positions to R^2 with_color: bool set to False to disable color cycling of edges with_node_counts: bool set to True to replace the label for collapsed nodes with the number of elements with_edge_counts: bool set to True to label collapsed edges with number of elements layout: function layout algorithm to compute layout_kwargs: dict keyword arguments passed to layout function ax: Axis matplotlib axis on which the plot is rendered edges_kwargs: dict keyword arguments passed to matplotlib.collections.PolyCollection for edges node_radius: None, int, float, or dict radius of all nodes, or dictionary of node:value; the default (None) calculates radius based on number of collapsed nodes; reasonable values range between 1 and 3 nodes_kwargs: dict keyword arguments passed to matplotlib.collections.PolyCollection for nodes edge_labels_on_edge: bool whether to draw edge labels on the edge (rubber band) or inside edge_labels_kwargs: dict keyword arguments passed to matplotlib.annotate for edge labels node_labels_kwargs: dict keyword arguments passed to matplotlib.annotate for node labels with_edge_labels: bool set to False to make edge labels invisible with_node_labels: bool set to False to make node labels invisible node_label_alpha: float the transparency (alpha) of the box behind text drawn in the figure for node labels edge_label_alpha: float the transparency (alpha) of the box behind text drawn in the figure for edge labels with_additional_edges: networkx.Graph ... contain_hyper_edges: bool whether the rubber band should be drawn around the location of the edge in the bipartite graph. This may be invisible unless \"with_additional_edges\" contains this information.
Source code in src/aeiva/hypergraph/visualization.py
def draw_rubber_band(\n H,\n pos=None,\n with_color=True,\n with_node_counts=False,\n with_edge_counts=False,\n layout=nx.spring_layout,\n layout_kwargs={},\n ax=None,\n node_radius=None,\n edges_kwargs={},\n nodes_kwargs={},\n edge_labels_on_edge=True,\n edge_labels={},\n edge_labels_kwargs={},\n node_labels={},\n node_labels_kwargs={},\n with_edge_labels=True,\n with_node_labels=True,\n node_label_alpha=0.35,\n edge_label_alpha=0.35,\n with_additional_edges=None,\n contain_hyper_edges=False,\n additional_edges_kwargs={},\n return_pos=False,\n):\n \"\"\"\n Draw a hypergraph as a Matplotlib figure\n\n By default this will draw a colorful \"rubber band\" like hypergraph, where\n convex hulls represent edges and are drawn around the nodes they contain.\n\n This is a convenience function that wraps calls with sensible parameters to\n the following lower-level drawing functions:\n\n * draw_hyper_edges,\n * draw_hyper_edge_labels,\n * draw_hyper_labels, and\n * draw_hyper_nodes\n\n The default layout algorithm is nx.spring_layout, but other layouts can be\n passed in. The Hypergraph is converted to a bipartite graph, and the layout\n algorithm is passed the bipartite graph.\n\n If you have a pre-determined layout, you can pass in a \"pos\" dictionary.\n This is a dictionary mapping from node id's to x-y coordinates. For example:\n\n >>> pos = {\n >>> 'A': (0, 0),\n >>> 'B': (1, 2),\n >>> 'C': (5, -3)\n >>> }\n\n will position the nodes {A, B, C} manually at the locations specified. The\n coordinate system is in Matplotlib \"data coordinates\", and the figure will\n be centered within the figure.\n\n By default, this will draw in a new figure, but the axis to render in can be\n specified using :code:`ax`.\n\n This approach works well for small hypergraphs, and does not guarantee\n a rigorously \"correct\" drawing. Overlapping of sets in the drawing generally\n implies that the sets intersect, but sometimes sets overlap if there is no\n intersection. 
It is not possible, in general, to draw a \"correct\" hypergraph\n this way for an arbitrary hypergraph, in the same way that not all graphs\n have planar drawings.\n\n Parameters\n ----------\n H: hnx.Hypergraph\n the entity to be drawn\n pos: dict\n mapping of node and edge positions to R^2\n with_color: bool\n set to False to disable color cycling of edges\n with_node_counts: bool\n set to True to replace the label for collapsed nodes with the number of elements\n with_edge_counts: bool\n set to True to label collapsed edges with number of elements\n layout: function\n layout algorithm to compute\n layout_kwargs: dict\n keyword arguments passed to layout function\n ax: Axis\n matplotlib axis on which the plot is rendered\n edges_kwargs: dict\n keyword arguments passed to matplotlib.collections.PolyCollection for edges\n node_radius: None, int, float, or dict\n radius of all nodes, or dictionary of node:value; the default (None) calculates radius based on number of collapsed nodes; reasonable values range between 1 and 3\n nodes_kwargs: dict\n keyword arguments passed to matplotlib.collections.PolyCollection for nodes\n edge_labels_on_edge: bool\n whether to draw edge labels on the edge (rubber band) or inside\n edge_labels_kwargs: dict\n keyword arguments passed to matplotlib.annotate for edge labels\n node_labels_kwargs: dict\n keyword argumetns passed to matplotlib.annotate for node labels\n with_edge_labels: bool\n set to False to make edge labels invisible\n with_node_labels: bool\n set to False to make node labels invisible\n node_label_alpha: float\n the transparency (alpha) of the box behind text drawn in the figure for node labels\n edge_label_alpha: float\n the transparency (alpha) of the box behind text drawn in the figure for edge labels\n with_additional_edges: networkx.Graph\n ...\n contain_hyper_edges: bool\n whether the rubber band shoudl be drawn around the location of the edge in the bipartite graph. 
This may be invisible unless \"with_additional_edges\" contains this information.\n\n \"\"\"\n\n ax = ax or plt.gca()\n\n if pos is None:\n pos = layout_node_link(H, with_additional_edges, layout=layout, **layout_kwargs)\n\n r0 = get_default_radius(H, pos)\n a0 = np.pi * r0**2\n\n def get_node_radius(v):\n if node_radius is None:\n return np.sqrt(a0 * get_collapsed_size(v) / np.pi)\n elif hasattr(node_radius, \"get\"):\n return node_radius.get(v, 1) * r0\n return node_radius * r0\n\n # guarantee that node radius is a dictionary mapping nodes to values\n node_radius = {v: get_node_radius(v) for v in H.nodes()}\n\n # for convenience, we are using setdefault to mutate the argument\n # however, we need to copy this to prevent side-effects\n edges_kwargs = edges_kwargs.copy()\n edges_kwargs.setdefault(\"edgecolors\", plt.cm.tab10(np.arange(len((H.edges()))) % 10))\n edges_kwargs.setdefault(\"facecolors\", \"none\")\n\n polys = draw_hyper_edges(\n H,\n pos,\n node_radius=node_radius,\n ax=ax,\n contain_hyper_edges=contain_hyper_edges,\n **edges_kwargs\n )\n\n if with_additional_edges:\n nx.draw_networkx_edges(\n with_additional_edges,\n pos=pos,\n ax=ax,\n **inflate_kwargs(with_additional_edges.edges(), additional_edges_kwargs)\n )\n\n if with_edge_labels:\n labels = get_frozenset_label(\n H.edges(), count=with_edge_counts, override=edge_labels\n )\n\n draw_hyper_edge_labels(\n H,\n pos,\n polys,\n color=edges_kwargs[\"edgecolors\"],\n backgroundcolor=(1, 1, 1, edge_label_alpha),\n labels=labels,\n ax=ax,\n edge_labels_on_edge=edge_labels_on_edge,\n **edge_labels_kwargs\n )\n\n if with_node_labels:\n labels = get_frozenset_label(\n H.nodes(), count=with_node_counts, override=node_labels\n )\n\n draw_hyper_labels(\n H,\n pos,\n node_radius=node_radius,\n labels=labels,\n ax=ax,\n va=\"center\",\n xytext=(5, 0),\n textcoords=\"offset points\",\n backgroundcolor=(1, 1, 1, node_label_alpha),\n **node_labels_kwargs\n )\n\n draw_hyper_nodes(H, pos, node_radius=node_radius, 
ax=ax, **nodes_kwargs)\n\n if len(H.nodes()) == 1:\n x, y = pos[list(H.nodes())[0]]\n s = 20\n\n ax.axis([x - s, x + s, y - s, y + s])\n else:\n ax.axis(\"equal\")\n\n ax.axis(\"off\")\n if return_pos:\n return pos\n
"},{"location":"reference/#src.aeiva.hypergraph.visualization.draw_two_column","title":"draw_two_column(H, with_node_labels=True, with_edge_labels=True, with_node_counts=False, with_edge_counts=False, with_color=True, edge_kwargs=None, ax=None)
","text":"Draw a hypergraph using a two-column layout.
This is intended to reproduce an illustrative technique for bipartite graphs and hypergraphs that is typically used in papers and textbooks.
The left column is reserved for nodes and the right column is reserved for edges. A line is drawn between a node and an edge.
The order of nodes and edges is optimized to reduce line crossings between the two columns. Spacing between disconnected components is adjusted to make the diagram easier to read, by reducing the angle of the lines.
"},{"location":"reference/#src.aeiva.hypergraph.visualization.draw_two_column--parameters","title":"Parameters","text":"H: hnx.Hypergraph the entity to be drawn with_node_labels: bool False to disable node labels with_edge_labels: bool False to disable edge labels with_node_counts: bool set to True to label collapsed nodes with number of elements with_edge_counts: bool set to True to label collapsed edges with number of elements with_color: bool set to False to disable color cycling of hyper edges edge_kwargs: dict keyword arguments to pass to matplotlib.LineCollection ax: Axis matplotlib axis on which the plot is rendered
Source code in src/aeiva/hypergraph/visualization.py
def draw_two_column(\n H,\n with_node_labels=True,\n with_edge_labels=True,\n with_node_counts=False,\n with_edge_counts=False,\n with_color=True,\n edge_kwargs=None,\n ax=None,\n):\n \"\"\"\n Draw a hypergraph using a two-column layout.\n\n This is intended to reproduce an illustrative technique for bipartite graphs\n and hypergraphs that is typically used in papers and textbooks.\n\n The left column is reserved for nodes and the right column is reserved for\n edges. A line is drawn between a node and an edge.\n\n The order of nodes and edges is optimized to reduce line crossings between\n the two columns. Spacing between disconnected components is adjusted to make\n the diagram easier to read, by reducing the angle of the lines.\n\n Parameters\n ----------\n H: hnx.Hypergraph\n the entity to be drawn\n with_node_labels: bool\n False to disable node labels\n with_edge_labels: bool\n False to disable edge labels\n with_node_counts: bool\n set to True to label collapsed nodes with number of elements\n with_edge_counts: bool\n set to True to label collapsed edges with number of elements\n with_color: bool\n set to False to disable color cycling of hyper edges\n edge_kwargs: dict\n keyword arguments to pass to matplotlib.LineCollection\n ax: Axis\n matplotlib axis on which the plot is rendered\n \"\"\"\n\n edge_kwargs = edge_kwargs or {}\n\n ax = ax or plt.gca()\n\n pos = layout_two_column(H)\n\n V = [v for v in H.nodes()]\n E = [e for e in H.edges()]\n\n labels = {}\n labels.update(get_frozenset_label(V, count=with_node_counts))\n labels.update(get_frozenset_label(E, count=with_edge_counts))\n\n if with_color:\n edge_kwargs[\"color\"] = {\n e: plt.cm.tab10(i % 10) for i, e in enumerate(H.edges())\n }\n\n draw_hyper_edges_two_column(H, pos, ax=ax, **edge_kwargs)\n draw_hyper_labels_two_column(\n H,\n pos,\n labels,\n ax=ax,\n with_node_labels=with_node_labels,\n with_edge_labels=with_edge_labels,\n )\n ax.autoscale_view()\n\n ax.axis(\"off\")\n
"},{"location":"reference/#src.aeiva.hypergraph.visualization.get_default_radius","title":"get_default_radius(H, pos)
","text":"Calculate a reasonable default node radius
This function iterates over the hyper edges and finds the most distant pair of points given the positions provided. Then, the node radius is a fraction of the median of this distance taken across all hyper-edges.
"},{"location":"reference/#src.aeiva.hypergraph.visualization.get_default_radius--parameters","title":"Parameters","text":"H: hnx.Hypergraph the entity to be drawn pos: dict mapping of node and edge positions to R^2
"},{"location":"reference/#src.aeiva.hypergraph.visualization.get_default_radius--returns","title":"Returns","text":"float the recommended radius
Source code in src/aeiva/hypergraph/visualization.py
def get_default_radius(H, pos):\n \"\"\"\n Calculate a reasonable default node radius\n\n This function iterates over the hyper edges and finds the most distant\n pair of points given the positions provided. Then, the node radius is a fraction\n of the median of this distance taken across all hyper-edges.\n\n Parameters\n ----------\n H: hnx.Hypergraph\n the entity to be drawn\n pos: dict\n mapping of node and edge positions to R^2\n\n Returns\n -------\n float\n the recommended radius\n\n \"\"\"\n if len(H) > 1:\n return 0.0125 * np.median(\n [pdist(np.vstack(list(map(pos.get, nodes)))).max() for nodes in H.edge_elements().values()]\n )\n return 1\n
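The median-diameter heuristic above can be tried in isolation with the standard library; the positions and hyperedges below are made-up illustration data, not part of the library API:

```python
import math
from statistics import median

# Made-up node positions and hyperedges for illustration.
pos = {"a": (0, 0), "b": (3, 4), "c": (6, 0)}
edges = [["a", "b"], ["b", "c"], ["a", "c"]]

def diameter(nodes):
    # farthest pair of member positions within one hyperedge
    return max(math.dist(pos[u], pos[v])
               for i, u in enumerate(nodes) for v in nodes[i + 1:])

# a small fraction of the median per-edge diameter, as in the function above
r0 = 0.0125 * median(diameter(e) for e in edges)  # -> 0.0625
```

With these positions the three edge diameters are 5, 5, and 6, so the median is 5 and the recommended radius is 0.0625.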
"},{"location":"reference/#src.aeiva.hypergraph.visualization.get_frozenset_label","title":"get_frozenset_label(S, count=False, override={})
","text":"Helper function for rendering the labels of possibly collapsed nodes and edges
"},{"location":"reference/#src.aeiva.hypergraph.visualization.get_frozenset_label--parameters","title":"Parameters","text":"S: iterable list of entities to be labeled count: bool True if labels should be counts of entities instead of list
"},{"location":"reference/#src.aeiva.hypergraph.visualization.get_frozenset_label--returns","title":"Returns","text":"dict mapping of entity to its string representation
Source code in src/aeiva/hypergraph/visualization.py
def get_frozenset_label(S, count=False, override={}):\n \"\"\"\n Helper function for rendering the labels of possibly collapsed nodes and edges\n\n Parameters\n ----------\n S: iterable\n list of entities to be labeled\n count: bool\n True if labels should be counts of entities instead of list\n\n Returns\n -------\n dict\n mapping of entity to its string representation\n \"\"\"\n\n def helper(v):\n if type(v) == str:\n n = get_collapsed_size(v)\n if count and n > 1:\n return f\"x {n}\"\n elif count:\n return \"\"\n return str(v)\n\n return {v: override.get(v, helper(v)) for v in S}\n
"},{"location":"reference/#src.aeiva.hypergraph.visualization.get_line_graph","title":"get_line_graph(H, collapse=True)
","text":"Computes the line graph, a directed graph, where a directed edge (u, v) exists if the edge u is a subset of the edge v in the hypergraph.
"},{"location":"reference/#src.aeiva.hypergraph.visualization.get_line_graph--parameters","title":"Parameters","text":"H: hnx.Hypergraph the entity to be drawn collapse: bool True if edges should be added if hyper edges are identical
"},{"location":"reference/#src.aeiva.hypergraph.visualization.get_line_graph--returns","title":"Returns","text":"networkx.DiGraph A directed graph
Source code in src/aeiva/hypergraph/visualization.py
def get_line_graph(H, collapse=True):\n \"\"\"\n Computes the line graph, a directed graph, where a directed edge (u, v)\n exists if the edge u is a subset of the edge v in the hypergraph.\n\n Parameters\n ----------\n H: hnx.Hypergraph\n the entity to be drawn\n collapse: bool\n True if edges should be added if hyper edges are identical\n\n Returns\n -------\n networkx.DiGraph\n A directed graph\n \"\"\"\n D = nx.DiGraph()\n\n V = {edge: set(nodes) for edge, nodes in H.edge_elements().items()}\n\n D.add_nodes_from(V)\n\n for u, v in combinations(V, 2):\n if V[u] != V[v] or not collapse:\n if V[u].issubset(V[v]):\n D.add_edge(u, v)\n elif V[v].issubset(V[u]):\n D.add_edge(v, u)\n\n return D\n
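The subset relation that drives `get_line_graph` can be reproduced with plain sets and dictionaries, without a hypergraph object; the hyperedges below are invented for the sketch:

```python
from itertools import combinations

# Toy reproduction of the subset "line graph": a directed edge (u, v) exists
# when hyperedge u's node set is a subset of hyperedge v's (and they differ).
edges = {"e1": {1, 2}, "e2": {1, 2, 3}, "e3": {4}}
D = {e: set() for e in edges}  # adjacency: successors of each edge

for u, v in combinations(edges, 2):
    if edges[u] != edges[v]:
        if edges[u] <= edges[v]:
            D[u].add(v)
        elif edges[v] <= edges[u]:
            D[v].add(u)
# D == {"e1": {"e2"}, "e2": set(), "e3": set()}
```

Here `e1` points to `e2` because {1, 2} is contained in {1, 2, 3}, while the disjoint edge `e3` has no relations.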
"},{"location":"reference/#src.aeiva.hypergraph.visualization.get_set_layering","title":"get_set_layering(H, collapse=True)
","text":"Computes a layering of the edges in the hyper graph.
In this layering, each edge is assigned a level. An edge u will be above (e.g., have a smaller level value) another edge v if v is a subset of u.
"},{"location":"reference/#src.aeiva.hypergraph.visualization.get_set_layering--parameters","title":"Parameters","text":"H: hnx.Hypergraph the entity to be drawn collapse: bool True if edges should be added if hyper edges are identical
"},{"location":"reference/#src.aeiva.hypergraph.visualization.get_set_layering--returns","title":"Returns","text":"dict a mapping of vertices in H to integer levels
Source code in src/aeiva/hypergraph/visualization.py
def get_set_layering(H, collapse=True):\n \"\"\"\n Computes a layering of the edges in the hyper graph.\n\n In this layering, each edge is assigned a level. An edge u will be above\n (e.g., have a smaller level value) another edge v if v is a subset of u.\n\n Parameters\n ----------\n H: hnx.Hypergraph\n the entity to be drawn\n collapse: bool\n True if edges should be added if hyper edges are identical\n\n Returns\n -------\n dict\n a mapping of vertices in H to integer levels\n \"\"\"\n\n D = get_line_graph(H, collapse=collapse)\n\n levels = {}\n\n for v in nx.topological_sort(D):\n parent_levels = [levels[u] for u, _ in D.in_edges(v)]\n levels[v] = max(parent_levels) + 1 if len(parent_levels) else 0\n\n return levels\n
"},{"location":"reference/#src.aeiva.hypergraph.visualization.inflate_kwargs","title":"inflate_kwargs(items, kwargs)
","text":"Helper function to expand keyword arguments.
"},{"location":"reference/#src.aeiva.hypergraph.visualization.inflate_kwargs--parameters","title":"Parameters","text":"items: iterable items whose length determines the length of each expanded list kwargs: dict keyword arguments to be expanded
"},{"location":"reference/#src.aeiva.hypergraph.visualization.inflate_kwargs--returns","title":"Returns","text":"dict dictionary with same keys as kwargs and whose values are lists of the same length as items
Source code in src/aeiva/hypergraph/visualization.py
def inflate_kwargs(items, kwargs):\n \"\"\"\n Helper function to expand keyword arguments.\n\n Parameters\n ----------\n items: iterable\n items whose length determines the length of each expanded list\n kwargs: dict\n keyword arguments to be expanded\n\n Returns\n -------\n dict\n dictionary with same keys as kwargs and whose values are lists of the same length as items\n \"\"\"\n\n return {k: inflate(items, v) for k, v in kwargs.items()}\n
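The expansion behaviour can be sketched with a minimal stand-in for the internal `inflate` helper; this simplified version is an assumption for illustration, not the library's exact implementation:

```python
def inflate(items, value):
    # Assumed behaviour: scalars are repeated once per item,
    # list-like values pass through unchanged.
    if isinstance(value, (list, tuple)):
        return list(value)
    return [value] * len(items)

def inflate_kwargs(items, kwargs):
    # one expanded list per keyword, each as long as `items`
    return {k: inflate(items, v) for k, v in kwargs.items()}

expanded = inflate_kwargs(["e0", "e1", "e2"],
                          {"linewidth": 2, "color": ["r", "g", "b"]})
# expanded == {"linewidth": [2, 2, 2], "color": ["r", "g", "b"]}
```

This is the shape matplotlib collection constructors expect: one value per drawn artist for every styling keyword.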
"},{"location":"reference/#src.aeiva.hypergraph.visualization.layout_hyper_edges","title":"layout_hyper_edges(H, pos, node_radius={}, dr=None, contain_hyper_edges=False)
","text":"Draws a convex hull for each edge in H.
Position of the nodes in the graph is specified by the position dictionary, pos. Convex hulls are spaced out such that if one set contains another, the convex hull will surround the contained set. The amount of spacing added between hulls is specified by the parameter, dr.
"},{"location":"reference/#src.aeiva.hypergraph.visualization.layout_hyper_edges--parameters","title":"Parameters","text":"H: hnx.Hypergraph the entity to be drawn pos: dict mapping of node and edge positions to R^2 node_radius: dict mapping of node to R^1 (radius of each node) dr: float the spacing between concentric rings ax: Axis matplotlib axis on which the plot is rendered
"},{"location":"reference/#src.aeiva.hypergraph.visualization.layout_hyper_edges--returns","title":"Returns","text":"dict A mapping from hyper edge ids to paths (Nx2 numpy matrices)
Source code in src/aeiva/hypergraph/visualization.py
def layout_hyper_edges(H, pos, node_radius={}, dr=None, contain_hyper_edges=False):\n \"\"\"\n Draws a convex hull for each edge in H.\n\n Position of the nodes in the graph is specified by the position dictionary,\n pos. Convex hulls are spaced out such that if one set contains another, the\n convex hull will surround the contained set. The amount of spacing added\n between hulls is specified by the parameter, dr.\n\n Parameters\n ----------\n H: hnx.Hypergraph\n the entity to be drawn\n pos: dict\n mapping of node and edge positions to R^2\n node_radius: dict\n mapping of node to R^1 (radius of each node)\n dr: float\n the spacing between concentric rings\n ax: Axis\n matplotlib axis on which the plot is rendered\n\n Returns\n -------\n dict\n A mapping from hyper edge ids to paths (Nx2 numpy matrices)\n \"\"\"\n\n if len(node_radius):\n r0 = min(node_radius.values())\n else:\n r0 = get_default_radius(H, pos)\n\n dr = dr or r0\n\n levels = get_set_layering(H)\n\n radii = {\n v: {v: i for i, v in enumerate(sorted(e, key=levels.get))}\n for v, e in H.node_memberships().items()\n }\n\n def get_padded_hull(uid, edge):\n # make sure the edge contains at least one node\n if len(edge):\n points = [\n cp * (node_radius.get(v, r0) + dr * (2 + radii[v][uid])) + pos[v]\n for v in edge\n ]\n\n if contain_hyper_edges:\n points.append(cp * r0 + pos[uid])\n\n points = np.vstack(points)\n\n # if not, draw an empty edge centered around the location of the edge node (in the bipartite graph)\n else:\n points = 4 * r0 * cp + pos[uid]\n\n hull = ConvexHull(points)\n\n return hull.points[hull.vertices]\n\n return [get_padded_hull(uid, list(H.edge_elements()[uid])) for uid in H.edges()]\n
"},{"location":"reference/#src.aeiva.hypergraph.visualization.layout_node_link","title":"layout_node_link(H, G=None, layout=nx.spring_layout, **kwargs)
","text":"Helper function to use a NetworkX-like graph layout algorithm on a Hypergraph
The hypergraph is converted to a bipartite graph, allowing the usual graph layout techniques to be applied.
"},{"location":"reference/#src.aeiva.hypergraph.visualization.layout_node_link--parameters","title":"Parameters","text":"H: hnx.Hypergraph the entity to be drawn G: Graph an additional set of links to consider during the layout process layout: function the layout algorithm which accepts a NetworkX graph and keyword arguments kwargs: dict Keyword arguments are passed through to the layout algorithm
"},{"location":"reference/#src.aeiva.hypergraph.visualization.layout_node_link--returns","title":"Returns","text":"dict mapping of node and edge positions to R^2
Source code in src/aeiva/hypergraph/visualization.py
def layout_node_link(H, G=None, layout=nx.spring_layout, **kwargs):\n \"\"\"\n Helper function to use a NetworkX-like graph layout algorithm on a Hypergraph\n\n The hypergraph is converted to a bipartite graph, allowing the usual graph layout\n techniques to be applied.\n\n Parameters\n ----------\n H: hnx.Hypergraph\n the entity to be drawn\n G: Graph\n an additional set of links to consider during the layout process\n layout: function\n the layout algorithm which accepts a NetworkX graph and keyword arguments\n kwargs: dict\n Keyword arguments are passed through to the layout algorithm\n\n Returns\n -------\n dict\n mapping of node and edge positions to R^2\n \"\"\"\n\n B = H.to_bipartite_graph()\n\n if G is not None:\n B.add_edges_from(G.edges())\n\n return layout(B, **kwargs)\n
"},{"location":"reference/#src.aeiva.hypergraph.visualization.layout_two_column","title":"layout_two_column(H, spacing=2)
","text":"Two column (bipartite) layout algorithm.
This algorithm first converts the hypergraph into a bipartite graph and then computes connected components. Disconnected components are handled independently and then stacked together.
Within a connected component, the spectral ordering of the bipartite graph provides a quick and dirty ordering that minimizes edge crossings in the diagram.
"},{"location":"reference/#src.aeiva.hypergraph.visualization.layout_two_column--parameters","title":"Parameters","text":"H: hnx.Hypergraph the entity to be drawn spacing: float amount of whitespace between disconnected components
Source code in src/aeiva/hypergraph/visualization.py
def layout_two_column(H, spacing=2):\n \"\"\"\n Two column (bipartite) layout algorithm.\n\n This algorithm first converts the hypergraph into a bipartite graph and\n then computes connected components. Disconnected components are handled\n independently and then stacked together.\n\n Within a connected component, the spectral ordering of the bipartite graph\n provides a quick and dirty ordering that minimizes edge crossings in the\n diagram.\n\n Parameters\n ----------\n H: hnx.Hypergraph\n the entity to be drawn\n spacing: float\n amount of whitespace between disconnected components\n \"\"\"\n offset = 0\n pos = {}\n\n def stack(vertices, x, height):\n for i, v in enumerate(vertices):\n pos[v] = (x, i + offset + (height - len(vertices)) / 2)\n\n G = H.to_bipartite_graph()\n for ci in nx.connected_components(G):\n Gi = G.subgraph(ci)\n key = {v: i for i, v in enumerate(nx.spectral_ordering(Gi))}.get\n ci_vertices, ci_edges = [\n sorted([v for v, d in Gi.nodes(data=True) if d[\"bipartite\"] == j], key=key)\n for j in [0, 1]\n ]\n\n height = max(len(ci_vertices), len(ci_edges))\n\n stack(ci_vertices, 0, height)\n stack(ci_edges, 1, height)\n\n offset += height + spacing\n\n return pos\n
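The stacking-and-offset arithmetic can be tried in isolation. Here the components and their orderings are supplied by hand rather than computed via `to_bipartite_graph` and spectral ordering, so this is a simplified sketch of only the placement step:

```python
# Each component is (node column, edge column), already ordered.
components = [(["a", "b", "c"], ["e1"]), (["d"], ["e2", "e3"])]
spacing = 2

pos, offset = {}, 0
for nodes, edges in components:
    height = max(len(nodes), len(edges))
    for x, column in ((0, nodes), (1, edges)):
        for i, v in enumerate(column):
            # center the shorter column within the component's height
            pos[v] = (x, i + offset + (height - len(column)) / 2)
    offset += height + spacing  # leave whitespace before the next component
```

The single edge `e1` is vertically centered against its three nodes, and the second component starts at a y-offset of 5 (height 3 plus spacing 2).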
"},{"location":"reference/#src.aeiva.llm","title":"llm
","text":""},{"location":"reference/#src.aeiva.llm.llm_client","title":"llm_client
","text":""},{"location":"reference/#src.aeiva.llm.llm_client.LLMClient","title":"LLMClient
","text":"Language Model interface that supports synchronous, asynchronous, and streaming modes, and optionally, tool usage via function calls.
Source code in src/aeiva/llm/llm_client.py
class LLMClient:\n \"\"\"\n Language Model interface that supports synchronous, asynchronous, and streaming modes,\n and optionally, tool usage via function calls.\n \"\"\"\n\n def __init__(self, config: LLMGatewayConfig):\n self.config = config\n self.metrics = LLMUsageMetrics()\n self.logger = get_logger(__name__, level=config.llm_logging_level.upper())\n self._validate_config()\n\n def _validate_config(self):\n if not self.config.llm_api_key:\n raise ValueError(\"API key must be provided in the configuration.\")\n\n @retry_sync(\n max_attempts=lambda self: self.config.llm_num_retries,\n backoff_factor=lambda self: self.config.llm_retry_backoff_factor,\n exceptions=(LLMGatewayError,), # Catching LLMGatewayError\n )\n def generate(\n self, messages: List[Any], tools: List[Dict[str, Any]] = None, **kwargs\n ) -> str:\n try:\n max_iterations = MAX_TOOL_CALL_LOOP # Prevent infinite loops\n iteration = 0\n\n while iteration < max_iterations:\n iteration += 1\n\n # Build parameters\n params = self._build_params(messages=messages, tools=tools, **kwargs)\n response = llm_completion(**params)\n self._update_metrics(response)\n response_message = response.choices[0].message\n\n tool_calls = response_message.tool_calls\n\n if tool_calls:\n # Append assistant's tool call message\n messages.append({\"role\": \"assistant\", \"tool_calls\": tool_calls})\n\n for tool_call in tool_calls:\n function_name = tool_call.function.name\n function_args = json.loads(tool_call.function.arguments)\n tool_call_id = tool_call.id\n self.logger.info(f\"Tool call id: {tool_call_id}\")\n\n try:\n function_response = self.call_tool_sync(\n api_name=function_name, function_name=function_name, params=function_args\n )\n except Exception as e:\n self.logger.error(f\"Error executing tool '{function_name}': {e}\")\n function_response = f\"Error executing tool '{function_name}': {e}\"\n\n # Append the function response to messages\n messages.append(\n {\n \"tool_call_id\": tool_call_id,\n \"role\": 
\"tool\",\n \"name\": function_name,\n \"content\": str(function_response),\n }\n )\n # Continue the loop to handle further function calls\n continue\n else:\n # Assistant provided a final response\n messages.append({\"role\": \"assistant\", \"content\": response_message.content})\n return response_message.content\n\n # If loop exceeds max iterations\n raise Exception(\"Maximum iterations reached without a final response.\")\n\n except Exception as e:\n self.logger.error(f\"LLM Gateway Error: {e}\")\n raise llm_gateway_exception(e)\n\n @retry_async(\n max_attempts=lambda self: self.config.llm_num_retries,\n backoff_factor=lambda self: self.config.llm_retry_backoff_factor,\n exceptions=(LLMGatewayError,), # Catching LLMGatewayError\n )\n async def agenerate(\n self, messages: List[Any], tools: List[Dict[str, Any]] = None, **kwargs\n ) -> str:\n try:\n max_iterations = MAX_TOOL_CALL_LOOP # Prevent infinite loops\n iteration = 0\n\n while iteration < max_iterations:\n iteration += 1\n\n # Build parameters\n params = self._build_params(messages=messages, tools=tools, **kwargs)\n response = await llm_acompletion(**params)\n self._update_metrics(response)\n response_message = response.choices[0].message\n\n tool_calls = response_message.tool_calls\n\n if tool_calls:\n # Append assistant's tool call message\n messages.append({\"role\": \"assistant\", \"tool_calls\": tool_calls})\n\n for tool_call in tool_calls:\n function_name = tool_call.function.name\n function_args = json.loads(tool_call.function.arguments)\n tool_call_id = tool_call.id\n\n try:\n function_response = await self.call_tool(\n api_name=function_name, function_name=function_name, params=function_args\n )\n except Exception as e:\n self.logger.error(f\"Error executing tool '{function_name}': {e}\")\n function_response = f\"Error executing tool '{function_name}': {e}\"\n\n # Append the function response to messages\n messages.append(\n {\n \"tool_call_id\": tool_call_id,\n \"role\": \"tool\",\n \"name\": 
function_name,\n \"content\": str(function_response),\n }\n )\n # Continue the loop to handle further function calls\n continue\n else:\n # Assistant provided a final response\n messages.append({\"role\": \"assistant\", \"content\": response_message.content})\n return response_message.content\n\n # If loop exceeds max iterations\n raise Exception(\"Maximum iterations reached without a final response.\")\n\n except Exception as e:\n self.logger.error(f\"LLM Asynchronous Generation Error: {e}\")\n raise llm_gateway_exception(e)\n\n async def stream_generate(\n self, messages: List[Any], tools: List[Dict[str, Any]] = None, **kwargs\n ) -> AsyncGenerator[str, None]:\n try:\n max_iterations = MAX_TOOL_CALL_LOOP # Prevent infinite loops\n iteration = 0\n\n while iteration < max_iterations:\n iteration += 1\n\n # Build parameters\n params = self._build_params(messages=messages, tools=tools, **kwargs)\n response_stream = await llm_acompletion(**params)\n\n # Prepare to collect the assistant's reply\n tool_calls = [] # Accumulator for tool calls\n full_delta_content = '' # Accumulator for assistant's content\n\n # Collect streamed responses\n async for response in response_stream:\n delta = response.choices[0].delta\n\n # Collect assistant's content and yield it\n if getattr(delta, 'content', None):\n full_delta_content += delta.content\n yield delta.content\n\n # Check for tool calls in the delta\n if getattr(delta, 'tool_calls', None):\n tc_chunk_list = delta.tool_calls\n for tc_chunk in tc_chunk_list:\n index = tc_chunk.index\n # Ensure tool_calls list is large enough\n while len(tool_calls) <= index:\n tool_calls.append({\"id\": \"\", \"type\": \"function\", \"function\": {\"name\": \"\", \"arguments\": \"\"}})\n tc = tool_calls[index]\n\n if getattr(tc_chunk, 'id', None):\n tc[\"id\"] += tc_chunk.id\n if getattr(tc_chunk.function, 'name', None):\n tc[\"function\"][\"name\"] += tc_chunk.function.name\n if getattr(tc_chunk.function, 'arguments', None):\n 
tc[\"function\"][\"arguments\"] += tc_chunk.function.arguments\n\n # After initial streaming, check if there are tool calls\n if tool_calls:\n # Append the assistant's tool_call message to messages\n messages.append({\"role\": \"assistant\", \"tool_calls\": tool_calls})\n\n # Process each tool_call\n available_functions = [tool[\"function\"][\"name\"] for tool in tools]\n for tool_call in tool_calls:\n function_name = tool_call[\"function\"][\"name\"]\n if function_name not in available_functions:\n # Handle error if function not found\n yield f\"Function {function_name} does not exist.\"\n return\n # Call the function with arguments\n try:\n function_args = json.loads(tool_call[\"function\"][\"arguments\"])\n except json.JSONDecodeError as e:\n self.logger.error(f\"Error decoding function arguments: {e}\")\n function_args = {}\n\n try:\n function_response = await self.call_tool(\n api_name=function_name, function_name=function_name, params=function_args\n )\n except Exception as e:\n self.logger.error(f\"Error executing tool '{function_name}': {e}\")\n function_response = f\"Error executing tool '{function_name}': {e}\"\n\n # Append the function's response to messages\n messages.append(\n {\n \"tool_call_id\": tool_call['id'],\n \"role\": \"tool\",\n \"name\": function_name,\n \"content\": str(function_response),\n }\n )\n # Continue the loop to handle further function calls\n continue\n else:\n # No tool calls, streaming is complete\n messages.append({\"role\": \"assistant\", \"content\": full_delta_content})\n return # Exit the loop\n\n # If loop exceeds max iterations\n yield \"Maximum iterations reached without a final response.\"\n\n except Exception as e:\n self.logger.error(f\"Streaming LLM Gateway Error: {e}\")\n yield \"An error occurred during streaming.\"\n\n def call_tool_via_server(self, api_name: str, function_name: str, params: Dict[str, Any]) -> Any: # TODO: may need revise\n \"\"\"Calls the API via FastAPI server.\"\"\"\n url = 
f\"http://localhost:8000/api/{api_name}/{function_name}\"\n self.logger.info(f\"Calling {api_name} with params: {params}\")\n response = requests.get(url, params=params)\n if response.status_code == 200:\n json_response = response.json()\n if \"result\" in json_response:\n return str(json_response[\"result\"])\n else:\n return f\"Error from API: {json_response.get('error', 'Unknown error')}\"\n else:\n return f\"HTTP Error {response.status_code}: {response.text}\"\n\n async def call_tool(self, api_name: str, function_name: str, params: Dict[str, Any]) -> Any: # TODO: may need revise\n \"\"\"Calls the API via action module.\"\"\"\n tool = Tool(api_name)\n return await tool.aexecute(params)\n\n def call_tool_sync(self, api_name: str, function_name: str, params: Dict[str, Any]) -> Any: # TODO: may need revise\n \"\"\"Calls the API via action module.\"\"\"\n tool = Tool(api_name)\n return tool.execute(params)\n\n def _build_params(\n self, messages: List[Any], tools: List[Dict[str, Any]] = None, **kwargs\n ) -> Dict[str, Any]:\n params = {\n \"model\": self.config.llm_model_name,\n \"messages\": messages,\n \"api_key\": self.config.llm_api_key,\n \"temperature\": self.config.llm_temperature,\n \"top_p\": self.config.llm_top_p,\n \"max_tokens\": self.config.llm_max_output_tokens,\n \"timeout\": self.config.llm_timeout,\n }\n params.update(self.config.llm_additional_params)\n params.update(kwargs)\n\n # Check if the model supports function calling\n if tools and supports_function_calling(self.config.llm_model_name):\n params[\"tools\"] = tools\n params[\"tool_choice\"] = \"auto\"\n\n return params\n\n def _update_metrics(self, response: Any, log: bool = False): # Note: log is False by default. 
Adjust according to the need.\n usage = getattr(response, \"usage\", {})\n self.metrics.add_tokens(\n prompt_tokens=getattr(usage, \"prompt_tokens\", 0),\n completion_tokens=getattr(usage, \"completion_tokens\", 0),\n )\n self.metrics.add_cost(getattr(usage, \"cost\", 0.0))\n if log:\n self.logger.info(\n f\"Tokens used: {self.metrics.total_tokens}, Cost: ${self.metrics.total_cost:.4f}\"\n )\n\n def __call__(\n self, messages: List[Any], tools: List[Dict[str, Any]] = None, **kwargs\n ) -> Any:\n if self.config.llm_use_async:\n if self.config.llm_stream:\n return self.stream_generate(messages, tools=tools, **kwargs)\n else:\n return self.agenerate(messages, tools=tools, **kwargs)\n else:\n if self.config.llm_stream:\n # OpenAI's API does not support synchronous streaming; streaming must be async\n raise NotImplementedError(\"Synchronous streaming is not supported.\")\n else:\n return self.generate(messages, tools=tools, **kwargs)\n
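The bounded tool-call loop in `generate` can be illustrated with scripted responses standing in for real LLM completions (the scripted replies and the `add` tool are invented for this sketch):

```python
MAX_TOOL_CALL_LOOP = 5  # mirrors the iteration bound used above

scripted = [  # fake model replies, consumed in order
    {"tool_calls": [{"id": "t1", "name": "add", "arguments": {"a": 2, "b": 3}}]},
    {"content": "The sum is 5."},
]
tools = {"add": lambda a, b: a + b}

def generate(messages):
    for step in range(MAX_TOOL_CALL_LOOP):
        reply = scripted[step]
        if reply.get("tool_calls"):
            # record the assistant's tool request, run each tool,
            # and append the results so the next turn can see them
            messages.append({"role": "assistant", "tool_calls": reply["tool_calls"]})
            for call in reply["tool_calls"]:
                result = tools[call["name"]](**call["arguments"])
                messages.append({"role": "tool", "tool_call_id": call["id"],
                                 "name": call["name"], "content": str(result)})
        else:  # a final text answer ends the loop
            messages.append({"role": "assistant", "content": reply["content"]})
            return reply["content"]
    raise RuntimeError("Maximum iterations reached without a final response.")

history = [{"role": "user", "content": "What is 2 + 3?"}]
answer = generate(history)  # -> "The sum is 5."
```

The loop shape is the same as in `LLMClient.generate`: keep appending tool results to the message list and re-prompting until the model returns plain content or the iteration cap is hit.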
"},{"location":"reference/#src.aeiva.llm.llm_client.LLMClient.call_tool","title":"call_tool(api_name, function_name, params)
async
","text":"Calls the API via action module.
Source code in src/aeiva/llm/llm_client.py
async def call_tool(self, api_name: str, function_name: str, params: Dict[str, Any]) -> Any: # TODO: may need revise\n \"\"\"Calls the API via action module.\"\"\"\n tool = Tool(api_name)\n return await tool.aexecute(params)\n
"},{"location":"reference/#src.aeiva.llm.llm_client.LLMClient.call_tool_sync","title":"call_tool_sync(api_name, function_name, params)
","text":"Calls the API via action module.
Source code in src/aeiva/llm/llm_client.py
def call_tool_sync(self, api_name: str, function_name: str, params: Dict[str, Any]) -> Any: # TODO: may need revise\n \"\"\"Calls the API via action module.\"\"\"\n tool = Tool(api_name)\n return tool.execute(params)\n
"},{"location":"reference/#src.aeiva.llm.llm_client.LLMClient.call_tool_via_server","title":"call_tool_via_server(api_name, function_name, params)
","text":"Calls the API via FastAPI server.
Source code in src/aeiva/llm/llm_client.py
def call_tool_via_server(self, api_name: str, function_name: str, params: Dict[str, Any]) -> Any: # TODO: may need revise\n \"\"\"Calls the API via FastAPI server.\"\"\"\n url = f\"http://localhost:8000/api/{api_name}/{function_name}\"\n self.logger.info(f\"Calling {api_name} with params: {params}\")\n response = requests.get(url, params=params)\n if response.status_code == 200:\n json_response = response.json()\n if \"result\" in json_response:\n return str(json_response[\"result\"])\n else:\n return f\"Error from API: {json_response.get('error', 'Unknown error')}\"\n else:\n return f\"HTTP Error {response.status_code}: {response.text}\"\n
"},{"location":"reference/#src.aeiva.llm.llm_gateway_config","title":"llm_gateway_config
","text":""},{"location":"reference/#src.aeiva.llm.llm_gateway_config.LLMGatewayConfig","title":"LLMGatewayConfig
dataclass
","text":" Bases: BaseConfig
Configuration for the Language Model (LLM).
Source code in src/aeiva/llm/llm_gateway_config.py
@dataclass\nclass LLMGatewayConfig(BaseConfig):\n \"\"\"\n Configuration for the Language Model (LLM).\n \"\"\"\n\n llm_model_name: Optional[str] = field(\n default='gpt-4',\n metadata={\"help\": \"The name of the LLM model to use (e.g., 'gpt-4', 'gpt-3.5-turbo').\"}\n )\n llm_api_key: Optional[str] = field(\n default=None,\n metadata={\"help\": \"The API key for authentication with the LLM provider.\"}\n )\n llm_base_url: Optional[str] = field(\n default=None,\n metadata={\"help\": \"The base URL for API requests to the LLM provider.\"}\n )\n llm_api_version: Optional[str] = field(\n default=None,\n metadata={\"help\": \"The version of the LLM API to use.\"}\n )\n llm_embedding_model: Optional[str] = field(\n default=None,\n metadata={\"help\": \"The embedding model to use for tasks requiring embeddings.\"}\n )\n llm_timeout: Optional[int] = field(\n default=30,\n metadata={\"help\": \"The timeout in seconds for API requests.\"}\n )\n llm_max_input_tokens: Optional[int] = field(\n default=4096,\n metadata={\"help\": \"The maximum number of input tokens allowed in a request.\"}\n )\n llm_max_output_tokens: Optional[int] = field(\n default=1024,\n metadata={\"help\": \"The maximum number of output tokens generated by the LLM.\"}\n )\n llm_temperature: Optional[float] = field(\n default=0.7,\n metadata={\"help\": \"Sampling temperature for response variability (range: 0.0 - 1.0).\"}\n )\n llm_top_p: Optional[float] = field(\n default=0.9,\n metadata={\"help\": \"Nucleus sampling probability for token selection (range: 0.0 - 1.0).\"}\n )\n llm_num_retries: Optional[int] = field(\n default=3,\n metadata={\"help\": \"The number of times to retry failed API requests.\"}\n )\n llm_retry_backoff_factor: Optional[float] = field(\n default=0.5,\n metadata={\"help\": \"Factor for exponential backoff between retries.\"}\n )\n llm_retry_on_status: Optional[Tuple[int, ...]] = field(\n default=(429, 500, 502, 503, 504),\n metadata={\"help\": \"HTTP status codes that should 
trigger a retry.\"}\n )\n llm_use_async: Optional[bool] = field(\n default=False,\n metadata={\"help\": \"Whether to use asynchronous API calls.\"}\n )\n llm_stream: Optional[bool] = field(\n default=False,\n metadata={\"help\": \"Whether to enable streaming responses from the LLM.\"}\n )\n llm_logging_level: Optional[str] = field(\n default='INFO',\n metadata={\"help\": \"Logging level for the LLM module (e.g., 'DEBUG', 'INFO').\"}\n )\n llm_additional_params: Optional[Dict[str, Any]] = field(\n default_factory=dict,\n metadata={\"help\": \"Additional parameters to pass to the LLM API.\"}\n )\n\n def __post_init__(self):\n super().__post_init__()\n # Load API keys from the configuration file if not provided\n if not self.llm_api_key:\n self.load_api_key()\n\n def load_api_key(self):\n config_path = os.path.join(os.path.dirname(__file__), '../../../configs/llm_api_keys.yaml')\n try:\n with open(config_path, 'r') as f:\n keys = yaml.safe_load(f)\n self.llm_api_key = keys.get('openai_api_key')\n except FileNotFoundError:\n raise FileNotFoundError('API keys file not found.')\n except Exception as e:\n raise e\n\n def to_dict(self):\n return {\n key: ('******' if key == 'llm_api_key' and value else value)\n for key, value in self.__dict__.items()\n if not key.startswith('_')\n }\n
"},{"location":"reference/#src.aeiva.llm.llm_gateway_exceptions","title":"llm_gateway_exceptions
","text":""},{"location":"reference/#src.aeiva.llm.llm_gateway_exceptions.LLMGatewayError","title":"LLMGatewayError
","text":" Bases: Exception
Unified exception class for all LLM-related errors.
Source code in src/aeiva/llm/llm_gateway_exceptions.py
class LLMGatewayError(Exception):\n \"\"\"Unified exception class for all LLM-related errors.\"\"\"\n\n def __init__(self, message: str, original_exception: Exception = None):\n super().__init__(message)\n self.original_exception = original_exception\n
"},{"location":"reference/#src.aeiva.llm.llm_gateway_exceptions.llm_gateway_exception","title":"llm_gateway_exception(e)
","text":"Converts a litellm exception to a unified LLMGatewayError.
Source code in src/aeiva/llm/llm_gateway_exceptions.py
def llm_gateway_exception(e: Exception) -> LLMGatewayError:\n \"\"\"Converts a litellm exception to a unified LLMGatewayError.\"\"\"\n exception_type = type(e)\n mapped_exception = LITELLM_EXCEPTION_MAP.get(exception_type, LLMGatewayError)\n return mapped_exception(str(e), original_exception=e)\n
"},{"location":"reference/#src.aeiva.llm.llm_usage_metrics","title":"llm_usage_metrics
","text":""},{"location":"reference/#src.aeiva.llm.llm_usage_metrics.LLMUsageMetrics","title":"LLMUsageMetrics
","text":"Tracks metrics such as token usage and cost.
Source code in src/aeiva/llm/llm_usage_metrics.py
class LLMUsageMetrics:\n \"\"\"\n Tracks metrics such as token usage and cost.\n \"\"\"\n def __init__(self):\n self.total_tokens = 0\n self.prompt_tokens = 0\n self.completion_tokens = 0\n self.total_cost = 0.0\n\n def add_tokens(self, prompt_tokens: int, completion_tokens: int):\n self.prompt_tokens += prompt_tokens\n self.completion_tokens += completion_tokens\n self.total_tokens += prompt_tokens + completion_tokens\n\n def add_cost(self, cost: float):\n self.total_cost += cost\n
"},{"location":"reference/#src.aeiva.model","title":"model
","text":""},{"location":"reference/#src.aeiva.model.macaw_model","title":"macaw_model
","text":""},{"location":"reference/#src.aeiva.model.macaw_model.LlamaAttention","title":"LlamaAttention
","text":" Bases: Module
Multi-headed attention from 'Attention Is All You Need' paper
Source code in src/aeiva/model/macaw_model.py
class LlamaAttention(nn.Module):\n \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\n\n def __init__(self, config: LlamaConfig):\n super().__init__()\n self.config = config\n self.hidden_size = config.hidden_size\n self.num_heads = config.num_attention_heads\n self.head_dim = self.hidden_size // self.num_heads\n self.max_position_embeddings = config.max_position_embeddings\n\n if (self.head_dim * self.num_heads) != self.hidden_size:\n raise ValueError(\n f\"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}\"\n f\" and `num_heads`: {self.num_heads}).\"\n )\n self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)\n self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)\n self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)\n self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)\n self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)\n\n def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):\n return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()\n\n def forward(\n self,\n hidden_states: torch.Tensor,\n attention_mask: Optional[torch.Tensor] = None,\n position_ids: Optional[torch.LongTensor] = None,\n past_key_value: Optional[Tuple[torch.Tensor]] = None,\n output_attentions: bool = False,\n use_cache: bool = False,\n ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n bsz, q_len, _ = hidden_states.size()\n\n query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n\n kv_seq_len = 
key_states.shape[-2]\n if past_key_value is not None:\n kv_seq_len += past_key_value[0].shape[-2]\n cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)\n query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)\n # [bsz, nh, t, hd]\n\n if past_key_value is not None:\n # reuse k, v, self_attention\n key_states = torch.cat([past_key_value[0], key_states], dim=2)\n value_states = torch.cat([past_key_value[1], value_states], dim=2)\n\n past_key_value = (key_states, value_states) if use_cache else None\n\n attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)\n\n if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):\n raise ValueError(\n f\"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is\"\n f\" {attn_weights.size()}\"\n )\n\n if attention_mask is not None:\n if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):\n raise ValueError(\n f\"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}\"\n )\n attn_weights = attn_weights + attention_mask\n attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))\n\n # upcast attention to fp32\n attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)\n attn_output = torch.matmul(attn_weights, value_states)\n\n if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):\n raise ValueError(\n f\"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is\"\n f\" {attn_output.size()}\"\n )\n\n attn_output = attn_output.transpose(1, 2)\n attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)\n\n attn_output = self.o_proj(attn_output)\n\n if not output_attentions:\n attn_weights = None\n\n return attn_output, attn_weights, past_key_value\n
"},{"location":"reference/#src.aeiva.model.macaw_model.LlamaDecoderLayer","title":"LlamaDecoderLayer
","text":" Bases: Module
Source code in src/aeiva/model/macaw_model.py
class LlamaDecoderLayer(nn.Module):\n def __init__(self, config: LlamaConfig):\n super().__init__()\n self.hidden_size = config.hidden_size\n self.self_attn = LlamaAttention(config=config)\n self.mlp = LlamaMLP(\n hidden_size=self.hidden_size,\n intermediate_size=config.intermediate_size,\n hidden_act=config.hidden_act,\n )\n self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n\n def forward(\n self,\n hidden_states: torch.Tensor,\n attention_mask: Optional[torch.Tensor] = None,\n position_ids: Optional[torch.LongTensor] = None,\n past_key_value: Optional[Tuple[torch.Tensor]] = None,\n output_attentions: Optional[bool] = False,\n use_cache: Optional[bool] = False,\n ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:\n \"\"\"\n Args:\n hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`\n attention_mask (`torch.FloatTensor`, *optional*): attention mask of size\n `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.\n output_attentions (`bool`, *optional*):\n Whether or not to return the attentions tensors of all attention layers. 
See `attentions` under\n returned tensors for more detail.\n use_cache (`bool`, *optional*):\n If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding\n (see `past_key_values`).\n past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states\n \"\"\"\n\n residual = hidden_states\n\n hidden_states = self.input_layernorm(hidden_states)\n\n # Self Attention\n hidden_states, self_attn_weights, present_key_value = self.self_attn(\n hidden_states=hidden_states,\n attention_mask=attention_mask,\n position_ids=position_ids,\n past_key_value=past_key_value,\n output_attentions=output_attentions,\n use_cache=use_cache,\n )\n hidden_states = residual + hidden_states\n\n # Fully Connected\n residual = hidden_states\n hidden_states = self.post_attention_layernorm(hidden_states)\n hidden_states = self.mlp(hidden_states)\n hidden_states = residual + hidden_states\n\n outputs = (hidden_states,)\n\n if output_attentions:\n outputs += (self_attn_weights,)\n\n if use_cache:\n outputs += (present_key_value,)\n\n return outputs\n
"},{"location":"reference/#src.aeiva.model.macaw_model.LlamaDecoderLayer.forward","title":"forward(hidden_states, attention_mask=None, position_ids=None, past_key_value=None, output_attentions=False, use_cache=False)
","text":"Parameters:
Name Type Description Default hidden_states
`torch.FloatTensor`
input to the layer of shape (batch, seq_len, embed_dim)
required attention_mask
`torch.FloatTensor`, *optional*
attention mask of size (batch, 1, tgt_len, src_len)
where padding elements are indicated by very large negative values.
None
output_attentions
`bool`, *optional*
Whether or not to return the attentions tensors of all attention layers. See attentions
under returned tensors for more detail.
False
use_cache
`bool`, *optional*
If set to True
, past_key_values
key value states are returned and can be used to speed up decoding (see past_key_values
).
False
past_key_value
`Tuple(torch.FloatTensor)`, *optional*
cached past key and value projection states
None
Source code in src/aeiva/model/macaw_model.py
def forward(\n self,\n hidden_states: torch.Tensor,\n attention_mask: Optional[torch.Tensor] = None,\n position_ids: Optional[torch.LongTensor] = None,\n past_key_value: Optional[Tuple[torch.Tensor]] = None,\n output_attentions: Optional[bool] = False,\n use_cache: Optional[bool] = False,\n) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:\n \"\"\"\n Args:\n hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`\n attention_mask (`torch.FloatTensor`, *optional*): attention mask of size\n `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.\n output_attentions (`bool`, *optional*):\n Whether or not to return the attentions tensors of all attention layers. See `attentions` under\n returned tensors for more detail.\n use_cache (`bool`, *optional*):\n If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding\n (see `past_key_values`).\n past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states\n \"\"\"\n\n residual = hidden_states\n\n hidden_states = self.input_layernorm(hidden_states)\n\n # Self Attention\n hidden_states, self_attn_weights, present_key_value = self.self_attn(\n hidden_states=hidden_states,\n attention_mask=attention_mask,\n position_ids=position_ids,\n past_key_value=past_key_value,\n output_attentions=output_attentions,\n use_cache=use_cache,\n )\n hidden_states = residual + hidden_states\n\n # Fully Connected\n residual = hidden_states\n hidden_states = self.post_attention_layernorm(hidden_states)\n hidden_states = self.mlp(hidden_states)\n hidden_states = residual + hidden_states\n\n outputs = (hidden_states,)\n\n if output_attentions:\n outputs += (self_attn_weights,)\n\n if use_cache:\n outputs += (present_key_value,)\n\n return outputs\n
"},{"location":"reference/#src.aeiva.model.macaw_model.LlamaModel","title":"LlamaModel
","text":" Bases: LlamaPreTrainedModel
Transformer decoder consisting of config.num_hidden_layers layers. Each layer is a [LlamaDecoderLayer
]
Parameters:
Name Type Description Default config
LlamaConfig
LlamaConfig
required Source code in src/aeiva/model/macaw_model.py
class LlamaModel(LlamaPreTrainedModel):\n \"\"\"\n Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]\n\n Args:\n config: LlamaConfig\n \"\"\"\n\n def __init__(self, config: LlamaConfig):\n super().__init__(config)\n self.padding_idx = config.pad_token_id\n self.vocab_size = config.vocab_size\n\n self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)\n self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])\n self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n\n self.gradient_checkpointing = False\n # Initialize weights and apply final processing\n self.post_init()\n\n def get_input_embeddings(self):\n return self.embed_tokens\n\n def set_input_embeddings(self, value):\n self.embed_tokens = value\n\n # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask\n def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):\n # create causal mask\n # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\n combined_attention_mask = None\n if input_shape[-1] > 1:\n combined_attention_mask = _make_causal_mask(\n input_shape,\n inputs_embeds.dtype,\n device=inputs_embeds.device,\n past_key_values_length=past_key_values_length,\n )\n\n if attention_mask is not None:\n # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\n expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(\n inputs_embeds.device\n )\n combined_attention_mask = (\n expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask\n )\n\n return combined_attention_mask\n\n # @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)\n def forward(\n self,\n input_ids: torch.LongTensor = None,\n attention_mask: Optional[torch.Tensor] = None,\n position_ids: Optional[torch.LongTensor] 
= None,\n past_key_values: Optional[List[torch.FloatTensor]] = None,\n inputs_embeds: Optional[torch.FloatTensor] = None,\n use_cache: Optional[bool] = None,\n output_attentions: Optional[bool] = None,\n output_hidden_states: Optional[bool] = None,\n return_dict: Optional[bool] = None,\n ) -> Union[Tuple, BaseModelOutputWithPast]:\n output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n output_hidden_states = (\n output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n )\n use_cache = use_cache if use_cache is not None else self.config.use_cache\n\n return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n # retrieve input_ids and inputs_embeds\n if input_ids is not None and inputs_embeds is not None:\n raise ValueError(\"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time\")\n elif input_ids is not None:\n batch_size, seq_length = input_ids.shape\n elif inputs_embeds is not None:\n batch_size, seq_length, _ = inputs_embeds.shape\n else:\n raise ValueError(\"You have to specify either decoder_input_ids or decoder_inputs_embeds\")\n\n seq_length_with_past = seq_length\n past_key_values_length = 0\n\n if past_key_values is not None:\n past_key_values_length = past_key_values[0][0].shape[2]\n seq_length_with_past = seq_length_with_past + past_key_values_length\n\n if position_ids is None:\n device = input_ids.device if input_ids is not None else inputs_embeds.device\n position_ids = torch.arange(\n past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device\n )\n position_ids = position_ids.unsqueeze(0).view(-1, seq_length)\n else:\n position_ids = position_ids.view(-1, seq_length).long()\n\n if inputs_embeds is None:\n inputs_embeds = self.embed_tokens(input_ids)\n # embed positions\n if attention_mask is None:\n attention_mask = torch.ones(\n (batch_size, seq_length_with_past), 
dtype=torch.bool, device=inputs_embeds.device\n )\n attention_mask = self._prepare_decoder_attention_mask(\n attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length\n )\n\n hidden_states = inputs_embeds\n\n if self.gradient_checkpointing and self.training:\n if use_cache:\n logger.warning_once(\n \"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...\"\n )\n use_cache = False\n\n # decoder layers\n all_hidden_states = () if output_hidden_states else None\n all_self_attns = () if output_attentions else None\n next_decoder_cache = () if use_cache else None\n\n for idx, decoder_layer in enumerate(self.layers):\n if output_hidden_states:\n all_hidden_states += (hidden_states,)\n\n past_key_value = past_key_values[idx] if past_key_values is not None else None\n\n if self.gradient_checkpointing and self.training:\n\n def create_custom_forward(module):\n def custom_forward(*inputs):\n # None for past_key_value\n return module(*inputs, output_attentions, None)\n\n return custom_forward\n\n layer_outputs = torch.utils.checkpoint.checkpoint(\n create_custom_forward(decoder_layer),\n hidden_states,\n attention_mask,\n position_ids,\n None,\n )\n else:\n layer_outputs = decoder_layer(\n hidden_states,\n attention_mask=attention_mask,\n position_ids=position_ids,\n past_key_value=past_key_value,\n output_attentions=output_attentions,\n use_cache=use_cache,\n )\n\n hidden_states = layer_outputs[0]\n\n if use_cache:\n next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)\n\n if output_attentions:\n all_self_attns += (layer_outputs[1],)\n\n hidden_states = self.norm(hidden_states)\n\n # add hidden states from the last decoder layer\n if output_hidden_states:\n all_hidden_states += (hidden_states,)\n\n next_cache = next_decoder_cache if use_cache else None\n if not return_dict:\n return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)\n return 
BaseModelOutputWithPast(\n last_hidden_state=hidden_states,\n past_key_values=next_cache,\n hidden_states=all_hidden_states,\n attentions=all_self_attns,\n )\n
"},{"location":"reference/#src.aeiva.model.macaw_model.LlamaRMSNorm","title":"LlamaRMSNorm
","text":" Bases: Module
Source code in src/aeiva/model/macaw_model.py
class LlamaRMSNorm(nn.Module):\n def __init__(self, hidden_size, eps=1e-6):\n \"\"\"\n LlamaRMSNorm is equivalent to T5LayerNorm\n \"\"\"\n super().__init__()\n self.weight = nn.Parameter(torch.ones(hidden_size))\n self.variance_epsilon = eps\n\n def forward(self, hidden_states):\n variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)\n hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n\n # convert into half-precision if necessary\n if self.weight.dtype in [torch.float16, torch.bfloat16]:\n hidden_states = hidden_states.to(self.weight.dtype)\n\n return self.weight * hidden_states\n
"},{"location":"reference/#src.aeiva.model.macaw_model.LlamaRMSNorm.__init__","title":"__init__(hidden_size, eps=1e-06)
","text":"LlamaRMSNorm is equivalent to T5LayerNorm
Source code in src/aeiva/model/macaw_model.py
def __init__(self, hidden_size, eps=1e-6):\n \"\"\"\n LlamaRMSNorm is equivalent to T5LayerNorm\n \"\"\"\n super().__init__()\n self.weight = nn.Parameter(torch.ones(hidden_size))\n self.variance_epsilon = eps\n
"},{"location":"reference/#src.aeiva.model.macaw_model.MM_LLMs_Config","title":"MM_LLMs_Config
","text":" Bases: PretrainedConfig
Source code in src/aeiva/model/macaw_model.py
class MM_LLMs_Config(PretrainedConfig):\n model_type = 'mm_llms'\n is_composition = True\n\n def __init__(self, n_frames=6, attention_heads=8, image_conv_kernel=48, image_conv_stride=36, \n video_conv_kernel=36, video_conv_stride=30, audio_conv_kernel=240, audio_conv_stride=220,\n clip_config=None, whisper_config=None, llm_config=None, **kwargs):\n\n self.image_config = clip_config\n self.audio_config = whisper_config\n self.llm_config = llm_config\n self.n_frames = n_frames\n self.attention_heads = attention_heads\n self.image_conv_kernel = image_conv_kernel\n self.image_conv_stride = image_conv_stride\n self.video_conv_kernel = video_conv_kernel\n self.video_conv_stride = video_conv_stride\n self.audio_conv_kernel = audio_conv_kernel\n self.audio_conv_stride = audio_conv_stride\n\n self.hidden_size = max(llm_config.hidden_size, clip_config.projection_dim, whisper_config.d_model, clip_config.projection_dim)\n\n super().__init__(**kwargs)\n\n def to_dict(self):\n \"\"\"\n Serializes this instance to a Python dictionary. 
Override the default [`~PretrainedConfig.to_dict`].\n\n Returns:\n `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,\n \"\"\"\n output = copy.deepcopy(self.__dict__)\n output[\"image_config\"] = self.image_config.to_dict()\n output[\"audio_config\"] = self.audio_config.to_dict()\n output['llm_config'] = self.llm_config.to_dict()\n output['n_frames'] = self.n_frames\n output['attention_heads'] = self.attention_heads\n output['image_conv_kernel'] = self.image_conv_kernel\n output['image_conv_stride'] = self.image_conv_stride\n output['video_conv_kernel'] = self.video_conv_kernel\n output['video_conv_stride'] = self.video_conv_stride\n output['audio_conv_kernel'] = self.audio_conv_kernel\n output['audio_conv_stride'] = self.audio_conv_stride\n output['hidden_size'] = self.hidden_size\n output[\"model_type\"] = self.__class__.model_type\n return output\n @classmethod\n def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):\n config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)\n\n clip_config = CLIPConfig.from_dict(config_dict['image_config'])\n whisper_config = WhisperConfig.from_dict(config_dict['audio_config'])\n llm_config = LlamaConfig.from_dict(config_dict['llm_config'])\n\n return cls(clip_config=clip_config, whisper_config=whisper_config, llm_config=llm_config, **kwargs)\n
"},{"location":"reference/#src.aeiva.model.macaw_model.MM_LLMs_Config.to_dict","title":"to_dict()
","text":"Serializes this instance to a Python dictionary. Override the default [~PretrainedConfig.to_dict
].
Returns:
Type Description Dict[str, any]
: Dictionary of all the attributes that make up this configuration instance.
Source code in src/aeiva/model/macaw_model.py
def to_dict(self):\n \"\"\"\n Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].\n\n Returns:\n `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,\n \"\"\"\n output = copy.deepcopy(self.__dict__)\n output[\"image_config\"] = self.image_config.to_dict()\n output[\"audio_config\"] = self.audio_config.to_dict()\n output['llm_config'] = self.llm_config.to_dict()\n output['n_frames'] = self.n_frames\n output['attention_heads'] = self.attention_heads\n output['image_conv_kernel'] = self.image_conv_kernel\n output['image_conv_stride'] = self.image_conv_stride\n output['video_conv_kernel'] = self.video_conv_kernel\n output['video_conv_stride'] = self.video_conv_stride\n output['audio_conv_kernel'] = self.audio_conv_kernel\n output['audio_conv_stride'] = self.audio_conv_stride\n output['hidden_size'] = self.hidden_size\n output[\"model_type\"] = self.__class__.model_type\n return output\n
"},{"location":"reference/#src.aeiva.model.macaw_model.WhisperEncoder","title":"WhisperEncoder
","text":" Bases: WhisperPreTrainedModel
Transformer encoder consisting of config.encoder_layers self attention layers. Each layer is a [WhisperEncoderLayer
].
Parameters:
Name Type Description Default config
WhisperConfig
WhisperConfig
required Source code in src/aeiva/model/macaw_model.py
class WhisperEncoder(WhisperPreTrainedModel):\n \"\"\"\n Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a\n [`WhisperEncoderLayer`].\n\n Args:\n config: WhisperConfig\n \"\"\"\n\n def __init__(self, config: WhisperConfig):\n super().__init__(config)\n self.dropout = config.dropout\n self.layerdrop = config.encoder_layerdrop\n\n embed_dim = config.d_model\n self.num_mel_bins = config.num_mel_bins\n self.padding_idx = config.pad_token_id\n self.max_source_positions = config.max_source_positions\n self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0\n\n self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)\n self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)\n\n self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim)\n\n self.layers = nn.ModuleList([WhisperEncoderLayer(config) for _ in range(config.encoder_layers)])\n self.layer_norm = nn.LayerNorm(config.d_model)\n\n self.gradient_checkpointing = False\n # Initialize weights and apply final processing\n self.post_init()\n\n def _freeze_parameters(self):\n for param in self.parameters():\n param.requires_grad = False\n self._requires_grad = False\n\n def get_input_embeddings(self) -> nn.Module:\n return self.conv1\n\n def set_input_embeddings(self, value: nn.Module):\n self.conv1 = value\n\n def forward(\n self,\n input_features,\n attention_mask=None,\n head_mask=None,\n output_attentions=None,\n output_hidden_states=None,\n return_dict=None,\n ):\n r\"\"\"\n Args:\n input_features (`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`):\n Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be\n obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a\n `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). 
To prepare the array into\n `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding\n and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]\n attention_mask (`torch.Tensor`)`, *optional*):\n Whisper does not support masking of the `input_features`, this argument is preserved for compatibility,\n but it is not used. By default the silence in the input log mel spectrogram are ignored.\n head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):\n Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:\n\n - 1 indicates the head is **not masked**,\n - 0 indicates the head is **masked**.\n output_attentions (`bool`, *optional*):\n Whether or not to return the attentions tensors of all attention layers. See `attentions` under\n returned tensors for more detail.\n output_hidden_states (`bool`, *optional*):\n Whether or not to return the hidden states of all layers. 
See `hidden_states` under returned tensors\n for more detail.\n return_dict (`bool`, *optional*):\n Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.\n \"\"\"\n output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n output_hidden_states = (\n output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n )\n return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n inputs_embeds = nn.functional.gelu(self.conv1(input_features))\n inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))\n\n inputs_embeds = inputs_embeds.permute(0, 2, 1)\n embed_pos = self.embed_positions.weight\n\n hidden_states = inputs_embeds + embed_pos\n hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)\n\n encoder_states = () if output_hidden_states else None\n all_attentions = () if output_attentions else None\n\n # check if head_mask has a correct number of layers specified if desired\n if head_mask is not None:\n assert head_mask.size()[0] == (\n len(self.layers)\n ), f\"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}.\"\n\n for idx, encoder_layer in enumerate(self.layers):\n if output_hidden_states:\n encoder_states = encoder_states + (hidden_states,)\n # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)\n dropout_probability = random.uniform(0, 1)\n if self.training and (dropout_probability < self.layerdrop): # skip the layer\n layer_outputs = (None, None)\n else:\n if self.gradient_checkpointing and self.training:\n\n def create_custom_forward(module):\n def custom_forward(*inputs):\n return module(*inputs, output_attentions)\n\n return custom_forward\n\n layer_outputs = torch.utils.checkpoint.checkpoint(\n create_custom_forward(encoder_layer),\n hidden_states,\n None,\n (head_mask[idx] if head_mask is not None else None),\n )\n 
else:\n layer_outputs = encoder_layer(\n hidden_states,\n None,\n layer_head_mask=(head_mask[idx] if head_mask is not None else None),\n output_attentions=output_attentions,\n )\n\n hidden_states = layer_outputs[0]\n\n if output_attentions:\n all_attentions = all_attentions + (layer_outputs[1],)\n\n hidden_states = self.layer_norm(hidden_states)\n if output_hidden_states:\n encoder_states = encoder_states + (hidden_states,)\n\n if not return_dict:\n return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)\n return BaseModelOutput(\n last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions\n )\n
"},{"location":"reference/#src.aeiva.model.macaw_model.WhisperEncoder.forward","title":"forward(input_features, attention_mask=None, head_mask=None, output_attentions=None, output_hidden_states=None, return_dict=None)
","text":"Parameters:
Name Type Description Default input_features
`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`
Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by loading a .flac
or .wav
audio file into an array of type List[float]
or a numpy.ndarray
, e.g. via the soundfile library (pip install soundfile
). To prepare the array into input_features
, the [AutoFeatureExtractor
] should be used for extracting the mel features, padding and conversion into a tensor of type torch.FloatTensor
. See [~WhisperFeatureExtractor.__call__
]
required attention_mask
`torch.Tensor`, *optional*
Whisper does not support masking of the input_features
, this argument is preserved for compatibility, but it is not used. By default the silence in the input log mel spectrogram are ignored.
None
head_mask
`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*
Mask to nullify selected heads of the attention modules. Mask values selected in [0, 1]
:
- 1 indicates the head is not masked,
- 0 indicates the head is masked.
None
output_attentions
`bool`, *optional*
Whether or not to return the attentions tensors of all attention layers. See attentions
under returned tensors for more detail.
None
output_hidden_states
`bool`, *optional*
Whether or not to return the hidden states of all layers. See hidden_states
under returned tensors for more detail.
None
return_dict
`bool`, *optional*
Whether or not to return a [~utils.ModelOutput
] instead of a plain tuple.
None
Source code in src/aeiva/model/macaw_model.py
def forward(\n self,\n input_features,\n attention_mask=None,\n head_mask=None,\n output_attentions=None,\n output_hidden_states=None,\n return_dict=None,\n):\n r\"\"\"\n Args:\n input_features (`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`):\n Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be\n obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a\n `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into\n `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding\n and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]\n attention_mask (`torch.Tensor`)`, *optional*):\n Whisper does not support masking of the `input_features`, this argument is preserved for compatibility,\n but it is not used. By default the silence in the input log mel spectrogram are ignored.\n head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):\n Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:\n\n - 1 indicates the head is **not masked**,\n - 0 indicates the head is **masked**.\n output_attentions (`bool`, *optional*):\n Whether or not to return the attentions tensors of all attention layers. See `attentions` under\n returned tensors for more detail.\n output_hidden_states (`bool`, *optional*):\n Whether or not to return the hidden states of all layers. 
See `hidden_states` under returned tensors\n for more detail.\n return_dict (`bool`, *optional*):\n Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.\n \"\"\"\n output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n output_hidden_states = (\n output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n )\n return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n inputs_embeds = nn.functional.gelu(self.conv1(input_features))\n inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))\n\n inputs_embeds = inputs_embeds.permute(0, 2, 1)\n embed_pos = self.embed_positions.weight\n\n hidden_states = inputs_embeds + embed_pos\n hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)\n\n encoder_states = () if output_hidden_states else None\n all_attentions = () if output_attentions else None\n\n # check if head_mask has a correct number of layers specified if desired\n if head_mask is not None:\n assert head_mask.size()[0] == (\n len(self.layers)\n ), f\"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}.\"\n\n for idx, encoder_layer in enumerate(self.layers):\n if output_hidden_states:\n encoder_states = encoder_states + (hidden_states,)\n # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)\n dropout_probability = random.uniform(0, 1)\n if self.training and (dropout_probability < self.layerdrop): # skip the layer\n layer_outputs = (None, None)\n else:\n if self.gradient_checkpointing and self.training:\n\n def create_custom_forward(module):\n def custom_forward(*inputs):\n return module(*inputs, output_attentions)\n\n return custom_forward\n\n layer_outputs = torch.utils.checkpoint.checkpoint(\n create_custom_forward(encoder_layer),\n hidden_states,\n None,\n (head_mask[idx] if head_mask is not None else None),\n )\n 
else:\n layer_outputs = encoder_layer(\n hidden_states,\n None,\n layer_head_mask=(head_mask[idx] if head_mask is not None else None),\n output_attentions=output_attentions,\n )\n\n hidden_states = layer_outputs[0]\n\n if output_attentions:\n all_attentions = all_attentions + (layer_outputs[1],)\n\n hidden_states = self.layer_norm(hidden_states)\n if output_hidden_states:\n encoder_states = encoder_states + (hidden_states,)\n\n if not return_dict:\n return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)\n return BaseModelOutput(\n last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions\n )\n
"},{"location":"reference/#src.aeiva.model.macaw_model.rotate_half","title":"rotate_half(x)
","text":"Rotates half the hidden dims of the input.
Source code in src/aeiva/model/macaw_model.py
def rotate_half(x):\n \"\"\"Rotates half the hidden dims of the input.\"\"\"\n x1 = x[..., : x.shape[-1] // 2]\n x2 = x[..., x.shape[-1] // 2 :]\n return torch.cat((-x2, x1), dim=-1)\n
"},{"location":"reference/#src.aeiva.model.macaw_model_old","title":"macaw_model_old
","text":"This script contains the implementation of the MACAW model. MACAW is a multimodal transformer model that combines the CLIP and Whisper models.
Author: Bang Liu Date: 2023-06-22
References: - Macaw-LLM code repository: https://github.com/lyuchenyang/Macaw-LLM/blob/main/modeling.py
"},{"location":"reference/#src.aeiva.model.macaw_model_old.LlamaAttention","title":"LlamaAttention
","text":" Bases: Module
Multi-headed attention from 'Attention Is All You Need' paper
Source code in src/aeiva/model/macaw_model_old.py
class LlamaAttention(nn.Module):\n \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\n\n def __init__(self, config: LlamaConfig):\n super().__init__()\n self.config = config\n self.hidden_size = config.hidden_size\n self.num_heads = config.num_attention_heads\n self.head_dim = self.hidden_size // self.num_heads\n self.max_position_embeddings = config.max_position_embeddings # !!! I want to change this variable name.\n\n if (self.head_dim * self.num_heads) != self.hidden_size:\n raise ValueError(\n f\"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}\"\n f\" and `num_heads`: {self.num_heads}).\"\n )\n self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)\n self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)\n self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)\n self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)\n self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)\n\n def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):\n return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()\n\n def forward(\n self,\n hidden_states: torch.Tensor,\n attention_mask: Optional[torch.Tensor] = None,\n position_ids: Optional[torch.LongTensor] = None,\n past_key_value: Optional[Tuple[torch.Tensor]] = None,\n output_attentions: bool = False,\n use_cache: bool = False,\n ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n bsz, q_len, _ = hidden_states.size()\n\n # By placing the num_heads dimension as the second dimension, it allows for \n # efficient batched matrix operations (e.g., matrix multiplication in attention computation) \n # across all the heads. 
It is basically a data layout optimization for computational efficiency \n # in the context of multi-head attention.\n query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n\n kv_seq_len = key_states.shape[-2] # the shape is [batch_size, num_heads, seq_len, head_dim], so -2 dimension is 'seq_len'\n if past_key_value is not None: \n # If past_key_value is not None, this means the model is being used in an autoregressive setting, \n # where the past key-value pairs are given to the current step.\n # past_key_value[0] refers to the previously computed key states,\n # past_key_value[1] refers to the previously computed value states.\n # The shape of past_key_value[0] and past_key_value[1] is [batch_size, num_heads, seq_len, head_dim].\n kv_seq_len += past_key_value[0].shape[-2] # + past seq_len\n\n cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)\n query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)\n\n if past_key_value is not None:\n # reuse k, v, self_attention\n key_states = torch.cat([past_key_value[0], key_states], dim=2)\n value_states = torch.cat([past_key_value[1], value_states], dim=2)\n\n past_key_value = (key_states, value_states) if use_cache else None\n\n attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)\n\n if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):\n raise ValueError(\n f\"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is\"\n f\" {attn_weights.size()}\"\n )\n\n if attention_mask is not None:\n if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):\n raise ValueError(\n f\"Attention mask should be of size {(bsz, 1, q_len, 
kv_seq_len)}, but is {attention_mask.size()}\"\n )\n attn_weights = attn_weights + attention_mask\n # This following line is ensuring numerical stability. It caps the minimum value of the attention weights\n # to be the minimum finite representable number for the data type of attn_weights. This avoids \n # potential issues with underflow when these weights are later passed through the softmax function.\n attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))\n\n # upcast attention to fp32\n # This is done to prevent numerical instability that can occur\n # during operations on very small numbers or very large numbers.\n attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)\n attn_output = torch.matmul(attn_weights, value_states)\n\n if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):\n raise ValueError(\n f\"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is\"\n f\" {attn_output.size()}\"\n )\n\n attn_output = attn_output.transpose(1, 2)\n attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) # self.hidden_size is equivalent to self.num_heads * self.head_dim\n\n attn_output = self.o_proj(attn_output)\n\n if not output_attentions:\n attn_weights = None\n\n return attn_output, attn_weights, past_key_value\n
"},{"location":"reference/#src.aeiva.model.macaw_model_old.LlamaDecoderLayer","title":"LlamaDecoderLayer
","text":" Bases: Module
Source code in src/aeiva/model/macaw_model_old.py
class LlamaDecoderLayer(nn.Module):\n def __init__(self, config: LlamaConfig):\n super().__init__()\n self.hidden_size = config.hidden_size\n self.self_attn = LlamaAttention(config=config)\n self.mlp = LlamaMLP(\n hidden_size=self.hidden_size,\n intermediate_size=config.intermediate_size,\n hidden_act=config.hidden_act,\n )\n self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n\n def forward(\n self,\n hidden_states: torch.Tensor,\n attention_mask: Optional[torch.Tensor] = None,\n position_ids: Optional[torch.LongTensor] = None,\n past_key_value: Optional[Tuple[torch.Tensor]] = None,\n output_attentions: Optional[bool] = False,\n use_cache: Optional[bool] = False,\n ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:\n \"\"\"\n Args:\n hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`\n attention_mask (`torch.FloatTensor`, *optional*): attention mask of size\n `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.\n output_attentions (`bool`, *optional*):\n Whether or not to return the attentions tensors of all attention layers. 
See `attentions` under\n returned tensors for more detail.\n use_cache (`bool`, *optional*):\n If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding\n (see `past_key_values`).\n past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states\n \"\"\"\n\n residual = hidden_states\n\n hidden_states = self.input_layernorm(hidden_states)\n\n # Self Attention\n hidden_states, self_attn_weights, present_key_value = self.self_attn(\n hidden_states=hidden_states,\n attention_mask=attention_mask,\n position_ids=position_ids,\n past_key_value=past_key_value,\n output_attentions=output_attentions,\n use_cache=use_cache,\n )\n hidden_states = residual + hidden_states\n\n # Fully Connected\n residual = hidden_states\n hidden_states = self.post_attention_layernorm(hidden_states)\n hidden_states = self.mlp(hidden_states)\n hidden_states = residual + hidden_states\n\n outputs = (hidden_states,)\n\n if output_attentions:\n outputs += (self_attn_weights,)\n\n if use_cache:\n outputs += (present_key_value,)\n\n return outputs\n
"},{"location":"reference/#src.aeiva.model.macaw_model_old.LlamaDecoderLayer.forward","title":"forward(hidden_states, attention_mask=None, position_ids=None, past_key_value=None, output_attentions=False, use_cache=False)
","text":"Parameters:
Name Type Description Default hidden_states
`torch.FloatTensor`
input to the layer of shape (batch, seq_len, embed_dim)
required attention_mask
`torch.FloatTensor`, *optional*
attention mask of size (batch, 1, tgt_len, src_len)
where padding elements are indicated by very large negative values.
None
output_attentions
`bool`, *optional*
Whether or not to return the attentions tensors of all attention layers. See attentions
under returned tensors for more detail.
False
use_cache
`bool`, *optional*
If set to True
, past_key_values
key value states are returned and can be used to speed up decoding (see past_key_values
).
False
past_key_value
`Tuple(torch.FloatTensor)`, *optional*
cached past key and value projection states
None
Source code in src/aeiva/model/macaw_model_old.py
def forward(\n self,\n hidden_states: torch.Tensor,\n attention_mask: Optional[torch.Tensor] = None,\n position_ids: Optional[torch.LongTensor] = None,\n past_key_value: Optional[Tuple[torch.Tensor]] = None,\n output_attentions: Optional[bool] = False,\n use_cache: Optional[bool] = False,\n) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:\n \"\"\"\n Args:\n hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`\n attention_mask (`torch.FloatTensor`, *optional*): attention mask of size\n `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.\n output_attentions (`bool`, *optional*):\n Whether or not to return the attentions tensors of all attention layers. See `attentions` under\n returned tensors for more detail.\n use_cache (`bool`, *optional*):\n If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding\n (see `past_key_values`).\n past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states\n \"\"\"\n\n residual = hidden_states\n\n hidden_states = self.input_layernorm(hidden_states)\n\n # Self Attention\n hidden_states, self_attn_weights, present_key_value = self.self_attn(\n hidden_states=hidden_states,\n attention_mask=attention_mask,\n position_ids=position_ids,\n past_key_value=past_key_value,\n output_attentions=output_attentions,\n use_cache=use_cache,\n )\n hidden_states = residual + hidden_states\n\n # Fully Connected\n residual = hidden_states\n hidden_states = self.post_attention_layernorm(hidden_states)\n hidden_states = self.mlp(hidden_states)\n hidden_states = residual + hidden_states\n\n outputs = (hidden_states,)\n\n if output_attentions:\n outputs += (self_attn_weights,)\n\n if use_cache:\n outputs += (present_key_value,)\n\n return outputs\n
"},{"location":"reference/#src.aeiva.model.macaw_model_old.LlamaModel","title":"LlamaModel
","text":" Bases: LlamaPreTrainedModel
Transformer decoder consisting of config.num_hidden_layers layers. Each layer is a [LlamaDecoderLayer
]
Parameters:
Name Type Description Default config
LlamaConfig
LlamaConfig
required Source code in src/aeiva/model/macaw_model_old.py
class LlamaModel(LlamaPreTrainedModel):\n \"\"\"\n Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]\n\n Args:\n config: LlamaConfig\n \"\"\"\n\n def __init__(self, config: LlamaConfig):\n super().__init__(config)\n # embedding layer, stacked decoder layers, and layer normalization in llama.\n self.padding_idx = config.pad_token_id\n self.vocab_size = config.vocab_size\n self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)\n\n self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])\n\n self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n\n # Gradient checkpointing is a technique to reduce the memory usage when training deep neural networks.\n # In deep learning, when you perform backpropagation to compute gradients and update the model parameters,\n # you need to store the intermediate activations from the forward pass, so you can use them in the backward pass. \n # For large models or long sequences, this can consume a lot of memory.\n # \n # Gradient checkpointing addresses this by not storing all the intermediate activations in memory during the forward pass. \n # Instead, it stores only a subset of the activations, and recomputes the rest during the backward pass as needed. \n # This trades off computation time (because you need to recompute some values) for memory usage.\n # \n # This technique is particularly useful when training large models that would otherwise not fit into GPU memory. 
\n # However, it can slow down training because of the extra computation.\n self.gradient_checkpointing = False\n\n # Initialize weights and apply final processing\n self.post_init()\n\n def get_input_embeddings(self):\n return self.embed_tokens\n\n def set_input_embeddings(self, value):\n self.embed_tokens = value\n\n # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask\n def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):\n # create causal mask\n # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\n combined_attention_mask = None\n if input_shape[-1] > 1: # seq_len > 1\n combined_attention_mask = _make_causal_mask(\n input_shape,\n inputs_embeds.dtype,\n device=inputs_embeds.device,\n past_key_values_length=past_key_values_length,\n )\n\n if attention_mask is not None:\n # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\n expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(\n inputs_embeds.device\n )\n combined_attention_mask = (\n expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask\n )\n\n return combined_attention_mask\n\n # @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)\n def forward(\n self,\n input_ids: torch.LongTensor = None,\n attention_mask: Optional[torch.Tensor] = None,\n position_ids: Optional[torch.LongTensor] = None,\n past_key_values: Optional[List[torch.FloatTensor]] = None,\n inputs_embeds: Optional[torch.FloatTensor] = None,\n use_cache: Optional[bool] = None,\n output_attentions: Optional[bool] = None,\n output_hidden_states: Optional[bool] = None,\n return_dict: Optional[bool] = None,\n ) -> Union[Tuple, BaseModelOutputWithPast]:\n # set output and cache flags\n output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n output_hidden_states = output_hidden_states if 
output_hidden_states is not None else self.config.output_hidden_states\n use_cache = use_cache if use_cache is not None else self.config.use_cache\n return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n # prepare input_ids/inputs_embeds\n if input_ids is not None and inputs_embeds is not None:\n raise ValueError(\"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time\")\n elif input_ids is not None:\n batch_size, seq_length = input_ids.shape\n elif inputs_embeds is not None:\n batch_size, seq_length, _ = inputs_embeds.shape\n else:\n raise ValueError(\"You have to specify either decoder_input_ids or decoder_inputs_embeds\")\n if inputs_embeds is None:\n inputs_embeds = self.embed_tokens(input_ids)\n\n # prepare attention mask and other parameters for decoder layers\n past_key_values_length = 0\n seq_length_with_past = seq_length\n\n if past_key_values is not None:\n past_key_values_length = past_key_values[0][0].shape[2]\n seq_length_with_past = seq_length_with_past + past_key_values_length\n\n if position_ids is None:\n device = input_ids.device if input_ids is not None else inputs_embeds.device\n position_ids = torch.arange(\n past_key_values_length, past_key_values_length + seq_length, dtype=torch.long, device=device\n )\n position_ids = position_ids.unsqueeze(0).view(-1, seq_length)\n else:\n position_ids = position_ids.view(-1, seq_length).long()\n\n if attention_mask is None:\n attention_mask = torch.ones(\n (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device\n )\n attention_mask = self._prepare_decoder_attention_mask(\n attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length\n )\n\n hidden_states = inputs_embeds\n\n if self.gradient_checkpointing and self.training:\n if use_cache:\n logger.warning_once(\n \"`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`...\"\n )\n use_cache = False\n\n # forward through all decoder layers\n all_hidden_states = () if output_hidden_states else None\n all_self_attns = () if output_attentions else None\n next_decoder_cache = () if use_cache else None\n\n for idx, decoder_layer in enumerate(self.layers):\n if output_hidden_states:\n all_hidden_states += (hidden_states,)\n\n past_key_value = past_key_values[idx] if past_key_values is not None else None\n\n if self.gradient_checkpointing and self.training:\n # define the function for gradient checkpointing\n # in checkpointing, we need to create a custom function for the forward pass \n # (the custom_forward function in your code) and then using the \n # torch.utils.checkpoint.checkpoint function to apply this custom function \n # with gradient checkpointing.\n def create_custom_forward(module):\n def custom_forward(*inputs):\n return module(*inputs, output_attentions, None) # None for past_key_value\n return custom_forward\n\n layer_outputs = torch.utils.checkpoint.checkpoint(\n create_custom_forward(decoder_layer),\n hidden_states,\n attention_mask,\n position_ids,\n None,\n )\n else:\n layer_outputs = decoder_layer(\n hidden_states,\n attention_mask=attention_mask,\n position_ids=position_ids,\n past_key_value=past_key_value,\n output_attentions=output_attentions,\n use_cache=use_cache,\n )\n\n hidden_states = layer_outputs[0]\n\n if use_cache:\n next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)\n\n if output_attentions:\n all_self_attns += (layer_outputs[1],)\n\n hidden_states = self.norm(hidden_states)\n\n # add hidden states from the last decoder layer\n if output_hidden_states:\n all_hidden_states += (hidden_states,)\n\n next_cache = next_decoder_cache if use_cache else None\n\n # output the hidden states, the self attentions and the cache (if needed)\n if not return_dict:\n return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)\n 
return BaseModelOutputWithPast(\n last_hidden_state=hidden_states,\n past_key_values=next_cache,\n hidden_states=all_hidden_states,\n attentions=all_self_attns,\n )\n
"},{"location":"reference/#src.aeiva.model.macaw_model_old.LlamaRMSNorm","title":"LlamaRMSNorm
","text":" Bases: Module
Source code in src/aeiva/model/macaw_model_old.py
class LlamaRMSNorm(nn.Module):\n def __init__(self, hidden_size, eps=1e-6):\n \"\"\"\n LlamaRMSNorm is equivalent to T5LayerNorm\n The overall effect of this layer is to ensure that,\n for each feature in the hidden_states,\n the activations have zero mean and unit variance across the batch.\n This can make the training process more stable and faster.\n \"\"\"\n super().__init__()\n self.weight = nn.Parameter(torch.ones(hidden_size)) # trainable parameter for affine transformation\n self.variance_epsilon = eps # for numerical stability\n\n def forward(self, hidden_states):\n variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)\n hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n\n # convert into half-precision if necessary\n if self.weight.dtype in [torch.float16, torch.bfloat16]:\n hidden_states = hidden_states.to(self.weight.dtype)\n\n return self.weight * hidden_states\n
"},{"location":"reference/#src.aeiva.model.macaw_model_old.LlamaRMSNorm.__init__","title":"__init__(hidden_size, eps=1e-06)
","text":"LlamaRMSNorm is equivalent to T5LayerNorm The overall effect of this layer is to ensure that, for each feature in the hidden_states, the activations have zero mean and unit variance across the batch. This can make the training process more stable and faster.
Source code in src/aeiva/model/macaw_model_old.py
def __init__(self, hidden_size, eps=1e-6):\n \"\"\"\n LlamaRMSNorm is equivalent to T5LayerNorm\n The overall effect of this layer is to ensure that,\n for each feature in the hidden_states,\n the activations have zero mean and unit variance across the batch.\n This can make the training process more stable and faster.\n \"\"\"\n super().__init__()\n self.weight = nn.Parameter(torch.ones(hidden_size)) # trainable parameter for affine transformation\n self.variance_epsilon = eps # for numerical stability\n
"},{"location":"reference/#src.aeiva.model.macaw_model_old.LlamaRotaryEmbedding","title":"LlamaRotaryEmbedding
","text":" Bases: Module
Rotary embedding described in: https://arxiv.org/pdf/2104.09864.pdf. It is used to modulate the position information in the input embeddings. LLaMA uses rotary embeddings.
Source code in src/aeiva/model/macaw_model_old.py
class LlamaRotaryEmbedding(torch.nn.Module):\n \"\"\"\n Rotary embedding described in: https://arxiv.org/pdf/2104.09864.pdf.\n It is used to modulate the position information in the input embeddings.\n Llama used rotary embedding.\n \"\"\"\n def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):\n super().__init__()\n # Compute the inverse frequencies, which will be used to modulate the position information\n inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))\n # The register_buffer() function is used in PyTorch to register a tensor that is not a parameter,\n # but you still want it to be a part of the model's state. It's used for tensors that should\n # have their state saved in the model's state_dict and should be moved to the device with the rest of the model.\n self.register_buffer(\"inv_freq\", inv_freq)\n\n # Build here to make `torch.jit.trace` work.\n # max_position_embeddings: max sequence length that this model might ever be used with\n self.max_seq_len_cached = max_position_embeddings\n\n # Compute the positional encodings (both cos and sin parts)\n t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)\n freqs = torch.einsum(\"i,j->ij\", t, self.inv_freq)\n\n # Different from paper, but it uses a different permutation in order to obtain the same calculation\n emb = torch.cat((freqs, freqs), dim=-1)\n self.register_buffer(\"cos_cached\", emb.cos()[None, None, :, :], persistent=False)\n self.register_buffer(\"sin_cached\", emb.sin()[None, None, :, :], persistent=False)\n\n def forward(self, x, seq_len=None):\n # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.\n # x.shape: [batch_size, num_attention_heads, sequence_length, head_size].\n # The forward function then outputs two tensors, each of which is a sin or cos embedding representation of the input x. 
\n # Both output tensors will have a shape of [1, 1, sequence_length, head_size].\n # NOTE: Only the dtype and device attributes of x are relevant here. The values are not used.\n if seq_len > self.max_seq_len_cached:\n self.max_seq_len_cached = seq_len\n t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)\n freqs = torch.einsum(\"i,j->ij\", t, self.inv_freq)\n # Different from paper, but it uses a different permutation in order to obtain the same calculation\n emb = torch.cat((freqs, freqs), dim=-1).to(x.device)\n self.register_buffer(\"cos_cached\", emb.cos()[None, None, :, :], persistent=False)\n self.register_buffer(\"sin_cached\", emb.sin()[None, None, :, :], persistent=False)\n return (\n self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),\n self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),\n )\n
"},{"location":"reference/#src.aeiva.model.macaw_model_old.MM_LLMs","title":"MM_LLMs
","text":" Bases: PreTrainedModel
This is the multimodal language model that combines CLIP and Whisper encoders with a language model. It requires a config file that specifies the multimodal encoder configurations.
Source code in src/aeiva/model/macaw_model_old.py
class MM_LLMs(PreTrainedModel):\n \"\"\"\n This is the multimodal language model that combines CLIP and Whisper encoders with a language model.\n We need a config file to specify the multimodal encoder configurations.\n \"\"\"\n def __init__(self, config):\n super().__init__(config)\n # multimodal config\n self.config = config\n\n # multimodal encoders\n self.image_encoder = CLIPModel(config.image_config) # NOTE: here they use CLIP for both image and video.\n self.video_encoder = CLIPModel(config.image_config)\n self.audio_encoder = WhisperModel(config.audio_config)\n self.llm = LlamaForCausalLM(config.llm_config)\n\n # video temporal position embedding layer\n self.temporal_position_embeddings = nn.Embedding(\n config.n_frames, \n config.image_config.projection_dim)\n\n # multimodal attention layers for mapping multimodal features to the same space\n attn_dropout = 0.1\n is_add_bias_kv = True\n is_add_zero_attn = True\n self.temporal_self_attention = nn.MultiheadAttention(config.image_config.projection_dim,\n config.attention_heads,\n dropout=attn_dropout,\n add_bias_kv=is_add_bias_kv,\n add_zero_attn=is_add_zero_attn)\n self.video_align_attention = nn.MultiheadAttention(config.llm_config.hidden_size, \n config.attention_heads,\n dropout=attn_dropout,\n add_bias_kv=is_add_bias_kv,\n add_zero_attn=is_add_zero_attn)\n self.audio_align_attention = nn.MultiheadAttention(config.llm_config.hidden_size, \n config.attention_heads,\n dropout=attn_dropout,\n add_bias_kv=is_add_bias_kv,\n add_zero_attn=is_add_zero_attn)\n self.image_align_attention = nn.MultiheadAttention(config.llm_config.hidden_size, \n config.attention_heads,\n dropout=attn_dropout,\n add_bias_kv=is_add_bias_kv,\n add_zero_attn=is_add_zero_attn)\n\n # multimodal projection layers for mapping multimodal features to the same space\n self.transform_video_to_hidden = nn.Linear(config.image_config.projection_dim, \n config.llm_config.hidden_size)\n self.transform_audio_to_hidden = 
nn.Linear(config.audio_config.d_model, \n config.llm_config.hidden_size)\n self.transform_image_to_hidden = nn.Linear(config.image_config.projection_dim, \n config.llm_config.hidden_size)\n\n self.project_image = nn.Conv1d(config.image_config.projection_dim, config.image_config.projection_dim, \n kernel_size=48, stride=36)\n self.project_video = nn.Conv1d(config.image_config.projection_dim, config.image_config.projection_dim, \n kernel_size=36, stride=30)\n self.project_audio = nn.Conv1d(config.audio_config.d_model, config.audio_config.d_model, \n kernel_size=240, stride=220)\n\n # multimodal fusion layers\n self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))\n\n self.layer_norm = nn.LayerNorm(config.image_config.projection_dim)\n self.softmax = nn.Softmax(dim=-1)\n self.relu = nn.ReLU()\n self.gelu = nn.GELU()\n self.elu = nn.ELU()\n self.sigmoid = nn.Sigmoid()\n\n self.loss_fct = CrossEntropyLoss()\n\n self.init_weights()\n\n def forward(self, inputs=None):\n # \"\"\"\n # :param inputs:\n # video_frames: (B x F)\n # audios: B x 1\n # images: B x 1\n # input_ids: B x L\n # labels: B x L\n #\n # :return: the output of the language model LlamaForCausalLM.\n # \"\"\"\n text_embeddings, attention_mask, labels = self.prepare_inputs_for_generation(inputs)\n\n if 'inference' in inputs and inputs['inference'] is True:\n # generate_ids = self.llm.generate(input_ids=inputs['input_ids'], inputs_embeds=text_embeddings, max_new_tokens=128)\n # generate_ids = self.llm.generate(inputs_embeds=text_embeddings, max_new_tokens=128)\n\n # !!! The code below will possibly trigger an error in : https://github.com/microsoft/DeepSpeed/issues/3156 (the solution only partially resolves the bug for me)\n generate_ids = self.llm.generate(\n inputs_embeds=text_embeddings, max_new_tokens=128, eos_token_id=2, bos_token_id=1, pad_token_id=32006 # !!! revise later. 
use config constants instead.\n )\n return generate_ids\n outputs = self.llm(inputs_embeds=text_embeddings, attention_mask=attention_mask, labels=labels)\n\n return outputs\n\n def prepare_inputs_for_generation(self, inputs):\n \"\"\"\n The purpose of this method is to integrate the different modalities into the text embeddings \n and prepare the associated attention mask and labels for the language model, so the model can \n generate text conditioned on all the input modalities.\n\n inputs is a dictionary containing the following keys: (!!! my hypothesis)\n video_frames: (B x F)\n audios: B x 1\n images: B x 1\n input_ids: B x L\n attention_mask: B x L\n labels: B x L\n video_starts: B x 1\n video_ends: B x 1\n audio_starts: B x 1\n audio_ends: B x 1\n image_starts: B x 1\n image_ends: B x 1\n inference: True/False\n \"\"\"\n # get multimodal embeddings\n image_features = self.encode_image(inputs['images']) if inputs['images'] is not None else None\n audio_features = self.encode_audio(inputs['audios']) if inputs['audios'] is not None else None\n video_features = self.encode_video(inputs['videos']) if inputs['videos'] is not None else None\n embed_tokens = self.llm.model.embed_tokens\n\n\n # for debug !!!!!!\n # Find maximum id in input_ids\n max_id = torch.max(inputs['input_ids'])\n print(f\"Max ID in input_ids: {max_id.item()}\")\n\n # Get vocab size from embedding layer\n vocab_size = embed_tokens.num_embeddings\n print(f\"Vocabulary size: {vocab_size}\")\n\n\n\n text_embeddings = embed_tokens(inputs['input_ids'])\n\n token_embeddings = embed_tokens.weight.unsqueeze(0).repeat(\n text_embeddings.size(0), 1, 1).transpose(0, 1)\n\n # ignore_num seems to be a counter that tracks the total size (or length) of the \n # multimodal input segments (video, audio, image) added to the original text inputs.\n ingore_num = 0\n\n # project and merge video features to the same space as text embeddings\n if video_features is not None:\n # get video starts and ends embeddings\n 
video_starts = embed_tokens(inputs['video_starts']).unsqueeze(1)\n video_ends = embed_tokens(inputs['video_ends']).unsqueeze(1)\n\n # project video features to the same space as text embeddings\n video_features = self.transform_video_to_hidden(video_features)\n\n video_features = self.video_align_attention(\n video_features.transpose(0, 1), token_embeddings, token_embeddings)[0].transpose(0, 1).contiguous()\n\n # concatenate video starts, video features, and video ends embeddings\n video_inputs = torch.cat([torch.cat([video_starts, video_features], dim=1), video_ends], dim=1)\n\n # concatenate video inputs to the original text embeddings\n # NOTE: the first token of text_embeddings keeps at the same position\n text_embeddings = torch.cat([torch.cat([text_embeddings[:, 0, :].unsqueeze(1), video_inputs], dim=1), text_embeddings[:, 1:, :]], dim=1)\n\n ingore_num += (video_inputs.size(1))\n\n # project and merge audio features to the same space as text embeddings\n if audio_features is not None:\n # get audio starts and ends embeddings\n audio_starts = embed_tokens(inputs['audio_starts']).unsqueeze(1)\n audio_ends = embed_tokens(inputs['audio_ends']).unsqueeze(1)\n\n # project audio features to the same space as text embeddings\n audio_features = self.project_audio(audio_features.transpose(1, 2).contiguous()).transpose(1, 2).contiguous()\n audio_features = self.transform_audio_to_hidden(audio_features)\n # mean pooling\n # audio_features = torch.sum(audio_features, dim=1) / audio_features.size(1) \n # audio_features = audio_features.unsqueeze(1)\n audio_features = self.audio_align_attention(\n audio_features.transpose(0, 1), token_embeddings, token_embeddings)[0].transpose(0, 1).contiguous()\n\n # concatenate audio starts, audio features, and audio ends embeddings\n audio_inputs = torch.cat([torch.cat([audio_starts, audio_features], dim=1), audio_ends], dim=1)\n\n # concatenate audio inputs to the original text embeddings\n text_embeddings = torch.cat(\n 
[torch.cat([text_embeddings[:, 0, :].unsqueeze(1), audio_inputs], dim=1), text_embeddings[:, 1:, :]],\n dim=1)\n\n ingore_num += (audio_inputs.size(1))\n\n # project and merge image features to the same space as text embeddings\n if image_features is not None:\n # get image starts and ends embeddings\n image_starts = embed_tokens(inputs['image_starts']).unsqueeze(1)\n image_ends = embed_tokens(inputs['image_ends']).unsqueeze(1)\n\n # project image features to the same space as text embeddings\n image_features = self.project_image(image_features.transpose(1, 2).contiguous()).transpose(1, 2).contiguous()\n image_features = self.transform_image_to_hidden(image_features)\n image_features = self.image_align_attention(\n image_features.transpose(0, 1), token_embeddings, token_embeddings)[0].transpose(0, 1).contiguous()\n\n # concatenate image starts, image features, and image ends embeddings\n image_inputs = torch.cat([torch.cat([image_starts, image_features], dim=1), image_ends], dim=1)\n\n # concatenate image inputs to the original text embeddings\n text_embeddings = torch.cat(\n [torch.cat([text_embeddings[:, 0, :].unsqueeze(1), image_inputs], dim=1), \n text_embeddings[:, 1:, :]], dim=1)\n\n ingore_num += (image_inputs.size(1))\n\n if 'attention_mask' in inputs:\n # increase the length of attention mask by adding the length of multimodal inputs\n attention_mask = torch.tensor([1]*ingore_num*text_embeddings.size(0), device=text_embeddings.device).view(text_embeddings.size(0), -1) # (B X ignore_num)\n attention_mask = torch.cat([attention_mask, inputs['attention_mask']], dim=1)\n else:\n attention_mask = None\n\n if 'labels' in inputs and inputs['labels'] is not None:\n # increase the length of labels by adding the length of labels\n # we use -100 to ignore the loss of labels in multimodal inputs\n # !!! 
we can replace -100 by config constants to make the code better\n\n # since the tokens corresponding to the image_inputs, audio_inputs, and video_inputs are not part of the original text \n # and don't have corresponding labels in the true text sequence, their labels are set to -100. This ensures that \n # the model's predictions for these tokens don't affect the loss and, consequently, the gradients and the model's subsequent learning.\n labels = torch.tensor([-100]*ingore_num*text_embeddings.size(0), device=text_embeddings.device).view(text_embeddings.size(0), -1)\n labels = torch.cat([labels, inputs['labels']], dim=1)\n else:\n labels = None\n\n # text_embeddings: (batch_size, sequence_length + ingore_num, embedding_dim)\n # attention_mask: (batch_size, sequence_length + ingore_num). 1 denotes we should attend to, and 0 denotes we should not attend to the token.\n # labels: (batch_size, sequence_length + ingore_num). -100 denotes we should ignore the token.\n return text_embeddings, attention_mask, labels\n\n def encode_video(self, videos):\n \"\"\"\n Encode video features to video embeddings.\n\n Args:\n videos: (batch_size, n_frames, n_channels, height, width)\n\n Returns:\n video_embeddings: (batch_size, n_frames, embedding_dim)\n \"\"\"\n # simple image encoding without temporal embedding and self attention\n # Reference: https://huggingface.co/docs/transformers/model_doc/clip\n videos = videos.view(-1, videos.size(-3), videos.size(-2), videos.size(-1)) # pixel_values (torch.FloatTensor of shape (batch_size * n_frames, num_channels, height, width)) \n video_outputs = self.video_encoder.get_image_features(videos) # image_features (torch.FloatTensor of shape (batch_size * n_frames, output_dim)\n video_features = video_outputs\n temporal_pos = torch.tensor(\n [[i for i in range(self.config.n_frames)] \n for j in range(videos.size(0) // self.config.n_frames)],\n dtype=torch.int, device=video_features.device).view(-1) # 2d indices to 1d indices, shape: 
(batch_size * n_frames)\n\n frame_temporal_pos_embed = self.temporal_position_embeddings(temporal_pos)\n\n video_features = (video_features + frame_temporal_pos_embed).view(\n videos.size(0) // self.config.n_frames, self.config.n_frames, -1) # (batch_size, n_frames, output_dim)\n\n video_features = video_features.transpose(0, 1).contiguous()\n # nn.MultiheadAttention takes query, key, value as inputs. Their shapes are (sequence_length, batch_size, embedding_dim).\n # The outputs are two elements: attn_output of shape (sequence_length, batch_size, embedding_dim), and attn_output_weights of shape (batch_size, sequence_length, sequence_length).\n self_attn_video_features = self.temporal_self_attention(video_features, video_features, video_features)[0]\n\n return self_attn_video_features.transpose(0, 1).contiguous() # (batch_size, n_frames, output_dim)\n\n def encode_video_long(self, videos):\n \"\"\"\n Encode video features to video embeddings.\n\n Args:\n videos: (batch_size, n_frames, n_channels, height, width)\n\n Returns:\n video_embeddings: (batch_size, n_frames, embedding_dim)\n \"\"\"\n # simple image encoding without temporal embedding and self attention\n videos = videos.view(-1, videos.size(-3), videos.size(-2), videos.size(-1)) # pixel_values (torch.FloatTensor of shape (batch_size * n_frames, num_channels, height, width))\n video_features = self.video_encoder.visual_projection(self.video_encoder.vision_model(videos)[0])[:, 1:, :]\n video_features = video_features.reshape(\n videos.size(0) // self.config.n_frames,\n self.config.n_frames * video_features.size(1),\n -1).contiguous()\n\n return video_features\n\n def encode_audio(self, audios):\n audio_features = self.audio_encoder.encoder(audios)\n return audio_features[0]\n\n def encode_image(self, images):\n # vision_outputs = self.image_encoder.get_image_features(images)\n # image_features = vision_outputs # pooled_output\n # image_features = self.visual_projection(pooled_output)\n # image_features = 
image_features.unsqueeze(1)\n image_features = self.image_encoder.visual_projection(self.image_encoder.vision_model(images)[0])[:, 1:, :]\n return image_features\n
"},{"location":"reference/#src.aeiva.model.macaw_model_old.MM_LLMs.encode_video","title":"encode_video(videos)
","text":"Encode video features to video embeddings.
Parameters:
Name Type Description Default videos
(batch_size, n_frames, n_channels, height, width)
required Returns:
Name Type Description video_embeddings
(batch_size, n_frames, embedding_dim)
Source code in src/aeiva/model/macaw_model_old.py
def encode_video(self, videos):\n \"\"\"\n Encode video features to video embeddings.\n\n Args:\n videos: (batch_size, n_frames, n_channels, height, width)\n\n Returns:\n video_embeddings: (batch_size, n_frames, embedding_dim)\n \"\"\"\n # simple image encoding without temporal embedding and self attention\n # Reference: https://huggingface.co/docs/transformers/model_doc/clip\n videos = videos.view(-1, videos.size(-3), videos.size(-2), videos.size(-1)) # pixel_values (torch.FloatTensor of shape (batch_size * n_frames, num_channels, height, width)) \n video_outputs = self.video_encoder.get_image_features(videos) # image_features (torch.FloatTensor of shape (batch_size * n_frames, output_dim)\n video_features = video_outputs\n temporal_pos = torch.tensor(\n [[i for i in range(self.config.n_frames)] \n for j in range(videos.size(0) // self.config.n_frames)],\n dtype=torch.int, device=video_features.device).view(-1) # 2d indices to 1d indices, shape: (batch_size * n_frames)\n\n frame_temporal_pos_embed = self.temporal_position_embeddings(temporal_pos)\n\n video_features = (video_features + frame_temporal_pos_embed).view(\n videos.size(0) // self.config.n_frames, self.config.n_frames, -1) # (batch_size, n_frames, output_dim)\n\n video_features = video_features.transpose(0, 1).contiguous()\n # nn.MultiheadAttention takes query, key, value as inputs. Their shapes are (sequence_length, batch_size, embedding_dim).\n # The outputs are two elements: attn_output of shape (sequence_length, batch_size, embedding_dim), and attn_output_weights of shape (batch_size, sequence_length, sequence_length).\n self_attn_video_features = self.temporal_self_attention(video_features, video_features, video_features)[0]\n\n return self_attn_video_features.transpose(0, 1).contiguous() # (batch_size, n_frames, output_dim)\n
"},{"location":"reference/#src.aeiva.model.macaw_model_old.MM_LLMs.encode_video_long","title":"encode_video_long(videos)
","text":"Encode video features to video embeddings.
Parameters:
Name Type Description Default videos
(batch_size, n_frames, n_channels, height, width)
required Returns:
Name Type Description video_embeddings
(batch_size, n_frames, embedding_dim)
Source code in src/aeiva/model/macaw_model_old.py
def encode_video_long(self, videos):\n \"\"\"\n Encode video features to video embeddings.\n\n Args:\n videos: (batch_size, n_frames, n_channels, height, width)\n\n Returns:\n video_embeddings: (batch_size, n_frames, embedding_dim)\n \"\"\"\n # simple image encoding without temporal embedding and self attention\n videos = videos.view(-1, videos.size(-3), videos.size(-2), videos.size(-1)) # pixel_values (torch.FloatTensor of shape (batch_size * n_frames, num_channels, height, width))\n video_features = self.video_encoder.visual_projection(self.video_encoder.vision_model(videos)[0])[:, 1:, :]\n video_features = video_features.reshape(\n videos.size(0) // self.config.n_frames,\n self.config.n_frames * video_features.size(1),\n -1).contiguous()\n\n return video_features\n
"},{"location":"reference/#src.aeiva.model.macaw_model_old.MM_LLMs.prepare_inputs_for_generation","title":"prepare_inputs_for_generation(inputs)
","text":"The purpose of this method is to integrate the different modalities into the text embeddings and prepare the associated attention mask and labels for the language model, so the model can generate text conditioned on all the input modalities.
inputs is a dictionary containing the following keys (!!! my hypothesis): video_frames: (B x F), audios: B x 1, images: B x 1, input_ids: B x L, attention_mask: B x L, labels: B x L, video_starts: B x 1, video_ends: B x 1, audio_starts: B x 1, audio_ends: B x 1, image_starts: B x 1, image_ends: B x 1, inference: True/False
Source code in src/aeiva/model/macaw_model_old.py
def prepare_inputs_for_generation(self, inputs):\n \"\"\"\n The purpose of this method is to integrate the different modalities into the text embeddings \n and prepare the associated attention mask and labels for the language model, so the model can \n generate text conditioned on all the input modalities.\n\n inputs is a dictionary containing the following keys: (!!! my hypothesis)\n video_frames: (B x F)\n audios: B x 1\n images: B x 1\n input_ids: B x L\n attention_mask: B x L\n labels: B x L\n video_starts: B x 1\n video_ends: B x 1\n audio_starts: B x 1\n audio_ends: B x 1\n image_starts: B x 1\n image_ends: B x 1\n inference: True/False\n \"\"\"\n # get multimodal embeddings\n image_features = self.encode_image(inputs['images']) if inputs['images'] is not None else None\n audio_features = self.encode_audio(inputs['audios']) if inputs['audios'] is not None else None\n video_features = self.encode_video(inputs['videos']) if inputs['videos'] is not None else None\n embed_tokens = self.llm.model.embed_tokens\n\n\n # for debug !!!!!!\n # Find maximum id in input_ids\n max_id = torch.max(inputs['input_ids'])\n print(f\"Max ID in input_ids: {max_id.item()}\")\n\n # Get vocab size from embedding layer\n vocab_size = embed_tokens.num_embeddings\n print(f\"Vocabulary size: {vocab_size}\")\n\n\n\n text_embeddings = embed_tokens(inputs['input_ids'])\n\n token_embeddings = embed_tokens.weight.unsqueeze(0).repeat(\n text_embeddings.size(0), 1, 1).transpose(0, 1)\n\n # ignore_num seems to be a counter that tracks the total size (or length) of the \n # multimodal input segments (video, audio, image) added to the original text inputs.\n ingore_num = 0\n\n # project and merge video features to the same space as text embeddings\n if video_features is not None:\n # get video starts and ends embeddings\n video_starts = embed_tokens(inputs['video_starts']).unsqueeze(1)\n video_ends = embed_tokens(inputs['video_ends']).unsqueeze(1)\n\n # project video features to the same space as 
text embeddings\n video_features = self.transform_video_to_hidden(video_features)\n\n video_features = self.video_align_attention(\n video_features.transpose(0, 1), token_embeddings, token_embeddings)[0].transpose(0, 1).contiguous()\n\n # concatenate video starts, video features, and video ends embeddings\n video_inputs = torch.cat([torch.cat([video_starts, video_features], dim=1), video_ends], dim=1)\n\n # concatenate video inputs to the original text embeddings\n # NOTE: the first token of text_embeddings keeps at the same position\n text_embeddings = torch.cat([torch.cat([text_embeddings[:, 0, :].unsqueeze(1), video_inputs], dim=1), text_embeddings[:, 1:, :]], dim=1)\n\n ingore_num += (video_inputs.size(1))\n\n # project and merge audio features to the same space as text embeddings\n if audio_features is not None:\n # get audio starts and ends embeddings\n audio_starts = embed_tokens(inputs['audio_starts']).unsqueeze(1)\n audio_ends = embed_tokens(inputs['audio_ends']).unsqueeze(1)\n\n # project audio features to the same space as text embeddings\n audio_features = self.project_audio(audio_features.transpose(1, 2).contiguous()).transpose(1, 2).contiguous()\n audio_features = self.transform_audio_to_hidden(audio_features)\n # mean pooling\n # audio_features = torch.sum(audio_features, dim=1) / audio_features.size(1) \n # audio_features = audio_features.unsqueeze(1)\n audio_features = self.audio_align_attention(\n audio_features.transpose(0, 1), token_embeddings, token_embeddings)[0].transpose(0, 1).contiguous()\n\n # concatenate audio starts, audio features, and audio ends embeddings\n audio_inputs = torch.cat([torch.cat([audio_starts, audio_features], dim=1), audio_ends], dim=1)\n\n # concatenate audio inputs to the original text embeddings\n text_embeddings = torch.cat(\n [torch.cat([text_embeddings[:, 0, :].unsqueeze(1), audio_inputs], dim=1), text_embeddings[:, 1:, :]],\n dim=1)\n\n ingore_num += (audio_inputs.size(1))\n\n # project and merge image features 
to the same space as text embeddings\n if image_features is not None:\n # get image starts and ends embeddings\n image_starts = embed_tokens(inputs['image_starts']).unsqueeze(1)\n image_ends = embed_tokens(inputs['image_ends']).unsqueeze(1)\n\n # project image features to the same space as text embeddings\n image_features = self.project_image(image_features.transpose(1, 2).contiguous()).transpose(1, 2).contiguous()\n image_features = self.transform_image_to_hidden(image_features)\n image_features = self.image_align_attention(\n image_features.transpose(0, 1), token_embeddings, token_embeddings)[0].transpose(0, 1).contiguous()\n\n # concatenate image starts, image features, and image ends embeddings\n image_inputs = torch.cat([torch.cat([image_starts, image_features], dim=1), image_ends], dim=1)\n\n # concatenate image inputs to the original text embeddings\n text_embeddings = torch.cat(\n [torch.cat([text_embeddings[:, 0, :].unsqueeze(1), image_inputs], dim=1), \n text_embeddings[:, 1:, :]], dim=1)\n\n ingore_num += (image_inputs.size(1))\n\n if 'attention_mask' in inputs:\n # increase the length of attention mask by adding the length of multimodal inputs\n attention_mask = torch.tensor([1]*ingore_num*text_embeddings.size(0), device=text_embeddings.device).view(text_embeddings.size(0), -1) # (B X ignore_num)\n attention_mask = torch.cat([attention_mask, inputs['attention_mask']], dim=1)\n else:\n attention_mask = None\n\n if 'labels' in inputs and inputs['labels'] is not None:\n # increase the length of labels by adding the length of labels\n # we use -100 to ignore the loss of labels in multimodal inputs\n # !!! we can replace -100 by config constants to make the code better\n\n # since the tokens corresponding to the image_inputs, audio_inputs, and video_inputs are not part of the original text \n # and don't have corresponding labels in the true text sequence, their labels are set to -100. 
This ensures that \n # the model's predictions for these tokens don't affect the loss and, consequently, the gradients and the model's subsequent learning.\n labels = torch.tensor([-100]*ingore_num*text_embeddings.size(0), device=text_embeddings.device).view(text_embeddings.size(0), -1)\n labels = torch.cat([labels, inputs['labels']], dim=1)\n else:\n labels = None\n\n # text_embeddings: (batch_size, sequence_length + ingore_num, embedding_dim)\n # attention_mask: (batch_size, sequence_length + ingore_num). 1 denotes we should attend to, and 0 denotes we should not attend to the token.\n # labels: (batch_size, sequence_length + ingore_num). -100 denotes we should ignore the token.\n return text_embeddings, attention_mask, labels\n
"},{"location":"reference/#src.aeiva.model.macaw_model_old.MM_LLMs_Config","title":"MM_LLMs_Config
","text":" Bases: PretrainedConfig
This is the configuration class to store the configuration of a MM_LLMsModel
. It contains class-level and instance-level attributes, as well as the from_pretrained (load) and to_dict (save) methods for loading and saving configuration files.
Source code in src/aeiva/model/macaw_model_old.py
class MM_LLMs_Config(PretrainedConfig):\n \"\"\"\n This is the configuration class to store the configuration of a `MM_LLMsModel`.\n It contains class level and instance level attributes.\n It also contains the load (from_pretrained) and save (to_dict) methods for saving and loading configuration files.\n \"\"\"\n # general class attributes for all model instances\n model_type = 'mm_llms'\n is_composition = True\n\n def __init__(self, n_frames=6, attention_heads=8, clip_config=None, whisper_config=None, llm_config=None, **kwargs):\n self.image_config = clip_config\n self.audio_config = whisper_config\n self.llm_config = llm_config # language model config\n self.n_frames = n_frames # video config information. How many frames are used for each video clip.\n self.attention_heads = attention_heads\n self.hidden_size = max(llm_config.hidden_size, clip_config.projection_dim, whisper_config.d_model, clip_config.projection_dim)\n super().__init__(**kwargs)\n\n def to_dict(self):\n \"\"\"\n Serializes this instance to a Python dictionary. 
Override the default [`~PretrainedConfig.to_dict`].\n This method overrides the base class method to include serialization of the \n image, audio, and language model configurations along with the base configuration.\n\n Returns:\n `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,\n \"\"\"\n output = copy.deepcopy(self.__dict__)\n output[\"image_config\"] = self.image_config.to_dict()\n output[\"audio_config\"] = self.audio_config.to_dict()\n output['llm_config'] = self.llm_config.to_dict()\n output['n_frames'] = self.n_frames\n output['attention_heads'] = self.attention_heads\n output['hidden_size'] = self.hidden_size\n output[\"model_type\"] = self.__class__.model_type\n return output\n\n @classmethod\n def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):\n config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)\n\n clip_config = CLIPConfig.from_dict(config_dict['image_config'])\n whisper_config = WhisperConfig.from_dict(config_dict['audio_config'])\n llm_config = LlamaConfig.from_dict(config_dict['llm_config'])\n\n return cls(clip_config=clip_config, whisper_config=whisper_config, llm_config=llm_config, **kwargs)\n
"},{"location":"reference/#src.aeiva.model.macaw_model_old.MM_LLMs_Config.to_dict","title":"to_dict()
","text":"Serializes this instance to a Python dictionary, overriding the default [~PretrainedConfig.to_dict]. The override includes serialization of the image, audio, and language model configurations along with the base configuration.
Returns:
Type Description Dict[str, any]
: Dictionary of all the attributes that make up this configuration instance.
Source code in src/aeiva/model/macaw_model_old.py
def to_dict(self):\n \"\"\"\n Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].\n This method overrides the base class method to include serialization of the \n image, audio, and language model configurations along with the base configuration.\n\n Returns:\n `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,\n \"\"\"\n output = copy.deepcopy(self.__dict__)\n output[\"image_config\"] = self.image_config.to_dict()\n output[\"audio_config\"] = self.audio_config.to_dict()\n output['llm_config'] = self.llm_config.to_dict()\n output['n_frames'] = self.n_frames\n output['attention_heads'] = self.attention_heads\n output['hidden_size'] = self.hidden_size\n output[\"model_type\"] = self.__class__.model_type\n return output\n
"},{"location":"reference/#src.aeiva.model.macaw_model_old.WhisperEncoder","title":"WhisperEncoder
","text":" Bases: WhisperPreTrainedModel
Transformer encoder consisting of config.encoder_layers self-attention layers. Each layer is a [WhisperEncoderLayer
].
Parameters:
Name Type Description Default config
WhisperConfig
WhisperConfig
required Source code in src/aeiva/model/macaw_model_old.py
class WhisperEncoder(WhisperPreTrainedModel):\n \"\"\"\n Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a\n [`WhisperEncoderLayer`].\n\n Args:\n config: WhisperConfig\n \"\"\"\n\n def __init__(self, config: WhisperConfig):\n super().__init__(config)\n self.dropout = config.dropout\n self.layerdrop = config.encoder_layerdrop\n\n embed_dim = config.d_model\n # num_mel_bins corresponds to the number of features extracted from the audio signal for each time step. \n # When we convert audio to a Mel spectrogram, each time step (or frame) in the spectrogram \n # is represented by a feature vector of size num_mel_bins. \n self.num_mel_bins = config.num_mel_bins\n self.padding_idx = config.pad_token_id\n self.max_source_positions = config.max_source_positions\n # embed_scale is a scaling factor that is applied to the embeddings.\n self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0\n\n self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)\n self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)\n\n # position embedding layer\n self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim)\n\n self.layers = nn.ModuleList([WhisperEncoderLayer(config) for _ in range(config.encoder_layers)])\n self.layer_norm = nn.LayerNorm(config.d_model)\n\n self.gradient_checkpointing = False\n # Initialize weights and apply final processing\n self.post_init()\n\n def _freeze_parameters(self):\n for param in self.parameters():\n param.requires_grad = False\n self._requires_grad = False\n\n def get_input_embeddings(self) -> nn.Module:\n return self.conv1\n\n def set_input_embeddings(self, value: nn.Module):\n self.conv1 = value\n\n def forward(\n self,\n input_features,\n attention_mask=None,\n head_mask=None,\n output_attentions=None,\n output_hidden_states=None,\n return_dict=None,\n ):\n r\"\"\"\n Args:\n input_features (`torch.LongTensor` of shape 
`(batch_size, feature_size, sequence_length)`):\n Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be\n obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a\n `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into\n `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding\n and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]\n attention_mask (`torch.Tensor`)`, *optional*):\n Whisper does not support masking of the `input_features`, this argument is preserved for compatibility,\n but it is not used. By default the silence in the input log mel spectrogram are ignored.\n head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):\n Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:\n\n - 1 indicates the head is **not masked**,\n - 0 indicates the head is **masked**.\n output_attentions (`bool`, *optional*):\n Whether or not to return the attentions tensors of all attention layers. See `attentions` under\n returned tensors for more detail.\n output_hidden_states (`bool`, *optional*):\n Whether or not to return the hidden states of all layers. 
See `hidden_states` under returned tensors\n for more detail.\n return_dict (`bool`, *optional*):\n Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.\n \"\"\"\n # set output flags\n output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n # embed audio features\n # input_features shape: (batch_size, feature_size, sequence_length)\n inputs_embeds = nn.functional.gelu(self.conv1(input_features)) # (batch_size, embed_dim, sequence_length)\n inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) # (batch_size, embed_dim, sequence_length/2), because the stride is 2. Downsampling by 2.\n inputs_embeds = inputs_embeds.permute(0, 2, 1) # (batch_size, sequence_length/2, embed_dim)\n embed_pos = self.embed_positions.weight # (max_source_positions, embed_dim)\n\n # add position embedding to audio features embedding\n # !!!: Do max_source_positions and sequence_length/2 must be the same??? 
Kind of confusing.\n hidden_states = inputs_embeds + embed_pos\n hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)\n\n encoder_states = () if output_hidden_states else None\n all_attentions = () if output_attentions else None\n\n # check if head_mask has a correct number of layers specified if desired\n if head_mask is not None:\n assert head_mask.size()[0] == (\n len(self.layers)\n ), f\"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}.\"\n\n # go through the whisper encoder layers to get the hidden states and attentions in all layers\n for idx, encoder_layer in enumerate(self.layers):\n if output_hidden_states:\n encoder_states = encoder_states + (hidden_states,)\n # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)\n dropout_probability = random.uniform(0, 1)\n if self.training and (dropout_probability < self.layerdrop): # skip the layer\n layer_outputs = (None, None)\n else:\n if self.gradient_checkpointing and self.training:\n\n def create_custom_forward(module):\n def custom_forward(*inputs):\n return module(*inputs, output_attentions)\n\n return custom_forward\n\n layer_outputs = torch.utils.checkpoint.checkpoint(\n create_custom_forward(encoder_layer),\n hidden_states,\n None,\n (head_mask[idx] if head_mask is not None else None),\n )\n else:\n # The layer_outputs is a tuple of (hidden_states, attention).\n # The attention is None if output_attentions is False.\n # hidden_states shape: (batch_size, sequence_length/2, embed_dim), as stride is 2 in the self.conv2\n # attention shape: (batch_size, num_heads, sequence_length/2, sequence_length/2)\n layer_outputs = encoder_layer(\n hidden_states,\n None,\n layer_head_mask=(head_mask[idx] if head_mask is not None else None),\n output_attentions=output_attentions,\n )\n\n hidden_states = layer_outputs[0]\n\n if output_attentions:\n all_attentions = all_attentions + (layer_outputs[1],)\n\n hidden_states 
= self.layer_norm(hidden_states)\n if output_hidden_states:\n encoder_states = encoder_states + (hidden_states,)\n\n # output\n if not return_dict:\n return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)\n return BaseModelOutput(\n last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions\n )\n
"},{"location":"reference/#src.aeiva.model.macaw_model_old.WhisperEncoder.forward","title":"forward(input_features, attention_mask=None, head_mask=None, output_attentions=None, output_hidden_states=None, return_dict=None)
","text":"Parameters:
Name Type Description Default input_features
`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`
Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by loading a .flac
or .wav
audio file into an array of type List[float]
or a numpy.ndarray
, e.g. via the soundfile library (pip install soundfile
). To prepare the array into input_features
, the [AutoFeatureExtractor
] should be used for extracting the mel features, padding and conversion into a tensor of type torch.FloatTensor
. See [~WhisperFeatureExtractor.__call__
]
required attention_mask
`torch.Tensor`, *optional*
Whisper does not support masking of the input_features
, this argument is preserved for compatibility but is not used. By default, silence in the input log-mel spectrogram is ignored.
None
head_mask
`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*
Mask to nullify selected heads of the attention modules. Mask values selected in [0, 1]
:
- 1 indicates the head is not masked,
- 0 indicates the head is masked.
None
output_attentions
`bool`, *optional*
Whether or not to return the attentions tensors of all attention layers. See attentions
under returned tensors for more detail.
None
output_hidden_states
`bool`, *optional*
Whether or not to return the hidden states of all layers. See hidden_states
under returned tensors for more detail.
None
return_dict
`bool`, *optional*
Whether or not to return a [~utils.ModelOutput
] instead of a plain tuple.
None
Source code in src/aeiva/model/macaw_model_old.py
def forward(\n self,\n input_features,\n attention_mask=None,\n head_mask=None,\n output_attentions=None,\n output_hidden_states=None,\n return_dict=None,\n):\n r\"\"\"\n Args:\n input_features (`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`):\n Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be\n obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a\n `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into\n `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding\n and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]\n attention_mask (`torch.Tensor`)`, *optional*):\n Whisper does not support masking of the `input_features`, this argument is preserved for compatibility,\n but it is not used. By default the silence in the input log mel spectrogram are ignored.\n head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):\n Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:\n\n - 1 indicates the head is **not masked**,\n - 0 indicates the head is **masked**.\n output_attentions (`bool`, *optional*):\n Whether or not to return the attentions tensors of all attention layers. See `attentions` under\n returned tensors for more detail.\n output_hidden_states (`bool`, *optional*):\n Whether or not to return the hidden states of all layers. 
See `hidden_states` under returned tensors\n for more detail.\n return_dict (`bool`, *optional*):\n Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.\n \"\"\"\n # set output flags\n output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n # embed audio features\n # input_features shape: (batch_size, feature_size, sequence_length)\n inputs_embeds = nn.functional.gelu(self.conv1(input_features)) # (batch_size, embed_dim, sequence_length)\n inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) # (batch_size, embed_dim, sequence_length/2), because the stride is 2. Downsampling by 2.\n inputs_embeds = inputs_embeds.permute(0, 2, 1) # (batch_size, sequence_length/2, embed_dim)\n embed_pos = self.embed_positions.weight # (max_source_positions, embed_dim)\n\n # add position embedding to audio features embedding\n # !!!: Do max_source_positions and sequence_length/2 must be the same??? 
Kind of confusing.\n hidden_states = inputs_embeds + embed_pos\n hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)\n\n encoder_states = () if output_hidden_states else None\n all_attentions = () if output_attentions else None\n\n # check if head_mask has a correct number of layers specified if desired\n if head_mask is not None:\n assert head_mask.size()[0] == (\n len(self.layers)\n ), f\"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}.\"\n\n # go through the whisper encoder layers to get the hidden states and attentions in all layers\n for idx, encoder_layer in enumerate(self.layers):\n if output_hidden_states:\n encoder_states = encoder_states + (hidden_states,)\n # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)\n dropout_probability = random.uniform(0, 1)\n if self.training and (dropout_probability < self.layerdrop): # skip the layer\n layer_outputs = (None, None)\n else:\n if self.gradient_checkpointing and self.training:\n\n def create_custom_forward(module):\n def custom_forward(*inputs):\n return module(*inputs, output_attentions)\n\n return custom_forward\n\n layer_outputs = torch.utils.checkpoint.checkpoint(\n create_custom_forward(encoder_layer),\n hidden_states,\n None,\n (head_mask[idx] if head_mask is not None else None),\n )\n else:\n # The layer_outputs is a tuple of (hidden_states, attention).\n # The attention is None if output_attentions is False.\n # hidden_states shape: (batch_size, sequence_length/2, embed_dim), as stride is 2 in the self.conv2\n # attention shape: (batch_size, num_heads, sequence_length/2, sequence_length/2)\n layer_outputs = encoder_layer(\n hidden_states,\n None,\n layer_head_mask=(head_mask[idx] if head_mask is not None else None),\n output_attentions=output_attentions,\n )\n\n hidden_states = layer_outputs[0]\n\n if output_attentions:\n all_attentions = all_attentions + (layer_outputs[1],)\n\n hidden_states 
= self.layer_norm(hidden_states)\n if output_hidden_states:\n encoder_states = encoder_states + (hidden_states,)\n\n # output\n if not return_dict:\n return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)\n return BaseModelOutput(\n last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions\n )\n
"},{"location":"reference/#src.aeiva.model.macaw_model_old.rotate_half","title":"rotate_half(x)
","text":"Rotates half the hidden dims of the input.
Source code in src/aeiva/model/macaw_model_old.py
def rotate_half(x):\n \"\"\"Rotates half the hidden dims of the input.\"\"\"\n x1 = x[..., : x.shape[-1] // 2]\n x2 = x[..., x.shape[-1] // 2 :]\n return torch.cat((-x2, x1), dim=-1)\n
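To make the tensor manipulation concrete, here is a minimal pure-Python sketch of the same operation (a torch-free stand-in, not the library code): `rotate_half` splits the last dimension in half, negates the second half, and swaps the two halves, which is the core step of rotary position embeddings.

```python
# Pure-Python sketch of rotate_half on a flat list (the torch version
# above does the same thing along the last tensor dimension).
def rotate_half(x):
    half = len(x) // 2
    x1, x2 = x[:half], x[half:]
    return [-v for v in x2] + list(x1)

print(rotate_half([1.0, 2.0, 3.0, 4.0]))  # [-3.0, -4.0, 1.0, 2.0]
```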
"},{"location":"reference/#src.aeiva.operator","title":"operator
","text":""},{"location":"reference/#src.aeiva.operator.custom_ops","title":"custom_ops
","text":""},{"location":"reference/#src.aeiva.operator.custom_ops.macaw_dataitem_ops","title":"macaw_dataitem_ops
","text":"This module contains the data item processing functions.
A data item processing function takes a data example (a dict) as input and returns a processed data example.
@Author: Bang Liu (chatsci.ai@gmail.com) @Date: 2023-07-11
Copyright (C) 2023 Bang Liu - All Rights Reserved. This source code is licensed under the license found in the LICENSE file in the root directory of this source tree.
"},{"location":"reference/#src.aeiva.operator.dataitem_ops","title":"dataitem_ops
","text":"This module contains the data item processing functions.
A data item processing function takes a data example (a dict) as input and returns a processed data example.
@Author: Bang Liu (chatsci.ai@gmail.com) @Date: 2023-07-11
Copyright (C) 2023 Bang Liu - All Rights Reserved. This source code is licensed under the license found in the LICENSE file in the root directory of this source tree.
"},{"location":"reference/#src.aeiva.operator.dataset_ops","title":"dataset_ops
","text":"This module contains the utils for processing datasets.
A dataset in aeiva is a dictionary with the following structure: { \"data\": [ {sample1}, {sample2}, ..., {sampleN} ], \"metadata\": { \"num_samples\": XX, ... } } where each sample is a dictionary itself, and metadata is a dictionary that contains the number of samples and possibly other fields.
@Author: Bang Liu (chatsci.ai@gmail.com) @Date: 2023-07-13
Copyright (C) 2023 Bang Liu - All Rights Reserved. This source code is licensed under the license found in the LICENSE file in the root directory of this source tree.
"},{"location":"reference/#src.aeiva.operator.dataset_ops.build_and_merge_datasets","title":"build_and_merge_datasets(dataset_names, input_filepaths_dict, pipeline, output_dir, max_samples=sys.maxsize)
","text":"Build multiple datasets by formatting and processing them.
Source code in src/aeiva/operator/dataset_ops.py
def build_and_merge_datasets(dataset_names: list[str],\n input_filepaths_dict: dict[str, str],\n pipeline: list[Callable],\n output_dir: Optional[str],\n max_samples: Optional[int] = sys.maxsize) -> DataSet:\n r\"\"\" Build multiple datasets by formatting and processing them.\n \"\"\"\n merged_datasets = []\n for dataset_name in dataset_names:\n dataset = build_dataset(dataset_name, input_filepaths_dict, pipeline, output_dir, max_samples)\n merged_datasets.append(dataset)\n result = merge_datasets(merged_datasets)\n return result\n
"},{"location":"reference/#src.aeiva.operator.dataset_ops.build_dataset","title":"build_dataset(dataset_name, input_filepaths_dict, pipeline, output_dir, max_samples=sys.maxsize)
","text":"Build a dataset by formatting and processing it.
Source code in src/aeiva/operator/dataset_ops.py
def build_dataset(dataset_name: str,\n input_filepaths_dict: dict[str, str],\n pipeline: list[Callable],\n output_dir: Optional[str],\n max_samples: Optional[int] = sys.maxsize) -> DataSet:\n r\"\"\" Build a dataset by formatting and processing it.\n \"\"\"\n operator_type = 'data_formatter'\n format_func = OPERATORS[operator_type][dataset_name]\n formatted_dataset = format_func(input_filepaths_dict, output_dir, max_samples)\n processed_dataset = process_dataset(formatted_dataset, pipeline, output_dir, dataset_name)\n print(f\"Completed processing dataset: {dataset_name} (output_dir: {output_dir})\")\n return processed_dataset\n
"},{"location":"reference/#src.aeiva.operator.dataset_ops.filter_dataset","title":"filter_dataset(dataset, filter_criteria, *args, **kwargs)
","text":"Filter a dataset by a filter function.
Source code in src/aeiva/operator/dataset_ops.py
def filter_dataset(dataset: DataSet, filter_criteria: str, *args, **kwargs) -> DataSet:\n r\"\"\" Filter a dataset by a filter function.\n \"\"\"\n operator_type = 'data_filter'\n filter_func = OPERATORS[operator_type][filter_criteria]\n filtered_data = filter_func(dataset, *args, **kwargs)\n return filtered_data\n
"},{"location":"reference/#src.aeiva.operator.dataset_ops.filter_dataset_by_keys","title":"filter_dataset_by_keys(dataset, keys_to_preserve)
","text":"Filter the dataset to only include specified keys in each sample.
Source code in src/aeiva/operator/dataset_ops.py
@register_data_filter(\"filter_dataset_by_keys\")\ndef filter_dataset_by_keys(dataset: DataSet, keys_to_preserve: list[str]) -> DataSet:\n r\"\"\" Filter the dataset to only include specified keys in each sample.\n \"\"\"\n filtered_data = []\n for sample in dataset[\"data\"]:\n for key in keys_to_preserve:\n if key not in sample:\n raise KeyError(f\"Key {key} not found in sample\")\n filtered_sample = {key: sample[key] for key in keys_to_preserve if key in sample}\n filtered_data.append(filtered_sample)\n return {\"data\": filtered_data, \"metadata\": dataset[\"metadata\"]}\n
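A short usage sketch of the key-filtering logic above (standalone, without the `@register_data_filter` registry; the sample contents are hypothetical):

```python
# Keep only a whitelist of keys in every sample; raise if a key is missing.
def filter_dataset_by_keys(dataset, keys_to_preserve):
    filtered = []
    for sample in dataset["data"]:
        for key in keys_to_preserve:
            if key not in sample:
                raise KeyError(f"Key {key} not found in sample")
        filtered.append({k: sample[k] for k in keys_to_preserve})
    return {"data": filtered, "metadata": dataset["metadata"]}

ds = {"data": [{"text": "hi", "audio": None}], "metadata": {"num_samples": 1}}
out = filter_dataset_by_keys(ds, ["text"])
print(out["data"][0])  # {'text': 'hi'}
```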
"},{"location":"reference/#src.aeiva.operator.dataset_ops.merge_datasets","title":"merge_datasets(datasets)
","text":"Merge multiple datasets into one.
Source code in src/aeiva/operator/dataset_ops.py
def merge_datasets(datasets: list[DataSet]) -> DataSet:\n r\"\"\" Merge multiple datasets into one.\n \"\"\"\n merged_data = []\n total_samples = 0\n for dataset in datasets:\n merged_data.extend(dataset[\"data\"])\n total_samples += dataset[\"metadata\"][\"num_samples\"]\n result = {\"data\": merged_data, \"metadata\": {\"num_samples\": total_samples}}\n return result\n
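Because a dataset here is just a `{"data": [...], "metadata": {...}}` dict, merging is plain list concatenation plus a sample-count update. A self-contained sketch with two toy datasets (the `id` fields are hypothetical):

```python
# Merge aeiva-style dataset dicts: concatenate "data" lists and sum counts.
def merge_datasets(datasets):
    merged_data = []
    total_samples = 0
    for dataset in datasets:
        merged_data.extend(dataset["data"])
        total_samples += dataset["metadata"]["num_samples"]
    return {"data": merged_data, "metadata": {"num_samples": total_samples}}

a = {"data": [{"id": 1}], "metadata": {"num_samples": 1}}
b = {"data": [{"id": 2}, {"id": 3}], "metadata": {"num_samples": 2}}
merged = merge_datasets([a, b])
print(merged["metadata"]["num_samples"])  # 3
```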
"},{"location":"reference/#src.aeiva.operator.dataset_ops.sample_dataset","title":"sample_dataset(dataset, n_samples)
","text":"Sample a number of samples from a dataset.
Source code in src/aeiva/operator/dataset_ops.py
def sample_dataset(dataset: DataSet, n_samples: int) -> DataSet:\n r\"\"\" Sample a number of samples from a dataset.\n \"\"\"\n random_indices = random.sample(range(dataset[\"metadata\"][\"num_samples\"]), n_samples)\n sampled_data = [dataset[\"data\"][i] for i in random_indices]\n return {\"data\": sampled_data, \"metadata\": {\"num_samples\": n_samples}}\n
"},{"location":"reference/#src.aeiva.operator.dataset_ops.save_dataset","title":"save_dataset(dataset, output_path)
","text":"Save a dataset to a file by pickling it.
Source code in src/aeiva/operator/dataset_ops.py
def save_dataset(dataset: DataSet, output_path: str) -> None:\n r\"\"\" Save a dataset to a file by pickling it.\n \"\"\"\n ensure_dir(output_path)\n pickle.dump(dataset, open(output_path, \"wb\"), protocol=4)\n
"},{"location":"reference/#src.aeiva.operator.dataset_ops.split_dataset","title":"split_dataset(dataset, train_ratio, seed=42)
","text":"Split a dataset into a training set and a validation set.
Source code in src/aeiva/operator/dataset_ops.py
def split_dataset(dataset: dict, train_ratio: float, seed: int = 42) -> Tuple[dict]:\n r\"\"\" Split a dataset into a training set and a validation set.\n \"\"\"\n np.random.seed(seed) # ensures the function is deterministic\n\n data = dataset[\"data\"]\n metadata = dataset[\"metadata\"]\n\n # Create a permutation of indices and shuffle the data.\n perm = np.random.permutation(len(data))\n shuffled_data = [data[i] for i in perm]\n\n # Calculate split index\n split_idx = int(train_ratio * len(shuffled_data))\n\n # Split the shuffled data\n train_data = shuffled_data[:split_idx]\n val_data = shuffled_data[split_idx:]\n\n # Create metadata for training and validation datasets\n train_metadata = metadata.copy()\n train_metadata[\"num_samples\"] = len(train_data)\n val_metadata = metadata.copy()\n val_metadata[\"num_samples\"] = len(val_data)\n\n # Create training and validation datasets\n train_dataset = {\"data\": train_data, \"metadata\": train_metadata}\n val_dataset = {\"data\": val_data, \"metadata\": val_metadata}\n\n return train_dataset, val_dataset\n
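The split logic can be mirrored without numpy: shuffle with a seeded RNG, cut at `int(train_ratio * len(data))`, and copy the metadata with updated counts. A pure-Python sketch under those assumptions (not the library implementation, which uses `np.random.permutation`):

```python
import random

# Seeded shuffle-and-cut split, mirroring split_dataset above.
def split_dataset(dataset, train_ratio, seed=42):
    rng = random.Random(seed)  # deterministic for a fixed seed
    data = list(dataset["data"])
    rng.shuffle(data)
    split_idx = int(train_ratio * len(data))
    train_data, val_data = data[:split_idx], data[split_idx:]
    meta = dataset["metadata"]
    train = {"data": train_data, "metadata": {**meta, "num_samples": len(train_data)}}
    val = {"data": val_data, "metadata": {**meta, "num_samples": len(val_data)}}
    return train, val

ds = {"data": [{"id": i} for i in range(10)], "metadata": {"num_samples": 10}}
train, val = split_dataset(ds, 0.8)
print(train["metadata"]["num_samples"], val["metadata"]["num_samples"])  # 8 2
```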
"},{"location":"reference/#src.aeiva.perception","title":"perception
","text":""},{"location":"reference/#src.aeiva.perception.base_perception_system","title":"base_perception_system
","text":""},{"location":"reference/#src.aeiva.perception.base_perception_system.PerceptionSystem","title":"PerceptionSystem
","text":" Bases: ABC
Abstract base class representing the Perception System of an agent.
The Perception System is responsible for capturing raw sensory data from the environment, processing this data into meaningful observations, and providing access to these observations for other components of the cognitive architecture.
Attributes:
Name Type Description config
Any
Configuration settings for the Perception System.
state
Any
The internal state of the Perception System, including raw data and observations.
Source code in src/aeiva/perception/base_perception_system.py
class PerceptionSystem(ABC):\n \"\"\"\n Abstract base class representing the Perception System of an agent.\n\n The Perception System is responsible for capturing raw sensory data from the environment,\n processing this data into meaningful observations, and providing access to these observations\n for other components of the cognitive architecture.\n\n Attributes:\n config (Any): Configuration settings for the Perception System.\n state (Any): The internal state of the Perception System, including raw data and observations.\n \"\"\"\n\n def __init__(self, config: Any):\n \"\"\"\n Initialize the Perception System with the provided configuration.\n\n Args:\n config (Any): Configuration settings for the Perception System.\n \"\"\"\n self.config = config\n self.state = self.init_state()\n\n @abstractmethod\n def init_state(self) -> Any:\n \"\"\"\n Initialize the internal state of the Perception System.\n\n This method should set up the initial state required for the Perception System's operations.\n\n Returns:\n Any: The initial state of the Perception System.\n \"\"\"\n pass\n\n @abstractmethod\n async def setup(self) -> None:\n \"\"\"\n Asynchronously set up the Perception System's components.\n\n This method should initialize any necessary components or resources based on the provided configuration.\n\n Raises:\n ConfigurationError: If the configuration is invalid or incomplete.\n \"\"\"\n pass\n\n @abstractmethod\n async def capture(self, raw_data: Any) -> None:\n \"\"\"\n Asynchronously capture raw sensory data from the environment.\n\n Args:\n raw_data (Any): The raw sensory data to capture.\n\n Raises:\n CaptureError: If capturing the raw data fails.\n \"\"\"\n pass\n\n @abstractmethod\n async def process(self) -> None:\n \"\"\"\n Asynchronously process the captured raw sensory data into meaningful observations.\n\n This method should transform raw data stored in the internal state into structured observations\n that can be utilized by other components of the 
cognitive architecture.\n\n Raises:\n ProcessingError: If processing the raw data fails.\n \"\"\"\n pass\n\n async def perceive(self, raw_data: Any) -> None:\n \"\"\"\n Asynchronously perform the full perception cycle: capture and process raw sensory data.\n\n Args:\n raw_data (Any): The raw sensory data to perceive.\n\n Raises:\n CaptureError: If capturing the raw data fails.\n ProcessingError: If processing the raw data fails.\n \"\"\"\n try:\n await self.capture(raw_data)\n await self.process()\n except Exception as e:\n self.handle_error(e)\n raise e\n\n def get_observations(self) -> Any:\n \"\"\"\n Retrieve the current processed observations from the Perception System.\n\n Returns:\n Any: The current observations.\n \"\"\"\n return self.state.get(\"observations\", None)\n\n def handle_error(self, error: Exception) -> None:\n \"\"\"\n Handle errors that occur during perception operations.\n\n This method can be overridden to implement custom error handling logic, such as logging\n or retry mechanisms.\n\n Args:\n error (Exception): The exception that was raised.\n \"\"\"\n # Default error handling: log the error\n print(f\"PerceptionSystem encountered an error: {error}\")\n
"},{"location":"reference/#src.aeiva.perception.base_perception_system.PerceptionSystem.__init__","title":"__init__(config)
","text":"Initialize the Perception System with the provided configuration.
Parameters:
Name Type Description Default config
Any
Configuration settings for the Perception System.
required Source code in src/aeiva/perception/base_perception_system.py
def __init__(self, config: Any):\n \"\"\"\n Initialize the Perception System with the provided configuration.\n\n Args:\n config (Any): Configuration settings for the Perception System.\n \"\"\"\n self.config = config\n self.state = self.init_state()\n
"},{"location":"reference/#src.aeiva.perception.base_perception_system.PerceptionSystem.capture","title":"capture(raw_data)
abstractmethod
async
","text":"Asynchronously capture raw sensory data from the environment.
Parameters:
Name Type Description Default raw_data
Any
The raw sensory data to capture.
required Raises:
Type Description CaptureError
If capturing the raw data fails.
Source code in src/aeiva/perception/base_perception_system.py
@abstractmethod\nasync def capture(self, raw_data: Any) -> None:\n \"\"\"\n Asynchronously capture raw sensory data from the environment.\n\n Args:\n raw_data (Any): The raw sensory data to capture.\n\n Raises:\n CaptureError: If capturing the raw data fails.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.perception.base_perception_system.PerceptionSystem.get_observations","title":"get_observations()
","text":"Retrieve the current processed observations from the Perception System.
Returns:
Name Type Description Any
Any
The current observations.
Source code in src/aeiva/perception/base_perception_system.py
def get_observations(self) -> Any:\n \"\"\"\n Retrieve the current processed observations from the Perception System.\n\n Returns:\n Any: The current observations.\n \"\"\"\n return self.state.get(\"observations\", None)\n
"},{"location":"reference/#src.aeiva.perception.base_perception_system.PerceptionSystem.handle_error","title":"handle_error(error)
","text":"Handle errors that occur during perception operations.
This method can be overridden to implement custom error handling logic, such as logging or retry mechanisms.
Parameters:
Name Type Description Default error
Exception
The exception that was raised.
required Source code in src/aeiva/perception/base_perception_system.py
def handle_error(self, error: Exception) -> None:\n \"\"\"\n Handle errors that occur during perception operations.\n\n This method can be overridden to implement custom error handling logic, such as logging\n or retry mechanisms.\n\n Args:\n error (Exception): The exception that was raised.\n \"\"\"\n # Default error handling: log the error\n print(f\"PerceptionSystem encountered an error: {error}\")\n
"},{"location":"reference/#src.aeiva.perception.base_perception_system.PerceptionSystem.init_state","title":"init_state()
abstractmethod
","text":"Initialize the internal state of the Perception System.
This method should set up the initial state required for the Perception System's operations.
Returns:
Name Type Description Any
Any
The initial state of the Perception System.
Source code in src/aeiva/perception/base_perception_system.py
@abstractmethod\ndef init_state(self) -> Any:\n \"\"\"\n Initialize the internal state of the Perception System.\n\n This method should set up the initial state required for the Perception System's operations.\n\n Returns:\n Any: The initial state of the Perception System.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.perception.base_perception_system.PerceptionSystem.perceive","title":"perceive(raw_data)
async
","text":"Asynchronously perform the full perception cycle: capture and process raw sensory data.
Parameters:
Name Type Description Default raw_data
Any
The raw sensory data to perceive.
required Raises:
Type Description CaptureError
If capturing the raw data fails.
ProcessingError
If processing the raw data fails.
Source code in src/aeiva/perception/base_perception_system.py
async def perceive(self, raw_data: Any) -> None:\n \"\"\"\n Asynchronously perform the full perception cycle: capture and process raw sensory data.\n\n Args:\n raw_data (Any): The raw sensory data to perceive.\n\n Raises:\n CaptureError: If capturing the raw data fails.\n ProcessingError: If processing the raw data fails.\n \"\"\"\n try:\n await self.capture(raw_data)\n await self.process()\n except Exception as e:\n self.handle_error(e)\n raise e\n
"},{"location":"reference/#src.aeiva.perception.base_perception_system.PerceptionSystem.process","title":"process()
abstractmethod
async
","text":"Asynchronously process the captured raw sensory data into meaningful observations.
This method should transform raw data stored in the internal state into structured observations that can be utilized by other components of the cognitive architecture.
Raises:
Type Description ProcessingError
If processing the raw data fails.
Source code in src/aeiva/perception/base_perception_system.py
@abstractmethod\nasync def process(self) -> None:\n \"\"\"\n Asynchronously process the captured raw sensory data into meaningful observations.\n\n This method should transform raw data stored in the internal state into structured observations\n that can be utilized by other components of the cognitive architecture.\n\n Raises:\n ProcessingError: If processing the raw data fails.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.perception.base_perception_system.PerceptionSystem.setup","title":"setup()
abstractmethod
async
","text":"Asynchronously set up the Perception System's components.
This method should initialize any necessary components or resources based on the provided configuration.
Raises:
Type Description ConfigurationError
If the configuration is invalid or incomplete.
Source code in src/aeiva/perception/base_perception_system.py
@abstractmethod\nasync def setup(self) -> None:\n \"\"\"\n Asynchronously set up the Perception System's components.\n\n This method should initialize any necessary components or resources based on the provided configuration.\n\n Raises:\n ConfigurationError: If the configuration is invalid or incomplete.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.perception.perception_system","title":"perception_system
","text":""},{"location":"reference/#src.aeiva.perception.perception_system.PerceptionSystem","title":"PerceptionSystem
","text":"Manages multiple sensors and emits stimuli via the EventBus.
Source code in src/aeiva/perception/perception_system.py
class PerceptionSystem:\n \"\"\"\n Manages multiple sensors and emits stimuli via the EventBus.\n \"\"\"\n def __init__(self, config: Dict, event_bus):\n \"\"\"\n Initializes the PerceptionSystem with a list of sensors.\n\n Args:\n config (Any): Configuration dictionary for the sensors.\n event_bus: The EventBus instance for emitting events.\n \"\"\"\n self.config = config\n self.event_bus = event_bus\n self.sensors: List[Sensor] = []\n self.logger = logging.getLogger('PerceptionSystem')\n\n def setup(self) -> None:\n \"\"\"\n Sets up the perception system by initializing all configured sensors.\n \"\"\"\n for sensor_config in self.config.get(\"sensors\", []):\n sensor_name = sensor_config.get(\"sensor_name\")\n sensor_params = sensor_config.get(\"sensor_params\", {})\n # TODO: revise later\n if sensor_name == 'percept_terminal_input':\n sensor = TerminalInputSensor(sensor_name, sensor_params, self.event_bus)\n self.sensors.append(sensor)\n else:\n self.logger.warning(f\"Unknown sensor type: {sensor_name}\")\n self.logger.info(\"PerceptionSystem setup complete.\")\n\n async def start(self) -> None: # TODO: maybe rename in the future\n \"\"\"\n Starts all sensors asynchronously.\n \"\"\"\n self.logger.info(\"Starting all sensors.\")\n for sensor in self.sensors:\n await sensor.start()\n\n async def stop(self) -> None:\n \"\"\"\n Stops all sensors asynchronously.\n \"\"\"\n self.logger.info(\"Stopping all sensors.\")\n for sensor in self.sensors:\n await sensor.stop()\n\n def signal_to_stimuli(self, data: Any) -> Any:\n \"\"\"\n Processes raw data from sensors into structured stimuli.\n\n Args:\n data: The raw data emitted by sensors.\n\n Returns:\n Processed data (stimuli).\n \"\"\"\n # Implement your data processing logic here\n signal = Signal(\n data=data,\n modularity=\"text\", # Or appropriate modality\n type=\"input\", # Or appropriate type\n # TODO: After revised Sensor class, Include other metadata as needed\n )\n stimuli = Stimuli(signals=[signal]) # TODO: 
add more fields\n return stimuli\n
"},{"location":"reference/#src.aeiva.perception.perception_system.PerceptionSystem.__init__","title":"__init__(config, event_bus)
","text":"Initializes the PerceptionSystem with a list of sensors.
Parameters:
Name Type Description Default config
Any
Configuration dictionary for the sensors.
required event_bus
The EventBus instance for emitting events.
required Source code in src/aeiva/perception/perception_system.py
def __init__(self, config: Dict, event_bus):\n \"\"\"\n Initializes the PerceptionSystem with a list of sensors.\n\n Args:\n config (Any): Configuration dictionary for the sensors.\n event_bus: The EventBus instance for emitting events.\n \"\"\"\n self.config = config\n self.event_bus = event_bus\n self.sensors: List[Sensor] = []\n self.logger = logging.getLogger('PerceptionSystem')\n
"},{"location":"reference/#src.aeiva.perception.perception_system.PerceptionSystem.setup","title":"setup()
","text":"Sets up the perception system by initializing all configured sensors.
Source code in src/aeiva/perception/perception_system.py
def setup(self) -> None:\n \"\"\"\n Sets up the perception system by initializing all configured sensors.\n \"\"\"\n for sensor_config in self.config.get(\"sensors\", []):\n sensor_name = sensor_config.get(\"sensor_name\")\n sensor_params = sensor_config.get(\"sensor_params\", {})\n # TODO: revise later\n if sensor_name == 'percept_terminal_input':\n sensor = TerminalInputSensor(sensor_name, sensor_params, self.event_bus)\n self.sensors.append(sensor)\n else:\n self.logger.warning(f\"Unknown sensor type: {sensor_name}\")\n self.logger.info(\"PerceptionSystem setup complete.\")\n
"},{"location":"reference/#src.aeiva.perception.perception_system.PerceptionSystem.signal_to_stimuli","title":"signal_to_stimuli(data)
","text":"Processes raw data from sensors into structured stimuli.
Parameters:
Name Type Description Default data
Any
The raw data emitted by sensors.
required Returns:
Type Description Any
Processed data (stimuli).
Source code in src/aeiva/perception/perception_system.py
def signal_to_stimuli(self, data: Any) -> Any:\n \"\"\"\n Processes raw data from sensors into structured stimuli.\n\n Args:\n data: The raw data emitted by sensors.\n\n Returns:\n Processed data (stimuli).\n \"\"\"\n # Implement your data processing logic here\n signal = Signal(\n data=data,\n modularity=\"text\", # Or appropriate modality\n type=\"input\", # Or appropriate type\n # TODO: After revised Sensor class, Include other metadata as needed\n )\n stimuli = Stimuli(signals=[signal]) # TODO: add more fields\n return stimuli\n
"},{"location":"reference/#src.aeiva.perception.perception_system.PerceptionSystem.start","title":"start()
async
","text":"Starts all sensors asynchronously.
Source code in src/aeiva/perception/perception_system.py
async def start(self) -> None: # TODO: maybe rename in the future\n \"\"\"\n Starts all sensors asynchronously.\n \"\"\"\n self.logger.info(\"Starting all sensors.\")\n for sensor in self.sensors:\n await sensor.start()\n
"},{"location":"reference/#src.aeiva.perception.perception_system.PerceptionSystem.stop","title":"stop()
async
","text":"Stops all sensors asynchronously.
Source code in src/aeiva/perception/perception_system.py
async def stop(self) -> None:\n \"\"\"\n Stops all sensors asynchronously.\n \"\"\"\n self.logger.info(\"Stopping all sensors.\")\n for sensor in self.sensors:\n await sensor.stop()\n
"},{"location":"reference/#src.aeiva.perception.sensation","title":"sensation
","text":""},{"location":"reference/#src.aeiva.perception.sensation.Signal","title":"Signal
","text":"Represents an atomic unit of perception that carries raw data from the environment. This class defines a signal, its characteristics, and its dependencies on other signals.
Source code in src/aeiva/perception/sensation.py
class Signal:\n \"\"\"\n Represents an atomic unit of perception that carries raw data from the environment.\n This class defines a signal, its characteristics, and its dependencies on other signals.\n \"\"\"\n\n def __init__(self, \n data: Any,\n name: Optional[str] = None, # Optional name for the signal\n modularity: Optional[str] = None,\n type: Optional[str] = None, # Renamed to avoid keyword conflict\n timestamp: Optional[datetime] = None,\n id: Optional[str] = None, # Optional unique identifier for the signal\n dependencies: Optional[Dict[str, Any]] = None, # Dependencies by other signal IDs with edge attributes\n description: Optional[str] = None,\n metadata: Optional[Dict[str, Any]] = None):\n \"\"\"\n Initialize a signal with its data and other optional metadata.\n\n Args:\n data (Any): The raw data of the signal.\n name (Optional[str]): An optional name for the signal.\n modularity (Optional[str]): The modality of the signal (e.g., image, video, text, audio).\n type (Optional[str]): A more detailed signal type (e.g., 'text', 'document', etc.).\n timestamp (Optional[datetime]): The time when the signal was created or captured.\n id (Optional[str]): Unique identifier for the signal.\n dependencies (Optional[Dict[str, Any]]): Attributes of dependencies (e.g., relationship types).\n description (Optional[str]): Description of the signal.\n metadata (Optional[Dict[str, Any]]): Optional additional metadata for the signal.\n \"\"\"\n self.data = data\n self.name = name\n self.modularity = modularity\n self.type = type\n self.timestamp = timestamp or datetime.now()\n self.id = id\n self.dependencies = dependencies or {} # Edge attributes (could be string, embedding, etc.)\n self.description = description\n self.metadata = metadata or {}\n\n def to_dict(self) -> Dict[str, Any]:\n \"\"\"\n Converts the signal into a dictionary representation.\n \"\"\"\n return {\n \"data\": self.data,\n \"name\": self.name,\n \"modularity\": self.modularity,\n \"type\": 
self.type,\n \"timestamp\": self.timestamp,\n \"id\": self.id,\n \"dependencies\": self.dependencies,\n \"description\": self.description,\n \"metadata\": self.metadata\n }\n
"},{"location":"reference/#src.aeiva.perception.sensation.Signal.__init__","title":"__init__(data, name=None, modularity=None, type=None, timestamp=None, id=None, dependencies=None, description=None, metadata=None)
","text":"Initialize a signal with its data and other optional metadata.
Parameters:
Name Type Description Default data
Any
The raw data of the signal.
required name
Optional[str]
An optional name for the signal.
None
modularity
Optional[str]
The modality of the signal (e.g., image, video, text, audio).
None
type
Optional[str]
A more detailed signal type (e.g., 'text', 'document', etc.).
None
timestamp
Optional[datetime]
The time when the signal was created or captured.
None
id
Optional[str]
Unique identifier for the signal.
None
dependencies
Optional[Dict[str, Any]]
Attributes of dependencies (e.g., relationship types).
None
description
Optional[str]
Description of the signal.
None
metadata
Optional[Dict[str, Any]]
Optional additional metadata for the signal.
None
Source code in src/aeiva/perception/sensation.py
def __init__(self, \n data: Any,\n name: Optional[str] = None, # Optional name for the signal\n modularity: Optional[str] = None,\n type: Optional[str] = None, # Renamed to avoid keyword conflict\n timestamp: Optional[datetime] = None,\n id: Optional[str] = None, # Optional unique identifier for the signal\n dependencies: Optional[Dict[str, Any]] = None, # Dependencies by other signal IDs with edge attributes\n description: Optional[str] = None,\n metadata: Optional[Dict[str, Any]] = None):\n \"\"\"\n Initialize a signal with its data and other optional metadata.\n\n Args:\n data (Any): The raw data of the signal.\n name (Optional[str]): An optional name for the signal.\n modularity (Optional[str]): The modality of the signal (e.g., image, video, text, audio).\n type (Optional[str]): A more detailed signal type (e.g., 'text', 'document', etc.).\n timestamp (Optional[datetime]): The time when the signal was created or captured.\n id (Optional[str]): Unique identifier for the signal.\n dependencies (Optional[Dict[str, Any]]): Attributes of dependencies (e.g., relationship types).\n description (Optional[str]): Description of the signal.\n metadata (Optional[Dict[str, Any]]): Optional additional metadata for the signal.\n \"\"\"\n self.data = data\n self.name = name\n self.modularity = modularity\n self.type = type\n self.timestamp = timestamp or datetime.now()\n self.id = id\n self.dependencies = dependencies or {} # Edge attributes (could be string, embedding, etc.)\n self.description = description\n self.metadata = metadata or {}\n
"},{"location":"reference/#src.aeiva.perception.sensation.Signal.to_dict","title":"to_dict()
","text":"Converts the signal into a dictionary representation.
Source code in src/aeiva/perception/sensation.py
def to_dict(self) -> Dict[str, Any]:\n \"\"\"\n Converts the signal into a dictionary representation.\n \"\"\"\n return {\n \"data\": self.data,\n \"name\": self.name,\n \"modularity\": self.modularity,\n \"type\": self.type,\n \"timestamp\": self.timestamp,\n \"id\": self.id,\n \"dependencies\": self.dependencies,\n \"description\": self.description,\n \"metadata\": self.metadata\n }\n
"},{"location":"reference/#src.aeiva.perception.sensor","title":"sensor
","text":""},{"location":"reference/#src.aeiva.perception.sensor.Sensor","title":"Sensor
","text":" Bases: ABC
Abstract base class for all sensors.
Source code in src/aeiva/perception/sensor.py
class Sensor(ABC):\n \"\"\"\n Abstract base class for all sensors.\n \"\"\"\n def __init__(self, name: str, params: dict, event_bus):\n \"\"\"\n Initializes the BaseSensor.\n\n Args:\n name (str): The name of the sensor.\n params (dict): Configuration parameters for the sensor.\n event_bus: The EventBus instance for emitting events.\n \"\"\"\n self.name = name\n self.params = params\n self.event_bus = event_bus\n\n @abstractmethod\n async def start(self):\n \"\"\"\n Starts the sensor.\n \"\"\"\n pass\n\n @abstractmethod\n async def stop(self):\n \"\"\"\n Stops the sensor.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.perception.sensor.Sensor.__init__","title":"__init__(name, params, event_bus)
","text":"Initializes the BaseSensor.
Parameters:
Name Type Description Default name
str
The name of the sensor.
required params
dict
Configuration parameters for the sensor.
required event_bus
The EventBus instance for emitting events.
required Source code in src/aeiva/perception/sensor.py
def __init__(self, name: str, params: dict, event_bus):\n \"\"\"\n Initializes the BaseSensor.\n\n Args:\n name (str): The name of the sensor.\n params (dict): Configuration parameters for the sensor.\n event_bus: The EventBus instance for emitting events.\n \"\"\"\n self.name = name\n self.params = params\n self.event_bus = event_bus\n
"},{"location":"reference/#src.aeiva.perception.sensor.Sensor.start","title":"start()
abstractmethod
async
","text":"Starts the sensor.
Source code in src/aeiva/perception/sensor.py
@abstractmethod\nasync def start(self):\n \"\"\"\n Starts the sensor.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.perception.sensor.Sensor.stop","title":"stop()
abstractmethod
async
","text":"Stops the sensor.
Source code in src/aeiva/perception/sensor.py
@abstractmethod\nasync def stop(self):\n \"\"\"\n Stops the sensor.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.perception.stimuli","title":"stimuli
","text":""},{"location":"reference/#src.aeiva.perception.stimuli.Stimuli","title":"Stimuli
","text":"Represents a structured composition of signals, where each node can be a Signal or a sub-Stimuli. The graph allows flexible, directed relationships between nodes, and the graph can contain cycles.
Source code in src/aeiva/perception/stimuli.py
class Stimuli:\n \"\"\"\n Represents a structured composition of signals, where each node can be a Signal or a sub-Stimuli.\n The graph allows flexible, directed relationships between nodes, and the graph can contain cycles.\n \"\"\"\n\n def __init__(self, \n signals: List[Union[Signal, 'Stimuli']],\n id: Optional[str] = None,\n name: Optional[str] = None,\n type: Optional[str] = None,\n modularity: Optional[str] = None,\n timestamp: Optional[str] = None,\n dependencies: Optional[Dict[str, Dict[str, Any]]] = None,\n description: Optional[str] = None,\n metadata: Optional[Dict[str, Any]] = None):\n \"\"\"\n Initializes the Stimuli object by organizing signals or sub-stimuli in a graph structure.\n \"\"\"\n self.signals = signals or [] # Default to an empty list if no signals provided\n self.id = id\n self.name = name\n self.type = type\n self.modularity = modularity\n self.timestamp = timestamp\n self.description = description\n self.metadata = metadata or {}\n self.dependencies = dependencies or {}\n\n # Graph to represent the structure of signals and their relationships\n self.graph = nx.DiGraph()\n\n # Add all signals and sub-stimuli as nodes in the graph\n for signal in signals:\n self.graph.add_node(signal)\n\n # Handle dependencies for signals or sub-stimuli\n for signal in signals:\n if signal.id in self.dependencies:\n for dep_id, edge_attr in self.dependencies[signal.id].items():\n dep_node = next((s for s in signals if s.id == dep_id), None)\n if dep_node and (isinstance(dep_node, Signal) or isinstance(dep_node, Stimuli)):\n self.graph.add_edge(dep_node, signal, **edge_attr)\n else:\n raise ValueError(f\"Dependency {dep_id} not found or is not valid for signal or stimuli {signal.id}.\")\n\n def traverse(self, method: str = 'dfs') -> List[Union[Signal, 'Stimuli']]:\n \"\"\"\n Traverses the graph using the specified method ('dfs' or 'bfs').\n\n Args:\n method (str): The traversal method to use, either 'dfs' (Depth-First Search) or 'bfs' (Breadth-First 
Search).\n\n Returns:\n List[Union[Signal, 'Stimuli']]: A list of signals or sub-stimuli in the order they were visited.\n \"\"\"\n if not self.graph.nodes:\n return []\n\n if method == 'dfs':\n return list(nx.dfs_postorder_nodes(self.graph))\n elif method == 'bfs':\n return list(nx.bfs_tree(self.graph, list(self.graph.nodes)[0])) # BFS starting from an arbitrary node\n else:\n raise ValueError(f\"Unknown traversal method: {method}\")\n\n def to_dict(self) -> Dict[str, Any]:\n \"\"\"\n Converts the stimuli into a dictionary representation, including its signals and their relationships.\n \"\"\"\n return {\n \"id\": self.id,\n \"name\": self.name,\n \"type\": self.type,\n \"modularity\": self.modularity,\n \"timestamp\": self.timestamp,\n \"description\": self.description,\n \"metadata\": self.metadata,\n \"signals\": [signal.to_dict() if isinstance(signal, Signal) else signal.to_dict() for signal in self.signals],\n \"dependencies\": self.dependencies\n }\n\n def visualize(self, save_path: Optional[str] = None):\n \"\"\"\n Visualizes the stimuli's structure using networkx and matplotlib.\n \"\"\"\n pos = nx.spring_layout(self.graph) # Layout for the graph\n labels = {node: f\"{node.id} ({node.type})\" if isinstance(node, Signal) else f\"{node.id} (Stimuli)\"\n for node in self.graph.nodes()}\n\n # Draw the graph with labels\n nx.draw(self.graph, pos, with_labels=True, labels=labels, node_color='lightblue', node_size=2000, font_size=10, font_weight='bold', arrows=True)\n\n plt.title(f\"{self.type} {self.description} Visualization\")\n if save_path:\n plt.savefig(save_path)\n else:\n plt.show()\n
"},{"location":"reference/#src.aeiva.perception.stimuli.Stimuli.__init__","title":"__init__(signals, id=None, name=None, type=None, modularity=None, timestamp=None, dependencies=None, description=None, metadata=None)
","text":"Initializes the Stimuli object by organizing signals or sub-stimuli in a graph structure.
Source code in src/aeiva/perception/stimuli.py
def __init__(self, \n signals: List[Union[Signal, 'Stimuli']],\n id: Optional[str] = None,\n name: Optional[str] = None,\n type: Optional[str] = None,\n modularity: Optional[str] = None,\n timestamp: Optional[str] = None,\n dependencies: Optional[Dict[str, Dict[str, Any]]] = None,\n description: Optional[str] = None,\n metadata: Optional[Dict[str, Any]] = None):\n \"\"\"\n Initializes the Stimuli object by organizing signals or sub-stimuli in a graph structure.\n \"\"\"\n self.signals = signals or [] # Default to an empty list if no signals provided\n self.id = id\n self.name = name\n self.type = type\n self.modularity = modularity\n self.timestamp = timestamp\n self.description = description\n self.metadata = metadata or {}\n self.dependencies = dependencies or {}\n\n # Graph to represent the structure of signals and their relationships\n self.graph = nx.DiGraph()\n\n # Add all signals and sub-stimuli as nodes in the graph\n for signal in signals:\n self.graph.add_node(signal)\n\n # Handle dependencies for signals or sub-stimuli\n for signal in signals:\n if signal.id in self.dependencies:\n for dep_id, edge_attr in self.dependencies[signal.id].items():\n dep_node = next((s for s in signals if s.id == dep_id), None)\n if dep_node and (isinstance(dep_node, Signal) or isinstance(dep_node, Stimuli)):\n self.graph.add_edge(dep_node, signal, **edge_attr)\n else:\n raise ValueError(f\"Dependency {dep_id} not found or is not valid for signal or stimuli {signal.id}.\")\n
"},{"location":"reference/#src.aeiva.perception.stimuli.Stimuli.to_dict","title":"to_dict()
","text":"Converts the stimuli into a dictionary representation, including its signals and their relationships.
Source code in src/aeiva/perception/stimuli.py
def to_dict(self) -> Dict[str, Any]:\n \"\"\"\n Converts the stimuli into a dictionary representation, including its signals and their relationships.\n \"\"\"\n return {\n \"id\": self.id,\n \"name\": self.name,\n \"type\": self.type,\n \"modularity\": self.modularity,\n \"timestamp\": self.timestamp,\n \"description\": self.description,\n \"metadata\": self.metadata,\n \"signals\": [signal.to_dict() if isinstance(signal, Signal) else signal.to_dict() for signal in self.signals],\n \"dependencies\": self.dependencies\n }\n
"},{"location":"reference/#src.aeiva.perception.stimuli.Stimuli.traverse","title":"traverse(method='dfs')
","text":"Traverses the graph using the specified method ('dfs' or 'bfs').
Parameters:
Name Type Description Default method
str
The traversal method to use, either 'dfs' (Depth-First Search) or 'bfs' (Breadth-First Search).
'dfs'
Returns:
Type Description List[Union[Signal, Stimuli]]
List[Union[Signal, 'Stimuli']]: A list of signals or sub-stimuli in the order they were visited.
Source code in src/aeiva/perception/stimuli.py
def traverse(self, method: str = 'dfs') -> List[Union[Signal, 'Stimuli']]:\n \"\"\"\n Traverses the graph using the specified method ('dfs' or 'bfs').\n\n Args:\n method (str): The traversal method to use, either 'dfs' (Depth-First Search) or 'bfs' (Breadth-First Search).\n\n Returns:\n List[Union[Signal, 'Stimuli']]: A list of signals or sub-stimuli in the order they were visited.\n \"\"\"\n if not self.graph.nodes:\n return []\n\n if method == 'dfs':\n return list(nx.dfs_postorder_nodes(self.graph))\n elif method == 'bfs':\n return list(nx.bfs_tree(self.graph, list(self.graph.nodes)[0])) # BFS starting from an arbitrary node\n else:\n raise ValueError(f\"Unknown traversal method: {method}\")\n
"},{"location":"reference/#src.aeiva.perception.stimuli.Stimuli.visualize","title":"visualize(save_path=None)
","text":"Visualizes the procedure's structure using networkx and matplotlib.
Source code in src/aeiva/perception/stimuli.py
def visualize(self, save_path: Optional[str] = None):\n \"\"\"\n Visualizes the procedure's structure using networkx and matplotlib.\n \"\"\"\n pos = nx.spring_layout(self.graph) # Layout for the graph\n labels = {node: f\"{node.id} ({node.type})\" if isinstance(node, Signal) else f\"{node.id} (Stimuli)\"\n for node in self.graph.nodes()}\n\n # Draw the graph with labels\n nx.draw(self.graph, pos, with_labels=True, labels=labels, node_color='lightblue', node_size=2000, font_size=10, font_weight='bold', arrows=True)\n\n plt.title(f\"{self.type} {self.description} Visualization\")\n if save_path:\n plt.savefig(save_path)\n else:\n plt.show()\n
"},{"location":"reference/#src.aeiva.perception.terminal_input_sensor","title":"terminal_input_sensor
","text":""},{"location":"reference/#src.aeiva.perception.terminal_input_sensor.TerminalInputSensor","title":"TerminalInputSensor
","text":" Bases: Sensor
A sensor that reads input from the terminal and emits stimuli via the EventBus.
Source code in src/aeiva/perception/terminal_input_sensor.py
class TerminalInputSensor(Sensor):\n \"\"\"\n A sensor that reads input from the terminal and emits stimuli via the EventBus.\n \"\"\"\n def __init__(self, name: str, params: dict, event_bus):\n super().__init__(name, params, event_bus)\n self.prompt_message = params.get('prompt_message', 'You: ')\n self._running = False\n self._thread = None\n # self.logger = logging.getLogger(f'TerminalInputSensor-{self.name}')\n\n async def start(self):\n \"\"\"\n Starts the sensor by launching the input thread.\n \"\"\"\n self._running = True\n self._thread = threading.Thread(target=self._run, daemon=True)\n self._thread.start()\n # self.logger.info(f\"{self.name} started.\")\n\n async def stop(self):\n \"\"\"\n Stops the sensor by signaling the thread to stop and waiting for it to finish.\n \"\"\"\n self._running = False\n if self._thread:\n self._thread.join()\n # self.logger.info(f\"{self.name} stopped.\")\n\n def _run(self):\n \"\"\"\n The main loop that reads user input and emits events.\n \"\"\"\n loop = self.event_bus.loop\n if loop is None:\n # self.logger.error(\"EventBus loop is not set. Cannot emit events.\")\n return\n\n while self._running:\n try:\n user_input = input(self.prompt_message)\n if not self._running:\n break # Exit if stopped during input\n\n # # Process input into stimuli\n # stimuli = self.signal_to_stimuli(user_input)\n\n # Emit the stimuli as an event\n asyncio.run_coroutine_threadsafe(\n self.event_bus.emit('perception.stimuli', payload=user_input), # TODO: rename event later\n loop\n )\n except EOFError:\n # Handle end of input (Ctrl+D)\n # self.logger.info(\"EOF received. Stopping TerminalInputSensor.\")\n self._running = False\n except KeyboardInterrupt:\n # Handle Ctrl+C\n # self.logger.info(\"KeyboardInterrupt received. Stopping TerminalInputSensor.\")\n self._running = False\n except Exception as e:\n # self.logger.error(f\"Error in TerminalInputSensor: {e}\")\n self._running = False\n
"},{"location":"reference/#src.aeiva.perception.terminal_input_sensor.TerminalInputSensor.start","title":"start()
async
","text":"Starts the sensor by launching the input thread.
Source code in src/aeiva/perception/terminal_input_sensor.py
async def start(self):\n \"\"\"\n Starts the sensor by launching the input thread.\n \"\"\"\n self._running = True\n self._thread = threading.Thread(target=self._run, daemon=True)\n self._thread.start()\n
"},{"location":"reference/#src.aeiva.perception.terminal_input_sensor.TerminalInputSensor.stop","title":"stop()
async
","text":"Stops the sensor by signaling the thread to stop and waiting for it to finish.
Source code in src/aeiva/perception/terminal_input_sensor.py
async def stop(self):\n \"\"\"\n Stops the sensor by signaling the thread to stop and waiting for it to finish.\n \"\"\"\n self._running = False\n if self._thread:\n self._thread.join()\n
"},{"location":"reference/#src.aeiva.perception.test","title":"test
","text":""},{"location":"reference/#src.aeiva.perception.test.handle_observation","title":"handle_observation(stimuli)
async
","text":"Processes stimuli using the cognition system and outputs the response.
Source code in src/aeiva/perception/test.py
async def handle_observation(stimuli):\n \"\"\"\n Processes stimuli using the cognition system and outputs the response.\n \"\"\"\n for signal in stimuli.signals:\n user_input = signal.data\n stimuli_data = [{\"role\": \"user\", \"content\": user_input}]\n response = await llm_brain.think(stimuli_data, stream=True)\n print(f\"LLM Response: {response}\")\n
"},{"location":"reference/#src.aeiva.plugin","title":"plugin
","text":""},{"location":"reference/#src.aeiva.plugin.ability","title":"ability
","text":""},{"location":"reference/#src.aeiva.plugin.ability.plugin_a","title":"plugin_a
","text":""},{"location":"reference/#src.aeiva.plugin.ability.plugin_a.plugin","title":"plugin
","text":""},{"location":"reference/#src.aeiva.plugin.ability.plugin_a.plugin.PluginA","title":"PluginA
","text":" Bases: Plugin
Example Plugin A.
Source code in src/aeiva/plugin/ability/plugin_a/plugin.py
class PluginA(Plugin):\n \"\"\"\n Example Plugin A.\n \"\"\"\n\n def activate(self) -> None:\n print(\"PluginA activated.\")\n\n def deactivate(self) -> None:\n print(\"PluginA deactivated.\")\n\n def run(self) -> None:\n print(\"PluginA is running.\")\n
"},{"location":"reference/#src.aeiva.plugin.ability.plugin_b","title":"plugin_b
","text":""},{"location":"reference/#src.aeiva.plugin.ability.plugin_b.plugin","title":"plugin
","text":""},{"location":"reference/#src.aeiva.plugin.ability.plugin_b.plugin.PluginB","title":"PluginB
","text":" Bases: Plugin
Example Plugin B.
Source code in src/aeiva/plugin/ability/plugin_b/plugin.py
class PluginB(Plugin):\n \"\"\"\n Example Plugin B.\n \"\"\"\n\n def activate(self) -> None:\n print(\"PluginB activated.\")\n\n def deactivate(self) -> None:\n print(\"PluginB deactivated.\")\n\n def run(self) -> None:\n print(\"PluginB is running.\")\n
"},{"location":"reference/#src.aeiva.plugin.plug","title":"plug
","text":""},{"location":"reference/#src.aeiva.plugin.plug--plug-module","title":"Plug Module","text":"This module provides a flexible plugin system with support for:
- Multiple plugin sources with isolation
- Context managers and import hooks
- Resource loading from plugins
- Loading plugins from directories and zip files
- Hot swapping and lazy loading of plugins
Author: Bang Liu Date: 2024-11-19
"},{"location":"reference/#src.aeiva.plugin.plug.Plugin","title":"Plugin
","text":" Bases: ABC
Abstract base class that all plugins must inherit from.
Source code in src/aeiva/plugin/plug.py
class Plugin(abc.ABC):\n \"\"\"\n Abstract base class that all plugins must inherit from.\n \"\"\"\n\n @abc.abstractmethod\n def activate(self) -> None:\n \"\"\"Method called when the plugin is activated.\"\"\"\n pass\n\n @abc.abstractmethod\n def deactivate(self) -> None:\n \"\"\"Method called when the plugin is deactivated.\"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.plugin.plug.Plugin.activate","title":"activate()
abstractmethod
","text":"Method called when the plugin is activated.
Source code in src/aeiva/plugin/plug.py
@abc.abstractmethod\ndef activate(self) -> None:\n \"\"\"Method called when the plugin is activated.\"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.plugin.plug.Plugin.deactivate","title":"deactivate()
abstractmethod
","text":"Method called when the plugin is deactivated.
Source code in src/aeiva/plugin/plug.py
@abc.abstractmethod\ndef deactivate(self) -> None:\n \"\"\"Method called when the plugin is deactivated.\"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.plugin.plug.PluginFinder","title":"PluginFinder
","text":" Bases: MetaPathFinder
Custom finder for plugin modules. Finds plugins as directories containing a plugin.py
file.
Source code in src/aeiva/plugin/plug.py
class PluginFinder(importlib.abc.MetaPathFinder):\n \"\"\"\n Custom finder for plugin modules.\n Finds plugins as directories containing a `plugin.py` file.\n \"\"\"\n\n def __init__(self, plugin_source: 'PluginSource') -> None:\n self.plugin_source = plugin_source\n\n def find_spec(\n self,\n fullname: str,\n path: Optional[List[str]],\n target: Optional[ModuleType] = None\n ) -> Optional[importlib.machinery.ModuleSpec]:\n \"\"\"\n Find the module spec for the given module.\n Handles both the namespace package and its submodules (plugins).\n \"\"\"\n if fullname == self.plugin_source.namespace:\n # Handle the namespace package itself\n print(f\"PluginFinder: Creating namespace package '{fullname}'\")\n spec = importlib.machinery.ModuleSpec(fullname, loader=None, is_package=True)\n spec.submodule_search_locations = []\n return spec\n\n elif fullname.startswith(self.plugin_source.namespace + '.'):\n # Handle submodules (plugins)\n plugin_name = fullname[len(self.plugin_source.namespace) + 1:]\n if plugin_name in self.plugin_source.list_plugins():\n print(f\"PluginFinder: Found plugin '{plugin_name}' for module '{fullname}'\")\n loader = PluginLoader(self.plugin_source, plugin_name)\n spec = importlib.util.spec_from_loader(fullname, loader)\n spec.submodule_search_locations = []\n return spec\n\n # If not handling this module, return None\n print(f\"PluginFinder: Not handling module '{fullname}'\")\n return None\n
"},{"location":"reference/#src.aeiva.plugin.plug.PluginFinder.find_spec","title":"find_spec(fullname, path, target=None)
","text":"Find the module spec for the given module. Handles both the namespace package and its submodules (plugins).
Source code in src/aeiva/plugin/plug.py
def find_spec(\n self,\n fullname: str,\n path: Optional[List[str]],\n target: Optional[ModuleType] = None\n) -> Optional[importlib.machinery.ModuleSpec]:\n \"\"\"\n Find the module spec for the given module.\n Handles both the namespace package and its submodules (plugins).\n \"\"\"\n if fullname == self.plugin_source.namespace:\n # Handle the namespace package itself\n print(f\"PluginFinder: Creating namespace package '{fullname}'\")\n spec = importlib.machinery.ModuleSpec(fullname, loader=None, is_package=True)\n spec.submodule_search_locations = []\n return spec\n\n elif fullname.startswith(self.plugin_source.namespace + '.'):\n # Handle submodules (plugins)\n plugin_name = fullname[len(self.plugin_source.namespace) + 1:]\n if plugin_name in self.plugin_source.list_plugins():\n print(f\"PluginFinder: Found plugin '{plugin_name}' for module '{fullname}'\")\n loader = PluginLoader(self.plugin_source, plugin_name)\n spec = importlib.util.spec_from_loader(fullname, loader)\n spec.submodule_search_locations = []\n return spec\n\n # If not handling this module, return None\n print(f\"PluginFinder: Not handling module '{fullname}'\")\n return None\n
"},{"location":"reference/#src.aeiva.plugin.plug.PluginLoader","title":"PluginLoader
","text":" Bases: Loader
Custom loader for plugin modules. Loads the plugin.py
file within the plugin directory.
Source code in src/aeiva/plugin/plug.py
class PluginLoader(importlib.abc.Loader):\n \"\"\"\n Custom loader for plugin modules.\n Loads the `plugin.py` file within the plugin directory.\n \"\"\"\n\n def __init__(self, plugin_source: 'PluginSource', plugin_name: str) -> None:\n self.plugin_source = plugin_source\n self.plugin_name = plugin_name\n\n def create_module(self, spec: importlib.machinery.ModuleSpec) -> Optional[ModuleType]:\n \"\"\"Use default module creation semantics.\"\"\"\n return None\n\n def exec_module(self, module: ModuleType) -> None:\n \"\"\"Execute the plugin's `plugin.py` module.\"\"\"\n try:\n code = self.plugin_source.get_plugin_code(self.plugin_name)\n except ImportError as e:\n print(f\"PluginLoader: Failed to get code for plugin '{self.plugin_name}': {e}\")\n raise\n\n # Compute project_root dynamically based on plug.py's location\n plugin_dir = os.path.dirname(os.path.abspath(__file__))\n project_root = os.path.abspath(os.path.join(plugin_dir, '../../../'))\n print(f\"PluginLoader: Adding '{project_root}' to sys.path for plugin '{self.plugin_name}'\")\n sys.path.insert(0, project_root)\n\n try:\n print(f\"PluginLoader: Executing plugin '{self.plugin_name}'\")\n exec(code, module.__dict__)\n print(f\"PluginLoader: Plugin '{self.plugin_name}' executed successfully\")\n except Exception as e:\n print(f\"PluginLoader: Error executing plugin '{self.plugin_name}': {e}\")\n raise\n finally:\n sys.path.pop(0)\n
"},{"location":"reference/#src.aeiva.plugin.plug.PluginLoader.create_module","title":"create_module(spec)
","text":"Use default module creation semantics.
Source code in src/aeiva/plugin/plug.py
def create_module(self, spec: importlib.machinery.ModuleSpec) -> Optional[ModuleType]:\n \"\"\"Use default module creation semantics.\"\"\"\n return None\n
"},{"location":"reference/#src.aeiva.plugin.plug.PluginLoader.exec_module","title":"exec_module(module)
","text":"Execute the plugin's plugin.py
module.
Source code in src/aeiva/plugin/plug.py
def exec_module(self, module: ModuleType) -> None:\n \"\"\"Execute the plugin's `plugin.py` module.\"\"\"\n try:\n code = self.plugin_source.get_plugin_code(self.plugin_name)\n except ImportError as e:\n print(f\"PluginLoader: Failed to get code for plugin '{self.plugin_name}': {e}\")\n raise\n\n # Compute project_root dynamically based on plug.py's location\n plugin_dir = os.path.dirname(os.path.abspath(__file__))\n project_root = os.path.abspath(os.path.join(plugin_dir, '../../../'))\n print(f\"PluginLoader: Adding '{project_root}' to sys.path for plugin '{self.plugin_name}'\")\n sys.path.insert(0, project_root)\n\n try:\n print(f\"PluginLoader: Executing plugin '{self.plugin_name}'\")\n exec(code, module.__dict__)\n print(f\"PluginLoader: Plugin '{self.plugin_name}' executed successfully\")\n except Exception as e:\n print(f\"PluginLoader: Error executing plugin '{self.plugin_name}': {e}\")\n raise\n finally:\n sys.path.pop(0)\n
"},{"location":"reference/#src.aeiva.plugin.plug.PluginManager","title":"PluginManager
","text":"Manages multiple PluginSources and controls plugin imports.
Source code in src/aeiva/plugin/plug.py
class PluginManager:\n \"\"\"\n Manages multiple PluginSources and controls plugin imports.\n \"\"\"\n\n def __init__(self) -> None:\n self.plugin_sources: Dict[str, PluginSource] = {}\n\n def create_plugin_source(self, name: str, search_path: Optional[List[str]] = None) -> PluginSource:\n \"\"\"\n Creates a new PluginSource.\n\n :param name: Unique name for the plugin source.\n :param search_path: List of paths to search for plugins.\n :return: The created PluginSource.\n \"\"\"\n if name in self.plugin_sources:\n raise ValueError(f\"Plugin source '{name}' already exists.\")\n source = PluginSource(name, search_path)\n self.plugin_sources[name] = source\n print(f\"PluginManager: Created plugin source '{name}' with search paths {search_path}.\")\n return source\n\n def get_plugin_source(self, name: str) -> Optional[PluginSource]:\n \"\"\"\n Retrieves a PluginSource by name.\n\n :param name: Name of the PluginSource.\n :return: The PluginSource instance, or None if not found.\n \"\"\"\n return self.plugin_sources.get(name)\n\n def remove_plugin_source(self, name: str) -> None:\n \"\"\"\n Removes a PluginSource.\n\n :param name: Name of the PluginSource to remove.\n \"\"\"\n source = self.plugin_sources.pop(name, None)\n if source:\n source.disable()\n for plugin_name in list(source._modules.keys()):\n source.unload_plugin(plugin_name)\n print(f\"PluginManager: Removed plugin source '{name}'.\")\n else:\n print(f\"PluginManager: Plugin source '{name}' does not exist.\")\n
"},{"location":"reference/#src.aeiva.plugin.plug.PluginManager.create_plugin_source","title":"create_plugin_source(name, search_path=None)
","text":"Creates a new PluginSource.
:param name: Unique name for the plugin source. :param search_path: List of paths to search for plugins. :return: The created PluginSource.
Source code in src/aeiva/plugin/plug.py
def create_plugin_source(self, name: str, search_path: Optional[List[str]] = None) -> PluginSource:\n \"\"\"\n Creates a new PluginSource.\n\n :param name: Unique name for the plugin source.\n :param search_path: List of paths to search for plugins.\n :return: The created PluginSource.\n \"\"\"\n if name in self.plugin_sources:\n raise ValueError(f\"Plugin source '{name}' already exists.\")\n source = PluginSource(name, search_path)\n self.plugin_sources[name] = source\n print(f\"PluginManager: Created plugin source '{name}' with search paths {search_path}.\")\n return source\n
"},{"location":"reference/#src.aeiva.plugin.plug.PluginManager.get_plugin_source","title":"get_plugin_source(name)
","text":"Retrieves a PluginSource by name.
:param name: Name of the PluginSource. :return: The PluginSource instance, or None if not found.
Source code in src/aeiva/plugin/plug.py
def get_plugin_source(self, name: str) -> Optional[PluginSource]:\n \"\"\"\n Retrieves a PluginSource by name.\n\n :param name: Name of the PluginSource.\n :return: The PluginSource instance, or None if not found.\n \"\"\"\n return self.plugin_sources.get(name)\n
"},{"location":"reference/#src.aeiva.plugin.plug.PluginManager.remove_plugin_source","title":"remove_plugin_source(name)
","text":"Removes a PluginSource.
:param name: Name of the PluginSource to remove.
Source code in src/aeiva/plugin/plug.py
def remove_plugin_source(self, name: str) -> None:\n \"\"\"\n Removes a PluginSource.\n\n :param name: Name of the PluginSource to remove.\n \"\"\"\n source = self.plugin_sources.pop(name, None)\n if source:\n source.disable()\n for plugin_name in list(source._modules.keys()):\n source.unload_plugin(plugin_name)\n print(f\"PluginManager: Removed plugin source '{name}'.\")\n else:\n print(f\"PluginManager: Plugin source '{name}' does not exist.\")\n
"},{"location":"reference/#src.aeiva.plugin.plug.PluginSource","title":"PluginSource
","text":"Represents an isolated source of plugins. Each plugin is a directory containing a plugin.py
file.
Source code in src/aeiva/plugin/plug.py
class PluginSource:\n \"\"\"\n Represents an isolated source of plugins.\n Each plugin is a directory containing a `plugin.py` file.\n \"\"\"\n\n def __init__(self, name: str, search_path: Optional[List[str]] = None) -> None:\n \"\"\"\n Initializes the PluginSource.\n\n :param name: Unique name for the plugin source.\n :param search_path: List of paths (directories or zip files) to search for plugins.\n \"\"\"\n self.name = name\n self.search_path = search_path or []\n self._lock = threading.Lock()\n self._modules: Dict[str, ModuleType] = {}\n self.namespace = f\"_plug_{self.name}\"\n self._finder = PluginFinder(self)\n self._finder_enabled = False\n\n def __enter__(self) -> 'PluginSource':\n \"\"\"Enter the runtime context related to this object.\"\"\"\n self.enable()\n return self\n\n def __exit__(self, exc_type, exc_value, traceback) -> None:\n \"\"\"Exit the runtime context.\"\"\"\n self.disable()\n\n def enable(self) -> None:\n \"\"\"Enable the plugin import mechanism.\"\"\"\n if not self._finder_enabled:\n sys.meta_path.insert(0, self._finder)\n self._finder_enabled = True\n print(f\"PluginSource: Import hook enabled for namespace '{self.namespace}'.\")\n\n def disable(self) -> None:\n \"\"\"Disable the plugin import mechanism.\"\"\"\n if self._finder_enabled:\n try:\n sys.meta_path.remove(self._finder)\n print(f\"PluginSource: Import hook disabled for namespace '{self.namespace}'.\")\n except ValueError:\n print(f\"PluginSource: Import hook for namespace '{self.namespace}' was not found in sys.meta_path.\")\n self._finder_enabled = False\n\n def list_plugins(self) -> List[str]:\n \"\"\"\n Lists available plugins in the search paths.\n Each plugin is a directory containing a `plugin.py` file.\n\n :return: List of plugin names.\n \"\"\"\n plugins = set()\n for path in self.search_path:\n if zipfile.is_zipfile(path):\n with zipfile.ZipFile(path, 'r') as z:\n # Identify top-level directories containing `plugin.py`\n plugin_dirs = set()\n for file in 
z.namelist():\n parts = file.split('/')\n if len(parts) >= 2 and parts[-1] == 'plugin.py':\n plugin_dir = parts[0]\n plugin_dirs.add(plugin_dir)\n plugins.update(plugin_dirs)\n else:\n # Assume it's a directory\n if not os.path.isdir(path):\n print(f\"PluginSource: Path '{path}' is not a directory or a zip file. Skipping.\")\n continue\n for entry in os.listdir(path):\n plugin_path = os.path.join(path, entry)\n if os.path.isdir(plugin_path):\n plugin_main = os.path.join(plugin_path, 'plugin.py')\n if os.path.isfile(plugin_main):\n plugins.add(entry)\n return list(plugins)\n\n def get_plugin_code(self, plugin_name: str) -> str:\n \"\"\"\n Get the source code of the plugin's `plugin.py`.\n\n :param plugin_name: Name of the plugin to load.\n :return: Source code of `plugin.py` as a string.\n \"\"\"\n for path in self.search_path:\n if zipfile.is_zipfile(path):\n with zipfile.ZipFile(path, 'r') as z:\n plugin_main = f\"{plugin_name}/plugin.py\"\n if plugin_main in z.namelist():\n print(f\"PluginSource: Found plugin '{plugin_name}' in zip file '{path}'.\")\n return z.read(plugin_main).decode('utf-8')\n else:\n # Assume it's a directory\n plugin_dir = os.path.join(path, plugin_name)\n plugin_main = os.path.join(plugin_dir, 'plugin.py')\n if os.path.isfile(plugin_main):\n print(f\"PluginSource: Found plugin '{plugin_name}' as module file '{plugin_main}'.\")\n with open(plugin_main, 'r', encoding='utf-8') as f:\n return f.read()\n raise ImportError(f\"Cannot find plugin '{plugin_name}'.\")\n\n def load_plugin(self, plugin_name: str) -> ModuleType:\n \"\"\"\n Loads a plugin by name.\n\n :param plugin_name: Name of the plugin to load.\n :return: The loaded plugin module.\n \"\"\"\n with self._lock:\n full_name = f\"{self.namespace}.{plugin_name}\"\n if full_name in sys.modules:\n print(f\"PluginSource: Plugin '{plugin_name}' is already loaded as '{full_name}'.\")\n return sys.modules[full_name]\n # Enable the finder if not already enabled\n self.enable()\n try:\n 
print(f\"PluginSource: Loading plugin '{plugin_name}' as '{full_name}'.\")\n module = importlib.import_module(full_name)\n self._modules[plugin_name] = module\n return module\n except ImportError as e:\n print(f\"PluginSource: Cannot import plugin '{plugin_name}': {e}\")\n raise\n\n def unload_plugin(self, plugin_name: str) -> None:\n \"\"\"\n Unloads a plugin by name.\n\n :param plugin_name: Name of the plugin to unload.\n \"\"\"\n with self._lock:\n full_name = f\"{self.namespace}.{plugin_name}\"\n module = self._modules.pop(plugin_name, None)\n if module:\n if hasattr(module, 'deactivate'):\n try:\n print(f\"PluginSource: Deactivating plugin '{plugin_name}'.\")\n getattr(module, 'deactivate')()\n except Exception as e:\n print(f\"PluginSource: Error during deactivation of plugin '{plugin_name}': {e}\")\n if full_name in sys.modules:\n del sys.modules[full_name]\n print(f\"PluginSource: Plugin '{plugin_name}' unloaded and removed from sys.modules.\")\n else:\n print(f\"PluginSource: Plugin '{plugin_name}' is not loaded.\")\n\n def load_resource(self, plugin_name: str, resource_name: str) -> bytes:\n \"\"\"\n Loads a resource from a plugin.\n\n :param plugin_name: Name of the plugin.\n :param resource_name: Name of the resource file.\n :return: Contents of the resource file as bytes.\n \"\"\"\n for path in self.search_path:\n if zipfile.is_zipfile(path):\n with zipfile.ZipFile(path, 'r') as z:\n resource_file = f\"{plugin_name}/{resource_name}\"\n if resource_file in z.namelist():\n print(f\"PluginSource: Loading resource '{resource_name}' from plugin '{plugin_name}' in zip '{path}'.\")\n return z.read(resource_file)\n else:\n # Assume it's a directory\n resource_path = os.path.join(path, plugin_name, resource_name)\n if os.path.isfile(resource_path):\n print(f\"PluginSource: Loading resource '{resource_name}' from plugin '{plugin_name}' at '{resource_path}'.\")\n with open(resource_path, 'rb') as f:\n return f.read()\n raise FileNotFoundError(f\"Resource 
'{resource_name}' not found in plugin '{plugin_name}'.\")\n
"},{"location":"reference/#src.aeiva.plugin.plug.PluginSource.__enter__","title":"__enter__()
","text":"Enter the runtime context related to this object.
Source code in src/aeiva/plugin/plug.py
def __enter__(self) -> 'PluginSource':\n \"\"\"Enter the runtime context related to this object.\"\"\"\n self.enable()\n return self\n
"},{"location":"reference/#src.aeiva.plugin.plug.PluginSource.__exit__","title":"__exit__(exc_type, exc_value, traceback)
","text":"Exit the runtime context.
Source code in src/aeiva/plugin/plug.py
def __exit__(self, exc_type, exc_value, traceback) -> None:\n \"\"\"Exit the runtime context.\"\"\"\n self.disable()\n
"},{"location":"reference/#src.aeiva.plugin.plug.PluginSource.__init__","title":"__init__(name, search_path=None)
","text":"Initializes the PluginSource.
:param name: Unique name for the plugin source. :param search_path: List of paths (directories or zip files) to search for plugins.
Source code in src/aeiva/plugin/plug.py
def __init__(self, name: str, search_path: Optional[List[str]] = None) -> None:\n \"\"\"\n Initializes the PluginSource.\n\n :param name: Unique name for the plugin source.\n :param search_path: List of paths (directories or zip files) to search for plugins.\n \"\"\"\n self.name = name\n self.search_path = search_path or []\n self._lock = threading.Lock()\n self._modules: Dict[str, ModuleType] = {}\n self.namespace = f\"_plug_{self.name}\"\n self._finder = PluginFinder(self)\n self._finder_enabled = False\n
"},{"location":"reference/#src.aeiva.plugin.plug.PluginSource.disable","title":"disable()
","text":"Disable the plugin import mechanism.
Source code in src/aeiva/plugin/plug.py
def disable(self) -> None:\n \"\"\"Disable the plugin import mechanism.\"\"\"\n if self._finder_enabled:\n try:\n sys.meta_path.remove(self._finder)\n print(f\"PluginSource: Import hook disabled for namespace '{self.namespace}'.\")\n except ValueError:\n print(f\"PluginSource: Import hook for namespace '{self.namespace}' was not found in sys.meta_path.\")\n self._finder_enabled = False\n
"},{"location":"reference/#src.aeiva.plugin.plug.PluginSource.enable","title":"enable()
","text":"Enable the plugin import mechanism.
Source code in src/aeiva/plugin/plug.py
def enable(self) -> None:\n \"\"\"Enable the plugin import mechanism.\"\"\"\n if not self._finder_enabled:\n sys.meta_path.insert(0, self._finder)\n self._finder_enabled = True\n print(f\"PluginSource: Import hook enabled for namespace '{self.namespace}'.\")\n
"},{"location":"reference/#src.aeiva.plugin.plug.PluginSource.get_plugin_code","title":"get_plugin_code(plugin_name)
","text":"Get the source code of the plugin's plugin.py
.
:param plugin_name: Name of the plugin to load. :return: Source code of plugin.py
as a string.
Source code in src/aeiva/plugin/plug.py
def get_plugin_code(self, plugin_name: str) -> str:\n \"\"\"\n Get the source code of the plugin's `plugin.py`.\n\n :param plugin_name: Name of the plugin to load.\n :return: Source code of `plugin.py` as a string.\n \"\"\"\n for path in self.search_path:\n if zipfile.is_zipfile(path):\n with zipfile.ZipFile(path, 'r') as z:\n plugin_main = f\"{plugin_name}/plugin.py\"\n if plugin_main in z.namelist():\n print(f\"PluginSource: Found plugin '{plugin_name}' in zip file '{path}'.\")\n return z.read(plugin_main).decode('utf-8')\n else:\n # Assume it's a directory\n plugin_dir = os.path.join(path, plugin_name)\n plugin_main = os.path.join(plugin_dir, 'plugin.py')\n if os.path.isfile(plugin_main):\n print(f\"PluginSource: Found plugin '{plugin_name}' as module file '{plugin_main}'.\")\n with open(plugin_main, 'r', encoding='utf-8') as f:\n return f.read()\n raise ImportError(f\"Cannot find plugin '{plugin_name}'.\")\n
"},{"location":"reference/#src.aeiva.plugin.plug.PluginSource.list_plugins","title":"list_plugins()
","text":"Lists available plugins in the search paths. Each plugin is a directory containing a plugin.py
file.
:return: List of plugin names.
Source code in src/aeiva/plugin/plug.py
def list_plugins(self) -> List[str]:\n \"\"\"\n Lists available plugins in the search paths.\n Each plugin is a directory containing a `plugin.py` file.\n\n :return: List of plugin names.\n \"\"\"\n plugins = set()\n for path in self.search_path:\n if zipfile.is_zipfile(path):\n with zipfile.ZipFile(path, 'r') as z:\n # Identify top-level directories containing `plugin.py`\n plugin_dirs = set()\n for file in z.namelist():\n parts = file.split('/')\n if len(parts) >= 2 and parts[-1] == 'plugin.py':\n plugin_dir = parts[0]\n plugin_dirs.add(plugin_dir)\n plugins.update(plugin_dirs)\n else:\n # Assume it's a directory\n if not os.path.isdir(path):\n print(f\"PluginSource: Path '{path}' is not a directory or a zip file. Skipping.\")\n continue\n for entry in os.listdir(path):\n plugin_path = os.path.join(path, entry)\n if os.path.isdir(plugin_path):\n plugin_main = os.path.join(plugin_path, 'plugin.py')\n if os.path.isfile(plugin_main):\n plugins.add(entry)\n return list(plugins)\n
"},{"location":"reference/#src.aeiva.plugin.plug.PluginSource.load_plugin","title":"load_plugin(plugin_name)
","text":"Loads a plugin by name.
:param plugin_name: Name of the plugin to load. :return: The loaded plugin module.
Source code in src/aeiva/plugin/plug.py
def load_plugin(self, plugin_name: str) -> ModuleType:\n \"\"\"\n Loads a plugin by name.\n\n :param plugin_name: Name of the plugin to load.\n :return: The loaded plugin module.\n \"\"\"\n with self._lock:\n full_name = f\"{self.namespace}.{plugin_name}\"\n if full_name in sys.modules:\n print(f\"PluginSource: Plugin '{plugin_name}' is already loaded as '{full_name}'.\")\n return sys.modules[full_name]\n # Enable the finder if not already enabled\n self.enable()\n try:\n print(f\"PluginSource: Loading plugin '{plugin_name}' as '{full_name}'.\")\n module = importlib.import_module(full_name)\n self._modules[plugin_name] = module\n return module\n except ImportError as e:\n print(f\"PluginSource: Cannot import plugin '{plugin_name}': {e}\")\n raise\n
"},{"location":"reference/#src.aeiva.plugin.plug.PluginSource.load_resource","title":"load_resource(plugin_name, resource_name)
","text":"Loads a resource from a plugin.
:param plugin_name: Name of the plugin. :param resource_name: Name of the resource file. :return: Contents of the resource file as bytes.
Source code in src/aeiva/plugin/plug.py
def load_resource(self, plugin_name: str, resource_name: str) -> bytes:\n \"\"\"\n Loads a resource from a plugin.\n\n :param plugin_name: Name of the plugin.\n :param resource_name: Name of the resource file.\n :return: Contents of the resource file as bytes.\n \"\"\"\n for path in self.search_path:\n if zipfile.is_zipfile(path):\n with zipfile.ZipFile(path, 'r') as z:\n resource_file = f\"{plugin_name}/{resource_name}\"\n if resource_file in z.namelist():\n print(f\"PluginSource: Loading resource '{resource_name}' from plugin '{plugin_name}' in zip '{path}'.\")\n return z.read(resource_file)\n else:\n # Assume it's a directory\n resource_path = os.path.join(path, plugin_name, resource_name)\n if os.path.isfile(resource_path):\n print(f\"PluginSource: Loading resource '{resource_name}' from plugin '{plugin_name}' at '{resource_path}'.\")\n with open(resource_path, 'rb') as f:\n return f.read()\n raise FileNotFoundError(f\"Resource '{resource_name}' not found in plugin '{plugin_name}'.\")\n
"},{"location":"reference/#src.aeiva.plugin.plug.PluginSource.unload_plugin","title":"unload_plugin(plugin_name)
","text":"Unloads a plugin by name.
:param plugin_name: Name of the plugin to unload.
Source code in src/aeiva/plugin/plug.py
def unload_plugin(self, plugin_name: str) -> None:\n \"\"\"\n Unloads a plugin by name.\n\n :param plugin_name: Name of the plugin to unload.\n \"\"\"\n with self._lock:\n full_name = f\"{self.namespace}.{plugin_name}\"\n module = self._modules.pop(plugin_name, None)\n if module:\n if hasattr(module, 'deactivate'):\n try:\n print(f\"PluginSource: Deactivating plugin '{plugin_name}'.\")\n getattr(module, 'deactivate')()\n except Exception as e:\n print(f\"PluginSource: Error during deactivation of plugin '{plugin_name}': {e}\")\n if full_name in sys.modules:\n del sys.modules[full_name]\n print(f\"PluginSource: Plugin '{plugin_name}' unloaded and removed from sys.modules.\")\n else:\n print(f\"PluginSource: Plugin '{plugin_name}' is not loaded.\")\n
"},{"location":"reference/#src.aeiva.plugin.test","title":"test
","text":""},{"location":"reference/#src.aeiva.plugin.test--main-application","title":"Main Application","text":"This script demonstrates the usage of the plug module and plugin system.
"},{"location":"reference/#src.aeiva.society","title":"society
","text":""},{"location":"reference/#src.aeiva.society.society","title":"society
","text":""},{"location":"reference/#src.aeiva.society.society.Society","title":"Society
","text":" Bases: ABC
Abstract base class representing a Society that connects an environment and agents.
The Society enables agents to interact with each other and with the environment, providing mechanisms for integrating social systems, such as communication or economy.
Attributes:
Name Type Description config
Any
Configuration settings for the society.
environment
Environment
The environment in which agents operate.
agents
Dict[str, Any]
A dictionary of agents within the society.
social_systems
Dict[str, Any]
A dictionary representing various social systems (e.g., communication).
Source code in src/aeiva/society/society.py
class Society(ABC):\n    \"\"\"\n    Abstract base class representing a Society that connects an environment and agents.\n\n    The Society enables agents to interact with each other and with the environment, providing\n    mechanisms for integrating social systems, such as communication or economy.\n\n    Attributes:\n        config (Any): Configuration settings for the society.\n        environment (Environment): The environment in which agents operate.\n        agents (Dict[str, Any]): A dictionary of agents within the society.\n        social_systems (Dict[str, Any]): A dictionary representing various social systems (e.g., communication).\n    \"\"\"\n\n    def __init__(self, config: Any, environment: Any, agents: Dict[str, Any]):\n        \"\"\"\n        Initialize the Society with the provided configuration, environment, and agents.\n\n        Args:\n            config (Any): Configuration settings for the society.\n            environment (Environment): The environment in which agents operate.\n            agents (Dict[str, Any]): A dictionary of agents within the society, keyed by their IDs.\n        \"\"\"\n        self.config = config\n        self.environment = environment\n        self.agents = agents  # Agents are stored in a dictionary with IDs as keys\n        self.social_systems = self.init_social_systems()\n\n    @abstractmethod\n    def init_social_systems(self) -> Dict[str, Any]:\n        \"\"\"\n        Initialize the social systems that operate within the society (e.g., communication, financial, law, political, social network systems).\n\n        Returns:\n            Dict[str, Any]: A dictionary of initialized social systems.\n        \"\"\"\n        pass\n\n    @abstractmethod\n    async def setup(self) -> None:\n        \"\"\"\n        Asynchronously set up the society's components, such as initializing the environment and agents.\n        \"\"\"\n        await self.environment.setup()\n        await asyncio.gather(*(agent.setup() for agent in self.agents.values()))\n        print(\"Society: Setup completed.\")\n\n    @abstractmethod\n    async def run(self) -> None:\n        \"\"\"\n        Asynchronously run the society, managing interactions between agents and the environment.\n\n        This method should control 
the flow of interactions between agents and the environment,\n and it can be designed as a continuous loop or a task-based execution.\n \"\"\"\n pass\n\n def add_agent(self, agent_id: str, agent: Any) -> None:\n \"\"\"\n Add a new agent to the society.\n\n Args:\n agent_id (str): The unique identifier of the agent.\n agent (Any): The agent object to add to the society.\n \"\"\"\n self.agents[agent_id] = agent\n\n def remove_agent(self, agent_id: str) -> None:\n \"\"\"\n Remove an agent from the society by its ID.\n\n Args:\n agent_id (str): The unique identifier of the agent.\n \"\"\"\n if agent_id in self.agents:\n del self.agents[agent_id]\n\n def get_agent(self, agent_id: str) -> Any:\n \"\"\"\n Retrieve an agent by its ID.\n\n Args:\n agent_id (str): The unique identifier of the agent.\n\n Returns:\n Any: The agent object, if found.\n \"\"\"\n return self.agents.get(agent_id, None)\n\n def handle_error(self, error: Exception) -> None:\n \"\"\"\n Handle errors that occur during society operations.\n\n Args:\n error (Exception): The exception that was raised.\n \"\"\"\n print(f\"Society encountered an error: {error}\")\n
"},{"location":"reference/#src.aeiva.society.society.Society.__init__","title":"__init__(config, environment, agents)
","text":"Initialize the Society with the provided configuration, environment, and agents.
Parameters:
Name Type Description Default config
Any
Configuration settings for the society.
required environment
Environment
The environment in which agents operate.
required agents
Dict[str, Any]
A dictionary of agents within the society, keyed by their IDs.
required Source code in src/aeiva/society/society.py
def __init__(self, config: Any, environment: Any, agents: Dict[str, Any]):\n    \"\"\"\n    Initialize the Society with the provided configuration, environment, and agents.\n\n    Args:\n        config (Any): Configuration settings for the society.\n        environment (Environment): The environment in which agents operate.\n        agents (Dict[str, Any]): A dictionary of agents within the society, keyed by their IDs.\n    \"\"\"\n    self.config = config\n    self.environment = environment\n    self.agents = agents  # Agents are stored in a dictionary with IDs as keys\n    self.social_systems = self.init_social_systems()\n
"},{"location":"reference/#src.aeiva.society.society.Society.add_agent","title":"add_agent(agent_id, agent)
","text":"Add a new agent to the society.
Parameters:
Name Type Description Default agent_id
str
The unique identifier of the agent.
required agent
Any
The agent object to add to the society.
required Source code in src/aeiva/society/society.py
def add_agent(self, agent_id: str, agent: Any) -> None:\n \"\"\"\n Add a new agent to the society.\n\n Args:\n agent_id (str): The unique identifier of the agent.\n agent (Any): The agent object to add to the society.\n \"\"\"\n self.agents[agent_id] = agent\n
"},{"location":"reference/#src.aeiva.society.society.Society.get_agent","title":"get_agent(agent_id)
","text":"Retrieve an agent by its ID.
Parameters:
Name Type Description Default agent_id
str
The unique identifier of the agent.
required Returns:
Name Type Description Any
Any
The agent object, if found.
Source code in src/aeiva/society/society.py
def get_agent(self, agent_id: str) -> Any:\n \"\"\"\n Retrieve an agent by its ID.\n\n Args:\n agent_id (str): The unique identifier of the agent.\n\n Returns:\n Any: The agent object, if found.\n \"\"\"\n return self.agents.get(agent_id, None)\n
"},{"location":"reference/#src.aeiva.society.society.Society.handle_error","title":"handle_error(error)
","text":"Handle errors that occur during society operations.
Parameters:
Name Type Description Default error
Exception
The exception that was raised.
required Source code in src/aeiva/society/society.py
def handle_error(self, error: Exception) -> None:\n \"\"\"\n Handle errors that occur during society operations.\n\n Args:\n error (Exception): The exception that was raised.\n \"\"\"\n print(f\"Society encountered an error: {error}\")\n
"},{"location":"reference/#src.aeiva.society.society.Society.init_social_systems","title":"init_social_systems()
abstractmethod
","text":"Initialize the social systems that operate within the society (e.g., communication, financial, law, political, social network systems).
Returns:
Type Description Dict[str, Any]
Dict[str, Any]: A dictionary of initialized social systems.
Source code in src/aeiva/society/society.py
@abstractmethod\ndef init_social_systems(self) -> Dict[str, Any]:\n \"\"\"\n Initialize the social systems that operate within the society (e.g., communication, financial, law, political, social network systems).\n\n Returns:\n Dict[str, Any]: A dictionary of initialized social systems.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.society.society.Society.remove_agent","title":"remove_agent(agent_id)
","text":"Remove an agent from the society by its ID.
Parameters:
Name Type Description Default agent_id
str
The unique identifier of the agent.
required Source code in src/aeiva/society/society.py
def remove_agent(self, agent_id: str) -> None:\n \"\"\"\n Remove an agent from the society by its ID.\n\n Args:\n agent_id (str): The unique identifier of the agent.\n \"\"\"\n if agent_id in self.agents:\n del self.agents[agent_id]\n
"},{"location":"reference/#src.aeiva.society.society.Society.run","title":"run()
abstractmethod
async
","text":"Asynchronously run the society, managing interactions between agents and the environment.
This method should control the flow of interactions between agents and the environment, and it can be designed as a continuous loop or a task-based execution.
Source code in src/aeiva/society/society.py
@abstractmethod\nasync def run(self) -> None:\n \"\"\"\n Asynchronously run the society, managing interactions between agents and the environment.\n\n This method should control the flow of interactions between agents and the environment,\n and it can be designed as a continuous loop or a task-based execution.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.society.society.Society.setup","title":"setup()
abstractmethod
async
","text":"Asynchronously set up the society's components, such as initializing the environment and agents.
Source code in src/aeiva/society/society.py
@abstractmethod\nasync def setup(self) -> None:\n    \"\"\"\n    Asynchronously set up the society's components, such as initializing the environment and agents.\n    \"\"\"\n    await self.environment.setup()\n    await asyncio.gather(*(agent.setup() for agent in self.agents.values()))\n    print(\"Society: Setup completed.\")\n
"},{"location":"reference/#src.aeiva.storage","title":"storage
","text":""},{"location":"reference/#src.aeiva.storage.azure_ai_search","title":"azure_ai_search
","text":""},{"location":"reference/#src.aeiva.storage.azure_ai_search.azure_ai_search_config","title":"azure_ai_search_config
","text":""},{"location":"reference/#src.aeiva.storage.azure_ai_search.azure_ai_search_config.AzureAISearchConfig","title":"AzureAISearchConfig
dataclass
","text":" Bases: BaseConfig
Configuration for Azure Cognitive Search vector database.
Source code in src/aeiva/storage/azure_ai_search/azure_ai_search_config.py
@dataclass\nclass AzureAISearchConfig(BaseConfig):\n \"\"\"\n Configuration for Azure Cognitive Search vector database.\n \"\"\"\n\n collection_name: str = field(\n default=\"mem0\",\n metadata={\"help\": \"Name of the collection (index name).\"}\n )\n service_name: Optional[str] = field(\n default=None,\n metadata={\"help\": \"Azure Cognitive Search service name.\"}\n )\n api_key: Optional[str] = field(\n default=None,\n metadata={\"help\": \"API key for the Azure Cognitive Search service.\"}\n )\n embedding_model_dims: int = field(\n default=1536,\n metadata={\"help\": \"Dimension of the embedding vector.\"}\n )\n use_compression: bool = field(\n default=False,\n metadata={\"help\": \"Whether to use scalar quantization vector compression.\"}\n )\n\n def __post_init__(self):\n super().__post_init__()\n # Validate that service_name and api_key are provided\n if not self.service_name or not self.api_key:\n raise ValueError(\"Both 'service_name' and 'api_key' must be provided.\")\n
"},{"location":"reference/#src.aeiva.storage.azure_ai_search.azure_ai_search_database","title":"azure_ai_search_database
","text":""},{"location":"reference/#src.aeiva.storage.azure_ai_search.azure_ai_search_database.AzureAISearchDatabase","title":"AzureAISearchDatabase
","text":" Bases: VectorDatabase
Concrete implementation of VectorDatabase using Azure Cognitive Search.
Source code in src/aeiva/storage/azure_ai_search/azure_ai_search_database.py
class AzureAISearchDatabase(VectorDatabase):\n \"\"\"\n Concrete implementation of VectorStoreBase using Azure Cognitive Search.\n \"\"\"\n\n def __init__(self, config: Dict[str, Any]) -> None:\n \"\"\"\n Initialize the Azure Cognitive Search vector store.\n\n Args:\n config (Dict[str, Any]): Configuration dictionary.\n \"\"\"\n self.config = config\n self.index_name = config.get('collection_name')\n self.service_name = config.get('service_name')\n self.api_key = config.get('api_key')\n self.embedding_model_dims = config.get('embedding_model_dims')\n self.use_compression = config.get('use_compression', False)\n\n if not all([self.service_name, self.api_key, self.index_name, self.embedding_model_dims]):\n raise ValueError(\"Required configuration parameters are missing.\")\n\n self.create_client(\n uri=None,\n service_name=self.service_name,\n api_key=self.api_key\n )\n self.create_collection(\n collection_name=self.index_name,\n vector_size=self.embedding_model_dims,\n distance_metric='cosine'\n )\n\n def create_client(\n self,\n uri: Optional[str] = None,\n service_name: Optional[str] = None,\n api_key: Optional[str] = None,\n **kwargs\n ) -> None:\n \"\"\"\n Initializes the client connection to the vector store.\n\n Args:\n uri (Optional[str]): Not used for Azure Cognitive Search.\n service_name (str): Azure Cognitive Search service name.\n api_key (str): API key for the Azure Cognitive Search service.\n **kwargs: Additional parameters.\n \"\"\"\n if not service_name or not api_key:\n raise ValueError(\"Both 'service_name' and 'api_key' must be provided.\")\n\n endpoint = f\"https://{service_name}.search.windows.net\"\n credential = AzureKeyCredential(api_key)\n self.search_client = SearchClient(\n endpoint=endpoint,\n index_name=self.index_name,\n credential=credential\n )\n self.index_client = SearchIndexClient(\n endpoint=endpoint,\n credential=credential\n )\n\n def create_collection(self, collection_name: str, vector_size: int, distance_metric: str) -> 
None:\n \"\"\"\n Create a new vector collection (index) in Azure Cognitive Search.\n\n Args:\n collection_name (str): The name of the collection.\n vector_size (int): The dimensionality of the vectors.\n distance_metric (str): The distance metric to use (e.g., 'cosine').\n \"\"\"\n # Check if the index already exists\n try:\n self.index_client.get_index(collection_name)\n logger.info(f\"Index {collection_name} already exists. Skipping creation.\")\n return\n except ResourceNotFoundError:\n pass # Index does not exist, proceed to create\n\n if self.use_compression:\n vector_type = \"Collection(Edm.Half)\"\n compression_name = \"myCompression\"\n compression_configurations = [ScalarQuantizationCompression(compression_name=compression_name)]\n else:\n vector_type = \"Collection(Edm.Single)\"\n compression_name = None\n compression_configurations = []\n\n fields = [\n SimpleField(name=\"id\", type=SearchFieldDataType.String, key=True),\n SearchField(\n name=\"vector\",\n type=vector_type,\n searchable=True,\n vector_search_dimensions=vector_size,\n vector_search_profile_name=\"my-vector-config\",\n ),\n SimpleField(name=\"payload\", type=SearchFieldDataType.String, searchable=True),\n ]\n\n vector_search = VectorSearch(\n profiles=[\n VectorSearchProfile(name=\"my-vector-config\", algorithm_configuration_name=\"my-algorithms-config\")\n ],\n algorithms=[HnswAlgorithmConfiguration(name=\"my-algorithms-config\")],\n compressions=compression_configurations,\n )\n index = SearchIndex(name=collection_name, fields=fields, vector_search=vector_search)\n self.index_client.create_or_update_index(index)\n logger.info(f\"Index {collection_name} created successfully.\")\n\n def insert_vectors(\n self,\n collection_name: str,\n vectors: List[List[float]],\n payloads: Optional[List[Dict[str, Any]]] = None,\n ids: Optional[List[str]] = None\n ) -> None:\n \"\"\"\n Insert vectors into the index.\n\n Args:\n collection_name (str): The name of the collection.\n vectors 
(List[List[float]]): A list of vectors to insert.\n payloads (Optional[List[Dict[str, Any]]]): Optional metadata associated with each vector.\n ids (Optional[List[str]]): Optional unique identifiers for each vector.\n \"\"\"\n if collection_name != self.index_name:\n raise ValueError(\"Collection name does not match initialized index name.\")\n\n if ids is None:\n ids = [str(i) for i in range(len(vectors))]\n if payloads is None:\n payloads = [{} for _ in range(len(vectors))]\n if not (len(ids) == len(vectors) == len(payloads)):\n raise ValueError(\"Lengths of ids, vectors, and payloads must be equal.\")\n\n documents = [\n {\"id\": id_, \"vector\": vector, \"payload\": json.dumps(payload)}\n for id_, vector, payload in zip(ids, vectors, payloads)\n ]\n self.search_client.upload_documents(documents)\n logger.info(f\"Inserted {len(vectors)} vectors into index {collection_name}.\")\n\n def search_vectors(\n self,\n collection_name: str,\n query_vector: List[float],\n top_k: int = 5,\n filters: Optional[Dict[str, Any]] = None\n ) -> List[Dict[str, Any]]:\n \"\"\"\n Search for similar vectors in a collection.\n\n Args:\n collection_name (str): The name of the collection.\n query_vector (List[float]): The vector to search with.\n top_k (int): The number of top results to return.\n filters (Optional[Dict[str, Any]]): Optional filters to apply to the search.\n\n Returns:\n List[Dict[str, Any]]: A list of search results.\n \"\"\"\n if collection_name != self.index_name:\n raise ValueError(\"Collection name does not match initialized index name.\")\n\n vector_query = VectorizedQuery(vector=query_vector, k_nearest_neighbors=top_k, fields=\"vector\")\n search_results = self.search_client.search(vector_queries=[vector_query], top=top_k)\n\n results = []\n for result in search_results:\n payload = json.loads(result[\"payload\"])\n # A bare 'continue' in the inner filter loop would not skip this result,\n # so evaluate all filters at once and skip non-matching results here.\n if filters and any(key not in payload or payload[key] != value for key, value in filters.items()):\n continue\n result_dict = {\n \"id\": 
result[\"id\"],\n \"score\": result[\"@search.score\"],\n \"payload\": payload\n }\n results.append(result_dict)\n return results\n\n def delete_vector(self, collection_name: str, vector_id: str) -> None:\n \"\"\"\n Delete a vector from a collection by its ID.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector to delete.\n \"\"\"\n if collection_name != self.index_name:\n raise ValueError(\"Collection name does not match initialized index name.\")\n self.search_client.delete_documents(documents=[{\"id\": vector_id}])\n logger.info(f\"Deleted vector with ID {vector_id} from collection {collection_name}.\")\n\n def update_vector(\n self,\n collection_name: str,\n vector_id: str,\n vector: Optional[List[float]] = None,\n payload: Optional[Dict[str, Any]] = None\n ) -> None:\n \"\"\"\n Update a vector's data or payload.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector to update.\n vector (Optional[List[float]]): The new vector data.\n payload (Optional[Dict[str, Any]]): The new payload data.\n \"\"\"\n if collection_name != self.index_name:\n raise ValueError(\"Collection name does not match initialized index name.\")\n\n document = {\"id\": vector_id}\n if vector is not None:\n document[\"vector\"] = vector\n if payload is not None:\n document[\"payload\"] = json.dumps(payload)\n self.search_client.merge_or_upload_documents(documents=[document])\n logger.info(f\"Updated vector with ID {vector_id} in collection {collection_name}.\")\n\n def get_vector(self, collection_name: str, vector_id: str) -> Dict[str, Any]:\n \"\"\"\n Retrieve a vector by its ID.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector.\n\n Returns:\n Dict[str, Any]: A dictionary containing the vector data and payload.\n \"\"\"\n if collection_name != self.index_name:\n raise ValueError(\"Collection 
name does not match initialized index name.\")\n try:\n result = self.search_client.get_document(key=vector_id)\n payload = json.loads(result[\"payload\"])\n vector_data = {\n \"id\": result[\"id\"],\n \"vector\": result[\"vector\"],\n \"payload\": payload\n }\n return vector_data\n except ResourceNotFoundError:\n raise KeyError(f\"Vector with ID {vector_id} not found in collection {collection_name}.\")\n\n def list_collections(self) -> List[str]:\n \"\"\"\n List all available vector collections.\n\n Returns:\n List[str]: A list of collection names.\n \"\"\"\n indexes = self.index_client.list_indexes()\n return [index.name for index in indexes]\n\n def delete_collection(self, collection_name: str) -> None:\n \"\"\"\n Delete an entire vector collection.\n\n Args:\n collection_name (str): The name of the collection to delete.\n \"\"\"\n self.index_client.delete_index(collection_name)\n logger.info(f\"Deleted collection {collection_name}.\")\n\n def get_collection_info(self, collection_name: str) -> Dict[str, Any]:\n \"\"\"\n Get information about a collection.\n\n Args:\n collection_name (str): The name of the collection.\n\n Returns:\n Dict[str, Any]: Information about the collection.\n \"\"\"\n index = self.index_client.get_index(collection_name)\n return {\n \"name\": index.name,\n \"fields\": [field.name for field in index.fields],\n \"vector_search\": index.vector_search\n }\n\n def __del__(self):\n \"\"\"Clean up resources.\"\"\"\n self.search_client.close()\n self.index_client.close()\n
"},{"location":"reference/#src.aeiva.storage.azure_ai_search.azure_ai_search_database.AzureAISearchDatabase.__del__","title":"__del__()
","text":"Clean up resources.
Source code in src/aeiva/storage/azure_ai_search/azure_ai_search_database.py
def __del__(self):\n \"\"\"Clean up resources.\"\"\"\n self.search_client.close()\n self.index_client.close()\n
"},{"location":"reference/#src.aeiva.storage.azure_ai_search.azure_ai_search_database.AzureAISearchDatabase.__init__","title":"__init__(config)
","text":"Initialize the Azure Cognitive Search vector store.
Parameters:
Name Type Description Default config
Dict[str, Any]
Configuration dictionary.
required Source code in src/aeiva/storage/azure_ai_search/azure_ai_search_database.py
def __init__(self, config: Dict[str, Any]) -> None:\n \"\"\"\n Initialize the Azure Cognitive Search vector store.\n\n Args:\n config (Dict[str, Any]): Configuration dictionary.\n \"\"\"\n self.config = config\n self.index_name = config.get('collection_name')\n self.service_name = config.get('service_name')\n self.api_key = config.get('api_key')\n self.embedding_model_dims = config.get('embedding_model_dims')\n self.use_compression = config.get('use_compression', False)\n\n if not all([self.service_name, self.api_key, self.index_name, self.embedding_model_dims]):\n raise ValueError(\"Required configuration parameters are missing.\")\n\n self.create_client(\n uri=None,\n service_name=self.service_name,\n api_key=self.api_key\n )\n self.create_collection(\n collection_name=self.index_name,\n vector_size=self.embedding_model_dims,\n distance_metric='cosine'\n )\n
"},{"location":"reference/#src.aeiva.storage.azure_ai_search.azure_ai_search_database.AzureAISearchDatabase.create_client","title":"create_client(uri=None, service_name=None, api_key=None, **kwargs)
","text":"Initializes the client connection to the vector store.
Parameters:
Name Type Description Default uri
Optional[str]
Not used for Azure Cognitive Search.
None
service_name
str
Azure Cognitive Search service name.
None
api_key
str
API key for the Azure Cognitive Search service.
None
**kwargs
Additional parameters.
{}
Source code in src/aeiva/storage/azure_ai_search/azure_ai_search_database.py
def create_client(\n self,\n uri: Optional[str] = None,\n service_name: Optional[str] = None,\n api_key: Optional[str] = None,\n **kwargs\n) -> None:\n \"\"\"\n Initializes the client connection to the vector store.\n\n Args:\n uri (Optional[str]): Not used for Azure Cognitive Search.\n service_name (str): Azure Cognitive Search service name.\n api_key (str): API key for the Azure Cognitive Search service.\n **kwargs: Additional parameters.\n \"\"\"\n if not service_name or not api_key:\n raise ValueError(\"Both 'service_name' and 'api_key' must be provided.\")\n\n endpoint = f\"https://{service_name}.search.windows.net\"\n credential = AzureKeyCredential(api_key)\n self.search_client = SearchClient(\n endpoint=endpoint,\n index_name=self.index_name,\n credential=credential\n )\n self.index_client = SearchIndexClient(\n endpoint=endpoint,\n credential=credential\n )\n
"},{"location":"reference/#src.aeiva.storage.azure_ai_search.azure_ai_search_database.AzureAISearchDatabase.create_collection","title":"create_collection(collection_name, vector_size, distance_metric)
","text":"Create a new vector collection (index) in Azure Cognitive Search.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required vector_size
int
The dimensionality of the vectors.
required distance_metric
str
The distance metric to use (e.g., 'cosine').
required Source code in src/aeiva/storage/azure_ai_search/azure_ai_search_database.py
def create_collection(self, collection_name: str, vector_size: int, distance_metric: str) -> None:\n \"\"\"\n Create a new vector collection (index) in Azure Cognitive Search.\n\n Args:\n collection_name (str): The name of the collection.\n vector_size (int): The dimensionality of the vectors.\n distance_metric (str): The distance metric to use (e.g., 'cosine').\n \"\"\"\n # Check if the index already exists\n try:\n self.index_client.get_index(collection_name)\n logger.info(f\"Index {collection_name} already exists. Skipping creation.\")\n return\n except ResourceNotFoundError:\n pass # Index does not exist, proceed to create\n\n if self.use_compression:\n vector_type = \"Collection(Edm.Half)\"\n compression_name = \"myCompression\"\n compression_configurations = [ScalarQuantizationCompression(compression_name=compression_name)]\n else:\n vector_type = \"Collection(Edm.Single)\"\n compression_name = None\n compression_configurations = []\n\n fields = [\n SimpleField(name=\"id\", type=SearchFieldDataType.String, key=True),\n SearchField(\n name=\"vector\",\n type=vector_type,\n searchable=True,\n vector_search_dimensions=vector_size,\n vector_search_profile_name=\"my-vector-config\",\n ),\n SimpleField(name=\"payload\", type=SearchFieldDataType.String, searchable=True),\n ]\n\n vector_search = VectorSearch(\n profiles=[\n VectorSearchProfile(name=\"my-vector-config\", algorithm_configuration_name=\"my-algorithms-config\")\n ],\n algorithms=[HnswAlgorithmConfiguration(name=\"my-algorithms-config\")],\n compressions=compression_configurations,\n )\n index = SearchIndex(name=collection_name, fields=fields, vector_search=vector_search)\n self.index_client.create_or_update_index(index)\n logger.info(f\"Index {collection_name} created successfully.\")\n
"},{"location":"reference/#src.aeiva.storage.azure_ai_search.azure_ai_search_database.AzureAISearchDatabase.delete_collection","title":"delete_collection(collection_name)
","text":"Delete an entire vector collection.
Parameters:
Name Type Description Default collection_name
str
The name of the collection to delete.
required Source code in src/aeiva/storage/azure_ai_search/azure_ai_search_database.py
def delete_collection(self, collection_name: str) -> None:\n \"\"\"\n Delete an entire vector collection.\n\n Args:\n collection_name (str): The name of the collection to delete.\n \"\"\"\n self.index_client.delete_index(collection_name)\n logger.info(f\"Deleted collection {collection_name}.\")\n
"},{"location":"reference/#src.aeiva.storage.azure_ai_search.azure_ai_search_database.AzureAISearchDatabase.delete_vector","title":"delete_vector(collection_name, vector_id)
","text":"Delete a vector from a collection by its ID.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required vector_id
str
The unique identifier of the vector to delete.
required Source code in src/aeiva/storage/azure_ai_search/azure_ai_search_database.py
def delete_vector(self, collection_name: str, vector_id: str) -> None:\n \"\"\"\n Delete a vector from a collection by its ID.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector to delete.\n \"\"\"\n if collection_name != self.index_name:\n raise ValueError(\"Collection name does not match initialized index name.\")\n self.search_client.delete_documents(documents=[{\"id\": vector_id}])\n logger.info(f\"Deleted vector with ID {vector_id} from collection {collection_name}.\")\n
"},{"location":"reference/#src.aeiva.storage.azure_ai_search.azure_ai_search_database.AzureAISearchDatabase.get_collection_info","title":"get_collection_info(collection_name)
","text":"Get information about a collection.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required Returns:
Type Description Dict[str, Any]
Dict[str, Any]: Information about the collection.
Source code in src/aeiva/storage/azure_ai_search/azure_ai_search_database.py
def get_collection_info(self, collection_name: str) -> Dict[str, Any]:\n \"\"\"\n Get information about a collection.\n\n Args:\n collection_name (str): The name of the collection.\n\n Returns:\n Dict[str, Any]: Information about the collection.\n \"\"\"\n index = self.index_client.get_index(collection_name)\n return {\n \"name\": index.name,\n \"fields\": [field.name for field in index.fields],\n \"vector_search\": index.vector_search\n }\n
"},{"location":"reference/#src.aeiva.storage.azure_ai_search.azure_ai_search_database.AzureAISearchDatabase.get_vector","title":"get_vector(collection_name, vector_id)
","text":"Retrieve a vector by its ID.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required vector_id
str
The unique identifier of the vector.
required Returns:
Type Description Dict[str, Any]
Dict[str, Any]: A dictionary containing the vector data and payload.
Source code in src/aeiva/storage/azure_ai_search/azure_ai_search_database.py
def get_vector(self, collection_name: str, vector_id: str) -> Dict[str, Any]:\n \"\"\"\n Retrieve a vector by its ID.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector.\n\n Returns:\n Dict[str, Any]: A dictionary containing the vector data and payload.\n \"\"\"\n if collection_name != self.index_name:\n raise ValueError(\"Collection name does not match initialized index name.\")\n try:\n result = self.search_client.get_document(key=vector_id)\n payload = json.loads(result[\"payload\"])\n vector_data = {\n \"id\": result[\"id\"],\n \"vector\": result[\"vector\"],\n \"payload\": payload\n }\n return vector_data\n except ResourceNotFoundError:\n raise KeyError(f\"Vector with ID {vector_id} not found in collection {collection_name}.\")\n
"},{"location":"reference/#src.aeiva.storage.azure_ai_search.azure_ai_search_database.AzureAISearchDatabase.insert_vectors","title":"insert_vectors(collection_name, vectors, payloads=None, ids=None)
","text":"Insert vectors into the index.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required vectors
List[List[float]]
A list of vectors to insert.
required payloads
Optional[List[Dict[str, Any]]]
Optional metadata associated with each vector.
None
ids
Optional[List[str]]
Optional unique identifiers for each vector.
None
Source code in src/aeiva/storage/azure_ai_search/azure_ai_search_database.py
def insert_vectors(\n self,\n collection_name: str,\n vectors: List[List[float]],\n payloads: Optional[List[Dict[str, Any]]] = None,\n ids: Optional[List[str]] = None\n) -> None:\n \"\"\"\n Insert vectors into the index.\n\n Args:\n collection_name (str): The name of the collection.\n vectors (List[List[float]]): A list of vectors to insert.\n payloads (Optional[List[Dict[str, Any]]]): Optional metadata associated with each vector.\n ids (Optional[List[str]]): Optional unique identifiers for each vector.\n \"\"\"\n if collection_name != self.index_name:\n raise ValueError(\"Collection name does not match initialized index name.\")\n\n if ids is None:\n ids = [str(i) for i in range(len(vectors))]\n if payloads is None:\n payloads = [{} for _ in range(len(vectors))]\n if not (len(ids) == len(vectors) == len(payloads)):\n raise ValueError(\"Lengths of ids, vectors, and payloads must be equal.\")\n\n documents = [\n {\"id\": id_, \"vector\": vector, \"payload\": json.dumps(payload)}\n for id_, vector, payload in zip(ids, vectors, payloads)\n ]\n self.search_client.upload_documents(documents)\n logger.info(f\"Inserted {len(vectors)} vectors into index {collection_name}.\")\n
"},{"location":"reference/#src.aeiva.storage.azure_ai_search.azure_ai_search_database.AzureAISearchDatabase.list_collections","title":"list_collections()
","text":"List all available vector collections.
Returns:
Type Description List[str]
List[str]: A list of collection names.
Source code in src/aeiva/storage/azure_ai_search/azure_ai_search_database.py
def list_collections(self) -> List[str]:\n \"\"\"\n List all available vector collections.\n\n Returns:\n List[str]: A list of collection names.\n \"\"\"\n indexes = self.index_client.list_indexes()\n return [index.name for index in indexes]\n
"},{"location":"reference/#src.aeiva.storage.azure_ai_search.azure_ai_search_database.AzureAISearchDatabase.search_vectors","title":"search_vectors(collection_name, query_vector, top_k=5, filters=None)
","text":"Search for similar vectors in a collection.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required query_vector
List[float]
The vector to search with.
required top_k
int
The number of top results to return.
5
filters
Optional[Dict[str, Any]]
Optional filters to apply to the search.
None
Returns:
Type Description List[Dict[str, Any]]
List[Dict[str, Any]]: A list of search results.
Source code in src/aeiva/storage/azure_ai_search/azure_ai_search_database.py
def search_vectors(\n self,\n collection_name: str,\n query_vector: List[float],\n top_k: int = 5,\n filters: Optional[Dict[str, Any]] = None\n) -> List[Dict[str, Any]]:\n \"\"\"\n Search for similar vectors in a collection.\n\n Args:\n collection_name (str): The name of the collection.\n query_vector (List[float]): The vector to search with.\n top_k (int): The number of top results to return.\n filters (Optional[Dict[str, Any]]): Optional filters to apply to the search.\n\n Returns:\n List[Dict[str, Any]]: A list of search results.\n \"\"\"\n if collection_name != self.index_name:\n raise ValueError(\"Collection name does not match initialized index name.\")\n\n vector_query = VectorizedQuery(vector=query_vector, k_nearest_neighbors=top_k, fields=\"vector\")\n search_results = self.search_client.search(vector_queries=[vector_query], top=top_k)\n\n results = []\n for result in search_results:\n payload = json.loads(result[\"payload\"])\n # A bare 'continue' in the inner filter loop would not skip this result,\n # so evaluate all filters at once and skip non-matching results here.\n if filters and any(key not in payload or payload[key] != value for key, value in filters.items()):\n continue\n result_dict = {\n \"id\": result[\"id\"],\n \"score\": result[\"@search.score\"],\n \"payload\": payload\n }\n results.append(result_dict)\n return results\n
"},{"location":"reference/#src.aeiva.storage.azure_ai_search.azure_ai_search_database.AzureAISearchDatabase.update_vector","title":"update_vector(collection_name, vector_id, vector=None, payload=None)
","text":"Update a vector's data or payload.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required vector_id
str
The unique identifier of the vector to update.
required vector
Optional[List[float]]
The new vector data.
None
payload
Optional[Dict[str, Any]]
The new payload data.
None
Source code in src/aeiva/storage/azure_ai_search/azure_ai_search_database.py
def update_vector(\n self,\n collection_name: str,\n vector_id: str,\n vector: Optional[List[float]] = None,\n payload: Optional[Dict[str, Any]] = None\n) -> None:\n \"\"\"\n Update a vector's data or payload.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector to update.\n vector (Optional[List[float]]): The new vector data.\n payload (Optional[Dict[str, Any]]): The new payload data.\n \"\"\"\n if collection_name != self.index_name:\n raise ValueError(\"Collection name does not match initialized index name.\")\n\n document = {\"id\": vector_id}\n if vector is not None:\n document[\"vector\"] = vector\n if payload is not None:\n document[\"payload\"] = json.dumps(payload)\n self.search_client.merge_or_upload_documents(documents=[document])\n logger.info(f\"Updated vector with ID {vector_id} in collection {collection_name}.\")\n
"},{"location":"reference/#src.aeiva.storage.chroma","title":"chroma
","text":""},{"location":"reference/#src.aeiva.storage.chroma.chroma_config","title":"chroma_config
","text":""},{"location":"reference/#src.aeiva.storage.chroma.chroma_config.ChromaConfig","title":"ChromaConfig
dataclass
","text":" Bases: BaseConfig
Configuration for ChromaDB vector database.
Source code in src/aeiva/storage/chroma/chroma_config.py
@dataclass\nclass ChromaConfig(BaseConfig):\n \"\"\"\n Configuration for ChromaDB vector database.\n \"\"\"\n\n collection_name: str = field(\n default=\"mem0\",\n metadata={\"help\": \"Name of the collection.\"}\n )\n client: Optional[Any] = field(\n default=None,\n metadata={\"help\": \"Existing ChromaDB client instance (if any).\"}\n )\n path: Optional[str] = field(\n default=None,\n metadata={\"help\": \"Path to the database directory for local storage.\"}\n )\n host: Optional[str] = field(\n default=None,\n metadata={\"help\": \"Remote host address for ChromaDB.\"}\n )\n port: Optional[int] = field(\n default=None,\n metadata={\"help\": \"Remote port for ChromaDB.\"}\n )\n\n def __post_init__(self):\n super().__post_init__()\n # Validate that either path or host and port are provided\n if not self.path and not (self.host and self.port):\n raise ValueError(\"Either 'path' for local storage or both 'host' and 'port' for remote connection must be provided.\")\n
"},{"location":"reference/#src.aeiva.storage.chroma.chroma_database","title":"chroma_database
","text":""},{"location":"reference/#src.aeiva.storage.chroma.chroma_database.ChromaDatabase","title":"ChromaDatabase
","text":" Bases: VectorDatabase
Concrete implementation of VectorStoreBase using ChromaDB.
Source code in src/aeiva/storage/chroma/chroma_database.py
class ChromaDatabase(VectorDatabase):\n \"\"\"\n Concrete implementation of VectorStoreBase using ChromaDB.\n \"\"\"\n\n def __init__(self, config: Dict[str, Any]) -> None:\n \"\"\"\n Initialize the ChromaDB vector store.\n\n Args:\n config (Dict[str, Any]): Configuration dictionary.\n \"\"\"\n self.config = config\n self.collection_name = config.get('collection_name')\n self.client = config.get('client')\n self.host = config.get('host')\n self.port = config.get('port')\n self.path = config.get('path')\n\n if not self.collection_name:\n raise ValueError(\"Collection name must be provided in the configuration.\")\n\n self.create_client(\n host=self.host,\n port=self.port,\n path=self.path\n )\n self.create_collection(\n collection_name=self.collection_name,\n vector_size=None, # ChromaDB does not require specifying vector size upfront\n distance_metric='cosine'\n )\n\n def create_client(\n self,\n uri: Optional[str] = None,\n host: Optional[str] = None,\n port: Optional[int] = None,\n path: Optional[str] = None,\n **kwargs\n ) -> None:\n \"\"\"\n Initializes the client connection to the vector store.\n\n Args:\n uri (Optional[str]): Not used for ChromaDB.\n host (Optional[str]): Host address for ChromaDB server.\n port (Optional[int]): Port for ChromaDB server.\n path (Optional[str]): Path to the database directory.\n **kwargs: Additional parameters.\n \"\"\"\n if self.client:\n return # Client already provided\n\n settings = Settings(anonymized_telemetry=False)\n\n if host and port:\n settings.chroma_api_impl = \"chromadb.api.fastapi.FastAPI\"\n settings.chroma_server_host = host\n settings.chroma_server_http_port = port\n else:\n if not path:\n path = \"db\"\n settings.persist_directory = path\n settings.is_persistent = True\n\n self.client = chromadb.Client(settings)\n logger.info(\"ChromaDB client initialized.\")\n\n def create_collection(self, collection_name: str, vector_size: int, distance_metric: str) -> None:\n \"\"\"\n Create a new vector collection in 
ChromaDB.\n\n Args:\n collection_name (str): The name of the collection.\n vector_size (int): Not used for ChromaDB.\n distance_metric (str): Not used for ChromaDB.\n \"\"\"\n # Check if collection exists\n existing_collections = self.list_collections()\n if collection_name in existing_collections:\n logger.info(f\"Collection {collection_name} already exists. Skipping creation.\")\n self.collection = self.client.get_collection(name=collection_name)\n else:\n self.collection = self.client.create_collection(name=collection_name)\n logger.info(f\"Collection {collection_name} created successfully.\")\n\n def insert_vectors(\n self,\n collection_name: str,\n vectors: List[List[float]],\n payloads: Optional[List[Dict[str, Any]]] = None,\n ids: Optional[List[str]] = None\n ) -> None:\n \"\"\"\n Insert vectors into a collection.\n\n Args:\n collection_name (str): The name of the collection.\n vectors (List[List[float]]): A list of vectors to insert.\n payloads (Optional[List[Dict[str, Any]]]): Optional metadata associated with each vector.\n ids (Optional[List[str]]): Optional unique identifiers for each vector.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n if ids is None:\n ids = [str(i) for i in range(len(vectors))]\n if payloads is None:\n payloads = [{} for _ in range(len(vectors))]\n if not (len(ids) == len(vectors) == len(payloads)):\n raise ValueError(\"Lengths of ids, vectors, and payloads must be equal.\")\n\n self.collection.add(ids=ids, embeddings=vectors, metadatas=payloads)\n logger.info(f\"Inserted {len(vectors)} vectors into collection {collection_name}.\")\n\n def search_vectors(\n self,\n collection_name: str,\n query_vector: List[float],\n top_k: int = 5,\n filters: Optional[Dict[str, Any]] = None\n ) -> List[Dict[str, Any]]:\n \"\"\"\n Search for similar vectors in a collection.\n\n Args:\n collection_name (str): The name of the collection.\n query_vector 
(List[float]): The vector to search with.\n top_k (int): The number of top results to return.\n filters (Optional[Dict[str, Any]]): Optional filters to apply to the search.\n\n Returns:\n List[Dict[str, Any]]: A list of search results.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n results = self.collection.query(\n query_embeddings=[query_vector],\n where=filters,\n n_results=top_k\n )\n # Parse the results\n output = []\n for idx, (ids, distances, metadatas) in enumerate(zip(results['ids'], results['distances'], results['metadatas'])):\n for i in range(len(ids)):\n result = {\n 'id': ids[i],\n 'score': distances[i],\n 'payload': metadatas[i]\n }\n output.append(result)\n return output\n\n def delete_vector(self, collection_name: str, vector_id: str) -> None:\n \"\"\"\n Delete a vector from a collection by its ID.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector to delete.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n self.collection.delete(ids=[vector_id])\n logger.info(f\"Deleted vector with ID {vector_id} from collection {collection_name}.\")\n\n def update_vector(\n self,\n collection_name: str,\n vector_id: str,\n vector: Optional[List[float]] = None,\n payload: Optional[Dict[str, Any]] = None\n ) -> None:\n \"\"\"\n Update a vector's data or payload.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector to update.\n vector (Optional[List[float]]): The new vector data.\n payload (Optional[Dict[str, Any]]): The new payload data.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n self.collection.update(ids=[vector_id], embeddings=[vector] if vector 
else None, metadatas=[payload] if payload else None)\n logger.info(f\"Updated vector with ID {vector_id} in collection {collection_name}.\")\n\n def get_vector(self, collection_name: str, vector_id: str) -> Dict[str, Any]:\n \"\"\"\n Retrieve a vector by its ID.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector.\n\n Returns:\n Dict[str, Any]: A dictionary containing the vector data and payload.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n # Request embeddings explicitly; Chroma's get() omits them by default,\n # leaving result['embeddings'] set to None.\n result = self.collection.get(ids=[vector_id], include=['embeddings', 'metadatas'])\n if not result['ids']:\n raise KeyError(f\"Vector with ID {vector_id} not found in collection {collection_name}.\")\n\n vector_data = {\n 'id': result['ids'][0],\n 'vector': result['embeddings'][0] if result.get('embeddings') is not None else None,\n 'payload': result['metadatas'][0]\n }\n return vector_data\n\n def list_collections(self) -> List[str]:\n \"\"\"\n List all available vector collections.\n\n Returns:\n List[str]: A list of collection names.\n \"\"\"\n collections = self.client.list_collections()\n return [collection.name for collection in collections]\n\n def delete_collection(self, collection_name: str) -> None:\n \"\"\"\n Delete an entire vector collection.\n\n Args:\n collection_name (str): The name of the collection to delete.\n \"\"\"\n self.client.delete_collection(name=collection_name)\n logger.info(f\"Deleted collection {collection_name}.\")\n\n def get_collection_info(self, collection_name: str) -> Dict[str, Any]:\n \"\"\"\n Get information about a collection.\n\n Args:\n collection_name (str): The name of the collection.\n\n Returns:\n Dict[str, Any]: Information about the collection.\n \"\"\"\n collection = self.client.get_collection(name=collection_name)\n return {\n 'name': collection.name,\n 'metadata': collection.metadata\n }\n
"},{"location":"reference/#src.aeiva.storage.chroma.chroma_database.ChromaDatabase.__init__","title":"__init__(config)
","text":"Initialize the ChromaDB vector store.
Parameters:
Name Type Description Default config
Dict[str, Any]
Configuration dictionary.
required Source code in src/aeiva/storage/chroma/chroma_database.py
def __init__(self, config: Dict[str, Any]) -> None:\n \"\"\"\n Initialize the ChromaDB vector store.\n\n Args:\n config (Dict[str, Any]): Configuration dictionary.\n \"\"\"\n self.config = config\n self.collection_name = config.get('collection_name')\n self.client = config.get('client')\n self.host = config.get('host')\n self.port = config.get('port')\n self.path = config.get('path')\n\n if not self.collection_name:\n raise ValueError(\"Collection name must be provided in the configuration.\")\n\n self.create_client(\n host=self.host,\n port=self.port,\n path=self.path\n )\n self.create_collection(\n collection_name=self.collection_name,\n vector_size=None, # ChromaDB does not require specifying vector size upfront\n distance_metric='cosine'\n )\n
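The constructor above reads a handful of keys from its config dictionary. A minimal sketch of that expected shape, with keys taken from the `__init__` source (only `collection_name` is required; the values here are hypothetical):

```python
# Hypothetical configuration for ChromaDatabase; keys mirror those read in __init__.
# Only collection_name is required -- a missing value raises ValueError.
config = {
    "collection_name": "documents",  # name of the Chroma collection to use
    "client": None,                  # optional pre-constructed chromadb client
    "host": None,                    # host + port together select a remote ChromaDB server
    "port": None,
    "path": "db",                    # otherwise, local persistence directory
}
# store = ChromaDatabase(config)    # requires the chromadb package to be installed
```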
"},{"location":"reference/#src.aeiva.storage.chroma.chroma_database.ChromaDatabase.create_client","title":"create_client(uri=None, host=None, port=None, path=None, **kwargs)
","text":"Initializes the client connection to the vector store.
Parameters:
Name Type Description Default uri
Optional[str]
Not used for ChromaDB.
None
host
Optional[str]
Host address for ChromaDB server.
None
port
Optional[int]
Port for ChromaDB server.
None
path
Optional[str]
Path to the database directory.
None
**kwargs
Additional parameters.
{}
Source code in src/aeiva/storage/chroma/chroma_database.py
def create_client(\n self,\n uri: Optional[str] = None,\n host: Optional[str] = None,\n port: Optional[int] = None,\n path: Optional[str] = None,\n **kwargs\n) -> None:\n \"\"\"\n Initializes the client connection to the vector store.\n\n Args:\n uri (Optional[str]): Not used for ChromaDB.\n host (Optional[str]): Host address for ChromaDB server.\n port (Optional[int]): Port for ChromaDB server.\n path (Optional[str]): Path to the database directory.\n **kwargs: Additional parameters.\n \"\"\"\n if self.client:\n return # Client already provided\n\n settings = Settings(anonymized_telemetry=False)\n\n if host and port:\n settings.chroma_api_impl = \"chromadb.api.fastapi.FastAPI\"\n settings.chroma_server_host = host\n settings.chroma_server_http_port = port\n else:\n if not path:\n path = \"db\"\n settings.persist_directory = path\n settings.is_persistent = True\n\n self.client = chromadb.Client(settings)\n logger.info(\"ChromaDB client initialized.\")\n
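The branching in `create_client` can be summarized in a pure-python sketch (a stand-in for illustration, not the chromadb calls themselves): a remote server client is built only when both `host` and `port` are given, otherwise a persistent on-disk store is used with the path defaulting to `"db"`.

```python
# Pure-python sketch of create_client's branching; returns a tuple describing
# which kind of chromadb client the real method would configure.
def client_mode(host=None, port=None, path=None):
    if host and port:
        return ("remote", host, port)       # chromadb FastAPI server client
    return ("persistent", path or "db")     # local persistent store

mode = client_mode(host="localhost", port=8000)
```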
"},{"location":"reference/#src.aeiva.storage.chroma.chroma_database.ChromaDatabase.create_collection","title":"create_collection(collection_name, vector_size, distance_metric)
","text":"Create a new vector collection in ChromaDB.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required vector_size
int
Not used for ChromaDB.
required distance_metric
str
Not used for ChromaDB.
required Source code in src/aeiva/storage/chroma/chroma_database.py
def create_collection(self, collection_name: str, vector_size: int, distance_metric: str) -> None:\n \"\"\"\n Create a new vector collection in ChromaDB.\n\n Args:\n collection_name (str): The name of the collection.\n vector_size (int): Not used for ChromaDB.\n distance_metric (str): Not used for ChromaDB.\n \"\"\"\n # Check if collection exists\n existing_collections = self.list_collections()\n if collection_name in existing_collections:\n logger.info(f\"Collection {collection_name} already exists. Skipping creation.\")\n self.collection = self.client.get_collection(name=collection_name)\n else:\n self.collection = self.client.create_collection(name=collection_name)\n logger.info(f\"Collection {collection_name} created successfully.\")\n
"},{"location":"reference/#src.aeiva.storage.chroma.chroma_database.ChromaDatabase.delete_collection","title":"delete_collection(collection_name)
","text":"Delete an entire vector collection.
Parameters:
Name Type Description Default collection_name
str
The name of the collection to delete.
required Source code in src/aeiva/storage/chroma/chroma_database.py
def delete_collection(self, collection_name: str) -> None:\n \"\"\"\n Delete an entire vector collection.\n\n Args:\n collection_name (str): The name of the collection to delete.\n \"\"\"\n self.client.delete_collection(name=collection_name)\n logger.info(f\"Deleted collection {collection_name}.\")\n
"},{"location":"reference/#src.aeiva.storage.chroma.chroma_database.ChromaDatabase.delete_vector","title":"delete_vector(collection_name, vector_id)
","text":"Delete a vector from a collection by its ID.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required vector_id
str
The unique identifier of the vector to delete.
required Source code in src/aeiva/storage/chroma/chroma_database.py
def delete_vector(self, collection_name: str, vector_id: str) -> None:\n \"\"\"\n Delete a vector from a collection by its ID.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector to delete.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n self.collection.delete(ids=[vector_id])\n logger.info(f\"Deleted vector with ID {vector_id} from collection {collection_name}.\")\n
"},{"location":"reference/#src.aeiva.storage.chroma.chroma_database.ChromaDatabase.get_collection_info","title":"get_collection_info(collection_name)
","text":"Get information about a collection.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required Returns:
Type Description Dict[str, Any]
Dict[str, Any]: Information about the collection.
Source code in src/aeiva/storage/chroma/chroma_database.py
def get_collection_info(self, collection_name: str) -> Dict[str, Any]:\n \"\"\"\n Get information about a collection.\n\n Args:\n collection_name (str): The name of the collection.\n\n Returns:\n Dict[str, Any]: Information about the collection.\n \"\"\"\n collection = self.client.get_collection(name=collection_name)\n return {\n 'name': collection.name,\n 'metadata': collection.metadata\n }\n
"},{"location":"reference/#src.aeiva.storage.chroma.chroma_database.ChromaDatabase.get_vector","title":"get_vector(collection_name, vector_id)
","text":"Retrieve a vector by its ID.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required vector_id
str
The unique identifier of the vector.
required Returns:
Type Description Dict[str, Any]
Dict[str, Any]: A dictionary containing the vector data and payload.
Source code in src/aeiva/storage/chroma/chroma_database.py
def get_vector(self, collection_name: str, vector_id: str) -> Dict[str, Any]:\n \"\"\"\n Retrieve a vector by its ID.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector.\n\n Returns:\n Dict[str, Any]: A dictionary containing the vector data and payload.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n result = self.collection.get(ids=[vector_id])\n if not result['ids']:\n raise KeyError(f\"Vector with ID {vector_id} not found in collection {collection_name}.\")\n\n vector_data = {\n 'id': result['ids'][0],\n 'vector': result['embeddings'][0] if 'embeddings' in result else None,\n 'payload': result['metadatas'][0]\n }\n return vector_data\n
"},{"location":"reference/#src.aeiva.storage.chroma.chroma_database.ChromaDatabase.insert_vectors","title":"insert_vectors(collection_name, vectors, payloads=None, ids=None)
","text":"Insert vectors into a collection.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required vectors
List[List[float]]
A list of vectors to insert.
required payloads
Optional[List[Dict[str, Any]]]
Optional metadata associated with each vector.
None
ids
Optional[List[str]]
Optional unique identifiers for each vector.
None
Source code in src/aeiva/storage/chroma/chroma_database.py
def insert_vectors(\n self,\n collection_name: str,\n vectors: List[List[float]],\n payloads: Optional[List[Dict[str, Any]]] = None,\n ids: Optional[List[str]] = None\n) -> None:\n \"\"\"\n Insert vectors into a collection.\n\n Args:\n collection_name (str): The name of the collection.\n vectors (List[List[float]]): A list of vectors to insert.\n payloads (Optional[List[Dict[str, Any]]]): Optional metadata associated with each vector.\n ids (Optional[List[str]]): Optional unique identifiers for each vector.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n if ids is None:\n ids = [str(i) for i in range(len(vectors))]\n if payloads is None:\n payloads = [{} for _ in range(len(vectors))]\n if not (len(ids) == len(vectors) == len(payloads)):\n raise ValueError(\"Lengths of ids, vectors, and payloads must be equal.\")\n\n self.collection.add(ids=ids, embeddings=vectors, metadatas=payloads)\n logger.info(f\"Inserted {len(vectors)} vectors into collection {collection_name}.\")\n
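The defaulting and validation logic at the top of `insert_vectors` is worth seeing in isolation; this pure-python sketch (helper name `prepare_batch` is invented for illustration) reproduces it: missing ids become stringified indices, missing payloads become empty dicts, and mismatched lengths raise.

```python
# Sketch of insert_vectors' argument defaulting and validation, without chromadb.
def prepare_batch(vectors, payloads=None, ids=None):
    if ids is None:
        ids = [str(i) for i in range(len(vectors))]   # default ids: "0", "1", ...
    if payloads is None:
        payloads = [{} for _ in vectors]              # default payload: empty metadata
    if not (len(ids) == len(vectors) == len(payloads)):
        raise ValueError("Lengths of ids, vectors, and payloads must be equal.")
    return ids, payloads

ids, payloads = prepare_batch([[0.1, 0.2], [0.3, 0.4]])
```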
"},{"location":"reference/#src.aeiva.storage.chroma.chroma_database.ChromaDatabase.list_collections","title":"list_collections()
","text":"List all available vector collections.
Returns:
Type Description List[str]
List[str]: A list of collection names.
Source code in src/aeiva/storage/chroma/chroma_database.py
def list_collections(self) -> List[str]:\n \"\"\"\n List all available vector collections.\n\n Returns:\n List[str]: A list of collection names.\n \"\"\"\n collections = self.client.list_collections()\n return [collection.name for collection in collections]\n
"},{"location":"reference/#src.aeiva.storage.chroma.chroma_database.ChromaDatabase.search_vectors","title":"search_vectors(collection_name, query_vector, top_k=5, filters=None)
","text":"Search for similar vectors in a collection.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required query_vector
List[float]
The vector to search with.
required top_k
int
The number of top results to return.
5
filters
Optional[Dict[str, Any]]
Optional filters to apply to the search.
None
Returns:
Type Description List[Dict[str, Any]]
List[Dict[str, Any]]: A list of search results.
Source code in src/aeiva/storage/chroma/chroma_database.py
def search_vectors(\n self,\n collection_name: str,\n query_vector: List[float],\n top_k: int = 5,\n filters: Optional[Dict[str, Any]] = None\n) -> List[Dict[str, Any]]:\n \"\"\"\n Search for similar vectors in a collection.\n\n Args:\n collection_name (str): The name of the collection.\n query_vector (List[float]): The vector to search with.\n top_k (int): The number of top results to return.\n filters (Optional[Dict[str, Any]]): Optional filters to apply to the search.\n\n Returns:\n List[Dict[str, Any]]: A list of search results.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n results = self.collection.query(\n query_embeddings=[query_vector],\n where=filters,\n n_results=top_k\n )\n # Parse the results\n output = []\n for idx, (ids, distances, metadatas) in enumerate(zip(results['ids'], results['distances'], results['metadatas'])):\n for i in range(len(ids)):\n result = {\n 'id': ids[i],\n 'score': distances[i],\n 'payload': metadatas[i]\n }\n output.append(result)\n return output\n
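The parsing loop at the end of `search_vectors` flattens ChromaDB's nested query result (one sub-list per query vector) into a flat list of `{id, score, payload}` dicts. A pure-python sketch of that flattening, run against a mocked result shape:

```python
# `results` mimics the nested shape returned by collection.query for one query vector.
results = {
    "ids": [["a", "b"]],
    "distances": [[0.12, 0.34]],
    "metadatas": [[{"tag": "x"}, {"tag": "y"}]],
}

# Same flattening as the tail of search_vectors: distances become scores,
# metadatas become payloads.
output = []
for ids, distances, metadatas in zip(results["ids"], results["distances"], results["metadatas"]):
    for i in range(len(ids)):
        output.append({"id": ids[i], "score": distances[i], "payload": metadatas[i]})
```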
"},{"location":"reference/#src.aeiva.storage.chroma.chroma_database.ChromaDatabase.update_vector","title":"update_vector(collection_name, vector_id, vector=None, payload=None)
","text":"Update a vector's data or payload.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required vector_id
str
The unique identifier of the vector to update.
required vector
Optional[List[float]]
The new vector data.
None
payload
Optional[Dict[str, Any]]
The new payload data.
None
Source code in src/aeiva/storage/chroma/chroma_database.py
def update_vector(\n self,\n collection_name: str,\n vector_id: str,\n vector: Optional[List[float]] = None,\n payload: Optional[Dict[str, Any]] = None\n) -> None:\n \"\"\"\n Update a vector's data or payload.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector to update.\n vector (Optional[List[float]]): The new vector data.\n payload (Optional[Dict[str, Any]]): The new payload data.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n self.collection.update(ids=[vector_id], embeddings=[vector] if vector else None, metadatas=[payload] if payload else None)\n logger.info(f\"Updated vector with ID {vector_id} in collection {collection_name}.\")\n
"},{"location":"reference/#src.aeiva.storage.database_factory","title":"database_factory
","text":""},{"location":"reference/#src.aeiva.storage.database_factory.DatabaseConfigFactory","title":"DatabaseConfigFactory
","text":"Factory class to create database configuration objects based on the provider name.
Example: config = DatabaseConfigFactory.create('milvus', host='localhost', port=19530, embedding_model_dims=128, ...)

Source code in src/aeiva/storage/database_factory.py
class DatabaseConfigFactory:\n \"\"\"\n Factory class to create database configuration objects based on the provider name.\n\n Example:\n config = DatabaseConfigFactory.create(\n 'milvus',\n host='localhost',\n port=19530,\n embedding_model_dims=128,\n ...\n )\n \"\"\"\n\n provider_to_class = {\n \"milvus\": \"aeiva.storage.milvus.milvus_config.MilvusConfig\",\n \"chroma\": \"aeiva.storage.chroma.chroma_config.ChromaConfig\",\n \"azure_ai_search\": \"aeiva.storage.azure_ai_search.azure_ai_search_config.AzureAISearchConfig\",\n \"pgvector\": \"aeiva.storage.pgvector.pgvector_config.PGVectorConfig\",\n \"qdrant\": \"aeiva.storage.qdrant.qdrant_config.QdrantConfig\",\n \"neo4j\": \"aeiva.storage.neo4jdb.neo4j_config.Neo4jConfig\",\n \"sqlite\": \"aeiva.storage.sqlite.sqlite_config.SQLiteConfig\",\n \"postgresql\": \"aeiva.storage.postgresql.postgresql_config.PostgreSQLConfig\",\n \"weaviate\": \"aeiva.storage.weaviate.weaviate_config.WeaviateConfig\",\n }\n\n @classmethod\n def create(cls, provider_name: str, **kwargs) -> Any:\n \"\"\"\n Create a database configuration object based on the provider name.\n\n Args:\n provider_name (str): The name of the database provider (e.g., 'milvus', 'chroma').\n **kwargs: Configuration parameters specific to the database provider.\n\n Returns:\n Any: An instance of the database configuration class.\n\n Raises:\n ValueError: If the provider name is not supported.\n ImportError: If the configuration class cannot be imported.\n \"\"\"\n class_path = cls.provider_to_class.get(provider_name.lower())\n if class_path:\n config_class = load_class(class_path)\n return config_class(**kwargs)\n else:\n raise ValueError(f\"Unsupported database provider: {provider_name}\")\n
"},{"location":"reference/#src.aeiva.storage.database_factory.DatabaseConfigFactory.create","title":"create(provider_name, **kwargs)
classmethod
","text":"Create a database configuration object based on the provider name.
Parameters:
Name Type Description Default provider_name
str
The name of the database provider (e.g., 'milvus', 'chroma').
required **kwargs
Configuration parameters specific to the database provider.
{}
Returns:
Name Type Description Any
Any
An instance of the database configuration class.
Raises:
Type Description ValueError
If the provider name is not supported.
ImportError
If the configuration class cannot be imported.
Source code in src/aeiva/storage/database_factory.py
@classmethod\ndef create(cls, provider_name: str, **kwargs) -> Any:\n \"\"\"\n Create a database configuration object based on the provider name.\n\n Args:\n provider_name (str): The name of the database provider (e.g., 'milvus', 'chroma').\n **kwargs: Configuration parameters specific to the database provider.\n\n Returns:\n Any: An instance of the database configuration class.\n\n Raises:\n ValueError: If the provider name is not supported.\n ImportError: If the configuration class cannot be imported.\n \"\"\"\n class_path = cls.provider_to_class.get(provider_name.lower())\n if class_path:\n config_class = load_class(class_path)\n return config_class(**kwargs)\n else:\n raise ValueError(f\"Unsupported database provider: {provider_name}\")\n
"},{"location":"reference/#src.aeiva.storage.database_factory.DatabaseFactory","title":"DatabaseFactory
","text":"Factory class to create database objects based on the provider name and configuration.
Example: db = DatabaseFactory.create('milvus', config)
Source code in src/aeiva/storage/database_factory.py
class DatabaseFactory:\n \"\"\"\n Factory class to create database objects based on the provider name and configuration.\n\n Example:\n db = DatabaseFactory.create('milvus', config)\n \"\"\"\n\n provider_to_class = {\n \"milvus\": \"aeiva.storage.milvus.milvus_database.MilvusDatabase\",\n \"chroma\": \"aeiva.storage.chroma.chroma_database.ChromaDatabase\",\n \"azure_ai_search\": \"aeiva.storage.azure_ai_search.azure_ai_search_database.AzureAISearchDatabase\",\n \"pgvector\": \"aeiva.storage.pgvector.pgvector_database.PGVectorDatabase\",\n \"qdrant\": \"aeiva.storage.qdrant.qdrant_database.QdrantDatabase\",\n \"neo4j\": \"aeiva.storage.neo4jdb.neo4j_database.Neo4jDatabase\",\n \"sqlite\": \"aeiva.storage.sqlite.sqlite_database.SQLiteDatabase\",\n \"postgresql\": \"aeiva.storage.postgresql.postgresql_database.PostgreSQLDatabase\",\n \"weaviate\": \"aeiva.storage.weaviate.weaviate_database.WeaviateDatabase\",\n }\n\n @classmethod\n def create(cls, provider_name: str, config: Any) -> Any:\n \"\"\"\n Create a database object based on the provider name and configuration.\n\n Args:\n provider_name (str): The name of the database provider.\n config (Any): Configuration object or dictionary for the database.\n\n Returns:\n Any: An instance of the database class.\n\n Raises:\n ValueError: If the provider name is not supported.\n ImportError: If the database class cannot be imported.\n TypeError: If the configuration cannot be converted to a dictionary.\n \"\"\"\n class_path = cls.provider_to_class.get(provider_name.lower())\n if class_path:\n db_class = load_class(class_path)\n if isinstance(config, dict):\n return db_class(config)\n elif hasattr(config, 'to_dict'):\n # Assuming config is a dataclass with a 'to_dict' method\n return db_class(config.to_dict())\n elif hasattr(config, '__dict__'):\n # If config is a dataclass without 'to_dict', use __dict__\n return db_class(config.__dict__)\n else:\n raise TypeError(\n \"Config must be a dict or an object with 'to_dict' or 
'__dict__' method.\"\n )\n else:\n raise ValueError(f\"Unsupported database provider: {provider_name}\")\n
"},{"location":"reference/#src.aeiva.storage.database_factory.DatabaseFactory.create","title":"create(provider_name, config)
classmethod
","text":"Create a database object based on the provider name and configuration.
Parameters:
Name Type Description Default provider_name
str
The name of the database provider.
required config
Any
Configuration object or dictionary for the database.
required Returns:
Name Type Description Any
Any
An instance of the database class.
Raises:
Type Description ValueError
If the provider name is not supported.
ImportError
If the database class cannot be imported.
TypeError
If the configuration cannot be converted to a dictionary.
Source code in src/aeiva/storage/database_factory.py
@classmethod\ndef create(cls, provider_name: str, config: Any) -> Any:\n \"\"\"\n Create a database object based on the provider name and configuration.\n\n Args:\n provider_name (str): The name of the database provider.\n config (Any): Configuration object or dictionary for the database.\n\n Returns:\n Any: An instance of the database class.\n\n Raises:\n ValueError: If the provider name is not supported.\n ImportError: If the database class cannot be imported.\n TypeError: If the configuration cannot be converted to a dictionary.\n \"\"\"\n class_path = cls.provider_to_class.get(provider_name.lower())\n if class_path:\n db_class = load_class(class_path)\n if isinstance(config, dict):\n return db_class(config)\n elif hasattr(config, 'to_dict'):\n # Assuming config is a dataclass with a 'to_dict' method\n return db_class(config.to_dict())\n elif hasattr(config, '__dict__'):\n # If config is a dataclass without 'to_dict', use __dict__\n return db_class(config.__dict__)\n else:\n raise TypeError(\n \"Config must be a dict or an object with 'to_dict' or '__dict__' method.\"\n )\n else:\n raise ValueError(f\"Unsupported database provider: {provider_name}\")\n
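`DatabaseFactory.create` accepts three config shapes: a plain dict, an object with a `to_dict()` method, or any object with a `__dict__`. This pure-python sketch (function and class names invented for illustration) isolates that normalization order:

```python
from dataclasses import dataclass

# Sketch of the config normalization performed inside DatabaseFactory.create:
# dicts pass through, to_dict() is preferred, __dict__ is the fallback.
def normalize(config):
    if isinstance(config, dict):
        return config
    if hasattr(config, "to_dict"):
        return config.to_dict()
    if hasattr(config, "__dict__"):
        return config.__dict__
    raise TypeError("Config must be a dict or an object with 'to_dict' or '__dict__' method.")

@dataclass
class DummyConfig:          # hypothetical config dataclass without to_dict()
    collection_name: str
    path: str = "db"

cfg = normalize(DummyConfig("documents"))
```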
"},{"location":"reference/#src.aeiva.storage.database_factory.load_class","title":"load_class(class_path)
","text":"Dynamically load a class from a string.
Parameters:
Name Type Description Default class_path
str
The full path to the class, e.g., 'module.submodule.ClassName'.
required Returns:
Name Type Description Type
Type
The class type.
Raises:
Type Description ImportError
If the module or class cannot be found.
Source code in src/aeiva/storage/database_factory.py
def load_class(class_path: str) -> Type:\n \"\"\"\n Dynamically load a class from a string.\n\n Args:\n class_path (str): The full path to the class, e.g., 'module.submodule.ClassName'.\n\n Returns:\n Type: The class type.\n\n Raises:\n ImportError: If the module or class cannot be found.\n \"\"\"\n try:\n module_path, class_name = class_path.rsplit('.', 1)\n module = importlib.import_module(module_path)\n return getattr(module, class_name)\n except (ImportError, AttributeError) as e:\n raise ImportError(f\"Cannot import '{class_name}' from '{module_path}': {e}\")\n
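Because `load_class` uses only the standard library, the same dynamic-import pattern can be demonstrated directly against a stdlib class:

```python
import importlib

# Same rsplit + import_module + getattr pattern as database_factory.load_class.
def load_class(class_path):
    module_path, class_name = class_path.rsplit(".", 1)
    module = importlib.import_module(module_path)
    return getattr(module, class_name)

OrderedDict = load_class("collections.OrderedDict")
```

This is how the factories above resolve strings like `"aeiva.storage.chroma.chroma_database.ChromaDatabase"` into classes without importing every backend at startup.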
"},{"location":"reference/#src.aeiva.storage.graph_database","title":"graph_database
","text":""},{"location":"reference/#src.aeiva.storage.graph_database.GraphDatabase","title":"GraphDatabase
","text":" Bases: ABC
Abstract base class for graph database operations.
Source code in src/aeiva/storage/graph_database.py
class GraphDatabase(ABC):\n \"\"\"\n Abstract base class for graph database operations.\n \"\"\"\n\n @abstractmethod\n def add_node(\n self, \n node_id: str, \n properties: Optional[Dict[str, Any]] = None, \n labels: Optional[List[str]] = None\n ) -> None:\n \"\"\"\n Adds a node to the graph.\n\n Args:\n node_id (str): Unique identifier for the node.\n properties (Optional[Dict[str, Any]]): Properties associated with the node.\n labels (Optional[List[str]]): Labels or types associated with the node.\n\n Raises:\n StorageError: If there is an issue adding the node.\n \"\"\"\n pass\n\n @abstractmethod\n def add_edge(\n self, \n source_id: str, \n target_id: str, \n relationship: str, \n properties: Optional[Dict[str, Any]] = None\n ) -> None:\n \"\"\"\n Adds an edge (relationship) between two nodes.\n\n Args:\n source_id (str): Unique identifier of the source node.\n target_id (str): Unique identifier of the target node.\n relationship (str): Type of the relationship.\n properties (Optional[Dict[str, Any]]): Properties associated with the edge.\n\n Raises:\n NodeNotFoundError: If either the source or target node does not exist.\n StorageError: If there is an issue adding the edge.\n \"\"\"\n pass\n\n @abstractmethod\n def get_node(self, node_id: str) -> Dict[str, Any]:\n \"\"\"\n Retrieves a node by its identifier.\n\n Args:\n node_id (str): Unique identifier of the node.\n\n Returns:\n Dict[str, Any]: A dictionary containing the node's properties and labels.\n\n Raises:\n NodeNotFoundError: If the node does not exist.\n StorageError: If there is an issue retrieving the node.\n \"\"\"\n pass\n\n @abstractmethod\n def update_node(self, node_id: str, properties: Dict[str, Any]) -> None:\n \"\"\"\n Updates properties of a node.\n\n Args:\n node_id (str): Unique identifier of the node.\n properties (Dict[str, Any]): Properties to update.\n\n Raises:\n NodeNotFoundError: If the node does not exist.\n StorageError: If there is an issue updating the node.\n \"\"\"\n 
pass\n\n @abstractmethod\n def delete_node(self, node_id: str) -> None:\n \"\"\"\n Deletes a node from the graph.\n\n Args:\n node_id (str): Unique identifier of the node.\n\n Raises:\n NodeNotFoundError: If the node does not exist.\n StorageError: If there is an issue deleting the node.\n \"\"\"\n pass\n\n @abstractmethod\n def delete_all(self) -> None:\n \"\"\"\n Deletes all nodes and their associated relationships from the graph.\n\n Raises:\n StorageError: If there is an issue deleting all nodes and relationships.\n \"\"\"\n pass\n\n @abstractmethod\n def delete_all_edges(self) -> None:\n \"\"\"\n Deletes all edges from the graph without deleting the nodes.\n\n Raises:\n StorageError: If there is an issue deleting all relationships.\n \"\"\"\n pass\n\n @abstractmethod\n def delete_relationships_by_type(self, relationship: str) -> None:\n \"\"\"\n Deletes all relationships of a specific type from the graph.\n\n Args:\n relationship (str): The type of relationships to delete.\n\n Raises:\n StorageError: If there is an issue deleting the relationships.\n \"\"\"\n pass\n\n @abstractmethod\n def delete_edge(\n self, \n source_id: str, \n target_id: str, \n relationship: str\n ) -> None:\n \"\"\"\n Deletes a specific relationship between two nodes.\n\n Args:\n source_id (str): Unique identifier of the source node.\n target_id (str): Unique identifier of the target node.\n relationship (str): Type of the relationship to delete.\n\n Raises:\n RelationshipNotFoundError: If the relationship does not exist.\n StorageError: If there is an issue deleting the relationship.\n \"\"\"\n pass\n\n @abstractmethod\n def update_edge(\n self, \n source_id: str, \n target_id: str, \n relationship: str, \n properties: Dict[str, Any]\n ) -> None:\n \"\"\"\n Updates properties of a specific relationship between two nodes.\n\n Args:\n source_id (str): Unique identifier of the source node.\n target_id (str): Unique identifier of the target node.\n relationship (str): Type of the 
relationship to update.\n properties (Dict[str, Any]): Properties to update on the relationship.\n\n Raises:\n RelationshipNotFoundError: If the relationship does not exist.\n StorageError: If there is an issue updating the relationship.\n \"\"\"\n pass\n\n @abstractmethod\n def get_relationship(\n self, \n source_id: str, \n target_id: str, \n relationship: str\n ) -> Dict[str, Any]:\n \"\"\"\n Retrieves a specific relationship between two nodes.\n\n Args:\n source_id (str): Unique identifier of the source node.\n target_id (str): Unique identifier of the target node.\n relationship (str): Type of the relationship to retrieve.\n\n Returns:\n Dict[str, Any]: A dictionary containing the relationship's properties.\n\n Raises:\n RelationshipNotFoundError: If the relationship does not exist.\n StorageError: If there is an issue retrieving the relationship.\n \"\"\"\n pass\n\n @abstractmethod\n def get_neighbors(\n self, \n node_id: str, \n relationship: Optional[str] = None, \n direction: str = \"both\"\n ) -> List[Dict[str, Any]]:\n \"\"\"\n Retrieves neighboring nodes connected by edges.\n\n Args:\n node_id (str): Unique identifier of the node.\n relationship (Optional[str]): Filter by relationship type.\n direction (str): Direction of the relationships ('in', 'out', 'both').\n\n Returns:\n List[Dict[str, Any]]: A list of neighboring nodes.\n\n Raises:\n NodeNotFoundError: If the node does not exist.\n StorageError: If there is an issue retrieving neighbors.\n \"\"\"\n pass\n\n @abstractmethod\n def query_nodes(\n self, \n properties: Dict[str, Any], \n labels: Optional[List[str]] = None\n ) -> List[Dict[str, Any]]:\n \"\"\"\n Queries nodes based on properties and labels.\n\n Args:\n properties (Dict[str, Any]): Properties to filter nodes.\n labels (Optional[List[str]]): Labels to filter nodes.\n\n Returns:\n List[Dict[str, Any]]: A list of nodes matching the query.\n\n Raises:\n StorageError: If there is an issue querying nodes.\n \"\"\"\n pass\n\n @abstractmethod\n 
def execute_query(\n self, \n query: str, \n parameters: Optional[Dict[str, Any]] = None\n ) -> Any:\n \"\"\"\n Executes a raw query against the graph database.\n\n Args:\n query (str): The query string.\n parameters (Optional[Dict[str, Any]]): Parameters for parameterized queries.\n\n Returns:\n Any: The result of the query.\n\n Raises:\n StorageError: If there is an issue executing the query.\n \"\"\"\n pass\n\n @abstractmethod\n def close(self) -> None:\n \"\"\"\n Closes the graph database connection and releases resources.\n \"\"\"\n pass\n
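To make the abstract interface concrete, here is a minimal in-memory stand-in covering three of its methods (`add_node`, `add_edge`, `get_neighbors`). It is illustrative only — not one of the real backends, and it skips labels, properties on lookup, and the custom exception types:

```python
# Minimal in-memory sketch of part of the GraphDatabase interface.
class InMemoryGraph:
    def __init__(self):
        self.nodes = {}   # node_id -> {"properties": ..., "labels": ...}
        self.edges = []   # (source_id, target_id, relationship, properties)

    def add_node(self, node_id, properties=None, labels=None):
        self.nodes[node_id] = {"properties": properties or {}, "labels": labels or []}

    def add_edge(self, source_id, target_id, relationship, properties=None):
        if source_id not in self.nodes or target_id not in self.nodes:
            raise KeyError("Both endpoints must exist before adding an edge.")
        self.edges.append((source_id, target_id, relationship, properties or {}))

    def get_neighbors(self, node_id, relationship=None, direction="both"):
        out = []
        for s, t, r, _ in self.edges:
            if relationship and r != relationship:
                continue
            if direction in ("out", "both") and s == node_id:
                out.append(t)
            if direction in ("in", "both") and t == node_id:
                out.append(s)
        return out

g = InMemoryGraph()
g.add_node("a"); g.add_node("b")
g.add_edge("a", "b", "KNOWS")
```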
"},{"location":"reference/#src.aeiva.storage.graph_database.GraphDatabase.add_edge","title":"add_edge(source_id, target_id, relationship, properties=None)
abstractmethod
","text":"Adds an edge (relationship) between two nodes.
Parameters:
Name Type Description Default source_id
str
Unique identifier of the source node.
required target_id
str
Unique identifier of the target node.
required relationship
str
Type of the relationship.
required properties
Optional[Dict[str, Any]]
Properties associated with the edge.
None
Raises:
Type Description NodeNotFoundError
If either the source or target node does not exist.
StorageError
If there is an issue adding the edge.
Source code in src/aeiva/storage/graph_database.py
@abstractmethod\ndef add_edge(\n self, \n source_id: str, \n target_id: str, \n relationship: str, \n properties: Optional[Dict[str, Any]] = None\n) -> None:\n \"\"\"\n Adds an edge (relationship) between two nodes.\n\n Args:\n source_id (str): Unique identifier of the source node.\n target_id (str): Unique identifier of the target node.\n relationship (str): Type of the relationship.\n properties (Optional[Dict[str, Any]]): Properties associated with the edge.\n\n Raises:\n NodeNotFoundError: If either the source or target node does not exist.\n StorageError: If there is an issue adding the edge.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.graph_database.GraphDatabase.add_node","title":"add_node(node_id, properties=None, labels=None)
abstractmethod
","text":"Adds a node to the graph.
Parameters:
Name Type Description Default node_id
str
Unique identifier for the node.
required properties
Optional[Dict[str, Any]]
Properties associated with the node.
None
labels
Optional[List[str]]
Labels or types associated with the node.
None
Raises:
Type Description StorageError
If there is an issue adding the node.
Source code in src/aeiva/storage/graph_database.py
@abstractmethod\ndef add_node(\n self, \n node_id: str, \n properties: Optional[Dict[str, Any]] = None, \n labels: Optional[List[str]] = None\n) -> None:\n \"\"\"\n Adds a node to the graph.\n\n Args:\n node_id (str): Unique identifier for the node.\n properties (Optional[Dict[str, Any]]): Properties associated with the node.\n labels (Optional[List[str]]): Labels or types associated with the node.\n\n Raises:\n StorageError: If there is an issue adding the node.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.graph_database.GraphDatabase.close","title":"close()
abstractmethod
","text":"Closes the graph database connection and releases resources.
Source code in src/aeiva/storage/graph_database.py
@abstractmethod\ndef close(self) -> None:\n \"\"\"\n Closes the graph database connection and releases resources.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.graph_database.GraphDatabase.delete_all","title":"delete_all()
abstractmethod
","text":"Deletes all nodes and their associated relationships from the graph.
Raises:
Type Description StorageError
If there is an issue deleting all nodes and relationships.
Source code in src/aeiva/storage/graph_database.py
@abstractmethod\ndef delete_all(self) -> None:\n \"\"\"\n Deletes all nodes and their associated relationships from the graph.\n\n Raises:\n StorageError: If there is an issue deleting all nodes and relationships.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.graph_database.GraphDatabase.delete_all_edges","title":"delete_all_edges()
abstractmethod
","text":"Deletes all edges from the graph without deleting the nodes.
Raises:
Type Description StorageError
If there is an issue deleting all relationships.
Source code in src/aeiva/storage/graph_database.py
@abstractmethod\ndef delete_all_edges(self) -> None:\n \"\"\"\n Deletes all edges from the graph without deleting the nodes.\n\n Raises:\n StorageError: If there is an issue deleting all relationships.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.graph_database.GraphDatabase.delete_edge","title":"delete_edge(source_id, target_id, relationship)
abstractmethod
","text":"Deletes a specific relationship between two nodes.
Parameters:
Name Type Description Default source_id
str
Unique identifier of the source node.
required target_id
str
Unique identifier of the target node.
required relationship
str
Type of the relationship to delete.
required Raises:
Type Description RelationshipNotFoundError
If the relationship does not exist.
StorageError
If there is an issue deleting the relationship.
Source code in src/aeiva/storage/graph_database.py
@abstractmethod\ndef delete_edge(\n self, \n source_id: str, \n target_id: str, \n relationship: str\n) -> None:\n \"\"\"\n Deletes a specific relationship between two nodes.\n\n Args:\n source_id (str): Unique identifier of the source node.\n target_id (str): Unique identifier of the target node.\n relationship (str): Type of the relationship to delete.\n\n Raises:\n RelationshipNotFoundError: If the relationship does not exist.\n StorageError: If there is an issue deleting the relationship.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.graph_database.GraphDatabase.delete_node","title":"delete_node(node_id)
abstractmethod
","text":"Deletes a node from the graph.
Parameters:
Name Type Description Default node_id
str
Unique identifier of the node.
required Raises:
Type Description NodeNotFoundError
If the node does not exist.
StorageError
If there is an issue deleting the node.
Source code in src/aeiva/storage/graph_database.py
@abstractmethod\ndef delete_node(self, node_id: str) -> None:\n \"\"\"\n Deletes a node from the graph.\n\n Args:\n node_id (str): Unique identifier of the node.\n\n Raises:\n NodeNotFoundError: If the node does not exist.\n StorageError: If there is an issue deleting the node.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.graph_database.GraphDatabase.delete_relationships_by_type","title":"delete_relationships_by_type(relationship)
abstractmethod
","text":"Deletes all relationships of a specific type from the graph.
Parameters:
Name Type Description Default relationship
str
The type of relationships to delete.
required Raises:
Type Description StorageError
If there is an issue deleting the relationships.
Source code in src/aeiva/storage/graph_database.py
@abstractmethod\ndef delete_relationships_by_type(self, relationship: str) -> None:\n \"\"\"\n Deletes all relationships of a specific type from the graph.\n\n Args:\n relationship (str): The type of relationships to delete.\n\n Raises:\n StorageError: If there is an issue deleting the relationships.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.graph_database.GraphDatabase.execute_query","title":"execute_query(query, parameters=None)
abstractmethod
","text":"Executes a raw query against the graph database.
Parameters:
Name Type Description Default query
str
The query string.
required parameters
Optional[Dict[str, Any]]
Parameters for parameterized queries.
None
Returns:
Name Type Description Any
Any
The result of the query.
Raises:
Type Description StorageError
If there is an issue executing the query.
Source code in src/aeiva/storage/graph_database.py
@abstractmethod\ndef execute_query(\n self, \n query: str, \n parameters: Optional[Dict[str, Any]] = None\n) -> Any:\n \"\"\"\n Executes a raw query against the graph database.\n\n Args:\n query (str): The query string.\n parameters (Optional[Dict[str, Any]]): Parameters for parameterized queries.\n\n Returns:\n Any: The result of the query.\n\n Raises:\n StorageError: If there is an issue executing the query.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.graph_database.GraphDatabase.get_neighbors","title":"get_neighbors(node_id, relationship=None, direction='both')
abstractmethod
","text":"Retrieves neighboring nodes connected by edges.
Parameters:
Name Type Description Default node_id
str
Unique identifier of the node.
required relationship
Optional[str]
Filter by relationship type.
None
direction
str
Direction of the relationships ('in', 'out', 'both').
'both'
Returns:
Type Description List[Dict[str, Any]]
List[Dict[str, Any]]: A list of neighboring nodes.
Raises:
Type Description NodeNotFoundError
If the node does not exist.
StorageError
If there is an issue retrieving neighbors.
Source code in src/aeiva/storage/graph_database.py
@abstractmethod\ndef get_neighbors(\n self, \n node_id: str, \n relationship: Optional[str] = None, \n direction: str = \"both\"\n) -> List[Dict[str, Any]]:\n \"\"\"\n Retrieves neighboring nodes connected by edges.\n\n Args:\n node_id (str): Unique identifier of the node.\n relationship (Optional[str]): Filter by relationship type.\n direction (str): Direction of the relationships ('in', 'out', 'both').\n\n Returns:\n List[Dict[str, Any]]: A list of neighboring nodes.\n\n Raises:\n NodeNotFoundError: If the node does not exist.\n StorageError: If there is an issue retrieving neighbors.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.graph_database.GraphDatabase.get_node","title":"get_node(node_id)
abstractmethod
","text":"Retrieves a node by its identifier.
Parameters:
Name Type Description Default node_id
str
Unique identifier of the node.
required Returns:
Type Description Dict[str, Any]
Dict[str, Any]: A dictionary containing the node's properties and labels.
Raises:
Type Description NodeNotFoundError
If the node does not exist.
StorageError
If there is an issue retrieving the node.
Source code in src/aeiva/storage/graph_database.py
@abstractmethod\ndef get_node(self, node_id: str) -> Dict[str, Any]:\n \"\"\"\n Retrieves a node by its identifier.\n\n Args:\n node_id (str): Unique identifier of the node.\n\n Returns:\n Dict[str, Any]: A dictionary containing the node's properties and labels.\n\n Raises:\n NodeNotFoundError: If the node does not exist.\n StorageError: If there is an issue retrieving the node.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.graph_database.GraphDatabase.get_relationship","title":"get_relationship(source_id, target_id, relationship)
abstractmethod
","text":"Retrieves a specific relationship between two nodes.
Parameters:
Name Type Description Default source_id
str
Unique identifier of the source node.
required target_id
str
Unique identifier of the target node.
required relationship
str
Type of the relationship to retrieve.
required Returns:
Type Description Dict[str, Any]
Dict[str, Any]: A dictionary containing the relationship's properties.
Raises:
Type Description RelationshipNotFoundError
If the relationship does not exist.
StorageError
If there is an issue retrieving the relationship.
Source code in src/aeiva/storage/graph_database.py
@abstractmethod\ndef get_relationship(\n self, \n source_id: str, \n target_id: str, \n relationship: str\n) -> Dict[str, Any]:\n \"\"\"\n Retrieves a specific relationship between two nodes.\n\n Args:\n source_id (str): Unique identifier of the source node.\n target_id (str): Unique identifier of the target node.\n relationship (str): Type of the relationship to retrieve.\n\n Returns:\n Dict[str, Any]: A dictionary containing the relationship's properties.\n\n Raises:\n RelationshipNotFoundError: If the relationship does not exist.\n StorageError: If there is an issue retrieving the relationship.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.graph_database.GraphDatabase.query_nodes","title":"query_nodes(properties, labels=None)
abstractmethod
","text":"Queries nodes based on properties and labels.
Parameters:
Name Type Description Default properties
Dict[str, Any]
Properties to filter nodes.
required labels
Optional[List[str]]
Labels to filter nodes.
None
Returns:
Type Description List[Dict[str, Any]]
List[Dict[str, Any]]: A list of nodes matching the query.
Raises:
Type Description StorageError
If there is an issue querying nodes.
Source code in src/aeiva/storage/graph_database.py
@abstractmethod\ndef query_nodes(\n self, \n properties: Dict[str, Any], \n labels: Optional[List[str]] = None\n) -> List[Dict[str, Any]]:\n \"\"\"\n Queries nodes based on properties and labels.\n\n Args:\n properties (Dict[str, Any]): Properties to filter nodes.\n labels (Optional[List[str]]): Labels to filter nodes.\n\n Returns:\n List[Dict[str, Any]]: A list of nodes matching the query.\n\n Raises:\n StorageError: If there is an issue querying nodes.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.graph_database.GraphDatabase.update_edge","title":"update_edge(source_id, target_id, relationship, properties)
abstractmethod
","text":"Updates properties of a specific relationship between two nodes.
Parameters:
Name Type Description Default source_id
str
Unique identifier of the source node.
required target_id
str
Unique identifier of the target node.
required relationship
str
Type of the relationship to update.
required properties
Dict[str, Any]
Properties to update on the relationship.
required Raises:
Type Description RelationshipNotFoundError
If the relationship does not exist.
StorageError
If there is an issue updating the relationship.
Source code in src/aeiva/storage/graph_database.py
@abstractmethod\ndef update_edge(\n self, \n source_id: str, \n target_id: str, \n relationship: str, \n properties: Dict[str, Any]\n) -> None:\n \"\"\"\n Updates properties of a specific relationship between two nodes.\n\n Args:\n source_id (str): Unique identifier of the source node.\n target_id (str): Unique identifier of the target node.\n relationship (str): Type of the relationship to update.\n properties (Dict[str, Any]): Properties to update on the relationship.\n\n Raises:\n RelationshipNotFoundError: If the relationship does not exist.\n StorageError: If there is an issue updating the relationship.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.graph_database.GraphDatabase.update_node","title":"update_node(node_id, properties)
abstractmethod
","text":"Updates properties of a node.
Parameters:
Name Type Description Default node_id
str
Unique identifier of the node.
required properties
Dict[str, Any]
Properties to update.
required Raises:
Type Description NodeNotFoundError
If the node does not exist.
StorageError
If there is an issue updating the node.
Source code in src/aeiva/storage/graph_database.py
@abstractmethod\ndef update_node(self, node_id: str, properties: Dict[str, Any]) -> None:\n \"\"\"\n Updates properties of a node.\n\n Args:\n node_id (str): Unique identifier of the node.\n properties (Dict[str, Any]): Properties to update.\n\n Raises:\n NodeNotFoundError: If the node does not exist.\n StorageError: If there is an issue updating the node.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.graph_database.NodeNotFoundError","title":"NodeNotFoundError
","text":" Bases: Exception
Exception raised when a node is not found in the graph database.
Source code in src/aeiva/storage/graph_database.py
class NodeNotFoundError(Exception):\n \"\"\"Exception raised when a node is not found in the graph database.\"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.graph_database.RelationshipNotFoundError","title":"RelationshipNotFoundError
","text":" Bases: Exception
Exception raised when a relationship is not found in the graph database.
Source code in src/aeiva/storage/graph_database.py
class RelationshipNotFoundError(Exception):\n \"\"\"Exception raised when a relationship is not found in the graph database.\"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.graph_database.StorageError","title":"StorageError
","text":" Bases: Exception
Exception raised when there is a storage-related error in the graph database.
Source code in src/aeiva/storage/graph_database.py
class StorageError(Exception):\n \"\"\"Exception raised when there is a storage-related error in the graph database.\"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.milvus","title":"milvus
","text":""},{"location":"reference/#src.aeiva.storage.milvus.milvus_config","title":"milvus_config
","text":""},{"location":"reference/#src.aeiva.storage.milvus.milvus_config.MilvusConfig","title":"MilvusConfig
dataclass
","text":" Bases: BaseConfig
Configuration for Milvus vector database.
Source code in src/aeiva/storage/milvus/milvus_config.py
@dataclass\nclass MilvusConfig(BaseConfig):\n \"\"\"\n Configuration for Milvus vector database.\n \"\"\"\n\n uri: str = field(\n default=\"http://localhost:19530\",\n metadata={\"help\": \"Full URL for Milvus server.\"}\n )\n token: Optional[str] = field(\n default=None,\n metadata={\"help\": \"Token for Milvus server authentication (if required).\"}\n )\n collection_name: str = field(\n default=\"mem0\",\n metadata={\"help\": \"Name of the collection.\"}\n )\n embedding_model_dims: int = field(\n default=1536,\n metadata={\"help\": \"Dimensions of the embedding model.\"}\n )\n metric_type: str = field(\n default=\"L2\",\n metadata={\"help\": \"Metric type for similarity search (e.g., 'L2', 'IP', 'COSINE').\"}\n )\n\n def __post_init__(self):\n super().__post_init__()\n # Validate metric_type\n valid_metrics = {\"L2\", \"IP\", \"COSINE\", \"HAMMING\", \"JACCARD\"}\n if self.metric_type not in valid_metrics:\n raise ValueError(f\"Invalid metric_type '{self.metric_type}'. Valid options are {valid_metrics}.\")\n
"},{"location":"reference/#src.aeiva.storage.milvus.milvus_database","title":"milvus_database
","text":""},{"location":"reference/#src.aeiva.storage.milvus.milvus_database.MilvusDatabase","title":"MilvusDatabase
","text":" Bases: VectorDatabase
Concrete implementation of VectorStoreBase using Milvus.
Source code in src/aeiva/storage/milvus/milvus_database.py
class MilvusDatabase(VectorDatabase):\n \"\"\"\n Concrete implementation of VectorStoreBase using Milvus.\n \"\"\"\n\n def __init__(self, config: Dict[str, Any]) -> None:\n \"\"\"\n Initialize the Milvus vector store.\n\n Args:\n config (Dict[str, Any]): Configuration dictionary.\n \"\"\"\n self.config = config\n self.collection_name = config.get('collection_name')\n self.uri = config.get('uri')\n self.user = config.get('user')\n self.password = config.get('password')\n self.token = config.get('token')\n self.embedding_model_dims = config.get('embedding_model_dims')\n self.metric_type = config.get('metric_type', 'L2') # Default to 'L2' metric\n\n if not all([self.collection_name, self.uri, self.embedding_model_dims]):\n raise ValueError(\"Required configuration parameters are missing.\")\n\n self.create_client(\n uri=self.uri,\n user=self.user,\n password=self.password,\n token=self.token\n )\n self.create_collection(\n collection_name=self.collection_name,\n vector_size=self.embedding_model_dims,\n distance_metric=self.metric_type\n )\n\n def create_client(\n self,\n uri: str,\n user: Optional[str] = None,\n password: Optional[str] = None,\n token: Optional[str] = None,\n **kwargs\n ) -> None:\n \"\"\"\n Initializes the client connection to the Milvus vector store.\n\n Args:\n uri (str): The URI of the vector store instance.\n user (Optional[str]): Username for authentication.\n password (Optional[str]): Password for authentication.\n token (Optional[str]): Access token for authentication.\n **kwargs: Additional parameters.\n \"\"\"\n try:\n connections.connect(\n alias=\"default\",\n uri=uri,\n user=user,\n password=password,\n token=token,\n **kwargs\n )\n logger.info(f\"Connected to Milvus at {uri}.\")\n except MilvusException as e:\n logger.error(f\"Failed to connect to Milvus: {e}\")\n raise ConnectionError(f\"Failed to connect to Milvus: {e}\")\n\n def create_collection(self, collection_name: str, vector_size: int, distance_metric: str) -> None:\n \"\"\"\n 
Create a new vector collection in Milvus.\n\n Args:\n collection_name (str): The name of the collection.\n vector_size (int): The dimensionality of the vectors.\n distance_metric (str): The distance metric to use (e.g., 'L2', 'IP', 'COSINE').\n \"\"\"\n if utility.has_collection(collection_name):\n logger.info(f\"Collection {collection_name} already exists. Skipping creation.\")\n self.collection = Collection(collection_name)\n return\n\n # Define the schema\n fields = [\n FieldSchema(name=\"id\", dtype=DataType.VARCHAR, is_primary=True, auto_id=False, max_length=64),\n FieldSchema(name=\"vector\", dtype=DataType.FLOAT_VECTOR, dim=vector_size),\n FieldSchema(name=\"payload\", dtype=DataType.JSON)\n ]\n schema = CollectionSchema(fields=fields, description=\"Milvus Vector Store Collection\")\n\n # Create the collection\n self.collection = Collection(name=collection_name, schema=schema)\n logger.info(f\"Collection {collection_name} created successfully.\")\n\n # Create index\n index_params = {\n \"metric_type\": distance_metric,\n \"index_type\": \"AUTOINDEX\",\n \"params\": {}\n }\n self.collection.create_index(field_name=\"vector\", index_params=index_params)\n logger.info(f\"Index created on collection {collection_name}.\")\n\n def insert_vectors(\n self,\n collection_name: str,\n vectors: List[List[float]],\n payloads: Optional[List[Dict[str, Any]]] = None,\n ids: Optional[List[str]] = None\n ) -> None:\n \"\"\"\n Insert vectors into a collection.\n\n Args:\n collection_name (str): The name of the collection.\n vectors (List[List[float]]): A list of vectors to insert.\n payloads (Optional[List[Dict[str, Any]]]): Optional metadata associated with each vector.\n ids (Optional[List[str]]): Optional unique identifiers for each vector.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n if ids is None:\n raise ValueError(\"Milvus requires IDs to be provided for each vector.\")\n 
if payloads is None:\n payloads = [{} for _ in range(len(vectors))]\n if not (len(ids) == len(vectors) == len(payloads)):\n raise ValueError(\"Lengths of ids, vectors, and payloads must be equal.\")\n\n data = [\n ids,\n vectors,\n payloads\n ]\n self.collection.insert(data)\n logger.info(f\"Inserted {len(vectors)} vectors into collection {collection_name}.\")\n\n def search_vectors(\n self,\n collection_name: str,\n query_vector: List[float],\n top_k: int = 5,\n filters: Optional[Dict[str, Any]] = None\n ) -> List[Dict[str, Any]]:\n \"\"\"\n Search for similar vectors in a collection.\n\n Args:\n collection_name (str): The name of the collection.\n query_vector (List[float]): The vector to search with.\n top_k (int): The number of top results to return.\n filters (Optional[Dict[str, Any]]): Optional filters to apply to the search.\n\n Returns:\n List[Dict[str, Any]]: A list of search results.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n search_params = {\n \"metric_type\": self.metric_type,\n \"params\": {}\n }\n\n expr = self._build_filter_expression(filters)\n results = self.collection.search(\n data=[query_vector],\n anns_field=\"vector\",\n param=search_params,\n limit=top_k,\n expr=expr,\n output_fields=[\"id\", \"payload\"]\n )\n\n output = []\n for hits in results:\n for hit in hits:\n result = {\n 'id': hit.entity.get('id'),\n 'score': hit.distance,\n 'payload': hit.entity.get('payload')\n }\n output.append(result)\n return output\n\n def _build_filter_expression(self, filters: Optional[Dict[str, Any]]) -> str:\n \"\"\"\n Build an expression string for filtering in Milvus.\n\n Args:\n filters (Optional[Dict[str, Any]]): Filters to apply.\n\n Returns:\n str: The expression string.\n \"\"\"\n if not filters:\n return \"\"\n\n expressions = []\n for key, value in filters.items():\n if isinstance(value, str):\n expressions.append(f'payload[\"{key}\"] == 
\"{value}\"')\n else:\n expressions.append(f'payload[\"{key}\"] == {value}')\n expr = \" and \".join(expressions)\n return expr\n\n def delete_vector(self, collection_name: str, vector_id: str) -> None:\n \"\"\"\n Delete a vector from a collection by its ID.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector to delete.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n expr = f'id == \"{vector_id}\"'\n self.collection.delete(expr)\n logger.info(f\"Deleted vector with ID {vector_id} from collection {collection_name}.\")\n\n def update_vector(\n self,\n collection_name: str,\n vector_id: str,\n vector: Optional[List[float]] = None,\n payload: Optional[Dict[str, Any]] = None\n ) -> None:\n \"\"\"\n Update a vector's data or payload.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector to update.\n vector (Optional[List[float]]): The new vector data.\n payload (Optional[Dict[str, Any]]): The new payload data.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n # Milvus doesn't support direct updates; need to delete and re-insert\n # Fetch existing vector and payload\n expr = f'id == \"{vector_id}\"'\n results = self.collection.query(expr=expr, output_fields=[\"vector\", \"payload\"])\n\n if not results:\n raise KeyError(f\"Vector with ID {vector_id} not found in collection {collection_name}.\")\n\n existing_vector = results[0]['vector']\n existing_payload = results[0]['payload']\n\n new_vector = vector if vector is not None else existing_vector\n new_payload = payload if payload is not None else existing_payload\n\n # Delete the existing vector\n self.collection.delete(expr)\n\n # Re-insert with updated data\n self.insert_vectors(\n 
collection_name=collection_name,\n vectors=[new_vector],\n payloads=[new_payload],\n ids=[vector_id]\n )\n logger.info(f\"Updated vector with ID {vector_id} in collection {collection_name}.\")\n\n def get_vector(self, collection_name: str, vector_id: str) -> Dict[str, Any]:\n \"\"\"\n Retrieve a vector by its ID.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector.\n\n Returns:\n Dict[str, Any]: A dictionary containing the vector data and payload.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n expr = f'id == \"{vector_id}\"'\n results = self.collection.query(expr=expr, output_fields=[\"vector\", \"payload\"])\n\n if not results:\n raise KeyError(f\"Vector with ID {vector_id} not found in collection {collection_name}.\")\n\n vector_data = {\n 'id': vector_id,\n 'vector': results[0]['vector'],\n 'payload': results[0]['payload']\n }\n return vector_data\n\n def list_collections(self) -> List[str]:\n \"\"\"\n List all available vector collections.\n\n Returns:\n List[str]: A list of collection names.\n \"\"\"\n return utility.list_collections()\n\n def delete_collection(self, collection_name: str) -> None:\n \"\"\"\n Delete an entire vector collection.\n\n Args:\n collection_name (str): The name of the collection to delete.\n \"\"\"\n self.collection.drop()\n logger.info(f\"Deleted collection {collection_name}.\")\n\n def get_collection_info(self, collection_name: str) -> Dict[str, Any]:\n \"\"\"\n Get information about a collection.\n\n Args:\n collection_name (str): The name of the collection.\n\n Returns:\n Dict[str, Any]: Information about the collection.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n info = self.collection.describe()\n return info\n\n def __del__(self):\n \"\"\"Clean up resources.\"\"\"\n 
connections.disconnect(\"default\")\n
"},{"location":"reference/#src.aeiva.storage.milvus.milvus_database.MilvusDatabase.__del__","title":"__del__()
","text":"Clean up resources.
Source code in src/aeiva/storage/milvus/milvus_database.py
def __del__(self):\n \"\"\"Clean up resources.\"\"\"\n connections.disconnect(\"default\")\n
"},{"location":"reference/#src.aeiva.storage.milvus.milvus_database.MilvusDatabase.__init__","title":"__init__(config)
","text":"Initialize the Milvus vector store.
Parameters:
Name Type Description Default config
Dict[str, Any]
Configuration dictionary.
required Source code in src/aeiva/storage/milvus/milvus_database.py
def __init__(self, config: Dict[str, Any]) -> None:\n \"\"\"\n Initialize the Milvus vector store.\n\n Args:\n config (Dict[str, Any]): Configuration dictionary.\n \"\"\"\n self.config = config\n self.collection_name = config.get('collection_name')\n self.uri = config.get('uri')\n self.user = config.get('user')\n self.password = config.get('password')\n self.token = config.get('token')\n self.embedding_model_dims = config.get('embedding_model_dims')\n self.metric_type = config.get('metric_type', 'L2') # Default to 'L2' metric\n\n if not all([self.collection_name, self.uri, self.embedding_model_dims]):\n raise ValueError(\"Required configuration parameters are missing.\")\n\n self.create_client(\n uri=self.uri,\n user=self.user,\n password=self.password,\n token=self.token\n )\n self.create_collection(\n collection_name=self.collection_name,\n vector_size=self.embedding_model_dims,\n distance_metric=self.metric_type\n )\n
"},{"location":"reference/#src.aeiva.storage.milvus.milvus_database.MilvusDatabase.create_client","title":"create_client(uri, user=None, password=None, token=None, **kwargs)
","text":"Initializes the client connection to the Milvus vector store.
Parameters:
Name Type Description Default uri
str
The URI of the vector store instance.
required user
Optional[str]
Username for authentication.
None
password
Optional[str]
Password for authentication.
None
token
Optional[str]
Access token for authentication.
None
**kwargs
Additional parameters.
{}
Source code in src/aeiva/storage/milvus/milvus_database.py
def create_client(\n self,\n uri: str,\n user: Optional[str] = None,\n password: Optional[str] = None,\n token: Optional[str] = None,\n **kwargs\n) -> None:\n \"\"\"\n Initializes the client connection to the Milvus vector store.\n\n Args:\n uri (str): The URI of the vector store instance.\n user (Optional[str]): Username for authentication.\n password (Optional[str]): Password for authentication.\n token (Optional[str]): Access token for authentication.\n **kwargs: Additional parameters.\n \"\"\"\n try:\n connections.connect(\n alias=\"default\",\n uri=uri,\n user=user,\n password=password,\n token=token,\n **kwargs\n )\n logger.info(f\"Connected to Milvus at {uri}.\")\n except MilvusException as e:\n logger.error(f\"Failed to connect to Milvus: {e}\")\n raise ConnectionError(f\"Failed to connect to Milvus: {e}\")\n
"},{"location":"reference/#src.aeiva.storage.milvus.milvus_database.MilvusDatabase.create_collection","title":"create_collection(collection_name, vector_size, distance_metric)
","text":"Create a new vector collection in Milvus.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required vector_size
int
The dimensionality of the vectors.
required distance_metric
str
The distance metric to use (e.g., 'L2', 'IP', 'COSINE').
required Source code in src/aeiva/storage/milvus/milvus_database.py
def create_collection(self, collection_name: str, vector_size: int, distance_metric: str) -> None:\n \"\"\"\n Create a new vector collection in Milvus.\n\n Args:\n collection_name (str): The name of the collection.\n vector_size (int): The dimensionality of the vectors.\n distance_metric (str): The distance metric to use (e.g., 'L2', 'IP', 'COSINE').\n \"\"\"\n if utility.has_collection(collection_name):\n logger.info(f\"Collection {collection_name} already exists. Skipping creation.\")\n self.collection = Collection(collection_name)\n return\n\n # Define the schema\n fields = [\n FieldSchema(name=\"id\", dtype=DataType.VARCHAR, is_primary=True, auto_id=False, max_length=64),\n FieldSchema(name=\"vector\", dtype=DataType.FLOAT_VECTOR, dim=vector_size),\n FieldSchema(name=\"payload\", dtype=DataType.JSON)\n ]\n schema = CollectionSchema(fields=fields, description=\"Milvus Vector Store Collection\")\n\n # Create the collection\n self.collection = Collection(name=collection_name, schema=schema)\n logger.info(f\"Collection {collection_name} created successfully.\")\n\n # Create index\n index_params = {\n \"metric_type\": distance_metric,\n \"index_type\": \"AUTOINDEX\",\n \"params\": {}\n }\n self.collection.create_index(field_name=\"vector\", index_params=index_params)\n logger.info(f\"Index created on collection {collection_name}.\")\n
"},{"location":"reference/#src.aeiva.storage.milvus.milvus_database.MilvusDatabase.delete_collection","title":"delete_collection(collection_name)
","text":"Delete an entire vector collection.
Parameters:
Name Type Description Default collection_name
str
The name of the collection to delete.
required Source code in src/aeiva/storage/milvus/milvus_database.py
def delete_collection(self, collection_name: str) -> None:\n \"\"\"\n Delete an entire vector collection.\n\n Args:\n collection_name (str): The name of the collection to delete.\n \"\"\"\n self.collection.drop()\n logger.info(f\"Deleted collection {collection_name}.\")\n
"},{"location":"reference/#src.aeiva.storage.milvus.milvus_database.MilvusDatabase.delete_vector","title":"delete_vector(collection_name, vector_id)
","text":"Delete a vector from a collection by its ID.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required vector_id
str
The unique identifier of the vector to delete.
required Source code in src/aeiva/storage/milvus/milvus_database.py
def delete_vector(self, collection_name: str, vector_id: str) -> None:\n \"\"\"\n Delete a vector from a collection by its ID.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector to delete.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n expr = f'id == \"{vector_id}\"'\n self.collection.delete(expr)\n logger.info(f\"Deleted vector with ID {vector_id} from collection {collection_name}.\")\n
"},{"location":"reference/#src.aeiva.storage.milvus.milvus_database.MilvusDatabase.get_collection_info","title":"get_collection_info(collection_name)
","text":"Get information about a collection.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required Returns:
Type Description Dict[str, Any]
Dict[str, Any]: Information about the collection.
Source code in src/aeiva/storage/milvus/milvus_database.py
def get_collection_info(self, collection_name: str) -> Dict[str, Any]:\n \"\"\"\n Get information about a collection.\n\n Args:\n collection_name (str): The name of the collection.\n\n Returns:\n Dict[str, Any]: Information about the collection.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n info = self.collection.describe()\n return info\n
"},{"location":"reference/#src.aeiva.storage.milvus.milvus_database.MilvusDatabase.get_vector","title":"get_vector(collection_name, vector_id)
","text":"Retrieve a vector by its ID.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required vector_id
str
The unique identifier of the vector.
required Returns:
Type Description Dict[str, Any]
Dict[str, Any]: A dictionary containing the vector data and payload.
Source code in src/aeiva/storage/milvus/milvus_database.py
def get_vector(self, collection_name: str, vector_id: str) -> Dict[str, Any]:\n \"\"\"\n Retrieve a vector by its ID.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector.\n\n Returns:\n Dict[str, Any]: A dictionary containing the vector data and payload.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n expr = f'id == \"{vector_id}\"'\n results = self.collection.query(expr=expr, output_fields=[\"vector\", \"payload\"])\n\n if not results:\n raise KeyError(f\"Vector with ID {vector_id} not found in collection {collection_name}.\")\n\n vector_data = {\n 'id': vector_id,\n 'vector': results[0]['vector'],\n 'payload': results[0]['payload']\n }\n return vector_data\n
"},{"location":"reference/#src.aeiva.storage.milvus.milvus_database.MilvusDatabase.insert_vectors","title":"insert_vectors(collection_name, vectors, payloads=None, ids=None)
","text":"Insert vectors into a collection.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required vectors
List[List[float]]
A list of vectors to insert.
required payloads
Optional[List[Dict[str, Any]]]
Optional metadata associated with each vector.
None
ids
Optional[List[str]]
Optional unique identifiers for each vector.
None
Source code in src/aeiva/storage/milvus/milvus_database.py
def insert_vectors(\n self,\n collection_name: str,\n vectors: List[List[float]],\n payloads: Optional[List[Dict[str, Any]]] = None,\n ids: Optional[List[str]] = None\n) -> None:\n \"\"\"\n Insert vectors into a collection.\n\n Args:\n collection_name (str): The name of the collection.\n vectors (List[List[float]]): A list of vectors to insert.\n payloads (Optional[List[Dict[str, Any]]]): Optional metadata associated with each vector.\n ids (Optional[List[str]]): Optional unique identifiers for each vector.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n if ids is None:\n raise ValueError(\"Milvus requires IDs to be provided for each vector.\")\n if payloads is None:\n payloads = [{} for _ in range(len(vectors))]\n if not (len(ids) == len(vectors) == len(payloads)):\n raise ValueError(\"Lengths of ids, vectors, and payloads must be equal.\")\n\n data = [\n ids,\n vectors,\n payloads\n ]\n self.collection.insert(data)\n logger.info(f\"Inserted {len(vectors)} vectors into collection {collection_name}.\")\n
"},{"location":"reference/#src.aeiva.storage.milvus.milvus_database.MilvusDatabase.list_collections","title":"list_collections()
","text":"List all available vector collections.
Returns:
Type Description List[str]
List[str]: A list of collection names.
Source code in src/aeiva/storage/milvus/milvus_database.py
def list_collections(self) -> List[str]:\n \"\"\"\n List all available vector collections.\n\n Returns:\n List[str]: A list of collection names.\n \"\"\"\n return utility.list_collections()\n
"},{"location":"reference/#src.aeiva.storage.milvus.milvus_database.MilvusDatabase.search_vectors","title":"search_vectors(collection_name, query_vector, top_k=5, filters=None)
","text":"Search for similar vectors in a collection.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required query_vector
List[float]
The vector to search with.
required top_k
int
The number of top results to return.
5
filters
Optional[Dict[str, Any]]
Optional filters to apply to the search.
None
Returns:
Type Description List[Dict[str, Any]]
List[Dict[str, Any]]: A list of search results.
Source code in src/aeiva/storage/milvus/milvus_database.py
def search_vectors(\n self,\n collection_name: str,\n query_vector: List[float],\n top_k: int = 5,\n filters: Optional[Dict[str, Any]] = None\n) -> List[Dict[str, Any]]:\n \"\"\"\n Search for similar vectors in a collection.\n\n Args:\n collection_name (str): The name of the collection.\n query_vector (List[float]): The vector to search with.\n top_k (int): The number of top results to return.\n filters (Optional[Dict[str, Any]]): Optional filters to apply to the search.\n\n Returns:\n List[Dict[str, Any]]: A list of search results.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n search_params = {\n \"metric_type\": self.metric_type,\n \"params\": {}\n }\n\n expr = self._build_filter_expression(filters)\n results = self.collection.search(\n data=[query_vector],\n anns_field=\"vector\",\n param=search_params,\n limit=top_k,\n expr=expr,\n output_fields=[\"id\", \"payload\"]\n )\n\n output = []\n for hits in results:\n for hit in hits:\n result = {\n 'id': hit.entity.get('id'),\n 'score': hit.distance,\n 'payload': hit.entity.get('payload')\n }\n output.append(result)\n return output\n
"},{"location":"reference/#src.aeiva.storage.milvus.milvus_database.MilvusDatabase.update_vector","title":"update_vector(collection_name, vector_id, vector=None, payload=None)
","text":"Update a vector's data or payload.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required vector_id
str
The unique identifier of the vector to update.
required vector
Optional[List[float]]
The new vector data.
None
payload
Optional[Dict[str, Any]]
The new payload data.
None
Source code in src/aeiva/storage/milvus/milvus_database.py
def update_vector(\n self,\n collection_name: str,\n vector_id: str,\n vector: Optional[List[float]] = None,\n payload: Optional[Dict[str, Any]] = None\n) -> None:\n \"\"\"\n Update a vector's data or payload.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector to update.\n vector (Optional[List[float]]): The new vector data.\n payload (Optional[Dict[str, Any]]): The new payload data.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n # Milvus doesn't support direct updates; need to delete and re-insert\n # Fetch existing vector and payload\n expr = f'id == \"{vector_id}\"'\n results = self.collection.query(expr=expr, output_fields=[\"vector\", \"payload\"])\n\n if not results:\n raise KeyError(f\"Vector with ID {vector_id} not found in collection {collection_name}.\")\n\n existing_vector = results[0]['vector']\n existing_payload = results[0]['payload']\n\n new_vector = vector if vector is not None else existing_vector\n new_payload = payload if payload is not None else existing_payload\n\n # Delete the existing vector\n self.collection.delete(expr)\n\n # Re-insert with updated data\n self.insert_vectors(\n collection_name=collection_name,\n vectors=[new_vector],\n payloads=[new_payload],\n ids=[vector_id]\n )\n logger.info(f\"Updated vector with ID {vector_id} in collection {collection_name}.\")\n
"},{"location":"reference/#src.aeiva.storage.neo4jdb","title":"neo4jdb
","text":""},{"location":"reference/#src.aeiva.storage.neo4jdb.neo4j_config","title":"neo4j_config
","text":""},{"location":"reference/#src.aeiva.storage.neo4jdb.neo4j_config.Neo4jConfig","title":"Neo4jConfig
dataclass
","text":" Bases: BaseConfig
Configuration for Neo4j graph database.
Source code in src/aeiva/storage/neo4jdb/neo4j_config.py
@dataclass\nclass Neo4jConfig(BaseConfig):\n \"\"\"\n Configuration for Neo4j graph database.\n \"\"\"\n\n uri: str = field(\n default=\"bolt://localhost:7687\",\n metadata={\"help\": \"URI for connecting to Neo4j (e.g., 'bolt://localhost:7687').\"}\n )\n user: Optional[str] = field(\n default=None,\n metadata={\"help\": \"Username for Neo4j authentication.\"}\n )\n password: Optional[str] = field(\n default=None,\n metadata={\"help\": \"Password for Neo4j authentication.\"}\n )\n database: Optional[str] = field(\n default=\"neo4j\",\n metadata={\"help\": \"Neo4j database name.\"}\n )\n encrypted: bool = field(\n default=True,\n metadata={\"help\": \"Whether to use encrypted connection (True or False).\"}\n )\n\n def __post_init__(self):\n super().__post_init__()\n if not self.user or not self.password:\n raise ValueError(\"Both 'user' and 'password' must be provided for Neo4j authentication.\")\n
"},{"location":"reference/#src.aeiva.storage.neo4jdb.neo4j_database","title":"neo4j_database
","text":""},{"location":"reference/#src.aeiva.storage.neo4jdb.neo4j_database.Neo4jDatabase","title":"Neo4jDatabase
","text":" Bases: GraphDatabase
Concrete implementation of GraphStoreBase using Neo4j.
Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
class Neo4jDatabase(GraphDatabase):\n \"\"\"\n Concrete implementation of GraphStoreBase using Neo4j.\n \"\"\"\n\n def __init__(self, config: Dict[str, Any]) -> None:\n \"\"\"\n Initialize the Neo4j graph database connection.\n\n Args:\n config (Dict[str, Any]): Configuration dictionary.\n \"\"\"\n self.config = config\n self.uri = config.get('uri')\n self.user = config.get('user')\n self.password = config.get('password')\n self.database = config.get('database', 'neo4j')\n self.encrypted = config.get('encrypted', True)\n\n if not all([self.uri, self.user, self.password]):\n raise ValueError(\"Required configuration parameters 'uri', 'user', and 'password' are missing.\")\n\n self.create_client(\n uri=self.uri,\n user=self.user,\n password=self.password,\n encrypted=self.encrypted\n )\n\n def create_client(\n self,\n uri: str,\n user: str,\n password: str,\n encrypted: bool = True,\n **kwargs\n ) -> None:\n \"\"\"\n Initializes the client connection to the Neo4j graph database.\n\n Args:\n uri (str): The URI of the Neo4j instance.\n user (str): Username for authentication.\n password (str): Password for authentication.\n encrypted (bool): Whether to use encrypted connection.\n **kwargs: Additional parameters.\n\n Raises:\n ConnectionError: If the client fails to connect to the graph database.\n \"\"\"\n try:\n auth = basic_auth(user, password)\n self.driver = Neo4jGraphDatabase.driver(uri, auth=auth, encrypted=encrypted, **kwargs)\n self.session = self.driver.session(database=self.database)\n logger.info(f\"Connected to Neo4j at {uri}.\")\n except exceptions.Neo4jError as e:\n logger.error(f\"Failed to connect to Neo4j: {e}\")\n raise ConnectionError(f\"Failed to connect to Neo4j: {e}\")\n\n def add_node(\n self,\n node_id: str,\n properties: Optional[Dict[str, Any]] = None,\n labels: Optional[List[str]] = None\n ) -> None:\n \"\"\"\n Adds a node to the graph.\n\n Args:\n node_id (str): Unique identifier for the node.\n properties (Optional[Dict[str, Any]]): 
Properties associated with the node.\n labels (Optional[List[str]]): Labels or types associated with the node.\n\n Raises:\n StorageError: If there is an issue adding the node.\n \"\"\"\n properties = properties or {}\n labels = labels or []\n labels_str = ':' + ':'.join(labels) if labels else ''\n cypher = f\"MERGE (n{labels_str} {{id: $node_id}}) SET n += $properties\"\n params = {\n 'node_id': node_id,\n 'properties': properties\n }\n try:\n self.session.run(cypher, params)\n logger.info(f\"Node with id '{node_id}' added to the graph.\")\n except exceptions.Neo4jError as e:\n logger.error(f\"Failed to add node: {e}\")\n raise StorageError(f\"Failed to add node: {e}\")\n\n def add_edge(\n self,\n source_id: str,\n target_id: str,\n relationship: str,\n properties: Optional[Dict[str, Any]] = None\n ) -> None:\n \"\"\"\n Adds an edge (relationship) between two nodes.\n\n Args:\n source_id (str): Unique identifier of the source node.\n target_id (str): Unique identifier of the target node.\n relationship (str): Type of the relationship.\n properties (Optional[Dict[str, Any]]): Properties associated with the edge.\n\n Raises:\n NodeNotFoundError: If either the source or target node does not exist.\n StorageError: If there is an issue adding the edge.\n \"\"\"\n properties = properties or {}\n # First, check if both nodes exist\n cypher_check = \"MATCH (a {id: $source_id}), (b {id: $target_id}) RETURN a, b\"\n params = {\n 'source_id': source_id,\n 'target_id': target_id\n }\n try:\n result = self.session.run(cypher_check, params)\n record = result.single()\n if not record:\n missing_nodes = []\n # Check if source node exists\n node_a_exists = self.session.run(\"MATCH (a {id: $source_id}) RETURN a\", {'source_id': source_id}).single()\n if not node_a_exists:\n missing_nodes.append(source_id)\n # Check if target node exists\n node_b_exists = self.session.run(\"MATCH (b {id: $target_id}) RETURN b\", {'target_id': target_id}).single()\n if not node_b_exists:\n 
missing_nodes.append(target_id)\n logger.warning(f\"Node(s) with id(s) {missing_nodes} not found.\")\n raise NodeNotFoundError(f\"Node(s) with id(s) {missing_nodes} not found.\")\n # Proceed to add the edge\n cypher_edge = (\n \"MATCH (a {id: $source_id}), (b {id: $target_id}) \"\n f\"MERGE (a)-[r:{relationship}]->(b) \"\n \"SET r += $properties\"\n )\n params['properties'] = properties\n self.session.run(cypher_edge, params)\n logger.info(f\"Relationship '{relationship}' added between '{source_id}' and '{target_id}'.\")\n except NodeNotFoundError:\n raise\n except exceptions.Neo4jError as e:\n logger.error(f\"Failed to add edge: {e}\")\n raise StorageError(f\"Failed to add edge: {e}\")\n\n def get_node(self, node_id: str) -> Dict[str, Any]:\n \"\"\"\n Retrieves a node by its identifier.\n\n Args:\n node_id (str): Unique identifier of the node.\n\n Returns:\n Dict[str, Any]: A dictionary containing the node's properties and labels.\n\n Raises:\n NodeNotFoundError: If the node does not exist.\n StorageError: If there is an issue retrieving the node.\n \"\"\"\n cypher = \"MATCH (n {id: $node_id}) RETURN n\"\n params = {'node_id': node_id}\n try:\n result = self.session.run(cypher, params)\n record = result.single()\n if record:\n node = record['n']\n node_data = {\n 'id': node['id'],\n 'properties': {k: v for k, v in node.items() if k != 'id'},\n 'labels': list(node.labels)\n }\n logger.info(f\"Node with id '{node_id}' retrieved.\")\n return node_data\n else:\n logger.warning(f\"Node with id '{node_id}' not found.\")\n raise NodeNotFoundError(f\"Node with id '{node_id}' not found.\")\n except exceptions.Neo4jError as e:\n logger.error(f\"Failed to get node: {e}\")\n raise StorageError(f\"Failed to get node: {e}\")\n\n def update_node(self, node_id: str, properties: Dict[str, Any]) -> None:\n \"\"\"\n Updates properties of a node.\n\n Args:\n node_id (str): Unique identifier of the node.\n properties (Dict[str, Any]): Properties to update.\n\n Raises:\n 
NodeNotFoundError: If the node does not exist.\n StorageError: If there is an issue updating the node.\n \"\"\"\n cypher = \"MATCH (n {id: $node_id}) SET n += $properties RETURN n\"\n params = {\n 'node_id': node_id,\n 'properties': properties\n }\n try:\n result = self.session.run(cypher, params)\n record = result.single()\n if record:\n logger.info(f\"Node with id '{node_id}' updated.\")\n else:\n logger.warning(f\"Node with id '{node_id}' not found.\")\n raise NodeNotFoundError(f\"Node with id '{node_id}' not found.\")\n except exceptions.Neo4jError as e:\n logger.error(f\"Failed to update node: {e}\")\n raise StorageError(f\"Failed to update node: {e}\")\n\n def delete_node(self, node_id: str) -> None:\n \"\"\"\n Deletes a node from the graph.\n\n Args:\n node_id (str): Unique identifier of the node.\n\n Raises:\n NodeNotFoundError: If the node does not exist.\n StorageError: If there is an issue deleting the node.\n \"\"\"\n cypher = \"MATCH (n {id: $node_id}) DETACH DELETE n RETURN COUNT(n) AS count\"\n params = {'node_id': node_id}\n try:\n result = self.session.run(cypher, params)\n record = result.single()\n if record and record['count'] > 0:\n logger.info(f\"Node with id '{node_id}' deleted.\")\n else:\n logger.warning(f\"Node with id '{node_id}' not found.\")\n raise NodeNotFoundError(f\"Node with id '{node_id}' not found.\")\n except exceptions.Neo4jError as e:\n logger.error(f\"Failed to delete node: {e}\")\n raise StorageError(f\"Failed to delete node: {e}\")\n\n def delete_edge(\n self,\n source_id: str,\n target_id: str,\n relationship: str\n ) -> None:\n \"\"\"\n Deletes a specific relationship between two nodes.\n\n Args:\n source_id (str): Unique identifier of the source node.\n target_id (str): Unique identifier of the target node.\n relationship (str): Type of the relationship to delete.\n\n Raises:\n StorageError: If there is an issue deleting the relationship.\n \"\"\"\n cypher = (\n \"MATCH (a {id: $source_id})-[r:%s]->(b {id: $target_id}) 
\"\n \"DELETE r\"\n ) % relationship\n params = {\n 'source_id': source_id,\n 'target_id': target_id\n }\n try:\n result = self.session.run(cypher, params)\n if result.consume().counters.relationships_deleted == 0:\n logger.warning(f\"No relationship '{relationship}' found between '{source_id}' and '{target_id}'.\")\n raise StorageError(f\"Relationship '{relationship}' between '{source_id}' and '{target_id}' not found.\")\n logger.info(f\"Relationship '{relationship}' between '{source_id}' and '{target_id}' deleted.\")\n except exceptions.Neo4jError as e:\n logger.error(f\"Failed to delete relationship: {e}\")\n raise StorageError(f\"Failed to delete relationship: {e}\")\n\n def update_edge(\n self,\n source_id: str,\n target_id: str,\n relationship: str,\n properties: Dict[str, Any]\n ) -> None:\n \"\"\"\n Updates properties of a specific relationship between two nodes.\n\n Args:\n source_id (str): Unique identifier of the source node.\n target_id (str): Unique identifier of the target node.\n relationship (str): Type of the relationship to update.\n properties (Dict[str, Any]): Properties to update on the relationship.\n\n Raises:\n StorageError: If there is an issue updating the relationship.\n \"\"\"\n cypher = (\n \"MATCH (a {id: $source_id})-[r:%s]->(b {id: $target_id}) \"\n \"SET r += $properties RETURN r\"\n ) % relationship\n params = {\n 'source_id': source_id,\n 'target_id': target_id,\n 'properties': properties\n }\n try:\n result = self.session.run(cypher, params)\n record = result.single()\n if record:\n logger.info(f\"Relationship '{relationship}' between '{source_id}' and '{target_id}' updated with properties {properties}.\")\n else:\n logger.warning(f\"Relationship '{relationship}' between '{source_id}' and '{target_id}' not found.\")\n raise StorageError(f\"Relationship '{relationship}' between '{source_id}' and '{target_id}' not found.\")\n except exceptions.Neo4jError as e:\n logger.error(f\"Failed to update relationship: {e}\")\n raise 
StorageError(f\"Failed to update relationship: {e}\")\n\n def get_relationship(\n self,\n source_id: str,\n target_id: str,\n relationship: str\n ) -> Dict[str, Any]:\n \"\"\"\n Retrieves a specific relationship between two nodes.\n\n Args:\n source_id (str): Unique identifier of the source node.\n target_id (str): Unique identifier of the target node.\n relationship (str): Type of the relationship to retrieve.\n\n Returns:\n Dict[str, Any]: A dictionary containing the relationship's properties.\n\n Raises:\n StorageError: If there is an issue retrieving the relationship.\n \"\"\"\n cypher = (\n \"MATCH (a {id: $source_id})-[r:%s]->(b {id: $target_id}) \"\n \"RETURN r\"\n ) % relationship\n params = {\n 'source_id': source_id,\n 'target_id': target_id\n }\n try:\n result = self.session.run(cypher, params)\n record = result.single()\n if record:\n relationship_data = record['r']\n properties = dict(relationship_data)\n properties['type'] = relationship # Include relationship type (parameter is already the type string)\n logger.info(f\"Relationship '{relationship}' between '{source_id}' and '{target_id}' retrieved.\")\n return properties\n else:\n logger.warning(f\"Relationship '{relationship}' between '{source_id}' and '{target_id}' not found.\")\n raise StorageError(f\"Relationship '{relationship}' between '{source_id}' and '{target_id}' not found.\")\n except exceptions.Neo4jError as e:\n logger.error(f\"Failed to retrieve relationship: {e}\")\n raise StorageError(f\"Failed to retrieve relationship: {e}\")\n\n def delete_all_edges(self) -> None:\n \"\"\"\n Deletes all relationships from the Neo4j graph database without deleting nodes.\n\n Raises:\n StorageError: If there is an issue deleting relationships.\n \"\"\"\n cypher = \"MATCH ()-[r]->() DELETE r\"\n try:\n self.session.run(cypher)\n logger.info(\"All relationships have been deleted from Neo4j.\")\n except exceptions.Neo4jError as e:\n logger.error(f\"Failed to delete all relationships: {e}\")\n raise StorageError(f\"Failed to delete all 
relationships: {e}\")\n\n def delete_relationships_by_type(self, relationship: str) -> None:\n \"\"\"\n Deletes all relationships of a specific type from the Neo4j graph database.\n\n Args:\n relationship (str): The type of relationships to delete.\n\n Raises:\n StorageError: If there is an issue deleting the relationships.\n \"\"\"\n cypher = f\"MATCH ()-[r:{relationship}]->() DELETE r\"\n try:\n self.session.run(cypher)\n logger.info(f\"All relationships of type '{relationship}' have been deleted from Neo4j.\")\n except exceptions.Neo4jError as e:\n logger.error(f\"Failed to delete relationships of type '{relationship}': {e}\")\n raise StorageError(f\"Failed to delete relationships of type '{relationship}': {e}\")\n\n def delete_all(self) -> None:\n \"\"\"\n Deletes all nodes and relationships from the Neo4j graph database.\n\n Raises:\n StorageError: If there is an issue deleting all nodes and relationships.\n \"\"\"\n cypher = \"MATCH (n) DETACH DELETE n\"\n try:\n self.session.run(cypher)\n logger.info(\"All nodes and relationships have been deleted from Neo4j.\")\n except exceptions.Neo4jError as e:\n logger.error(f\"Failed to delete all nodes and relationships: {e}\")\n raise StorageError(f\"Failed to delete all nodes and relationships: {e}\")\n\n def get_neighbors(\n self,\n node_id: str,\n relationship: Optional[str] = None,\n direction: str = \"both\"\n ) -> List[Dict[str, Any]]:\n \"\"\"\n Retrieves neighboring nodes connected by edges.\n\n Args:\n node_id (str): Unique identifier of the node.\n relationship (Optional[str]): Filter by relationship type.\n direction (str): Direction of the relationships ('in', 'out', 'both').\n\n Returns:\n List[Dict[str, Any]]: A list of neighboring nodes.\n\n Raises:\n NodeNotFoundError: If the node does not exist.\n StorageError: If there is an issue retrieving neighbors.\n \"\"\"\n if direction not in [\"in\", \"out\", \"both\"]:\n raise ValueError(\"Invalid direction. 
Must be 'in', 'out', or 'both'.\")\n\n rel_type = f\":{relationship}\" if relationship else ''\n if direction == \"in\":\n pattern = f\"<-[r{rel_type}]-\"\n elif direction == \"out\":\n pattern = f\"-[r{rel_type}]->\"\n else: # both\n pattern = f\"-[r{rel_type}]-\"\n\n cypher = f\"MATCH (n {{id: $node_id}}){pattern}(neighbor) RETURN neighbor\"\n params = {'node_id': node_id}\n try:\n # First, check if the node exists\n node_exists_query = \"MATCH (n {id: $node_id}) RETURN n\"\n node_result = self.session.run(node_exists_query, params)\n if not node_result.single():\n logger.warning(f\"Node with id '{node_id}' not found.\")\n raise NodeNotFoundError(f\"Node with id '{node_id}' not found.\")\n # Get neighbors\n result = self.session.run(cypher, params)\n neighbors = []\n for record in result:\n node = record['neighbor']\n neighbor_data = {\n 'id': node['id'],\n 'properties': {k: v for k, v in node.items() if k != 'id'},\n 'labels': list(node.labels)\n }\n neighbors.append(neighbor_data)\n logger.info(f\"Neighbors of node '{node_id}' retrieved.\")\n return neighbors\n except NodeNotFoundError:\n raise\n except exceptions.Neo4jError as e:\n logger.error(f\"Failed to get neighbors: {e}\")\n raise StorageError(f\"Failed to get neighbors: {e}\")\n\n def query_nodes(\n self,\n properties: Dict[str, Any],\n labels: Optional[List[str]] = None\n ) -> List[Dict[str, Any]]:\n \"\"\"\n Queries nodes based on properties and labels.\n\n Args:\n properties (Dict[str, Any]): Properties to filter nodes.\n labels (Optional[List[str]]): Labels to filter nodes.\n\n Returns:\n List[Dict[str, Any]]: A list of nodes matching the query.\n\n Raises:\n StorageError: If there is an issue querying nodes.\n \"\"\"\n labels_str = ':' + ':'.join(labels) if labels else ''\n params = {}\n cypher = f\"MATCH (n{labels_str})\"\n\n if properties:\n props_conditions = ' AND '.join([f\"n.{key} = ${key}\" for key in properties.keys()])\n cypher += f\" WHERE {props_conditions}\"\n 
params.update(properties)\n\n cypher += \" RETURN n\"\n\n try:\n result = self.session.run(cypher, params)\n nodes = []\n for record in result:\n node = record['n']\n node_data = {\n 'id': node['id'],\n 'properties': {k: v for k, v in node.items() if k != 'id'},\n 'labels': list(node.labels)\n }\n nodes.append(node_data)\n logger.info(f\"Query returned {len(nodes)} nodes.\")\n return nodes\n except exceptions.Neo4jError as e:\n logger.error(f\"Failed to query nodes: {e}\")\n raise StorageError(f\"Failed to query nodes: {e}\")\n\n def execute_query(self, query: str, parameters: Optional[Dict[str, Any]] = None) -> Any:\n \"\"\"\n Executes a raw query against the graph database.\n\n Args:\n query (str): The query string.\n parameters (Optional[Dict[str, Any]]): Parameters for parameterized queries.\n\n Returns:\n Any: The result of the query.\n\n Raises:\n StorageError: If there is an issue executing the query.\n \"\"\"\n try:\n result = self.session.run(query, parameters)\n records = [record.data() for record in result]\n logger.info(f\"Executed query: {query}\")\n return records\n except exceptions.Neo4jError as e:\n logger.error(f\"Failed to execute query: {e}\")\n raise StorageError(f\"Failed to execute query: {e}\")\n\n def close(self) -> None:\n \"\"\"\n Closes the graph database connection and releases resources.\n \"\"\"\n if hasattr(self, 'session') and self.session:\n self.session.close()\n if hasattr(self, 'driver') and self.driver:\n self.driver.close()\n logger.info(\"Closed connection to Neo4j database.\")\n\n def __del__(self):\n \"\"\"Destructor to ensure resources are cleaned up.\"\"\"\n self.close()\n
"},{"location":"reference/#src.aeiva.storage.neo4jdb.neo4j_database.Neo4jDatabase.__del__","title":"__del__()
","text":"Destructor to ensure resources are cleaned up.
Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
def __del__(self):\n \"\"\"Destructor to ensure resources are cleaned up.\"\"\"\n self.close()\n
"},{"location":"reference/#src.aeiva.storage.neo4jdb.neo4j_database.Neo4jDatabase.__init__","title":"__init__(config)
","text":"Initialize the Neo4j graph database connection.
Parameters:
Name Type Description Default config
Dict[str, Any]
Configuration dictionary.
required Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
def __init__(self, config: Dict[str, Any]) -> None:\n \"\"\"\n Initialize the Neo4j graph database connection.\n\n Args:\n config (Dict[str, Any]): Configuration dictionary.\n \"\"\"\n self.config = config\n self.uri = config.get('uri')\n self.user = config.get('user')\n self.password = config.get('password')\n self.database = config.get('database', 'neo4j')\n self.encrypted = config.get('encrypted', True)\n\n if not all([self.uri, self.user, self.password]):\n raise ValueError(\"Required configuration parameters 'uri', 'user', and 'password' are missing.\")\n\n self.create_client(\n uri=self.uri,\n user=self.user,\n password=self.password,\n encrypted=self.encrypted\n )\n
"},{"location":"reference/#src.aeiva.storage.neo4jdb.neo4j_database.Neo4jDatabase.add_edge","title":"add_edge(source_id, target_id, relationship, properties=None)
","text":"Adds an edge (relationship) between two nodes.
Parameters:
Name Type Description Default source_id
str
Unique identifier of the source node.
required target_id
str
Unique identifier of the target node.
required relationship
str
Type of the relationship.
required properties
Optional[Dict[str, Any]]
Properties associated with the edge.
None
Raises:
Type Description NodeNotFoundError
If either the source or target node does not exist.
StorageError
If there is an issue adding the edge.
Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
def add_edge(\n self,\n source_id: str,\n target_id: str,\n relationship: str,\n properties: Optional[Dict[str, Any]] = None\n) -> None:\n \"\"\"\n Adds an edge (relationship) between two nodes.\n\n Args:\n source_id (str): Unique identifier of the source node.\n target_id (str): Unique identifier of the target node.\n relationship (str): Type of the relationship.\n properties (Optional[Dict[str, Any]]): Properties associated with the edge.\n\n Raises:\n NodeNotFoundError: If either the source or target node does not exist.\n StorageError: If there is an issue adding the edge.\n \"\"\"\n properties = properties or {}\n # First, check if both nodes exist\n cypher_check = \"MATCH (a {id: $source_id}), (b {id: $target_id}) RETURN a, b\"\n params = {\n 'source_id': source_id,\n 'target_id': target_id\n }\n try:\n result = self.session.run(cypher_check, params)\n record = result.single()\n if not record:\n missing_nodes = []\n # Check if source node exists\n node_a_exists = self.session.run(\"MATCH (a {id: $source_id}) RETURN a\", {'source_id': source_id}).single()\n if not node_a_exists:\n missing_nodes.append(source_id)\n # Check if target node exists\n node_b_exists = self.session.run(\"MATCH (b {id: $target_id}) RETURN b\", {'target_id': target_id}).single()\n if not node_b_exists:\n missing_nodes.append(target_id)\n logger.warning(f\"Node(s) with id(s) {missing_nodes} not found.\")\n raise NodeNotFoundError(f\"Node(s) with id(s) {missing_nodes} not found.\")\n # Proceed to add the edge\n cypher_edge = (\n \"MATCH (a {id: $source_id}), (b {id: $target_id}) \"\n f\"MERGE (a)-[r:{relationship}]->(b) \"\n \"SET r += $properties\"\n )\n params['properties'] = properties\n self.session.run(cypher_edge, params)\n logger.info(f\"Relationship '{relationship}' added between '{source_id}' and '{target_id}'.\")\n except NodeNotFoundError:\n raise\n except exceptions.Neo4jError as e:\n logger.error(f\"Failed to add edge: {e}\")\n raise StorageError(f\"Failed to add edge: 
{e}\")\n
"},{"location":"reference/#src.aeiva.storage.neo4jdb.neo4j_database.Neo4jDatabase.add_node","title":"add_node(node_id, properties=None, labels=None)
","text":"Adds a node to the graph.
Parameters:
Name Type Description Default node_id
str
Unique identifier for the node.
required properties
Optional[Dict[str, Any]]
Properties associated with the node.
None
labels
Optional[List[str]]
Labels or types associated with the node.
None
Raises:
Type Description StorageError
If there is an issue adding the node.
Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
def add_node(\n self,\n node_id: str,\n properties: Optional[Dict[str, Any]] = None,\n labels: Optional[List[str]] = None\n) -> None:\n \"\"\"\n Adds a node to the graph.\n\n Args:\n node_id (str): Unique identifier for the node.\n properties (Optional[Dict[str, Any]]): Properties associated with the node.\n labels (Optional[List[str]]): Labels or types associated with the node.\n\n Raises:\n StorageError: If there is an issue adding the node.\n \"\"\"\n properties = properties or {}\n labels = labels or []\n labels_str = ':' + ':'.join(labels) if labels else ''\n cypher = f\"MERGE (n{labels_str} {{id: $node_id}}) SET n += $properties\"\n params = {\n 'node_id': node_id,\n 'properties': properties\n }\n try:\n self.session.run(cypher, params)\n logger.info(f\"Node with id '{node_id}' added to the graph.\")\n except exceptions.Neo4jError as e:\n logger.error(f\"Failed to add node: {e}\")\n raise StorageError(f\"Failed to add node: {e}\")\n
"},{"location":"reference/#src.aeiva.storage.neo4jdb.neo4j_database.Neo4jDatabase.close","title":"close()
","text":"Closes the graph database connection and releases resources.
Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
def close(self) -> None:\n \"\"\"\n Closes the graph database connection and releases resources.\n \"\"\"\n if hasattr(self, 'session') and self.session:\n self.session.close()\n if hasattr(self, 'driver') and self.driver:\n self.driver.close()\n logger.info(\"Closed connection to Neo4j database.\")\n
"},{"location":"reference/#src.aeiva.storage.neo4jdb.neo4j_database.Neo4jDatabase.create_client","title":"create_client(uri, user, password, encrypted=True, **kwargs)
","text":"Initializes the client connection to the Neo4j graph database.
Parameters:
Name Type Description Default uri
str
The URI of the Neo4j instance.
required user
str
Username for authentication.
required password
str
Password for authentication.
required encrypted
bool
Whether to use encrypted connection.
True
**kwargs
Additional parameters.
{}
Raises:
Type Description ConnectionError
If the client fails to connect to the graph database.
Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
def create_client(\n self,\n uri: str,\n user: str,\n password: str,\n encrypted: bool = True,\n **kwargs\n) -> None:\n \"\"\"\n Initializes the client connection to the Neo4j graph database.\n\n Args:\n uri (str): The URI of the Neo4j instance.\n user (str): Username for authentication.\n password (str): Password for authentication.\n encrypted (bool): Whether to use encrypted connection.\n **kwargs: Additional parameters.\n\n Raises:\n ConnectionError: If the client fails to connect to the graph database.\n \"\"\"\n try:\n auth = basic_auth(user, password)\n self.driver = Neo4jGraphDatabase.driver(uri, auth=auth, encrypted=encrypted, **kwargs)\n self.session = self.driver.session(database=self.database)\n logger.info(f\"Connected to Neo4j at {uri}.\")\n except exceptions.Neo4jError as e:\n logger.error(f\"Failed to connect to Neo4j: {e}\")\n raise ConnectionError(f\"Failed to connect to Neo4j: {e}\")\n
"},{"location":"reference/#src.aeiva.storage.neo4jdb.neo4j_database.Neo4jDatabase.delete_all","title":"delete_all()
","text":"Deletes all nodes and relationships from the Neo4j graph database.
Raises:
Type Description StorageError
If there is an issue deleting all nodes and relationships.
Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
def delete_all(self) -> None:\n \"\"\"\n Deletes all nodes and relationships from the Neo4j graph database.\n\n Raises:\n StorageError: If there is an issue deleting all nodes and relationships.\n \"\"\"\n cypher = \"MATCH (n) DETACH DELETE n\"\n try:\n self.session.run(cypher)\n logger.info(\"All nodes and relationships have been deleted from Neo4j.\")\n except exceptions.Neo4jError as e:\n logger.error(f\"Failed to delete all nodes and relationships: {e}\")\n raise StorageError(f\"Failed to delete all nodes and relationships: {e}\")\n
"},{"location":"reference/#src.aeiva.storage.neo4jdb.neo4j_database.Neo4jDatabase.delete_all_edges","title":"delete_all_edges()
","text":"Deletes all relationships from the Neo4j graph database without deleting nodes.
Raises:
Type Description StorageError
If there is an issue deleting relationships.
Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
def delete_all_edges(self) -> None:\n \"\"\"\n Deletes all relationships from the Neo4j graph database without deleting nodes.\n\n Raises:\n StorageError: If there is an issue deleting relationships.\n \"\"\"\n cypher = \"MATCH ()-[r]->() DELETE r\"\n try:\n self.session.run(cypher)\n logger.info(\"All relationships have been deleted from Neo4j.\")\n except exceptions.Neo4jError as e:\n logger.error(f\"Failed to delete all relationships: {e}\")\n raise StorageError(f\"Failed to delete all relationships: {e}\")\n
"},{"location":"reference/#src.aeiva.storage.neo4jdb.neo4j_database.Neo4jDatabase.delete_edge","title":"delete_edge(source_id, target_id, relationship)
","text":"Deletes a specific relationship between two nodes.
Parameters:
Name Type Description Default source_id
str
Unique identifier of the source node.
required target_id
str
Unique identifier of the target node.
required relationship
str
Type of the relationship to delete.
required Raises:
Type Description StorageError
If there is an issue deleting the relationship.
Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
def delete_edge(\n self,\n source_id: str,\n target_id: str,\n relationship: str\n) -> None:\n \"\"\"\n Deletes a specific relationship between two nodes.\n\n Args:\n source_id (str): Unique identifier of the source node.\n target_id (str): Unique identifier of the target node.\n relationship (str): Type of the relationship to delete.\n\n Raises:\n StorageError: If there is an issue deleting the relationship.\n \"\"\"\n cypher = (\n \"MATCH (a {id: $source_id})-[r:%s]->(b {id: $target_id}) \"\n \"DELETE r\"\n ) % relationship\n params = {\n 'source_id': source_id,\n 'target_id': target_id\n }\n try:\n result = self.session.run(cypher, params)\n if result.consume().counters.relationships_deleted == 0:\n logger.warning(f\"No relationship '{relationship}' found between '{source_id}' and '{target_id}'.\")\n raise StorageError(f\"Relationship '{relationship}' between '{source_id}' and '{target_id}' not found.\")\n logger.info(f\"Relationship '{relationship}' between '{source_id}' and '{target_id}' deleted.\")\n except exceptions.Neo4jError as e:\n logger.error(f\"Failed to delete relationship: {e}\")\n raise StorageError(f\"Failed to delete relationship: {e}\")\n
"},{"location":"reference/#src.aeiva.storage.neo4jdb.neo4j_database.Neo4jDatabase.delete_node","title":"delete_node(node_id)
","text":"Deletes a node from the graph.
Parameters:
Name Type Description Default node_id
str
Unique identifier of the node.
required Raises:
Type Description NodeNotFoundError
If the node does not exist.
StorageError
If there is an issue deleting the node.
Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
def delete_node(self, node_id: str) -> None:\n \"\"\"\n Deletes a node from the graph.\n\n Args:\n node_id (str): Unique identifier of the node.\n\n Raises:\n NodeNotFoundError: If the node does not exist.\n StorageError: If there is an issue deleting the node.\n \"\"\"\n cypher = \"MATCH (n {id: $node_id}) DETACH DELETE n RETURN COUNT(n) AS count\"\n params = {'node_id': node_id}\n try:\n result = self.session.run(cypher, params)\n record = result.single()\n if record and record['count'] > 0:\n logger.info(f\"Node with id '{node_id}' deleted.\")\n else:\n logger.warning(f\"Node with id '{node_id}' not found.\")\n raise NodeNotFoundError(f\"Node with id '{node_id}' not found.\")\n except exceptions.Neo4jError as e:\n logger.error(f\"Failed to delete node: {e}\")\n raise StorageError(f\"Failed to delete node: {e}\")\n
"},{"location":"reference/#src.aeiva.storage.neo4jdb.neo4j_database.Neo4jDatabase.delete_relationships_by_type","title":"delete_relationships_by_type(relationship)
","text":"Deletes all relationships of a specific type from the Neo4j graph database.
Parameters:
Name Type Description Default relationship
str
The type of relationships to delete.
required Raises:
Type Description StorageError
If there is an issue deleting the relationships.
Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
def delete_relationships_by_type(self, relationship: str) -> None:\n \"\"\"\n Deletes all relationships of a specific type from the Neo4j graph database.\n\n Args:\n relationship (str): The type of relationships to delete.\n\n Raises:\n StorageError: If there is an issue deleting the relationships.\n \"\"\"\n cypher = f\"MATCH ()-[r:{relationship}]->() DELETE r\"\n try:\n self.session.run(cypher)\n logger.info(f\"All relationships of type '{relationship}' have been deleted from Neo4j.\")\n except exceptions.Neo4jError as e:\n logger.error(f\"Failed to delete relationships of type '{relationship}': {e}\")\n raise StorageError(f\"Failed to delete relationships of type '{relationship}': {e}\")\n
"},{"location":"reference/#src.aeiva.storage.neo4jdb.neo4j_database.Neo4jDatabase.execute_query","title":"execute_query(query, parameters=None)
","text":"Executes a raw query against the graph database.
Parameters:
Name Type Description Default query
str
The query string.
required parameters
Optional[Dict[str, Any]]
Parameters for parameterized queries.
None
Returns:
Name Type Description Any
Any
The result of the query.
Raises:
Type Description StorageError
If there is an issue executing the query.
Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
def execute_query(self, query: str, parameters: Optional[Dict[str, Any]] = None) -> Any:\n \"\"\"\n Executes a raw query against the graph database.\n\n Args:\n query (str): The query string.\n parameters (Optional[Dict[str, Any]]): Parameters for parameterized queries.\n\n Returns:\n Any: The result of the query.\n\n Raises:\n StorageError: If there is an issue executing the query.\n \"\"\"\n try:\n result = self.session.run(query, parameters)\n records = [record.data() for record in result]\n logger.info(f\"Executed query: {query}\")\n return records\n except exceptions.Neo4jError as e:\n logger.error(f\"Failed to execute query: {e}\")\n raise StorageError(f\"Failed to execute query: {e}\")\n
"},{"location":"reference/#src.aeiva.storage.neo4jdb.neo4j_database.Neo4jDatabase.get_neighbors","title":"get_neighbors(node_id, relationship=None, direction='both')
","text":"Retrieves neighboring nodes connected by edges.
Parameters:
Name Type Description Default node_id
str
Unique identifier of the node.
required relationship
Optional[str]
Filter by relationship type.
None
direction
str
Direction of the relationships ('in', 'out', 'both').
'both'
Returns:
Type Description List[Dict[str, Any]]
List[Dict[str, Any]]: A list of neighboring nodes.
Raises:
Type Description NodeNotFoundError
If the node does not exist.
StorageError
If there is an issue retrieving neighbors.
Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
def get_neighbors(\n self,\n node_id: str,\n relationship: Optional[str] = None,\n direction: str = \"both\"\n) -> List[Dict[str, Any]]:\n \"\"\"\n Retrieves neighboring nodes connected by edges.\n\n Args:\n node_id (str): Unique identifier of the node.\n relationship (Optional[str]): Filter by relationship type.\n direction (str): Direction of the relationships ('in', 'out', 'both').\n\n Returns:\n List[Dict[str, Any]]: A list of neighboring nodes.\n\n Raises:\n NodeNotFoundError: If the node does not exist.\n StorageError: If there is an issue retrieving neighbors.\n \"\"\"\n if direction not in [\"in\", \"out\", \"both\"]:\n raise ValueError(\"Invalid direction. Must be 'in', 'out', or 'both'.\")\n\n rel_type = f\":{relationship}\" if relationship else ''\n if direction == \"in\":\n pattern = f\"<-[r{rel_type}]-\"\n elif direction == \"out\":\n pattern = f\"-[r{rel_type}]->\"\n else: # both\n pattern = f\"-[r{rel_type}]-\"\n\n cypher = f\"MATCH (n {{id: $node_id}}){pattern}(neighbor) RETURN neighbor\"\n params = {'node_id': node_id}\n try:\n # First, check if the node exists\n node_exists_query = \"MATCH (n {id: $node_id}) RETURN n\"\n node_result = self.session.run(node_exists_query, params)\n if not node_result.single():\n logger.warning(f\"Node with id '{node_id}' not found.\")\n raise NodeNotFoundError(f\"Node with id '{node_id}' not found.\")\n # Get neighbors\n result = self.session.run(cypher, params)\n neighbors = []\n for record in result:\n node = record['neighbor']\n neighbor_data = {\n 'id': node['id'],\n 'properties': {k: v for k, v in node.items() if k != 'id'},\n 'labels': list(node.labels)\n }\n neighbors.append(neighbor_data)\n logger.info(f\"Neighbors of node '{node_id}' retrieved.\")\n return neighbors\n except NodeNotFoundError:\n raise\n except exceptions.Neo4jError as e:\n logger.error(f\"Failed to get neighbors: {e}\")\n raise StorageError(f\"Failed to get neighbors: {e}\")\n
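The direction argument maps to a Cypher relationship pattern; a minimal stand-alone sketch of that mapping (hypothetical helper name, mirroring the source above):

```python
# Hypothetical helper mirroring the direction-to-pattern mapping in get_neighbors above.
def neighbor_pattern(relationship, direction):
    if direction not in ('in', 'out', 'both'):
        raise ValueError('Invalid direction. Must be in, out, or both.')
    rel_type = f':{relationship}' if relationship else ''
    if direction == 'in':
        return f'<-[r{rel_type}]-'
    if direction == 'out':
        return f'-[r{rel_type}]->'
    return f'-[r{rel_type}]-'  # 'both': undirected match

print(neighbor_pattern('KNOWS', 'out'))
# -[r:KNOWS]->
```

The returned fragment is interpolated between the matched node and the neighbor, e.g. MATCH (n {id: $node_id})-[r:KNOWS]->(neighbor) RETURN neighbor.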
"},{"location":"reference/#src.aeiva.storage.neo4jdb.neo4j_database.Neo4jDatabase.get_node","title":"get_node(node_id)
","text":"Retrieves a node by its identifier.
Parameters:
Name Type Description Default node_id
str
Unique identifier of the node.
required Returns:
Type Description Dict[str, Any]
Dict[str, Any]: A dictionary containing the node's properties and labels.
Raises:
Type Description NodeNotFoundError
If the node does not exist.
StorageError
If there is an issue retrieving the node.
Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
def get_node(self, node_id: str) -> Dict[str, Any]:\n \"\"\"\n Retrieves a node by its identifier.\n\n Args:\n node_id (str): Unique identifier of the node.\n\n Returns:\n Dict[str, Any]: A dictionary containing the node's properties and labels.\n\n Raises:\n NodeNotFoundError: If the node does not exist.\n StorageError: If there is an issue retrieving the node.\n \"\"\"\n cypher = \"MATCH (n {id: $node_id}) RETURN n\"\n params = {'node_id': node_id}\n try:\n result = self.session.run(cypher, params)\n record = result.single()\n if record:\n node = record['n']\n node_data = {\n 'id': node['id'],\n 'properties': {k: v for k, v in node.items() if k != 'id'},\n 'labels': list(node.labels)\n }\n logger.info(f\"Node with id '{node_id}' retrieved.\")\n return node_data\n else:\n logger.warning(f\"Node with id '{node_id}' not found.\")\n raise NodeNotFoundError(f\"Node with id '{node_id}' not found.\")\n except exceptions.Neo4jError as e:\n logger.error(f\"Failed to get node: {e}\")\n raise StorageError(f\"Failed to get node: {e}\")\n
"},{"location":"reference/#src.aeiva.storage.neo4jdb.neo4j_database.Neo4jDatabase.get_relationship","title":"get_relationship(source_id, target_id, relationship)
","text":"Retrieves a specific relationship between two nodes.
Parameters:
Name Type Description Default source_id
str
Unique identifier of the source node.
required target_id
str
Unique identifier of the target node.
required relationship
str
Type of the relationship to retrieve.
required Returns:
Type Description Dict[str, Any]
Dict[str, Any]: A dictionary containing the relationship's properties.
Raises:
Type Description StorageError
If there is an issue retrieving the relationship.
Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
def get_relationship(\n    self,\n    source_id: str,\n    target_id: str,\n    relationship: str\n) -> Dict[str, Any]:\n    \"\"\"\n    Retrieves a specific relationship between two nodes.\n\n    Args:\n        source_id (str): Unique identifier of the source node.\n        target_id (str): Unique identifier of the target node.\n        relationship (str): Type of the relationship to retrieve.\n\n    Returns:\n        Dict[str, Any]: A dictionary containing the relationship's properties.\n\n    Raises:\n        StorageError: If there is an issue retrieving the relationship.\n    \"\"\"\n    cypher = (\n        \"MATCH (a {id: $source_id})-[r:%s]->(b {id: $target_id}) \"\n        \"RETURN r\"\n    ) % relationship\n    params = {\n        'source_id': source_id,\n        'target_id': target_id\n    }\n    try:\n        result = self.session.run(cypher, params)\n        record = result.single()\n        if record:\n            relationship_data = record['r']\n            properties = dict(relationship_data)\n            properties['type'] = relationship_data.type  # Include relationship type\n            logger.info(f\"Relationship '{relationship}' between '{source_id}' and '{target_id}' retrieved.\")\n            return properties\n        else:\n            logger.warning(f\"Relationship '{relationship}' between '{source_id}' and '{target_id}' not found.\")\n            raise StorageError(f\"Relationship '{relationship}' between '{source_id}' and '{target_id}' not found.\")\n    except exceptions.Neo4jError as e:\n        logger.error(f\"Failed to retrieve relationship: {e}\")\n        raise StorageError(f\"Failed to retrieve relationship: {e}\")\n
"},{"location":"reference/#src.aeiva.storage.neo4jdb.neo4j_database.Neo4jDatabase.query_nodes","title":"query_nodes(properties, labels=None)
","text":"Queries nodes based on properties and labels.
Parameters:
Name Type Description Default properties
Dict[str, Any]
Properties to filter nodes.
required labels
Optional[List[str]]
Labels to filter nodes.
None
Returns:
Type Description List[Dict[str, Any]]
List[Dict[str, Any]]: A list of nodes matching the query.
Raises:
Type Description StorageError
If there is an issue querying nodes.
Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
def query_nodes(\n self,\n properties: Dict[str, Any],\n labels: Optional[List[str]] = None\n) -> List[Dict[str, Any]]:\n \"\"\"\n Queries nodes based on properties and labels.\n\n Args:\n properties (Dict[str, Any]): Properties to filter nodes.\n labels (Optional[List[str]]): Labels to filter nodes.\n\n Returns:\n List[Dict[str, Any]]: A list of nodes matching the query.\n\n Raises:\n StorageError: If there is an issue querying nodes.\n \"\"\"\n labels_str = ':' + ':'.join(labels) if labels else ''\n params = {}\n cypher = f\"MATCH (n{labels_str})\"\n\n if properties:\n props_conditions = ' AND '.join([f\"n.{key} = ${key}\" for key in properties.keys()])\n cypher += f\" WHERE {props_conditions}\"\n params.update(properties)\n\n cypher += \" RETURN n\"\n\n try:\n result = self.session.run(cypher, params)\n nodes = []\n for record in result:\n node = record['n']\n node_data = {\n 'id': node['id'],\n 'properties': {k: v for k, v in node.items() if k != 'id'},\n 'labels': list(node.labels)\n }\n nodes.append(node_data)\n logger.info(f\"Query returned {len(nodes)} nodes.\")\n return nodes\n except exceptions.Neo4jError as e:\n logger.error(f\"Failed to query nodes: {e}\")\n raise StorageError(f\"Failed to query nodes: {e}\")\n
"},{"location":"reference/#src.aeiva.storage.neo4jdb.neo4j_database.Neo4jDatabase.update_edge","title":"update_edge(source_id, target_id, relationship, properties)
","text":"Updates properties of a specific relationship between two nodes.
Parameters:
Name Type Description Default source_id
str
Unique identifier of the source node.
required target_id
str
Unique identifier of the target node.
required relationship
str
Type of the relationship to update.
required properties
Dict[str, Any]
Properties to update on the relationship.
required Raises:
Type Description StorageError
If there is an issue updating the relationship.
Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
def update_edge(\n self,\n source_id: str,\n target_id: str,\n relationship: str,\n properties: Dict[str, Any]\n) -> None:\n \"\"\"\n Updates properties of a specific relationship between two nodes.\n\n Args:\n source_id (str): Unique identifier of the source node.\n target_id (str): Unique identifier of the target node.\n relationship (str): Type of the relationship to update.\n properties (Dict[str, Any]): Properties to update on the relationship.\n\n Raises:\n StorageError: If there is an issue updating the relationship.\n \"\"\"\n cypher = (\n \"MATCH (a {id: $source_id})-[r:%s]->(b {id: $target_id}) \"\n \"SET r += $properties RETURN r\"\n ) % relationship\n params = {\n 'source_id': source_id,\n 'target_id': target_id,\n 'properties': properties\n }\n try:\n result = self.session.run(cypher, params)\n record = result.single()\n if record:\n logger.info(f\"Relationship '{relationship}' between '{source_id}' and '{target_id}' updated with properties {properties}.\")\n else:\n logger.warning(f\"Relationship '{relationship}' between '{source_id}' and '{target_id}' not found.\")\n raise StorageError(f\"Relationship '{relationship}' between '{source_id}' and '{target_id}' not found.\")\n except exceptions.Neo4jError as e:\n logger.error(f\"Failed to update relationship: {e}\")\n raise StorageError(f\"Failed to update relationship: {e}\")\n
"},{"location":"reference/#src.aeiva.storage.neo4jdb.neo4j_database.Neo4jDatabase.update_node","title":"update_node(node_id, properties)
","text":"Updates properties of a node.
Parameters:
Name Type Description Default node_id
str
Unique identifier of the node.
required properties
Dict[str, Any]
Properties to update.
required Raises:
Type Description NodeNotFoundError
If the node does not exist.
StorageError
If there is an issue updating the node.
Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
def update_node(self, node_id: str, properties: Dict[str, Any]) -> None:\n \"\"\"\n Updates properties of a node.\n\n Args:\n node_id (str): Unique identifier of the node.\n properties (Dict[str, Any]): Properties to update.\n\n Raises:\n NodeNotFoundError: If the node does not exist.\n StorageError: If there is an issue updating the node.\n \"\"\"\n cypher = \"MATCH (n {id: $node_id}) SET n += $properties RETURN n\"\n params = {\n 'node_id': node_id,\n 'properties': properties\n }\n try:\n result = self.session.run(cypher, params)\n record = result.single()\n if record:\n logger.info(f\"Node with id '{node_id}' updated.\")\n else:\n logger.warning(f\"Node with id '{node_id}' not found.\")\n raise NodeNotFoundError(f\"Node with id '{node_id}' not found.\")\n except exceptions.Neo4jError as e:\n logger.error(f\"Failed to update node: {e}\")\n raise StorageError(f\"Failed to update node: {e}\")\n
"},{"location":"reference/#src.aeiva.storage.neo4jdb.neo4j_database.NodeNotFoundError","title":"NodeNotFoundError
","text":" Bases: Exception
Exception raised when a node is not found in the graph database.
Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
class NodeNotFoundError(Exception):\n \"\"\"Exception raised when a node is not found in the graph database.\"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.neo4jdb.neo4j_database.StorageError","title":"StorageError
","text":" Bases: Exception
Exception raised when there is a storage-related error in the graph database.
Source code in src/aeiva/storage/neo4jdb/neo4j_database.py
class StorageError(Exception):\n \"\"\"Exception raised when there is a storage-related error in the graph database.\"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.pgvector","title":"pgvector
","text":""},{"location":"reference/#src.aeiva.storage.pgvector.pgvector_config","title":"pgvector_config
","text":""},{"location":"reference/#src.aeiva.storage.pgvector.pgvector_config.PGVectorConfig","title":"PGVectorConfig
dataclass
","text":" Bases: BaseConfig
Configuration for PGVector (PostgreSQL with vector extension).
Source code in src/aeiva/storage/pgvector/pgvector_config.py
@dataclass\nclass PGVectorConfig(BaseConfig):\n \"\"\"\n Configuration for PGVector (PostgreSQL with vector extension).\n \"\"\"\n\n dbname: str = field(\n default=\"postgres\",\n metadata={\"help\": \"Name of the database.\"}\n )\n collection_name: str = field(\n default=\"mem0\",\n metadata={\"help\": \"Name of the collection (table name).\"}\n )\n embedding_model_dims: int = field(\n default=1536,\n metadata={\"help\": \"Dimensions of the embedding model.\"}\n )\n user: Optional[str] = field(\n default=None,\n metadata={\"help\": \"Database user.\"}\n )\n password: Optional[str] = field(\n default=None,\n metadata={\"help\": \"Database password.\"}\n )\n host: str = field(\n default=\"localhost\",\n metadata={\"help\": \"Database host.\"}\n )\n port: int = field(\n default=5432,\n metadata={\"help\": \"Database port.\"}\n )\n use_diskann: bool = field(\n default=True,\n metadata={\"help\": \"Whether to use diskann for approximate nearest neighbors search.\"}\n )\n\n def __post_init__(self):\n super().__post_init__()\n # Validate that user and password are provided\n if not self.user or not self.password:\n raise ValueError(\"Both 'user' and 'password' must be provided.\")\n
"},{"location":"reference/#src.aeiva.storage.pgvector.pgvector_database","title":"pgvector_database
","text":""},{"location":"reference/#src.aeiva.storage.pgvector.pgvector_database.PGVectorDatabase","title":"PGVectorDatabase
","text":" Bases: VectorDatabase
Concrete implementation of VectorStoreBase using PGVector.
Source code in src/aeiva/storage/pgvector/pgvector_database.py
class PGVectorDatabase(VectorDatabase):\n    \"\"\"\n    Concrete implementation of VectorStoreBase using PGVector.\n    \"\"\"\n\n    def __init__(self, config: Dict[str, Any]) -> None:\n        \"\"\"\n        Initialize the PGVector vector store.\n\n        Args:\n            config (Dict[str, Any]): Configuration dictionary.\n        \"\"\"\n        self.config = config\n        self.collection_name = config.get('collection_name')\n        self.dbname = config.get('dbname')\n        self.user = config.get('user')\n        self.password = config.get('password')\n        self.host = config.get('host', 'localhost')\n        self.port = config.get('port', 5432)\n        self.embedding_model_dims = config.get('embedding_model_dims')\n        self.use_diskann = config.get('use_diskann', False)\n\n        if not all([self.collection_name, self.dbname, self.user, self.password, self.embedding_model_dims]):\n            raise ValueError(\"Required configuration parameters are missing.\")\n\n        self.create_client()\n        self.create_collection(\n            collection_name=self.collection_name,\n            vector_size=self.embedding_model_dims,\n            distance_metric='cosine'  # PGVector uses cosine by default\n        )\n\n    def create_client(self, **kwargs) -> None:\n        \"\"\"\n        Initializes the client connection to the PGVector database.\n\n        Args:\n            **kwargs: Additional parameters.\n        \"\"\"\n        try:\n            self.conn = psycopg2.connect(\n                dbname=self.dbname,\n                user=self.user,\n                password=self.password,\n                host=self.host,\n                port=self.port,\n                **kwargs\n            )\n            self.cur = self.conn.cursor()\n            logger.info(\"Connected to PGVector database.\")\n        except psycopg2.Error as e:\n            logger.error(f\"Failed to connect to PGVector database: {e}\")\n            raise ConnectionError(f\"Failed to connect to PGVector database: {e}\")\n\n    def create_collection(self, collection_name: str, vector_size: int, distance_metric: str) -> None:\n        \"\"\"\n        Create a new vector collection (table) in PGVector.\n\n        Args:\n            collection_name (str): The name of the collection.\n            vector_size (int): The dimensionality of the vectors.\n            distance_metric (str): The distance metric to use (e.g., 'cosine').\n        \"\"\"\n        # Check if table exists\n        self.cur.execute(\n            \"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name=%s);\",\n            (collection_name,)\n        )\n        exists = self.cur.fetchone()[0]\n        if exists:\n            logger.info(f\"Table {collection_name} already exists. Skipping creation.\")\n            return\n\n        # Create table\n        create_table_query = f\"\"\"\n        CREATE TABLE {collection_name} (\n            id VARCHAR(64) PRIMARY KEY,\n            vector vector({vector_size}),\n            payload JSONB\n        );\n        \"\"\"\n        self.cur.execute(create_table_query)\n        self.conn.commit()\n        logger.info(f\"Table {collection_name} created successfully.\")\n\n        # Create an ivfflat cosine index if use_diskann is True\n        if self.use_diskann:\n            create_index_query = f\"\"\"\n            CREATE INDEX {collection_name}_vector_idx\n            ON {collection_name}\n            USING ivfflat (vector vector_cosine_ops)\n            WITH (lists = 100);\n            \"\"\"\n            self.cur.execute(create_index_query)\n            self.conn.commit()\n            logger.info(f\"Index created on table {collection_name}.\")\n\n    def insert_vectors(\n        self,\n        collection_name: str,\n        vectors: List[List[float]],\n        payloads: Optional[List[Dict[str, Any]]] = None,\n        ids: Optional[List[str]] = None\n    ) -> None:\n        \"\"\"\n        Insert vectors into a collection.\n\n        Args:\n            collection_name (str): The name of the collection.\n            vectors (List[List[float]]): A list of vectors to insert.\n            payloads (Optional[List[Dict[str, Any]]]): Optional metadata associated with each vector.\n            ids (Optional[List[str]]): Optional unique identifiers for each vector.\n        \"\"\"\n        if collection_name != self.collection_name:\n            raise ValueError(\"Collection name does not match initialized collection name.\")\n\n        if ids is None:\n            raise ValueError(\"PGVector requires IDs to be provided for each vector.\")\n        if payloads is None:\n            payloads = [{} for _ in range(len(vectors))]\n        if not (len(ids) == len(vectors) == len(payloads)):\n            raise ValueError(\"Lengths of ids, vectors, and payloads must be equal.\")\n\n        records = [\n            (id_, vector, Json(payload))\n            for id_, vector, payload in zip(ids, vectors, payloads)\n        ]\n        insert_query = f\"INSERT INTO {collection_name} (id, vector, payload) VALUES %s;\"\n        execute_values(self.cur, insert_query, records)\n        self.conn.commit()\n        logger.info(f\"Inserted {len(vectors)} vectors into table {collection_name}.\")\n\n    def search_vectors(\n        self,\n        collection_name: str,\n        query_vector: List[float],\n        top_k: int = 5,\n        filters: Optional[Dict[str, Any]] = None\n    ) -> List[Dict[str, Any]]:\n        \"\"\"\n        Search for similar vectors in a collection.\n\n        Args:\n            collection_name (str): The name of the collection.\n            query_vector (List[float]): The vector to search with.\n            top_k (int): The number of top results to return.\n            filters (Optional[Dict[str, Any]]): Optional filters to apply to the search.\n\n        Returns:\n            List[Dict[str, Any]]: A list of search results.\n        \"\"\"\n        if collection_name != self.collection_name:\n            raise ValueError(\"Collection name does not match initialized collection name.\")\n\n        filter_clause = \"\"\n        params = [query_vector]\n\n        if filters:\n            filter_conditions = []\n            for key, value in filters.items():\n                filter_conditions.append(\"payload ->> %s = %s\")\n                params.extend([key, str(value)])\n            filter_clause = \"WHERE \" + \" AND \".join(filter_conditions)\n\n        # Use the cosine distance operator '<=>' to match the cosine index created above;\n        # 1 - cosine distance gives cosine similarity as the score.\n        search_query = f\"\"\"\n        SELECT id, vector, payload, 1 - (vector <=> %s::vector) AS score\n        FROM {collection_name}\n        {filter_clause}\n        ORDER BY vector <=> %s::vector\n        LIMIT %s;\n        \"\"\"\n        params.extend([query_vector, top_k])\n        self.cur.execute(search_query, params)\n        results = self.cur.fetchall()\n\n        output = []\n        for row in results:\n            result = {\n                'id': row[0],\n                'score': row[3],\n                'payload': row[2]\n            }\n            output.append(result)\n        return output\n\n    def delete_vector(self, collection_name: str, vector_id: str) -> None:\n        \"\"\"\n        Delete a vector from a collection by its ID.\n\n        Args:\n            collection_name (str): The name of the collection.\n            vector_id (str): The unique identifier of the vector to delete.\n        \"\"\"\n        if collection_name != self.collection_name:\n            raise ValueError(\"Collection name does not match initialized collection name.\")\n\n        delete_query = f\"DELETE FROM {collection_name} WHERE id = %s;\"\n        self.cur.execute(delete_query, (vector_id,))\n        self.conn.commit()\n        logger.info(f\"Deleted vector with ID {vector_id} from table {collection_name}.\")\n\n    def update_vector(\n        self,\n        collection_name: str,\n        vector_id: str,\n        vector: Optional[List[float]] = None,\n        payload: Optional[Dict[str, Any]] = None\n    ) -> None:\n        \"\"\"\n        Update a vector's data or payload.\n\n        Args:\n            collection_name (str): The name of the collection.\n            vector_id (str): The unique identifier of the vector to update.\n            vector (Optional[List[float]]): The new vector data.\n            payload (Optional[Dict[str, Any]]): The new payload data.\n        \"\"\"\n        if collection_name != self.collection_name:\n            raise ValueError(\"Collection name does not match initialized collection name.\")\n\n        if vector is not None:\n            update_query = f\"UPDATE {collection_name} SET vector = %s WHERE id = %s;\"\n            self.cur.execute(update_query, (vector, vector_id))\n        if payload is not None:\n            update_query = f\"UPDATE {collection_name} SET payload = %s WHERE id = %s;\"\n            self.cur.execute(update_query, (Json(payload), vector_id))\n        self.conn.commit()\n        logger.info(f\"Updated vector with ID {vector_id} in table {collection_name}.\")\n\n    def get_vector(self, collection_name: str, vector_id: str) -> Dict[str, Any]:\n        \"\"\"\n        Retrieve a vector by its ID.\n\n        Args:\n            collection_name (str): The name of the collection.\n            vector_id (str): The unique identifier of the vector.\n\n        Returns:\n            Dict[str, Any]: A dictionary containing the vector data and payload.\n        \"\"\"\n        if collection_name != self.collection_name:\n            raise ValueError(\"Collection name does not match initialized collection name.\")\n\n        select_query = f\"SELECT id, vector, payload FROM {collection_name} WHERE id = %s;\"\n        self.cur.execute(select_query, (vector_id,))\n        result = self.cur.fetchone()\n\n        if not result:\n            raise KeyError(f\"Vector with ID {vector_id} not found in table {collection_name}.\")\n\n        vector_data = {\n            'id': result[0],\n            'vector': result[1],\n            'payload': result[2]\n        }\n        return vector_data\n\n    def list_collections(self) -> List[str]:\n        \"\"\"\n        List all available vector collections (tables).\n\n        Returns:\n            List[str]: A list of collection names.\n        \"\"\"\n        self.cur.execute(\n            \"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';\"\n        )\n        tables = self.cur.fetchall()\n        return [table[0] for table in tables]\n\n    def delete_collection(self, collection_name: str) -> None:\n        \"\"\"\n        Delete an entire vector collection.\n\n        Args:\n            collection_name (str): The name of the collection to delete.\n        \"\"\"\n        drop_query = f\"DROP TABLE IF EXISTS {collection_name};\"\n        self.cur.execute(drop_query)\n        self.conn.commit()\n        logger.info(f\"Deleted table {collection_name}.\")\n\n    def get_collection_info(self, collection_name: str) -> Dict[str, Any]:\n        \"\"\"\n        Get information about a collection.\n\n        Args:\n            collection_name (str): The name of the collection.\n\n        Returns:\n            Dict[str, Any]: Information about the collection.\n        \"\"\"\n        self.cur.execute(\n            \"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = %s;\",\n            (collection_name,)\n        )\n        columns = self.cur.fetchall()\n        info = {\n            'name': collection_name,\n            'columns': {column[0]: column[1] for column in columns}\n        }\n        return info\n\n    def __del__(self):\n        \"\"\"Clean up resources.\"\"\"\n        if hasattr(self, 'cur') and self.cur:\n            self.cur.close()\n        if hasattr(self, 'conn') and self.conn:\n            self.conn.close()\n        logger.info(\"Closed connection to PGVector database.\")\n
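The metadata filtering in search_vectors reduces to assembling a parameterized WHERE clause over JSONB fields; a minimal stand-alone sketch (hypothetical helper name, mirroring the source above):

```python
# Hypothetical helper mirroring how search_vectors assembles its WHERE clause above.
def build_filter_clause(filters):
    if not filters:
        return '', []
    conditions, params = [], []
    for key, value in filters.items():
        # payload ->> %s extracts a JSONB field as text for comparison.
        conditions.append('payload ->> %s = %s')
        params.extend([key, str(value)])
    return 'WHERE ' + ' AND '.join(conditions), params

clause, params = build_filter_clause({'user_id': 'u1'})
print(clause)
# WHERE payload ->> %s = %s
```

Keys and values travel as query parameters rather than being interpolated into the SQL text, so user-supplied filter values cannot inject SQL.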
"},{"location":"reference/#src.aeiva.storage.pgvector.pgvector_database.PGVectorDatabase.__del__","title":"__del__()
","text":"Clean up resources.
Source code in src/aeiva/storage/pgvector/pgvector_database.py
def __del__(self):\n \"\"\"Clean up resources.\"\"\"\n if hasattr(self, 'cur') and self.cur:\n self.cur.close()\n if hasattr(self, 'conn') and self.conn:\n self.conn.close()\n logger.info(\"Closed connection to PGVector database.\")\n
"},{"location":"reference/#src.aeiva.storage.pgvector.pgvector_database.PGVectorDatabase.__init__","title":"__init__(config)
","text":"Initialize the PGVector vector store.
Parameters:
Name Type Description Default config
Dict[str, Any]
Configuration dictionary.
required Source code in src/aeiva/storage/pgvector/pgvector_database.py
def __init__(self, config: Dict[str, Any]) -> None:\n \"\"\"\n Initialize the PGVector vector store.\n\n Args:\n config (Dict[str, Any]): Configuration dictionary.\n \"\"\"\n self.config = config\n self.collection_name = config.get('collection_name')\n self.dbname = config.get('dbname')\n self.user = config.get('user')\n self.password = config.get('password')\n self.host = config.get('host', 'localhost')\n self.port = config.get('port', 5432)\n self.embedding_model_dims = config.get('embedding_model_dims')\n self.use_diskann = config.get('use_diskann', False)\n\n if not all([self.collection_name, self.dbname, self.user, self.password, self.embedding_model_dims]):\n raise ValueError(\"Required configuration parameters are missing.\")\n\n self.create_client()\n self.create_collection(\n collection_name=self.collection_name,\n vector_size=self.embedding_model_dims,\n distance_metric='cosine' # PGVector uses cosine by default\n )\n
"},{"location":"reference/#src.aeiva.storage.pgvector.pgvector_database.PGVectorDatabase.create_client","title":"create_client(**kwargs)
","text":"Initializes the client connection to the PGVector database.
Parameters:
Name Type Description Default **kwargs
Additional parameters.
{}
Source code in src/aeiva/storage/pgvector/pgvector_database.py
def create_client(self, **kwargs) -> None:\n \"\"\"\n Initializes the client connection to the PGVector database.\n\n Args:\n **kwargs: Additional parameters.\n \"\"\"\n try:\n self.conn = psycopg2.connect(\n dbname=self.dbname,\n user=self.user,\n password=self.password,\n host=self.host,\n port=self.port,\n **kwargs\n )\n self.cur = self.conn.cursor()\n logger.info(\"Connected to PGVector database.\")\n except psycopg2.Error as e:\n logger.error(f\"Failed to connect to PGVector database: {e}\")\n raise ConnectionError(f\"Failed to connect to PGVector database: {e}\")\n
"},{"location":"reference/#src.aeiva.storage.pgvector.pgvector_database.PGVectorDatabase.create_collection","title":"create_collection(collection_name, vector_size, distance_metric)
","text":"Create a new vector collection (table) in PGVector.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required vector_size
int
The dimensionality of the vectors.
required distance_metric
str
The distance metric to use (e.g., 'cosine').
required Source code in src/aeiva/storage/pgvector/pgvector_database.py
def create_collection(self, collection_name: str, vector_size: int, distance_metric: str) -> None:\n \"\"\"\n Create a new vector collection (table) in PGVector.\n\n Args:\n collection_name (str): The name of the collection.\n vector_size (int): The dimensionality of the vectors.\n distance_metric (str): The distance metric to use (e.g., 'cosine').\n \"\"\"\n # Check if table exists\n self.cur.execute(\n \"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name=%s);\",\n (collection_name,)\n )\n exists = self.cur.fetchone()[0]\n if exists:\n logger.info(f\"Table {collection_name} already exists. Skipping creation.\")\n return\n\n # Create table\n create_table_query = f\"\"\"\n CREATE TABLE {collection_name} (\n id VARCHAR(64) PRIMARY KEY,\n vector vector({vector_size}),\n payload JSONB\n );\n \"\"\"\n self.cur.execute(create_table_query)\n self.conn.commit()\n logger.info(f\"Table {collection_name} created successfully.\")\n\n # Create index if use_diskann is True\n if self.use_diskann:\n create_index_query = f\"\"\"\n CREATE INDEX {collection_name}_vector_idx\n ON {collection_name}\n USING ivfflat (vector vector_cosine_ops)\n WITH (lists = 100);\n \"\"\"\n self.cur.execute(create_index_query)\n self.conn.commit()\n logger.info(f\"Index created on table {collection_name}.\")\n
"},{"location":"reference/#src.aeiva.storage.pgvector.pgvector_database.PGVectorDatabase.delete_collection","title":"delete_collection(collection_name)
","text":"Delete an entire vector collection.
Parameters:
Name Type Description Default collection_name
str
The name of the collection to delete.
required Source code in src/aeiva/storage/pgvector/pgvector_database.py
def delete_collection(self, collection_name: str) -> None:\n \"\"\"\n Delete an entire vector collection.\n\n Args:\n collection_name (str): The name of the collection to delete.\n \"\"\"\n drop_query = f\"DROP TABLE IF EXISTS {collection_name};\"\n self.cur.execute(drop_query)\n self.conn.commit()\n logger.info(f\"Deleted table {collection_name}.\")\n
"},{"location":"reference/#src.aeiva.storage.pgvector.pgvector_database.PGVectorDatabase.delete_vector","title":"delete_vector(collection_name, vector_id)
","text":"Delete a vector from a collection by its ID.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required vector_id
str
The unique identifier of the vector to delete.
required Source code in src/aeiva/storage/pgvector/pgvector_database.py
def delete_vector(self, collection_name: str, vector_id: str) -> None:\n \"\"\"\n Delete a vector from a collection by its ID.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector to delete.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n delete_query = f\"DELETE FROM {collection_name} WHERE id = %s;\"\n self.cur.execute(delete_query, (vector_id,))\n self.conn.commit()\n logger.info(f\"Deleted vector with ID {vector_id} from table {collection_name}.\")\n
"},{"location":"reference/#src.aeiva.storage.pgvector.pgvector_database.PGVectorDatabase.get_collection_info","title":"get_collection_info(collection_name)
","text":"Get information about a collection.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required Returns:
Type Description Dict[str, Any]
Dict[str, Any]: Information about the collection.
Source code in src/aeiva/storage/pgvector/pgvector_database.py
def get_collection_info(self, collection_name: str) -> Dict[str, Any]:\n \"\"\"\n Get information about a collection.\n\n Args:\n collection_name (str): The name of the collection.\n\n Returns:\n Dict[str, Any]: Information about the collection.\n \"\"\"\n self.cur.execute(\n \"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = %s;\",\n (collection_name,)\n )\n columns = self.cur.fetchall()\n info = {\n 'name': collection_name,\n 'columns': {column[0]: column[1] for column in columns}\n }\n return info\n
"},{"location":"reference/#src.aeiva.storage.pgvector.pgvector_database.PGVectorDatabase.get_vector","title":"get_vector(collection_name, vector_id)
","text":"Retrieve a vector by its ID.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required vector_id
str
The unique identifier of the vector.
required Returns:
Type Description Dict[str, Any]
Dict[str, Any]: A dictionary containing the vector data and payload.
Source code in src/aeiva/storage/pgvector/pgvector_database.py
def get_vector(self, collection_name: str, vector_id: str) -> Dict[str, Any]:\n \"\"\"\n Retrieve a vector by its ID.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector.\n\n Returns:\n Dict[str, Any]: A dictionary containing the vector data and payload.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n select_query = f\"SELECT id, vector, payload FROM {collection_name} WHERE id = %s;\"\n self.cur.execute(select_query, (vector_id,))\n result = self.cur.fetchone()\n\n if not result:\n raise KeyError(f\"Vector with ID {vector_id} not found in table {collection_name}.\")\n\n vector_data = {\n 'id': result[0],\n 'vector': result[1],\n 'payload': result[2]\n }\n return vector_data\n
"},{"location":"reference/#src.aeiva.storage.pgvector.pgvector_database.PGVectorDatabase.insert_vectors","title":"insert_vectors(collection_name, vectors, payloads=None, ids=None)
","text":"Insert vectors into a collection.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required vectors
List[List[float]]
A list of vectors to insert.
required payloads
Optional[List[Dict[str, Any]]]
Optional metadata associated with each vector.
None
ids
Optional[List[str]]
Optional unique identifiers for each vector.
None
Source code in src/aeiva/storage/pgvector/pgvector_database.py
def insert_vectors(\n self,\n collection_name: str,\n vectors: List[List[float]],\n payloads: Optional[List[Dict[str, Any]]] = None,\n ids: Optional[List[str]] = None\n) -> None:\n \"\"\"\n Insert vectors into a collection.\n\n Args:\n collection_name (str): The name of the collection.\n vectors (List[List[float]]): A list of vectors to insert.\n payloads (Optional[List[Dict[str, Any]]]): Optional metadata associated with each vector.\n ids (Optional[List[str]]): Optional unique identifiers for each vector.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n if ids is None:\n raise ValueError(\"PGVector requires IDs to be provided for each vector.\")\n if payloads is None:\n payloads = [{} for _ in range(len(vectors))]\n if not (len(ids) == len(vectors) == len(payloads)):\n raise ValueError(\"Lengths of ids, vectors, and payloads must be equal.\")\n\n records = [\n (id_, vector, Json(payload))\n for id_, vector, payload in zip(ids, vectors, payloads)\n ]\n insert_query = f\"INSERT INTO {collection_name} (id, vector, payload) VALUES %s;\"\n execute_values(self.cur, insert_query, records)\n self.conn.commit()\n logger.info(f\"Inserted {len(vectors)} vectors into table {collection_name}.\")\n
"},{"location":"reference/#src.aeiva.storage.pgvector.pgvector_database.PGVectorDatabase.list_collections","title":"list_collections()
","text":"List all available vector collections (tables).
Returns:
Type Description List[str]
List[str]: A list of collection names.
Source code in src/aeiva/storage/pgvector/pgvector_database.py
def list_collections(self) -> List[str]:\n \"\"\"\n List all available vector collections (tables).\n\n Returns:\n List[str]: A list of collection names.\n \"\"\"\n self.cur.execute(\n \"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';\"\n )\n tables = self.cur.fetchall()\n return [table[0] for table in tables]\n
"},{"location":"reference/#src.aeiva.storage.pgvector.pgvector_database.PGVectorDatabase.search_vectors","title":"search_vectors(collection_name, query_vector, top_k=5, filters=None)
","text":"Search for similar vectors in a collection.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required query_vector
List[float]
The vector to search with.
required top_k
int
The number of top results to return.
5
filters
Optional[Dict[str, Any]]
Optional filters to apply to the search.
None
Returns:
Type Description List[Dict[str, Any]]
List[Dict[str, Any]]: A list of search results.
Source code in src/aeiva/storage/pgvector/pgvector_database.py
def search_vectors(\n self,\n collection_name: str,\n query_vector: List[float],\n top_k: int = 5,\n filters: Optional[Dict[str, Any]] = None\n) -> List[Dict[str, Any]]:\n \"\"\"\n Search for similar vectors in a collection.\n\n Args:\n collection_name (str): The name of the collection.\n query_vector (List[float]): The vector to search with.\n top_k (int): The number of top results to return.\n filters (Optional[Dict[str, Any]]): Optional filters to apply to the search.\n\n Returns:\n List[Dict[str, Any]]: A list of search results.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n filter_clause = \"\"\n params = [query_vector]\n\n if filters:\n filter_conditions = []\n for key, value in filters.items():\n filter_conditions.append(f\"payload ->> %s = %s\")\n params.extend([key, str(value)])\n filter_clause = \"WHERE \" + \" AND \".join(filter_conditions)\n\n search_query = f\"\"\"\n SELECT id, vector, payload, 1 - (vector <#> %s::vector) AS score\n FROM {collection_name}\n {filter_clause}\n ORDER BY vector <#> %s::vector\n LIMIT %s;\n \"\"\"\n params.extend([query_vector, top_k])\n self.cur.execute(search_query, params)\n results = self.cur.fetchall()\n\n output = []\n for row in results:\n result = {\n 'id': row[0],\n 'score': row[3],\n 'payload': row[2]\n }\n output.append(result)\n return output\n
"},{"location":"reference/#src.aeiva.storage.pgvector.pgvector_database.PGVectorDatabase.update_vector","title":"update_vector(collection_name, vector_id, vector=None, payload=None)
","text":"Update a vector's data or payload.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required vector_id
str
The unique identifier of the vector to update.
required vector
Optional[List[float]]
The new vector data.
None
payload
Optional[Dict[str, Any]]
The new payload data.
None
Source code in src/aeiva/storage/pgvector/pgvector_database.py
def update_vector(\n self,\n collection_name: str,\n vector_id: str,\n vector: Optional[List[float]] = None,\n payload: Optional[Dict[str, Any]] = None\n) -> None:\n \"\"\"\n Update a vector's data or payload.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector to update.\n vector (Optional[List[float]]): The new vector data.\n payload (Optional[Dict[str, Any]]): The new payload data.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n if vector is not None:\n update_query = f\"UPDATE {collection_name} SET vector = %s WHERE id = %s;\"\n self.cur.execute(update_query, (vector, vector_id))\n if payload is not None:\n update_query = f\"UPDATE {collection_name} SET payload = %s WHERE id = %s;\"\n self.cur.execute(update_query, (Json(payload), vector_id))\n self.conn.commit()\n logger.info(f\"Updated vector with ID {vector_id} in table {collection_name}.\")\n
"},{"location":"reference/#src.aeiva.storage.postgresql","title":"postgresql
","text":""},{"location":"reference/#src.aeiva.storage.postgresql.postgresql_config","title":"postgresql_config
","text":""},{"location":"reference/#src.aeiva.storage.postgresql.postgresql_config.PostgreSQLConfig","title":"PostgreSQLConfig
dataclass
","text":" Bases: BaseConfig
Configuration for PostgreSQL database.
Source code in src/aeiva/storage/postgresql/postgresql_config.py
@dataclass\nclass PostgreSQLConfig(BaseConfig):\n \"\"\"\n Configuration for PostgreSQL database.\n \"\"\"\n dbname: str = field(\n default='postgres',\n metadata={\"help\": \"Name of the PostgreSQL database.\"}\n )\n user: str = field(\n default='postgres',\n metadata={\"help\": \"Username for PostgreSQL authentication.\"}\n )\n password: str = field(\n default='',\n metadata={\"help\": \"Password for PostgreSQL authentication.\"}\n )\n host: str = field(\n default='localhost',\n metadata={\"help\": \"Host address for PostgreSQL server.\"}\n )\n port: int = field(\n default=5432,\n metadata={\"help\": \"Port number for PostgreSQL server.\"}\n )\n
"},{"location":"reference/#src.aeiva.storage.postgresql.postgresql_database","title":"postgresql_database
","text":""},{"location":"reference/#src.aeiva.storage.postgresql.postgresql_database.PostgreSQLDatabase","title":"PostgreSQLDatabase
","text":" Bases: RelationalDatabase
Concrete implementation of RelationalStoreBase using PostgreSQL.
Source code in src/aeiva/storage/postgresql/postgresql_database.py
class PostgreSQLDatabase(RelationalDatabase):\n \"\"\"\n Concrete implementation of RelationalStoreBase using PostgreSQL.\n \"\"\"\n\n def __init__(self, config: Dict[str, Any]) -> None:\n \"\"\"\n Initialize the PostgreSQL database connection.\n\n Args:\n config (Dict[str, Any]): Configuration dictionary.\n \"\"\"\n self.config = config\n self.connection = None\n self.cursor = None\n self.connect()\n\n def connect(self) -> None:\n \"\"\"\n Establishes a connection to the PostgreSQL database.\n \"\"\"\n try:\n self.connection = psycopg2.connect(\n dbname=self.config.get('dbname'),\n user=self.config.get('user'),\n password=self.config.get('password'),\n host=self.config.get('host'),\n port=self.config.get('port')\n )\n self.connection.autocommit = True # Enable autocommit for DDL statements\n self.cursor = self.connection.cursor(cursor_factory=RealDictCursor)\n except psycopg2.Error as e:\n raise ConnectionError(f\"Failed to connect to PostgreSQL database: {e}\")\n\n def close(self) -> None:\n \"\"\"\n Closes the database connection and releases resources.\n \"\"\"\n if self.cursor:\n self.cursor.close()\n if self.connection:\n self.connection.close()\n\n def insert_record(self, table: str, record: Dict[str, Any]) -> Any:\n \"\"\"\n Inserts a record into a table.\n\n Args:\n table (str): The name of the table.\n record (Dict[str, Any]): A dictionary representing the record to insert.\n\n Returns:\n Any: The primary key of the inserted record.\n\n Raises:\n StorageError: If there is an issue inserting the record.\n \"\"\"\n try:\n columns = ', '.join(record.keys())\n placeholders = ', '.join(f\"%({key})s\" for key in record.keys())\n sql = f\"INSERT INTO {table} ({columns}) VALUES ({placeholders}) RETURNING id\"\n self.cursor.execute(sql, record)\n result = self.cursor.fetchone()\n return result['id']\n except psycopg2.IntegrityError as e:\n self.connection.rollback()\n raise StorageError(f\"Integrity error: {e}\")\n except psycopg2.Error as e:\n 
self.connection.rollback()\n raise StorageError(f\"Failed to insert record: {e}\")\n\n def get_record(self, table: str, primary_key: Any) -> Dict[str, Any]:\n \"\"\"\n Retrieves a record by its primary key.\n\n Args:\n table (str): The name of the table.\n primary_key (Any): The primary key of the record.\n\n Returns:\n Dict[str, Any]: The retrieved record.\n\n Raises:\n RecordNotFoundError: If the record does not exist.\n StorageError: If there is an issue retrieving the record.\n \"\"\"\n try:\n sql = f\"SELECT * FROM {table} WHERE id = %s\"\n self.cursor.execute(sql, (primary_key,))\n row = self.cursor.fetchone()\n if row is None:\n raise RecordNotFoundError(f\"Record with primary key {primary_key} not found in table '{table}'.\")\n return dict(row)\n except psycopg2.Error as e:\n raise StorageError(f\"Failed to get record: {e}\")\n\n def update_record(self, table: str, primary_key: Any, updates: Dict[str, Any]) -> None:\n \"\"\"\n Updates a record in a table.\n\n Args:\n table (str): The name of the table.\n primary_key (Any): The primary key of the record.\n updates (Dict[str, Any]): A dictionary of fields to update.\n\n Raises:\n RecordNotFoundError: If the record does not exist.\n StorageError: If there is an issue updating the record.\n \"\"\"\n try:\n set_clause = ', '.join(f\"{key} = %({key})s\" for key in updates.keys())\n sql = f\"UPDATE {table} SET {set_clause} WHERE id = %(id)s\"\n updates['id'] = primary_key\n self.cursor.execute(sql, updates)\n if self.cursor.rowcount == 0:\n self.connection.rollback()\n raise RecordNotFoundError(f\"Record with primary key {primary_key} not found in table '{table}'.\")\n except psycopg2.Error as e:\n self.connection.rollback()\n raise StorageError(f\"Failed to update record: {e}\")\n\n def delete_record(self, table: str, primary_key: Any) -> None:\n \"\"\"\n Deletes a record from a table.\n\n Args:\n table (str): The name of the table.\n primary_key (Any): The primary key of the record.\n\n Raises:\n 
RecordNotFoundError: If the record does not exist.\n StorageError: If there is an issue deleting the record.\n \"\"\"\n try:\n sql = f\"DELETE FROM {table} WHERE id = %s\"\n self.cursor.execute(sql, (primary_key,))\n if self.cursor.rowcount == 0:\n self.connection.rollback()\n raise RecordNotFoundError(f\"Record with primary key {primary_key} not found in table '{table}'.\")\n except psycopg2.Error as e:\n self.connection.rollback()\n raise StorageError(f\"Failed to delete record: {e}\")\n\n def query_records(\n self,\n table: str,\n conditions: Optional[Dict[str, Any]] = None,\n limit: Optional[int] = None,\n offset: Optional[int] = None\n ) -> List[Dict[str, Any]]:\n \"\"\"\n Queries records from a table based on conditions.\n\n Args:\n table (str): The name of the table.\n conditions (Optional[Dict[str, Any]]): Conditions to filter records.\n limit (Optional[int]): Maximum number of records to return.\n offset (Optional[int]): Number of records to skip.\n\n Returns:\n List[Dict[str, Any]]: A list of records matching the query.\n\n Raises:\n StorageError: If there is an issue querying records.\n \"\"\"\n try:\n sql = f\"SELECT * FROM {table}\"\n params = {}\n if conditions:\n where_clause = ' AND '.join(f\"{key} = %({key})s\" for key in conditions.keys())\n sql += f\" WHERE {where_clause}\"\n params.update(conditions)\n if limit is not None:\n sql += f\" LIMIT {limit}\"\n if offset is not None:\n sql += f\" OFFSET {offset}\"\n self.cursor.execute(sql, params)\n rows = self.cursor.fetchall()\n return [dict(row) for row in rows]\n except psycopg2.Error as e:\n raise StorageError(f\"Failed to query records: {e}\")\n\n def execute_sql(self, query: str, parameters: Optional[Dict[str, Any]] = None) -> Any:\n \"\"\"\n Executes a raw SQL query.\n\n Args:\n query (str): The SQL query string.\n parameters (Optional[Dict[str, Any]]): Parameters for parameterized queries.\n\n Returns:\n Any: The result of the query.\n\n Raises:\n StorageError: If there is an issue executing 
the query.\n \"\"\"\n cursor = self.connection.cursor(cursor_factory=psycopg2.extras.DictCursor)\n try:\n if parameters:\n cursor.execute(query, parameters)\n else:\n cursor.execute(query)\n if query.strip().upper().startswith(\"SELECT\"):\n return cursor\n else:\n self.connection.commit()\n return cursor\n except psycopg2.Error as e:\n self.connection.rollback()\n raise StorageError(f\"Failed to execute SQL query: {e}\")\n\n def begin_transaction(self) -> None:\n \"\"\"\n Begins a transaction.\n \"\"\"\n self.connection.autocommit = False\n\n def commit_transaction(self) -> None:\n \"\"\"\n Commits the current transaction.\n \"\"\"\n self.connection.commit()\n self.connection.autocommit = True\n\n def rollback_transaction(self) -> None:\n \"\"\"\n Rolls back the current transaction.\n \"\"\"\n self.connection.rollback()\n self.connection.autocommit = True\n
"},{"location":"reference/#src.aeiva.storage.postgresql.postgresql_database.PostgreSQLDatabase.__init__","title":"__init__(config)
","text":"Initialize the PostgreSQL database connection.
Parameters:
Name Type Description Default config
Dict[str, Any]
Configuration dictionary.
required Source code in src/aeiva/storage/postgresql/postgresql_database.py
def __init__(self, config: Dict[str, Any]) -> None:\n \"\"\"\n Initialize the PostgreSQL database connection.\n\n Args:\n config (Dict[str, Any]): Configuration dictionary.\n \"\"\"\n self.config = config\n self.connection = None\n self.cursor = None\n self.connect()\n
"},{"location":"reference/#src.aeiva.storage.postgresql.postgresql_database.PostgreSQLDatabase.begin_transaction","title":"begin_transaction()
","text":"Begins a transaction.
Source code in src/aeiva/storage/postgresql/postgresql_database.py
def begin_transaction(self) -> None:\n \"\"\"\n Begins a transaction.\n \"\"\"\n self.connection.autocommit = False\n
"},{"location":"reference/#src.aeiva.storage.postgresql.postgresql_database.PostgreSQLDatabase.close","title":"close()
","text":"Closes the database connection and releases resources.
Source code in src/aeiva/storage/postgresql/postgresql_database.py
def close(self) -> None:\n \"\"\"\n Closes the database connection and releases resources.\n \"\"\"\n if self.cursor:\n self.cursor.close()\n if self.connection:\n self.connection.close()\n
"},{"location":"reference/#src.aeiva.storage.postgresql.postgresql_database.PostgreSQLDatabase.commit_transaction","title":"commit_transaction()
","text":"Commits the current transaction.
Source code in src/aeiva/storage/postgresql/postgresql_database.py
def commit_transaction(self) -> None:\n \"\"\"\n Commits the current transaction.\n \"\"\"\n self.connection.commit()\n self.connection.autocommit = True\n
"},{"location":"reference/#src.aeiva.storage.postgresql.postgresql_database.PostgreSQLDatabase.connect","title":"connect()
","text":"Establishes a connection to the PostgreSQL database.
Source code in src/aeiva/storage/postgresql/postgresql_database.py
def connect(self) -> None:\n \"\"\"\n Establishes a connection to the PostgreSQL database.\n \"\"\"\n try:\n self.connection = psycopg2.connect(\n dbname=self.config.get('dbname'),\n user=self.config.get('user'),\n password=self.config.get('password'),\n host=self.config.get('host'),\n port=self.config.get('port')\n )\n self.connection.autocommit = True # Enable autocommit for DDL statements\n self.cursor = self.connection.cursor(cursor_factory=RealDictCursor)\n except psycopg2.Error as e:\n raise ConnectionError(f\"Failed to connect to PostgreSQL database: {e}\")\n
"},{"location":"reference/#src.aeiva.storage.postgresql.postgresql_database.PostgreSQLDatabase.delete_record","title":"delete_record(table, primary_key)
","text":"Deletes a record from a table.
Parameters:
Name Type Description Default table
str
The name of the table.
required primary_key
Any
The primary key of the record.
required Raises:
Type Description RecordNotFoundError
If the record does not exist.
StorageError
If there is an issue deleting the record.
Source code in src/aeiva/storage/postgresql/postgresql_database.py
def delete_record(self, table: str, primary_key: Any) -> None:\n \"\"\"\n Deletes a record from a table.\n\n Args:\n table (str): The name of the table.\n primary_key (Any): The primary key of the record.\n\n Raises:\n RecordNotFoundError: If the record does not exist.\n StorageError: If there is an issue deleting the record.\n \"\"\"\n try:\n sql = f\"DELETE FROM {table} WHERE id = %s\"\n self.cursor.execute(sql, (primary_key,))\n if self.cursor.rowcount == 0:\n self.connection.rollback()\n raise RecordNotFoundError(f\"Record with primary key {primary_key} not found in table '{table}'.\")\n except psycopg2.Error as e:\n self.connection.rollback()\n raise StorageError(f\"Failed to delete record: {e}\")\n
"},{"location":"reference/#src.aeiva.storage.postgresql.postgresql_database.PostgreSQLDatabase.execute_sql","title":"execute_sql(query, parameters=None)
","text":"Executes a raw SQL query.
Parameters:
Name Type Description Default query
str
The SQL query string.
required parameters
Optional[Dict[str, Any]]
Parameters for parameterized queries.
None
Returns:
Name Type Description Any
Any
The result of the query.
Raises:
Type Description StorageError
If there is an issue executing the query.
Source code in src/aeiva/storage/postgresql/postgresql_database.py
def execute_sql(self, query: str, parameters: Optional[Dict[str, Any]] = None) -> Any:\n \"\"\"\n Executes a raw SQL query.\n\n Args:\n query (str): The SQL query string.\n parameters (Optional[Dict[str, Any]]): Parameters for parameterized queries.\n\n Returns:\n Any: The result of the query.\n\n Raises:\n StorageError: If there is an issue executing the query.\n \"\"\"\n cursor = self.connection.cursor(cursor_factory=psycopg2.extras.DictCursor)\n try:\n if parameters:\n cursor.execute(query, parameters)\n else:\n cursor.execute(query)\n if query.strip().upper().startswith(\"SELECT\"):\n return cursor\n else:\n self.connection.commit()\n return cursor\n except psycopg2.Error as e:\n self.connection.rollback()\n raise StorageError(f\"Failed to execute SQL query: {e}\")\n
"},{"location":"reference/#src.aeiva.storage.postgresql.postgresql_database.PostgreSQLDatabase.get_record","title":"get_record(table, primary_key)
","text":"Retrieves a record by its primary key.
Parameters:
Name Type Description Default table
str
The name of the table.
required primary_key
Any
The primary key of the record.
required Returns:
Type Description Dict[str, Any]
Dict[str, Any]: The retrieved record.
Raises:
Type Description RecordNotFoundError
If the record does not exist.
StorageError
If there is an issue retrieving the record.
Source code in src/aeiva/storage/postgresql/postgresql_database.py
def get_record(self, table: str, primary_key: Any) -> Dict[str, Any]:\n \"\"\"\n Retrieves a record by its primary key.\n\n Args:\n table (str): The name of the table.\n primary_key (Any): The primary key of the record.\n\n Returns:\n Dict[str, Any]: The retrieved record.\n\n Raises:\n RecordNotFoundError: If the record does not exist.\n StorageError: If there is an issue retrieving the record.\n \"\"\"\n try:\n sql = f\"SELECT * FROM {table} WHERE id = %s\"\n self.cursor.execute(sql, (primary_key,))\n row = self.cursor.fetchone()\n if row is None:\n raise RecordNotFoundError(f\"Record with primary key {primary_key} not found in table '{table}'.\")\n return dict(row)\n except psycopg2.Error as e:\n raise StorageError(f\"Failed to get record: {e}\")\n
"},{"location":"reference/#src.aeiva.storage.postgresql.postgresql_database.PostgreSQLDatabase.insert_record","title":"insert_record(table, record)
","text":"Inserts a record into a table.
Parameters:
Name Type Description Default table
str
The name of the table.
required record
Dict[str, Any]
A dictionary representing the record to insert.
required Returns:
Name Type Description Any
Any
The primary key of the inserted record.
Raises:
Type Description StorageError
If there is an issue inserting the record.
Source code in src/aeiva/storage/postgresql/postgresql_database.py
def insert_record(self, table: str, record: Dict[str, Any]) -> Any:\n \"\"\"\n Inserts a record into a table.\n\n Args:\n table (str): The name of the table.\n record (Dict[str, Any]): A dictionary representing the record to insert.\n\n Returns:\n Any: The primary key of the inserted record.\n\n Raises:\n StorageError: If there is an issue inserting the record.\n \"\"\"\n try:\n columns = ', '.join(record.keys())\n placeholders = ', '.join(f\"%({key})s\" for key in record.keys())\n sql = f\"INSERT INTO {table} ({columns}) VALUES ({placeholders}) RETURNING id\"\n self.cursor.execute(sql, record)\n result = self.cursor.fetchone()\n return result['id']\n except psycopg2.IntegrityError as e:\n self.connection.rollback()\n raise StorageError(f\"Integrity error: {e}\")\n except psycopg2.Error as e:\n self.connection.rollback()\n raise StorageError(f\"Failed to insert record: {e}\")\n
"},{"location":"reference/#src.aeiva.storage.postgresql.postgresql_database.PostgreSQLDatabase.query_records","title":"query_records(table, conditions=None, limit=None, offset=None)
","text":"Queries records from a table based on conditions.
Parameters:
Name Type Description Default table
str
The name of the table.
required conditions
Optional[Dict[str, Any]]
Conditions to filter records.
None
limit
Optional[int]
Maximum number of records to return.
None
offset
Optional[int]
Number of records to skip.
None
Returns:
Type Description List[Dict[str, Any]]
List[Dict[str, Any]]: A list of records matching the query.
Raises:
Type Description StorageError
If there is an issue querying records.
Source code in src/aeiva/storage/postgresql/postgresql_database.py
def query_records(\n self,\n table: str,\n conditions: Optional[Dict[str, Any]] = None,\n limit: Optional[int] = None,\n offset: Optional[int] = None\n) -> List[Dict[str, Any]]:\n \"\"\"\n Queries records from a table based on conditions.\n\n Args:\n table (str): The name of the table.\n conditions (Optional[Dict[str, Any]]): Conditions to filter records.\n limit (Optional[int]): Maximum number of records to return.\n offset (Optional[int]): Number of records to skip.\n\n Returns:\n List[Dict[str, Any]]: A list of records matching the query.\n\n Raises:\n StorageError: If there is an issue querying records.\n \"\"\"\n try:\n sql = f\"SELECT * FROM {table}\"\n params = {}\n if conditions:\n where_clause = ' AND '.join(f\"{key} = %({key})s\" for key in conditions.keys())\n sql += f\" WHERE {where_clause}\"\n params.update(conditions)\n if limit is not None:\n sql += f\" LIMIT {limit}\"\n if offset is not None:\n sql += f\" OFFSET {offset}\"\n self.cursor.execute(sql, params)\n rows = self.cursor.fetchall()\n return [dict(row) for row in rows]\n except psycopg2.Error as e:\n raise StorageError(f\"Failed to query records: {e}\")\n
"},{"location":"reference/#src.aeiva.storage.postgresql.postgresql_database.PostgreSQLDatabase.rollback_transaction","title":"rollback_transaction()
","text":"Rolls back the current transaction.
Source code in src/aeiva/storage/postgresql/postgresql_database.py
def rollback_transaction(self) -> None:\n \"\"\"\n Rolls back the current transaction.\n \"\"\"\n self.connection.rollback()\n self.connection.autocommit = True\n
"},{"location":"reference/#src.aeiva.storage.postgresql.postgresql_database.PostgreSQLDatabase.update_record","title":"update_record(table, primary_key, updates)
","text":"Updates a record in a table.
Parameters:
Name Type Description Default table
str
The name of the table.
required primary_key
Any
The primary key of the record.
required updates
Dict[str, Any]
A dictionary of fields to update.
required Raises:
Type Description RecordNotFoundError
If the record does not exist.
StorageError
If there is an issue updating the record.
Source code in src/aeiva/storage/postgresql/postgresql_database.py
def update_record(self, table: str, primary_key: Any, updates: Dict[str, Any]) -> None:\n \"\"\"\n Updates a record in a table.\n\n Args:\n table (str): The name of the table.\n primary_key (Any): The primary key of the record.\n updates (Dict[str, Any]): A dictionary of fields to update.\n\n Raises:\n RecordNotFoundError: If the record does not exist.\n StorageError: If there is an issue updating the record.\n \"\"\"\n try:\n set_clause = ', '.join(f\"{key} = %({key})s\" for key in updates.keys())\n sql = f\"UPDATE {table} SET {set_clause} WHERE id = %(id)s\"\n updates['id'] = primary_key\n self.cursor.execute(sql, updates)\n if self.cursor.rowcount == 0:\n self.connection.rollback()\n raise RecordNotFoundError(f\"Record with primary key {primary_key} not found in table '{table}'.\")\n except psycopg2.Error as e:\n self.connection.rollback()\n raise StorageError(f\"Failed to update record: {e}\")\n
"},{"location":"reference/#src.aeiva.storage.postgresql.postgresql_database.RecordNotFoundError","title":"RecordNotFoundError
","text":" Bases: Exception
Exception raised when a record is not found in the database.
Source code in src/aeiva/storage/postgresql/postgresql_database.py
class RecordNotFoundError(Exception):\n \"\"\"Exception raised when a record is not found in the database.\"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.postgresql.postgresql_database.StorageError","title":"StorageError
","text":" Bases: Exception
Exception raised when there is a storage-related error in the database.
Source code in src/aeiva/storage/postgresql/postgresql_database.py
class StorageError(Exception):\n \"\"\"Exception raised when there is a storage-related error in the database.\"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.postgresql.test","title":"test
","text":""},{"location":"reference/#src.aeiva.storage.postgresql.test.RecordNotFoundError","title":"RecordNotFoundError
","text":" Bases: Exception
Exception raised when a record is not found in the database.
Source code in src/aeiva/storage/postgresql/test.py
class RecordNotFoundError(Exception):\n \"\"\"Exception raised when a record is not found in the database.\"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.qdrant","title":"qdrant
","text":""},{"location":"reference/#src.aeiva.storage.qdrant.qdrant_config","title":"qdrant_config
","text":""},{"location":"reference/#src.aeiva.storage.qdrant.qdrant_config.QdrantConfig","title":"QdrantConfig
dataclass
","text":" Bases: BaseConfig
Configuration for Qdrant vector database.
Source code in src/aeiva/storage/qdrant/qdrant_config.py
@dataclass\nclass QdrantConfig(BaseConfig):\n \"\"\"\n Configuration for Qdrant vector database.\n \"\"\"\n\n collection_name: str = field(\n default=\"mem0\",\n metadata={\"help\": \"Name of the collection.\"}\n )\n embedding_model_dims: int = field(\n default=1536,\n metadata={\"help\": \"Dimensions of the embedding model.\"}\n )\n client: Optional[Any] = field(\n default=None,\n metadata={\"help\": \"Existing Qdrant client instance (if any).\"}\n )\n host: Optional[str] = field(\n default=None,\n metadata={\"help\": \"Host address for Qdrant server.\"}\n )\n port: Optional[int] = field(\n default=None,\n metadata={\"help\": \"Port for Qdrant server.\"}\n )\n path: Optional[str] = field(\n default=None,\n metadata={\"help\": \"Path for local Qdrant database storage.\"}\n )\n url: Optional[str] = field(\n default=None,\n metadata={\"help\": \"Full URL for Qdrant server.\"}\n )\n api_key: Optional[str] = field(\n default=None,\n metadata={\"help\": \"API key for Qdrant server authentication.\"}\n )\n on_disk: bool = field(\n default=False,\n metadata={\"help\": \"Whether to enable persistent storage on disk.\"}\n )\n\n def __post_init__(self):\n super().__post_init__()\n # Validate that connection parameters are provided\n if not self.path and not ((self.host and self.port) or (self.url and self.api_key)):\n raise ValueError(\"Provide 'path' for local storage, or 'host' and 'port', or 'url' and 'api_key' for remote connection.\")\n
"},{"location":"reference/#src.aeiva.storage.qdrant.qdrant_database","title":"qdrant_database
","text":""},{"location":"reference/#src.aeiva.storage.qdrant.qdrant_database.QdrantDatabase","title":"QdrantDatabase
","text":" Bases: VectorDatabase
Concrete implementation of VectorStoreBase using Qdrant.
Source code in src/aeiva/storage/qdrant/qdrant_database.py
class QdrantDatabase(VectorDatabase):\n \"\"\"\n Concrete implementation of VectorStoreBase using Qdrant.\n \"\"\"\n\n def __init__(self, config: Dict[str, Any]) -> None:\n \"\"\"\n Initialize the Qdrant vector store.\n\n Args:\n config (Dict[str, Any]): Configuration dictionary.\n \"\"\"\n self.config = config\n self.collection_name = config.get('collection_name')\n self.embedding_model_dims = config.get('embedding_model_dims')\n self.client = config.get('client')\n self.host = config.get('host')\n self.port = config.get('port')\n self.path = config.get('path')\n self.url = config.get('url')\n self.api_key = config.get('api_key')\n self.on_disk = config.get('on_disk', False)\n\n if not all([self.collection_name, self.embedding_model_dims]):\n raise ValueError(\"Required configuration parameters are missing.\")\n\n self.create_client()\n self.create_collection(\n collection_name=self.collection_name,\n vector_size=self.embedding_model_dims,\n distance_metric='COSINE'\n )\n\n def create_client(self, **kwargs) -> None:\n \"\"\"\n Initializes the client connection to the Qdrant vector store.\n\n Args:\n **kwargs: Additional parameters.\n \"\"\"\n if self.client:\n return # Client already provided\n\n client_params = {}\n if self.api_key:\n client_params['api_key'] = self.api_key\n if self.url:\n client_params['url'] = self.url\n elif self.host and self.port:\n client_params['host'] = self.host\n client_params['port'] = self.port\n else:\n client_params['path'] = self.path\n\n self.client = QdrantClient(**client_params)\n logger.info(\"Qdrant client initialized.\")\n\n def create_collection(self, collection_name: str, vector_size: int, distance_metric: str) -> None:\n \"\"\"\n Create a new vector collection in Qdrant.\n\n Args:\n collection_name (str): The name of the collection.\n vector_size (int): The dimensionality of the vectors.\n distance_metric (str): The distance metric to use (e.g., 'COSINE').\n \"\"\"\n # Check if collection exists\n collections = 
self.list_collections()\n if collection_name in collections:\n logger.info(f\"Collection {collection_name} already exists. Skipping creation.\")\n return\n\n vector_params = VectorParams(\n size=vector_size,\n distance=getattr(Distance, distance_metric.upper()),\n on_disk=self.on_disk\n )\n self.client.create_collection(\n collection_name=collection_name,\n vectors_config=vector_params\n )\n logger.info(f\"Collection {collection_name} created successfully.\")\n\n def insert_vectors(\n self,\n collection_name: str,\n vectors: List[List[float]],\n payloads: Optional[List[Dict[str, Any]]] = None,\n ids: Optional[List[str]] = None\n ) -> None:\n \"\"\"\n Insert vectors into a collection.\n\n Args:\n collection_name (str): The name of the collection.\n vectors (List[List[float]]): A list of vectors to insert.\n payloads (Optional[List[Dict[str, Any]]]): Optional metadata associated with each vector.\n ids (Optional[List[str]]): Optional unique identifiers for each vector.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n if ids is None:\n ids = [i for i in range(len(vectors))]\n if payloads is None:\n payloads = [{} for _ in range(len(vectors))]\n if not (len(ids) == len(vectors) == len(payloads)):\n raise ValueError(\"Lengths of ids, vectors, and payloads must be equal.\")\n\n points = [\n PointStruct(\n id=id_,\n vector=vector,\n payload=payload\n )\n for id_, vector, payload in zip(ids, vectors, payloads)\n ]\n self.client.upsert(\n collection_name=collection_name,\n points=points\n )\n logger.info(f\"Inserted {len(vectors)} vectors into collection {collection_name}.\")\n\n def search_vectors(\n self,\n collection_name: str,\n query_vector: List[float],\n top_k: int = 5,\n filters: Optional[Dict[str, Any]] = None\n ) -> List[Dict[str, Any]]:\n \"\"\"\n Search for similar vectors in a collection.\n\n Args:\n collection_name (str): The name of the collection.\n query_vector 
(List[float]): The vector to search with.\n top_k (int): The number of top results to return.\n filters (Optional[Dict[str, Any]]): Optional filters to apply to the search.\n\n Returns:\n List[Dict[str, Any]]: A list of search results.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n query_filter = self._build_filter(filters)\n results = self.client.search(\n collection_name=collection_name,\n query_vector=query_vector,\n limit=top_k,\n query_filter=query_filter\n )\n\n output = []\n for hit in results:\n result = {\n 'id': hit.id,\n 'score': hit.score,\n 'payload': hit.payload\n }\n output.append(result)\n return output\n\n def _build_filter(self, filters: Optional[Dict[str, Any]]) -> Optional[Filter]:\n \"\"\"\n Build a Qdrant filter object from a dictionary.\n\n Args:\n filters (Optional[Dict[str, Any]]): Filters to apply.\n\n Returns:\n Optional[Filter]: A Qdrant Filter object.\n \"\"\"\n if not filters:\n return None\n\n conditions = []\n for key, value in filters.items():\n conditions.append(\n FieldCondition(\n key=key,\n match=MatchValue(value=value)\n )\n )\n return Filter(must=conditions)\n\n def delete_vector(self, collection_name: str, vector_id: str) -> None:\n \"\"\"\n Delete a vector from a collection by its ID.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector to delete.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n self.client.delete(\n collection_name=collection_name,\n points_selector=[vector_id]\n )\n logger.info(f\"Deleted vector with ID {vector_id} from collection {collection_name}.\")\n\n def update_vector(\n self,\n collection_name: str,\n vector_id: str,\n vector: Optional[List[float]] = None,\n payload: Optional[Dict[str, Any]] = None\n ) -> None:\n \"\"\"\n Update a vector's data or 
payload.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector to update.\n vector (Optional[List[float]]): The new vector data.\n payload (Optional[Dict[str, Any]]): The new payload data.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n point = PointStruct(\n id=vector_id,\n vector=vector,\n payload=payload\n )\n self.client.upsert(\n collection_name=collection_name,\n points=[point]\n )\n logger.info(f\"Updated vector with ID {vector_id} in collection {collection_name}.\")\n\n def get_vector(self, collection_name: str, vector_id: str) -> Dict[str, Any]:\n \"\"\"\n Retrieve a vector by its ID.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector.\n\n Returns:\n Dict[str, Any]: A dictionary containing the vector data and payload.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n result = self.client.retrieve(\n collection_name=collection_name,\n ids=[vector_id]\n )\n if not result:\n raise KeyError(f\"Vector with ID {vector_id} not found in collection {collection_name}.\")\n\n point = result[0]\n vector_data = {\n 'id': point.id,\n 'vector': point.vector,\n 'payload': point.payload\n }\n return vector_data\n\n def list_collections(self) -> List[str]:\n \"\"\"\n List all available vector collections.\n\n Returns:\n List[str]: A list of collection names.\n \"\"\"\n collections = self.client.get_collections().collections\n return [collection.name for collection in collections]\n\n def delete_collection(self, collection_name: str) -> None:\n \"\"\"\n Delete an entire vector collection.\n\n Args:\n collection_name (str): The name of the collection to delete.\n \"\"\"\n self.client.delete_collection(collection_name=collection_name)\n logger.info(f\"Deleted 
collection {collection_name}.\")\n\n def get_collection_info(self, collection_name: str) -> Dict[str, Any]:\n \"\"\"\n Get information about a collection.\n\n Args:\n collection_name (str): The name of the collection.\n\n Returns:\n Dict[str, Any]: Information about the collection.\n \"\"\"\n info = self.client.get_collection(collection_name=collection_name)\n return info.dict()\n
"},{"location":"reference/#src.aeiva.storage.qdrant.qdrant_database.QdrantDatabase.__init__","title":"__init__(config)
","text":"Initialize the Qdrant vector store.
Parameters:
Name Type Description Default config
Dict[str, Any]
Configuration dictionary.
required Source code in src/aeiva/storage/qdrant/qdrant_database.py
def __init__(self, config: Dict[str, Any]) -> None:\n \"\"\"\n Initialize the Qdrant vector store.\n\n Args:\n config (Dict[str, Any]): Configuration dictionary.\n \"\"\"\n self.config = config\n self.collection_name = config.get('collection_name')\n self.embedding_model_dims = config.get('embedding_model_dims')\n self.client = config.get('client')\n self.host = config.get('host')\n self.port = config.get('port')\n self.path = config.get('path')\n self.url = config.get('url')\n self.api_key = config.get('api_key')\n self.on_disk = config.get('on_disk', False)\n\n if not all([self.collection_name, self.embedding_model_dims]):\n raise ValueError(\"Required configuration parameters are missing.\")\n\n self.create_client()\n self.create_collection(\n collection_name=self.collection_name,\n vector_size=self.embedding_model_dims,\n distance_metric='COSINE'\n )\n
"},{"location":"reference/#src.aeiva.storage.qdrant.qdrant_database.QdrantDatabase.create_client","title":"create_client(**kwargs)
","text":"Initializes the client connection to the Qdrant vector store.
Parameters:
Name Type Description Default **kwargs
Additional parameters.
{}
Source code in src/aeiva/storage/qdrant/qdrant_database.py
def create_client(self, **kwargs) -> None:\n \"\"\"\n Initializes the client connection to the Qdrant vector store.\n\n Args:\n **kwargs: Additional parameters.\n \"\"\"\n if self.client:\n return # Client already provided\n\n client_params = {}\n if self.api_key:\n client_params['api_key'] = self.api_key\n if self.url:\n client_params['url'] = self.url\n elif self.host and self.port:\n client_params['host'] = self.host\n client_params['port'] = self.port\n else:\n client_params['path'] = self.path\n\n self.client = QdrantClient(**client_params)\n logger.info(\"Qdrant client initialized.\")\n
"},{"location":"reference/#src.aeiva.storage.qdrant.qdrant_database.QdrantDatabase.create_collection","title":"create_collection(collection_name, vector_size, distance_metric)
","text":"Create a new vector collection in Qdrant.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required vector_size
int
The dimensionality of the vectors.
required distance_metric
str
The distance metric to use (e.g., 'COSINE').
required Source code in src/aeiva/storage/qdrant/qdrant_database.py
def create_collection(self, collection_name: str, vector_size: int, distance_metric: str) -> None:\n \"\"\"\n Create a new vector collection in Qdrant.\n\n Args:\n collection_name (str): The name of the collection.\n vector_size (int): The dimensionality of the vectors.\n distance_metric (str): The distance metric to use (e.g., 'COSINE').\n \"\"\"\n # Check if collection exists\n collections = self.list_collections()\n if collection_name in collections:\n logger.info(f\"Collection {collection_name} already exists. Skipping creation.\")\n return\n\n vector_params = VectorParams(\n size=vector_size,\n distance=getattr(Distance, distance_metric.upper()),\n on_disk=self.on_disk\n )\n self.client.create_collection(\n collection_name=collection_name,\n vectors_config=vector_params\n )\n logger.info(f\"Collection {collection_name} created successfully.\")\n
"},{"location":"reference/#src.aeiva.storage.qdrant.qdrant_database.QdrantDatabase.delete_collection","title":"delete_collection(collection_name)
","text":"Delete an entire vector collection.
Parameters:
Name Type Description Default collection_name
str
The name of the collection to delete.
required Source code in src/aeiva/storage/qdrant/qdrant_database.py
def delete_collection(self, collection_name: str) -> None:\n \"\"\"\n Delete an entire vector collection.\n\n Args:\n collection_name (str): The name of the collection to delete.\n \"\"\"\n self.client.delete_collection(collection_name=collection_name)\n logger.info(f\"Deleted collection {collection_name}.\")\n
"},{"location":"reference/#src.aeiva.storage.qdrant.qdrant_database.QdrantDatabase.delete_vector","title":"delete_vector(collection_name, vector_id)
","text":"Delete a vector from a collection by its ID.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required vector_id
str
The unique identifier of the vector to delete.
required Source code in src/aeiva/storage/qdrant/qdrant_database.py
def delete_vector(self, collection_name: str, vector_id: str) -> None:\n \"\"\"\n Delete a vector from a collection by its ID.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector to delete.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n self.client.delete(\n collection_name=collection_name,\n points_selector=[vector_id]\n )\n logger.info(f\"Deleted vector with ID {vector_id} from collection {collection_name}.\")\n
"},{"location":"reference/#src.aeiva.storage.qdrant.qdrant_database.QdrantDatabase.get_collection_info","title":"get_collection_info(collection_name)
","text":"Get information about a collection.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required Returns:
Type Description Dict[str, Any]
Dict[str, Any]: Information about the collection.
Source code in src/aeiva/storage/qdrant/qdrant_database.py
def get_collection_info(self, collection_name: str) -> Dict[str, Any]:\n \"\"\"\n Get information about a collection.\n\n Args:\n collection_name (str): The name of the collection.\n\n Returns:\n Dict[str, Any]: Information about the collection.\n \"\"\"\n info = self.client.get_collection(collection_name=collection_name)\n return info.dict()\n
"},{"location":"reference/#src.aeiva.storage.qdrant.qdrant_database.QdrantDatabase.get_vector","title":"get_vector(collection_name, vector_id)
","text":"Retrieve a vector by its ID.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required vector_id
str
The unique identifier of the vector.
required Returns:
Type Description Dict[str, Any]
Dict[str, Any]: A dictionary containing the vector data and payload.
Source code in src/aeiva/storage/qdrant/qdrant_database.py
def get_vector(self, collection_name: str, vector_id: str) -> Dict[str, Any]:\n \"\"\"\n Retrieve a vector by its ID.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector.\n\n Returns:\n Dict[str, Any]: A dictionary containing the vector data and payload.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n result = self.client.retrieve(\n collection_name=collection_name,\n ids=[vector_id]\n )\n if not result:\n raise KeyError(f\"Vector with ID {vector_id} not found in collection {collection_name}.\")\n\n point = result[0]\n vector_data = {\n 'id': point.id,\n 'vector': point.vector,\n 'payload': point.payload\n }\n return vector_data\n
"},{"location":"reference/#src.aeiva.storage.qdrant.qdrant_database.QdrantDatabase.insert_vectors","title":"insert_vectors(collection_name, vectors, payloads=None, ids=None)
","text":"Insert vectors into a collection.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required vectors
List[List[float]]
A list of vectors to insert.
required payloads
Optional[List[Dict[str, Any]]]
Optional metadata associated with each vector.
None
ids
Optional[List[str]]
Optional unique identifiers for each vector.
None
Source code in src/aeiva/storage/qdrant/qdrant_database.py
def insert_vectors(\n self,\n collection_name: str,\n vectors: List[List[float]],\n payloads: Optional[List[Dict[str, Any]]] = None,\n ids: Optional[List[str]] = None\n) -> None:\n \"\"\"\n Insert vectors into a collection.\n\n Args:\n collection_name (str): The name of the collection.\n vectors (List[List[float]]): A list of vectors to insert.\n payloads (Optional[List[Dict[str, Any]]]): Optional metadata associated with each vector.\n ids (Optional[List[str]]): Optional unique identifiers for each vector.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n if ids is None:\n ids = [i for i in range(len(vectors))]\n if payloads is None:\n payloads = [{} for _ in range(len(vectors))]\n if not (len(ids) == len(vectors) == len(payloads)):\n raise ValueError(\"Lengths of ids, vectors, and payloads must be equal.\")\n\n points = [\n PointStruct(\n id=id_,\n vector=vector,\n payload=payload\n )\n for id_, vector, payload in zip(ids, vectors, payloads)\n ]\n self.client.upsert(\n collection_name=collection_name,\n points=points\n )\n logger.info(f\"Inserted {len(vectors)} vectors into collection {collection_name}.\")\n
"},{"location":"reference/#src.aeiva.storage.qdrant.qdrant_database.QdrantDatabase.list_collections","title":"list_collections()
","text":"List all available vector collections.
Returns:
Type Description List[str]
List[str]: A list of collection names.
Source code in src/aeiva/storage/qdrant/qdrant_database.py
def list_collections(self) -> List[str]:\n \"\"\"\n List all available vector collections.\n\n Returns:\n List[str]: A list of collection names.\n \"\"\"\n collections = self.client.get_collections().collections\n return [collection.name for collection in collections]\n
"},{"location":"reference/#src.aeiva.storage.qdrant.qdrant_database.QdrantDatabase.search_vectors","title":"search_vectors(collection_name, query_vector, top_k=5, filters=None)
","text":"Search for similar vectors in a collection.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required query_vector
List[float]
The vector to search with.
required top_k
int
The number of top results to return.
5
filters
Optional[Dict[str, Any]]
Optional filters to apply to the search.
None
Returns:
Type Description List[Dict[str, Any]]
List[Dict[str, Any]]: A list of search results.
Source code in src/aeiva/storage/qdrant/qdrant_database.py
def search_vectors(\n self,\n collection_name: str,\n query_vector: List[float],\n top_k: int = 5,\n filters: Optional[Dict[str, Any]] = None\n) -> List[Dict[str, Any]]:\n \"\"\"\n Search for similar vectors in a collection.\n\n Args:\n collection_name (str): The name of the collection.\n query_vector (List[float]): The vector to search with.\n top_k (int): The number of top results to return.\n filters (Optional[Dict[str, Any]]): Optional filters to apply to the search.\n\n Returns:\n List[Dict[str, Any]]: A list of search results.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n query_filter = self._build_filter(filters)\n results = self.client.search(\n collection_name=collection_name,\n query_vector=query_vector,\n limit=top_k,\n query_filter=query_filter\n )\n\n output = []\n for hit in results:\n result = {\n 'id': hit.id,\n 'score': hit.score,\n 'payload': hit.payload\n }\n output.append(result)\n return output\n
"},{"location":"reference/#src.aeiva.storage.qdrant.qdrant_database.QdrantDatabase.update_vector","title":"update_vector(collection_name, vector_id, vector=None, payload=None)
","text":"Update a vector's data or payload.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required vector_id
str
The unique identifier of the vector to update.
required vector
Optional[List[float]]
The new vector data.
None
payload
Optional[Dict[str, Any]]
The new payload data.
None
Source code in src/aeiva/storage/qdrant/qdrant_database.py
def update_vector(\n self,\n collection_name: str,\n vector_id: str,\n vector: Optional[List[float]] = None,\n payload: Optional[Dict[str, Any]] = None\n) -> None:\n \"\"\"\n Update a vector's data or payload.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector to update.\n vector (Optional[List[float]]): The new vector data.\n payload (Optional[Dict[str, Any]]): The new payload data.\n \"\"\"\n if collection_name != self.collection_name:\n raise ValueError(\"Collection name does not match initialized collection name.\")\n\n point = PointStruct(\n id=vector_id,\n vector=vector,\n payload=payload\n )\n self.client.upsert(\n collection_name=collection_name,\n points=[point]\n )\n logger.info(f\"Updated vector with ID {vector_id} in collection {collection_name}.\")\n
"},{"location":"reference/#src.aeiva.storage.relational_database","title":"relational_database
","text":""},{"location":"reference/#src.aeiva.storage.relational_database.RelationalDatabase","title":"RelationalDatabase
","text":" Bases: ABC
Abstract base class for relational database operations.
Source code in src/aeiva/storage/relational_database.py
class RelationalDatabase(ABC):\n \"\"\"\n Abstract base class for relational database operations.\n \"\"\"\n\n @abstractmethod\n def insert_record(self, table: str, record: Dict[str, Any]) -> Any:\n \"\"\"\n Inserts a record into a table.\n\n Args:\n table (str): The name of the table.\n record (Dict[str, Any]): A dictionary representing the record to insert.\n\n Returns:\n Any: The primary key of the inserted record.\n\n Raises:\n StorageError: If there is an issue inserting the record.\n \"\"\"\n pass\n\n @abstractmethod\n def get_record(self, table: str, primary_key: Any) -> Dict[str, Any]:\n \"\"\"\n Retrieves a record by its primary key.\n\n Args:\n table (str): The name of the table.\n primary_key (Any): The primary key of the record.\n\n Returns:\n Dict[str, Any]: The retrieved record.\n\n Raises:\n RecordNotFoundError: If the record does not exist.\n StorageError: If there is an issue retrieving the record.\n \"\"\"\n pass\n\n @abstractmethod\n def update_record(self, table: str, primary_key: Any, updates: Dict[str, Any]) -> None:\n \"\"\"\n Updates a record in a table.\n\n Args:\n table (str): The name of the table.\n primary_key (Any): The primary key of the record.\n updates (Dict[str, Any]): A dictionary of fields to update.\n\n Raises:\n RecordNotFoundError: If the record does not exist.\n StorageError: If there is an issue updating the record.\n \"\"\"\n pass\n\n @abstractmethod\n def delete_record(self, table: str, primary_key: Any) -> None:\n \"\"\"\n Deletes a record from a table.\n\n Args:\n table (str): The name of the table.\n primary_key (Any): The primary key of the record.\n\n Raises:\n RecordNotFoundError: If the record does not exist.\n StorageError: If there is an issue deleting the record.\n \"\"\"\n pass\n\n @abstractmethod\n def query_records(self, table: str, conditions: Optional[Dict[str, Any]] = None, limit: Optional[int] = None, offset: Optional[int] = None) -> List[Dict[str, Any]]:\n \"\"\"\n Queries records from a table based on 
conditions.\n\n Args:\n table (str): The name of the table.\n conditions (Optional[Dict[str, Any]]): Conditions to filter records.\n limit (Optional[int]): Maximum number of records to return.\n offset (Optional[int]): Number of records to skip.\n\n Returns:\n List[Dict[str, Any]]: A list of records matching the query.\n\n Raises:\n StorageError: If there is an issue querying records.\n \"\"\"\n pass\n\n @abstractmethod\n def execute_sql(self, query: str, parameters: Optional[Dict[str, Any]] = None) -> Any:\n \"\"\"\n Executes a raw SQL query.\n\n Args:\n query (str): The SQL query string.\n parameters (Optional[Dict[str, Any]]): Parameters for parameterized queries.\n\n Returns:\n Any: The result of the query.\n\n Raises:\n StorageError: If there is an issue executing the query.\n \"\"\"\n pass\n\n @abstractmethod\n def begin_transaction(self) -> None:\n \"\"\"\n Begins a transaction.\n \"\"\"\n pass\n\n @abstractmethod\n def commit_transaction(self) -> None:\n \"\"\"\n Commits the current transaction.\n \"\"\"\n pass\n\n @abstractmethod\n def rollback_transaction(self) -> None:\n \"\"\"\n Rolls back the current transaction.\n \"\"\"\n pass\n\n @abstractmethod\n def close(self) -> None:\n \"\"\"\n Closes the database connection and releases resources.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.relational_database.RelationalDatabase.begin_transaction","title":"begin_transaction()
abstractmethod
","text":"Begins a transaction.
Source code in src/aeiva/storage/relational_database.py
@abstractmethod\ndef begin_transaction(self) -> None:\n \"\"\"\n Begins a transaction.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.relational_database.RelationalDatabase.close","title":"close()
abstractmethod
","text":"Closes the database connection and releases resources.
Source code in src/aeiva/storage/relational_database.py
@abstractmethod\ndef close(self) -> None:\n \"\"\"\n Closes the database connection and releases resources.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.relational_database.RelationalDatabase.commit_transaction","title":"commit_transaction()
abstractmethod
","text":"Commits the current transaction.
Source code in src/aeiva/storage/relational_database.py
@abstractmethod\ndef commit_transaction(self) -> None:\n \"\"\"\n Commits the current transaction.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.relational_database.RelationalDatabase.delete_record","title":"delete_record(table, primary_key)
abstractmethod
","text":"Deletes a record from a table.
Parameters:
Name Type Description Default table
str
The name of the table.
required primary_key
Any
The primary key of the record.
required Raises:
Type Description RecordNotFoundError
If the record does not exist.
StorageError
If there is an issue deleting the record.
Source code in src/aeiva/storage/relational_database.py
@abstractmethod\ndef delete_record(self, table: str, primary_key: Any) -> None:\n \"\"\"\n Deletes a record from a table.\n\n Args:\n table (str): The name of the table.\n primary_key (Any): The primary key of the record.\n\n Raises:\n RecordNotFoundError: If the record does not exist.\n StorageError: If there is an issue deleting the record.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.relational_database.RelationalDatabase.execute_sql","title":"execute_sql(query, parameters=None)
abstractmethod
","text":"Executes a raw SQL query.
Parameters:
Name Type Description Default query
str
The SQL query string.
required parameters
Optional[Dict[str, Any]]
Parameters for parameterized queries.
None
Returns:
Name Type Description Any
Any
The result of the query.
Raises:
Type Description StorageError
If there is an issue executing the query.
Source code in src/aeiva/storage/relational_database.py
@abstractmethod\ndef execute_sql(self, query: str, parameters: Optional[Dict[str, Any]] = None) -> Any:\n \"\"\"\n Executes a raw SQL query.\n\n Args:\n query (str): The SQL query string.\n parameters (Optional[Dict[str, Any]]): Parameters for parameterized queries.\n\n Returns:\n Any: The result of the query.\n\n Raises:\n StorageError: If there is an issue executing the query.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.relational_database.RelationalDatabase.get_record","title":"get_record(table, primary_key)
abstractmethod
","text":"Retrieves a record by its primary key.
Parameters:
Name Type Description Default table
str
The name of the table.
required primary_key
Any
The primary key of the record.
required Returns:
Type Description Dict[str, Any]
Dict[str, Any]: The retrieved record.
Raises:
Type Description RecordNotFoundError
If the record does not exist.
StorageError
If there is an issue retrieving the record.
Source code in src/aeiva/storage/relational_database.py
@abstractmethod\ndef get_record(self, table: str, primary_key: Any) -> Dict[str, Any]:\n \"\"\"\n Retrieves a record by its primary key.\n\n Args:\n table (str): The name of the table.\n primary_key (Any): The primary key of the record.\n\n Returns:\n Dict[str, Any]: The retrieved record.\n\n Raises:\n RecordNotFoundError: If the record does not exist.\n StorageError: If there is an issue retrieving the record.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.relational_database.RelationalDatabase.insert_record","title":"insert_record(table, record)
abstractmethod
","text":"Inserts a record into a table.
Parameters:
Name Type Description Default table
str
The name of the table.
required record
Dict[str, Any]
A dictionary representing the record to insert.
required Returns:
Name Type Description Any
Any
The primary key of the inserted record.
Raises:
Type Description StorageError
If there is an issue inserting the record.
Source code in src/aeiva/storage/relational_database.py
@abstractmethod\ndef insert_record(self, table: str, record: Dict[str, Any]) -> Any:\n \"\"\"\n Inserts a record into a table.\n\n Args:\n table (str): The name of the table.\n record (Dict[str, Any]): A dictionary representing the record to insert.\n\n Returns:\n Any: The primary key of the inserted record.\n\n Raises:\n StorageError: If there is an issue inserting the record.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.relational_database.RelationalDatabase.query_records","title":"query_records(table, conditions=None, limit=None, offset=None)
abstractmethod
","text":"Queries records from a table based on conditions.
Parameters:
Name Type Description Default table
str
The name of the table.
required conditions
Optional[Dict[str, Any]]
Conditions to filter records.
None
limit
Optional[int]
Maximum number of records to return.
None
offset
Optional[int]
Number of records to skip.
None
Returns:
Type Description List[Dict[str, Any]]
List[Dict[str, Any]]: A list of records matching the query.
Raises:
Type Description StorageError
If there is an issue querying records.
Source code in src/aeiva/storage/relational_database.py
@abstractmethod\ndef query_records(self, table: str, conditions: Optional[Dict[str, Any]] = None, limit: Optional[int] = None, offset: Optional[int] = None) -> List[Dict[str, Any]]:\n \"\"\"\n Queries records from a table based on conditions.\n\n Args:\n table (str): The name of the table.\n conditions (Optional[Dict[str, Any]]): Conditions to filter records.\n limit (Optional[int]): Maximum number of records to return.\n offset (Optional[int]): Number of records to skip.\n\n Returns:\n List[Dict[str, Any]]: A list of records matching the query.\n\n Raises:\n StorageError: If there is an issue querying records.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.relational_database.RelationalDatabase.rollback_transaction","title":"rollback_transaction()
abstractmethod
","text":"Rolls back the current transaction.
Source code in src/aeiva/storage/relational_database.py
@abstractmethod\ndef rollback_transaction(self) -> None:\n \"\"\"\n Rolls back the current transaction.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.relational_database.RelationalDatabase.update_record","title":"update_record(table, primary_key, updates)
abstractmethod
","text":"Updates a record in a table.
Parameters:
Name Type Description Default table
str
The name of the table.
required primary_key
Any
The primary key of the record.
required updates
Dict[str, Any]
A dictionary of fields to update.
required Raises:
Type Description RecordNotFoundError
If the record does not exist.
StorageError
If there is an issue updating the record.
Source code in src/aeiva/storage/relational_database.py
@abstractmethod\ndef update_record(self, table: str, primary_key: Any, updates: Dict[str, Any]) -> None:\n \"\"\"\n Updates a record in a table.\n\n Args:\n table (str): The name of the table.\n primary_key (Any): The primary key of the record.\n updates (Dict[str, Any]): A dictionary of fields to update.\n\n Raises:\n RecordNotFoundError: If the record does not exist.\n StorageError: If there is an issue updating the record.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.sqlite","title":"sqlite
","text":""},{"location":"reference/#src.aeiva.storage.sqlite.sqlite_config","title":"sqlite_config
","text":""},{"location":"reference/#src.aeiva.storage.sqlite.sqlite_config.SQLiteConfig","title":"SQLiteConfig
dataclass
","text":" Bases: BaseConfig
Configuration for SQLite database.
Source code in src/aeiva/storage/sqlite/sqlite_config.py
@dataclass\nclass SQLiteConfig(BaseConfig):\n \"\"\"\n Configuration for SQLite database.\n \"\"\"\n database: str = field(\n default=':memory:',\n metadata={\"help\": \"Path to the SQLite database file. Use ':memory:' for an in-memory database.\"}\n )\n
"},{"location":"reference/#src.aeiva.storage.sqlite.sqlite_database","title":"sqlite_database
","text":""},{"location":"reference/#src.aeiva.storage.sqlite.sqlite_database.RecordNotFoundError","title":"RecordNotFoundError
","text":" Bases: Exception
Exception raised when a record is not found in the database.
Source code in src/aeiva/storage/sqlite/sqlite_database.py
class RecordNotFoundError(Exception):\n \"\"\"Exception raised when a record is not found in the database.\"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.sqlite.sqlite_database.SQLiteDatabase","title":"SQLiteDatabase
","text":" Bases: RelationalDatabase
Concrete implementation of RelationalStoreBase using SQLite.
Source code in src/aeiva/storage/sqlite/sqlite_database.py
class SQLiteDatabase(RelationalDatabase):\n \"\"\"\n Concrete implementation of RelationalStoreBase using SQLite.\n \"\"\"\n\n def __init__(self, config: Dict[str, Any]) -> None:\n \"\"\"\n Initialize the SQLite database connection.\n\n Args:\n config (Dict[str, Any]): Configuration dictionary.\n \"\"\"\n self.config = config\n self.database = config.get('database', ':memory:')\n self.connection = None\n self.cursor = None\n self.connect()\n\n def connect(self) -> None:\n \"\"\"\n Establishes a connection to the SQLite database.\n \"\"\"\n try:\n self.connection = sqlite3.connect(self.database)\n self.connection.row_factory = sqlite3.Row # To get dict-like rows\n self.cursor = self.connection.cursor()\n # self.connection.execute('PRAGMA foreign_keys = ON') # Enable foreign key support\n except sqlite3.Error as e:\n raise ConnectionError(f\"Failed to connect to SQLite database: {e}\")\n\n def close(self) -> None:\n \"\"\"\n Closes the database connection and releases resources.\n \"\"\"\n if self.cursor:\n self.cursor.close()\n if self.connection:\n self.connection.close()\n\n def insert_record(self, table: str, record: Dict[str, Any]) -> Any:\n \"\"\"\n Inserts a record into a table.\n\n Args:\n table (str): The name of the table.\n record (Dict[str, Any]): A dictionary representing the record to insert.\n\n Returns:\n Any: The primary key of the inserted record.\n\n Raises:\n StorageError: If there is an issue inserting the record.\n \"\"\"\n try:\n columns = ', '.join(record.keys())\n placeholders = ', '.join('?' 
for _ in record)\n sql = f\"INSERT INTO {table} ({columns}) VALUES ({placeholders})\"\n values = list(record.values())\n self.cursor.execute(sql, values)\n self.connection.commit()\n return self.cursor.lastrowid\n except sqlite3.IntegrityError as e:\n self.connection.rollback()\n raise StorageError(f\"Integrity error: {e}\")\n except sqlite3.Error as e:\n self.connection.rollback()\n raise StorageError(f\"Failed to insert record: {e}\")\n\n def get_record(self, table: str, primary_key: Any) -> Dict[str, Any]:\n \"\"\"\n Retrieves a record by its primary key.\n\n Args:\n table (str): The name of the table.\n primary_key (Any): The primary key of the record.\n\n Returns:\n Dict[str, Any]: The retrieved record.\n\n Raises:\n RecordNotFoundError: If the record does not exist.\n StorageError: If there is an issue retrieving the record.\n \"\"\"\n try:\n sql = f\"SELECT * FROM {table} WHERE id = ?\"\n self.cursor.execute(sql, (primary_key,))\n row = self.cursor.fetchone()\n if row is None:\n raise RecordNotFoundError(f\"Record with primary key {primary_key} not found in table '{table}'.\")\n return dict(row)\n except sqlite3.Error as e:\n raise StorageError(f\"Failed to get record: {e}\")\n\n def update_record(self, table: str, primary_key: Any, updates: Dict[str, Any]) -> None:\n \"\"\"\n Updates a record in a table.\n\n Args:\n table (str): The name of the table.\n primary_key (Any): The primary key of the record.\n updates (Dict[str, Any]): A dictionary of fields to update.\n\n Raises:\n RecordNotFoundError: If the record does not exist.\n StorageError: If there is an issue updating the record.\n \"\"\"\n try:\n set_clause = ', '.join(f\"{key} = ?\" for key in updates.keys())\n sql = f\"UPDATE {table} SET {set_clause} WHERE id = ?\"\n values = list(updates.values()) + [primary_key]\n self.cursor.execute(sql, values)\n if self.cursor.rowcount == 0:\n self.connection.rollback()\n raise RecordNotFoundError(f\"Record with primary key {primary_key} not found in table 
'{table}'.\")\n self.connection.commit()\n except sqlite3.Error as e:\n self.connection.rollback()\n raise StorageError(f\"Failed to update record: {e}\")\n\n def delete_record(self, table: str, primary_key: Any) -> None:\n \"\"\"\n Deletes a record from a table.\n\n Args:\n table (str): The name of the table.\n primary_key (Any): The primary key of the record.\n\n Raises:\n RecordNotFoundError: If the record does not exist.\n StorageError: If there is an issue deleting the record.\n \"\"\"\n try:\n sql = f\"DELETE FROM {table} WHERE id = ?\"\n self.cursor.execute(sql, (primary_key,))\n if self.cursor.rowcount == 0:\n self.connection.rollback()\n raise RecordNotFoundError(f\"Record with primary key {primary_key} not found in table '{table}'.\")\n self.connection.commit()\n except sqlite3.Error as e:\n self.connection.rollback()\n raise StorageError(f\"Failed to delete record: {e}\")\n\n def query_records(\n self,\n table: str,\n conditions: Optional[Dict[str, Any]] = None,\n limit: Optional[int] = None,\n offset: Optional[int] = None\n ) -> List[Dict[str, Any]]:\n \"\"\"\n Queries records from a table based on conditions.\n\n Args:\n table (str): The name of the table.\n conditions (Optional[Dict[str, Any]]): Conditions to filter records.\n limit (Optional[int]): Maximum number of records to return.\n offset (Optional[int]): Number of records to skip.\n\n Returns:\n List[Dict[str, Any]]: A list of records matching the query.\n\n Raises:\n StorageError: If there is an issue querying records.\n \"\"\"\n try:\n sql = f\"SELECT * FROM {table}\"\n params = []\n if conditions:\n where_clause = ' AND '.join(f\"{key} = ?\" for key in conditions.keys())\n sql += f\" WHERE {where_clause}\"\n params.extend(conditions.values())\n if limit is not None:\n sql += f\" LIMIT {limit}\"\n if offset is not None:\n sql += f\" OFFSET {offset}\"\n self.cursor.execute(sql, params)\n rows = self.cursor.fetchall()\n return [dict(row) for row in rows]\n except sqlite3.Error as e:\n raise 
StorageError(f\"Failed to query records: {e}\")\n\n def execute_sql(self, query: str, params: Optional[Tuple] = None):\n \"\"\"\n Executes a SQL query and returns the cursor.\n\n Args:\n query (str): The SQL query to execute.\n params (Optional[Tuple]): Parameters to substitute into the query.\n\n Returns:\n sqlite3.Cursor: The cursor after executing the query.\n \"\"\"\n cursor = self.connection.cursor()\n try:\n if params:\n cursor.execute(query, params)\n else:\n cursor.execute(query)\n # For SELECT queries, do not commit. For INSERT/UPDATE/DELETE, you may need to commit.\n if query.strip().upper().startswith(\"SELECT\"):\n return cursor\n else:\n self.connection.commit()\n return cursor\n except sqlite3.Error as e:\n print(f\"SQLite query failed: {e}\")\n raise e\n\n def begin_transaction(self) -> None:\n \"\"\"\n Begins a transaction.\n \"\"\"\n self.connection.isolation_level = None\n self.cursor.execute('BEGIN')\n\n def commit_transaction(self) -> None:\n \"\"\"\n Commits the current transaction.\n \"\"\"\n self.connection.commit()\n self.connection.isolation_level = None\n\n def rollback_transaction(self) -> None:\n \"\"\"\n Rolls back the current transaction.\n \"\"\"\n self.connection.rollback()\n self.connection.isolation_level = None\n
"},{"location":"reference/#src.aeiva.storage.sqlite.sqlite_database.SQLiteDatabase.__init__","title":"__init__(config)
","text":"Initialize the SQLite database connection.
Parameters:
Name Type Description Default config
Dict[str, Any]
Configuration dictionary.
required Source code in src/aeiva/storage/sqlite/sqlite_database.py
def __init__(self, config: Dict[str, Any]) -> None:\n \"\"\"\n Initialize the SQLite database connection.\n\n Args:\n config (Dict[str, Any]): Configuration dictionary.\n \"\"\"\n self.config = config\n self.database = config.get('database', ':memory:')\n self.connection = None\n self.cursor = None\n self.connect()\n
"},{"location":"reference/#src.aeiva.storage.sqlite.sqlite_database.SQLiteDatabase.begin_transaction","title":"begin_transaction()
","text":"Begins a transaction.
Source code in src/aeiva/storage/sqlite/sqlite_database.py
def begin_transaction(self) -> None:\n \"\"\"\n Begins a transaction.\n \"\"\"\n self.connection.isolation_level = None\n self.cursor.execute('BEGIN')\n
"},{"location":"reference/#src.aeiva.storage.sqlite.sqlite_database.SQLiteDatabase.close","title":"close()
","text":"Closes the database connection and releases resources.
Source code in src/aeiva/storage/sqlite/sqlite_database.py
def close(self) -> None:\n \"\"\"\n Closes the database connection and releases resources.\n \"\"\"\n if self.cursor:\n self.cursor.close()\n if self.connection:\n self.connection.close()\n
"},{"location":"reference/#src.aeiva.storage.sqlite.sqlite_database.SQLiteDatabase.commit_transaction","title":"commit_transaction()
","text":"Commits the current transaction.
Source code in src/aeiva/storage/sqlite/sqlite_database.py
def commit_transaction(self) -> None:\n \"\"\"\n Commits the current transaction.\n \"\"\"\n self.connection.commit()\n self.connection.isolation_level = None\n
"},{"location":"reference/#src.aeiva.storage.sqlite.sqlite_database.SQLiteDatabase.connect","title":"connect()
","text":"Establishes a connection to the SQLite database.
Source code in src/aeiva/storage/sqlite/sqlite_database.py
def connect(self) -> None:\n \"\"\"\n Establishes a connection to the SQLite database.\n \"\"\"\n try:\n self.connection = sqlite3.connect(self.database)\n self.connection.row_factory = sqlite3.Row # To get dict-like rows\n self.cursor = self.connection.cursor()\n # self.connection.execute('PRAGMA foreign_keys = ON') # Enable foreign key support\n except sqlite3.Error as e:\n raise ConnectionError(f\"Failed to connect to SQLite database: {e}\")\n
"},{"location":"reference/#src.aeiva.storage.sqlite.sqlite_database.SQLiteDatabase.delete_record","title":"delete_record(table, primary_key)
","text":"Deletes a record from a table.
Parameters:
Name Type Description Default table
str
The name of the table.
required primary_key
Any
The primary key of the record.
required Raises:
Type Description RecordNotFoundError
If the record does not exist.
StorageError
If there is an issue deleting the record.
Source code in src/aeiva/storage/sqlite/sqlite_database.py
def delete_record(self, table: str, primary_key: Any) -> None:\n \"\"\"\n Deletes a record from a table.\n\n Args:\n table (str): The name of the table.\n primary_key (Any): The primary key of the record.\n\n Raises:\n RecordNotFoundError: If the record does not exist.\n StorageError: If there is an issue deleting the record.\n \"\"\"\n try:\n sql = f\"DELETE FROM {table} WHERE id = ?\"\n self.cursor.execute(sql, (primary_key,))\n if self.cursor.rowcount == 0:\n self.connection.rollback()\n raise RecordNotFoundError(f\"Record with primary key {primary_key} not found in table '{table}'.\")\n self.connection.commit()\n except sqlite3.Error as e:\n self.connection.rollback()\n raise StorageError(f\"Failed to delete record: {e}\")\n
"},{"location":"reference/#src.aeiva.storage.sqlite.sqlite_database.SQLiteDatabase.execute_sql","title":"execute_sql(query, params=None)
","text":"Executes a SQL query and returns the cursor.
Parameters:
Name Type Description Default query
str
The SQL query to execute.
required params
Optional[Tuple]
Parameters to substitute into the query.
None
Returns:
Type Description sqlite3.Cursor: The cursor after executing the query.
Source code in src/aeiva/storage/sqlite/sqlite_database.py
def execute_sql(self, query: str, params: Optional[Tuple] = None):\n \"\"\"\n Executes a SQL query and returns the cursor.\n\n Args:\n query (str): The SQL query to execute.\n params (Optional[Tuple]): Parameters to substitute into the query.\n\n Returns:\n sqlite3.Cursor: The cursor after executing the query.\n \"\"\"\n cursor = self.connection.cursor()\n try:\n if params:\n cursor.execute(query, params)\n else:\n cursor.execute(query)\n # For SELECT queries, do not commit. For INSERT/UPDATE/DELETE, you may need to commit.\n if query.strip().upper().startswith(\"SELECT\"):\n return cursor\n else:\n self.connection.commit()\n return cursor\n except sqlite3.Error as e:\n print(f\"SQLite query failed: {e}\")\n raise e\n
"},{"location":"reference/#src.aeiva.storage.sqlite.sqlite_database.SQLiteDatabase.get_record","title":"get_record(table, primary_key)
","text":"Retrieves a record by its primary key.
Parameters:
Name Type Description Default table
str
The name of the table.
required primary_key
Any
The primary key of the record.
required Returns:
Type Description Dict[str, Any]
Dict[str, Any]: The retrieved record.
Raises:
Type Description RecordNotFoundError
If the record does not exist.
StorageError
If there is an issue retrieving the record.
Source code in src/aeiva/storage/sqlite/sqlite_database.py
def get_record(self, table: str, primary_key: Any) -> Dict[str, Any]:\n \"\"\"\n Retrieves a record by its primary key.\n\n Args:\n table (str): The name of the table.\n primary_key (Any): The primary key of the record.\n\n Returns:\n Dict[str, Any]: The retrieved record.\n\n Raises:\n RecordNotFoundError: If the record does not exist.\n StorageError: If there is an issue retrieving the record.\n \"\"\"\n try:\n sql = f\"SELECT * FROM {table} WHERE id = ?\"\n self.cursor.execute(sql, (primary_key,))\n row = self.cursor.fetchone()\n if row is None:\n raise RecordNotFoundError(f\"Record with primary key {primary_key} not found in table '{table}'.\")\n return dict(row)\n except sqlite3.Error as e:\n raise StorageError(f\"Failed to get record: {e}\")\n
"},{"location":"reference/#src.aeiva.storage.sqlite.sqlite_database.SQLiteDatabase.insert_record","title":"insert_record(table, record)
","text":"Inserts a record into a table.
Parameters:
Name Type Description Default table
str
The name of the table.
required record
Dict[str, Any]
A dictionary representing the record to insert.
required Returns:
Name Type Description Any
Any
The primary key of the inserted record.
Raises:
Type Description StorageError
If there is an issue inserting the record.
Source code in src/aeiva/storage/sqlite/sqlite_database.py
def insert_record(self, table: str, record: Dict[str, Any]) -> Any:\n \"\"\"\n Inserts a record into a table.\n\n Args:\n table (str): The name of the table.\n record (Dict[str, Any]): A dictionary representing the record to insert.\n\n Returns:\n Any: The primary key of the inserted record.\n\n Raises:\n StorageError: If there is an issue inserting the record.\n \"\"\"\n try:\n columns = ', '.join(record.keys())\n placeholders = ', '.join('?' for _ in record)\n sql = f\"INSERT INTO {table} ({columns}) VALUES ({placeholders})\"\n values = list(record.values())\n self.cursor.execute(sql, values)\n self.connection.commit()\n return self.cursor.lastrowid\n except sqlite3.IntegrityError as e:\n self.connection.rollback()\n raise StorageError(f\"Integrity error: {e}\")\n except sqlite3.Error as e:\n self.connection.rollback()\n raise StorageError(f\"Failed to insert record: {e}\")\n
"},{"location":"reference/#src.aeiva.storage.sqlite.sqlite_database.SQLiteDatabase.query_records","title":"query_records(table, conditions=None, limit=None, offset=None)
","text":"Queries records from a table based on conditions.
Parameters:
Name Type Description Default table
str
The name of the table.
required conditions
Optional[Dict[str, Any]]
Conditions to filter records.
None
limit
Optional[int]
Maximum number of records to return.
None
offset
Optional[int]
Number of records to skip.
None
Returns:
Type Description List[Dict[str, Any]]
List[Dict[str, Any]]: A list of records matching the query.
Raises:
Type Description StorageError
If there is an issue querying records.
Source code in src/aeiva/storage/sqlite/sqlite_database.py
def query_records(\n self,\n table: str,\n conditions: Optional[Dict[str, Any]] = None,\n limit: Optional[int] = None,\n offset: Optional[int] = None\n) -> List[Dict[str, Any]]:\n \"\"\"\n Queries records from a table based on conditions.\n\n Args:\n table (str): The name of the table.\n conditions (Optional[Dict[str, Any]]): Conditions to filter records.\n limit (Optional[int]): Maximum number of records to return.\n offset (Optional[int]): Number of records to skip.\n\n Returns:\n List[Dict[str, Any]]: A list of records matching the query.\n\n Raises:\n StorageError: If there is an issue querying records.\n \"\"\"\n try:\n sql = f\"SELECT * FROM {table}\"\n params = []\n if conditions:\n where_clause = ' AND '.join(f\"{key} = ?\" for key in conditions.keys())\n sql += f\" WHERE {where_clause}\"\n params.extend(conditions.values())\n if limit is not None:\n sql += f\" LIMIT {limit}\"\n if offset is not None:\n sql += f\" OFFSET {offset}\"\n self.cursor.execute(sql, params)\n rows = self.cursor.fetchall()\n return [dict(row) for row in rows]\n except sqlite3.Error as e:\n raise StorageError(f\"Failed to query records: {e}\")\n
"},{"location":"reference/#src.aeiva.storage.sqlite.sqlite_database.SQLiteDatabase.rollback_transaction","title":"rollback_transaction()
","text":"Rolls back the current transaction.
Source code in src/aeiva/storage/sqlite/sqlite_database.py
def rollback_transaction(self) -> None:\n \"\"\"\n Rolls back the current transaction.\n \"\"\"\n self.connection.rollback()\n self.connection.isolation_level = None\n
"},{"location":"reference/#src.aeiva.storage.sqlite.sqlite_database.SQLiteDatabase.update_record","title":"update_record(table, primary_key, updates)
","text":"Updates a record in a table.
Parameters:
Name Type Description Default table
str
The name of the table.
required primary_key
Any
The primary key of the record.
required updates
Dict[str, Any]
A dictionary of fields to update.
required Raises:
Type Description RecordNotFoundError
If the record does not exist.
StorageError
If there is an issue updating the record.
Source code in src/aeiva/storage/sqlite/sqlite_database.py
def update_record(self, table: str, primary_key: Any, updates: Dict[str, Any]) -> None:\n \"\"\"\n Updates a record in a table.\n\n Args:\n table (str): The name of the table.\n primary_key (Any): The primary key of the record.\n updates (Dict[str, Any]): A dictionary of fields to update.\n\n Raises:\n RecordNotFoundError: If the record does not exist.\n StorageError: If there is an issue updating the record.\n \"\"\"\n try:\n set_clause = ', '.join(f\"{key} = ?\" for key in updates.keys())\n sql = f\"UPDATE {table} SET {set_clause} WHERE id = ?\"\n values = list(updates.values()) + [primary_key]\n self.cursor.execute(sql, values)\n if self.cursor.rowcount == 0:\n self.connection.rollback()\n raise RecordNotFoundError(f\"Record with primary key {primary_key} not found in table '{table}'.\")\n self.connection.commit()\n except sqlite3.Error as e:\n self.connection.rollback()\n raise StorageError(f\"Failed to update record: {e}\")\n
"},{"location":"reference/#src.aeiva.storage.sqlite.sqlite_database.StorageError","title":"StorageError
","text":" Bases: Exception
Exception raised when there is a storage-related error in the database.
Source code in src/aeiva/storage/sqlite/sqlite_database.py
class StorageError(Exception):\n \"\"\"Exception raised when there is a storage-related error in the database.\"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.sqlite.test","title":"test
","text":""},{"location":"reference/#src.aeiva.storage.sqlite.test.RecordNotFoundError","title":"RecordNotFoundError
","text":" Bases: Exception
Exception raised when a record is not found in the database.
Source code in src/aeiva/storage/sqlite/test.py
class RecordNotFoundError(Exception):\n \"\"\"Exception raised when a record is not found in the database.\"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.test","title":"test
","text":""},{"location":"reference/#src.aeiva.storage.test.main","title":"main()
","text":"Main function to run tests for Milvus, Neo4j, and SQLite databases.
Source code in src/aeiva/storage/test.py
def main():\n \"\"\"\n Main function to run tests for Milvus, Neo4j, and SQLite databases.\n \"\"\"\n test_milvus()\n test_neo4j()\n test_sqlite()\n
"},{"location":"reference/#src.aeiva.storage.test.test_milvus","title":"test_milvus()
","text":"Test the DatabaseFactory and DatabaseConfigFactory with Milvus database.
Source code in src/aeiva/storage/test.py
def test_milvus():\n \"\"\"\n Test the DatabaseFactory and DatabaseConfigFactory with Milvus database.\n \"\"\"\n print(\"\\n--- Testing Milvus Database ---\")\n # Create configuration for Milvus\n milvus_config = DatabaseConfigFactory.create(\n 'milvus',\n # uri='tcp://localhost:19530',\n uri='storage/milvus_demo.db',\n collection_name='test_collection',\n embedding_model_dims=128,\n metric_type='COSINE',\n )\n\n # Create Milvus database instance\n milvus_db = DatabaseFactory.create('milvus', milvus_config)\n\n try:\n # Prepare sample data\n vector_dimension = milvus_config.embedding_model_dims\n vectors = [\n [float(i) for i in range(vector_dimension)], # Sample vector 1\n [float(i + 1) for i in range(vector_dimension)], # Sample vector 2\n ]\n payloads = [\n {'name': 'Vector 1', 'description': 'First test vector.'},\n {'name': 'Vector 2', 'description': 'Second test vector.'},\n ]\n ids = [str(uuid.uuid4()), str(uuid.uuid4())] # Generate unique IDs\n\n # Insert vectors into the collection\n milvus_db.insert_vectors(\n collection_name=milvus_config.collection_name,\n vectors=vectors,\n payloads=payloads,\n ids=ids\n )\n logging.info(f\"Inserted vectors with IDs: {ids}\")\n\n # Search for similar vectors\n query_vector = [float(i + 0.5) for i in range(vector_dimension)] # Query vector\n search_results = milvus_db.search_vectors(\n collection_name=milvus_config.collection_name,\n query_vector=query_vector,\n top_k=2\n )\n print(f\"Milvus Search results:\\n{search_results}\")\n\n except Exception as e:\n logging.error(f\"An error occurred while testing Milvus: {e}\")\n finally:\n # Close the connection\n del milvus_db\n
"},{"location":"reference/#src.aeiva.storage.test.test_neo4j","title":"test_neo4j()
","text":"Test the DatabaseFactory and DatabaseConfigFactory with Neo4j database.
Source code in src/aeiva/storage/test.py
def test_neo4j():\n \"\"\"\n Test the DatabaseFactory and DatabaseConfigFactory with Neo4j database.\n \"\"\"\n print(\"\\n--- Testing Neo4j Database ---\")\n # Create configuration for Neo4j\n neo4j_config = DatabaseConfigFactory.create(\n 'neo4j',\n uri='bolt://localhost:7687',\n user='neo4j',\n password='cf57bwP9pcdcEK3', # Replace with your actual password\n database='neo4j',\n encrypted=False,\n )\n\n # Create Neo4j database instance\n neo4j_db = DatabaseFactory.create('neo4j', neo4j_config)\n\n try:\n # Add a node\n node_id = 'node1'\n neo4j_db.add_node(\n node_id=node_id,\n properties={'name': 'Alice', 'age': 30},\n labels=['Person']\n )\n logging.info(f\"Added node with ID: {node_id}\")\n\n # Retrieve the node\n node_data = neo4j_db.get_node(node_id)\n print(f\"Neo4j Node data: {node_data}\")\n\n # Add another node and create a relationship\n node_id2 = 'node2'\n neo4j_db.add_node(\n node_id=node_id2,\n properties={'name': 'Bob', 'age': 25},\n labels=['Person']\n )\n neo4j_db.add_edge(\n source_id=node_id,\n target_id=node_id2,\n relationship='KNOWS',\n properties={'since': 2020}\n )\n logging.info(f\"Added edge between {node_id} and {node_id2}\")\n\n # Get neighbors\n neighbors = neo4j_db.get_neighbors(node_id, relationship='KNOWS', direction='out')\n print(f\"Neo4j Neighbors of {node_id}: {neighbors}\")\n\n except Exception as e:\n logging.error(f\"An error occurred while testing Neo4j: {e}\")\n finally:\n # Close the connection\n neo4j_db.close()\n
"},{"location":"reference/#src.aeiva.storage.test.test_sqlite","title":"test_sqlite()
","text":"Test the DatabaseFactory and DatabaseConfigFactory with SQLite database.
Source code in src/aeiva/storage/test.py
def test_sqlite():\n \"\"\"\n Test the DatabaseFactory and DatabaseConfigFactory with SQLite database.\n \"\"\"\n print(\"\\n--- Testing SQLite Database ---\")\n # Create configuration for SQLite\n sqlite_config = DatabaseConfigFactory.create(\n 'sqlite',\n database='storage/test_database.db' # Use a file-based database for persistence\n )\n\n # Create SQLite database instance\n sqlite_db = DatabaseFactory.create('sqlite', sqlite_config)\n\n try:\n # Create a sample table\n create_table_sql = \"\"\"\n CREATE TABLE IF NOT EXISTS users (\n id INTEGER PRIMARY KEY AUTOINCREMENT,\n name TEXT NOT NULL,\n age INTEGER,\n email TEXT UNIQUE\n );\n \"\"\"\n sqlite_db.execute_sql(create_table_sql)\n logging.info(\"Created table 'users' in SQLite database.\")\n\n # Insert a record\n record = {'name': 'Alice', 'age': 30, 'email': 'alice@example.com'}\n user_id = sqlite_db.insert_record('users', record)\n logging.info(f\"Inserted user with ID: {user_id}\")\n\n # Retrieve the record\n retrieved_record = sqlite_db.get_record('users', user_id)\n print(f\"SQLite Retrieved record: {retrieved_record}\")\n\n # Update the record\n updates = {'age': 31}\n sqlite_db.update_record('users', user_id, updates)\n logging.info(f\"Updated user with ID: {user_id}\")\n\n # Query records\n conditions = {'age': 31}\n users = sqlite_db.query_records('users', conditions)\n print(f\"SQLite Users with age 31: {users}\")\n\n except Exception as e:\n logging.error(f\"An error occurred while testing SQLite: {e}\")\n finally:\n # Close the database connection\n sqlite_db.close()\n
"},{"location":"reference/#src.aeiva.storage.vector_database","title":"vector_database
","text":""},{"location":"reference/#src.aeiva.storage.vector_database.VectorDatabase","title":"VectorDatabase
","text":" Bases: ABC
Abstract base class for vector storage operations.
Source code in src/aeiva/storage/vector_database.py
class VectorDatabase(ABC):\n \"\"\"\n Abstract base class for vector storage operations.\n \"\"\"\n\n @abstractmethod\n def create_client(\n self,\n uri: str,\n user: Optional[str] = None,\n password: Optional[str] = None,\n db_name: Optional[str] = None,\n token: Optional[str] = None,\n timeout: Optional[float] = None,\n **kwargs\n ) -> None:\n \"\"\"\n Initializes the client connection to the vector store.\n\n Args:\n uri (str): The URI of the vector store instance.\n user (Optional[str]): Username for authentication.\n password (Optional[str]): Password for authentication.\n db_name (Optional[str]): Name of the database.\n token (Optional[str]): Access token for authentication.\n timeout (Optional[float]): Timeout duration for operations.\n **kwargs: Additional implementation-specific parameters.\n\n Raises:\n ConnectionError: If the client fails to connect to the vector store.\n \"\"\"\n pass\n\n @abstractmethod\n def create_collection(self, collection_name: str, vector_size: int, distance_metric: str) -> None:\n \"\"\"\n Create a new vector collection.\n\n Args:\n collection_name (str): The name of the collection.\n vector_size (int): The dimensionality of the vectors.\n distance_metric (str): The distance metric to use (e.g., 'euclidean', 'cosine').\n\n Raises:\n CollectionAlreadyExistsError: If a collection with the given name already exists.\n StorageError: If there is an issue creating the collection.\n \"\"\"\n pass\n\n @abstractmethod\n def insert_vectors(self, collection_name: str, vectors: List[List[float]], payloads: Optional[List[Dict[str, Any]]] = None, ids: Optional[List[str]] = None) -> None:\n \"\"\"\n Insert vectors into a collection.\n\n Args:\n collection_name (str): The name of the collection.\n vectors (List[List[float]]): A list of vectors to insert.\n payloads (Optional[List[Dict[str, Any]]]): Optional metadata associated with each vector.\n ids (Optional[List[str]]): Optional unique identifiers for each vector.\n\n Raises:\n 
CollectionNotFoundError: If the specified collection does not exist.\n StorageError: If there is an issue inserting the vectors.\n \"\"\"\n pass\n\n @abstractmethod\n def search_vectors(self, collection_name: str, query_vector: List[float], top_k: int = 5, filters: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:\n \"\"\"\n Search for similar vectors in a collection.\n\n Args:\n collection_name (str): The name of the collection.\n query_vector (List[float]): The vector to search with.\n top_k (int): The number of top results to return.\n filters (Optional[Dict[str, Any]]): Optional filters to apply to the search.\n\n Returns:\n List[Dict[str, Any]]: A list of search results, each containing the vector ID, score, and payload.\n\n Raises:\n CollectionNotFoundError: If the specified collection does not exist.\n StorageError: If there is an issue performing the search.\n \"\"\"\n pass\n\n @abstractmethod\n def delete_vector(self, collection_name: str, vector_id: str) -> None:\n \"\"\"\n Delete a vector from a collection by its ID.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector to delete.\n\n Raises:\n CollectionNotFoundError: If the specified collection does not exist.\n VectorNotFoundError: If the vector with the specified ID does not exist.\n StorageError: If there is an issue deleting the vector.\n \"\"\"\n pass\n\n @abstractmethod\n def update_vector(self, collection_name: str, vector_id: str, vector: Optional[List[float]] = None, payload: Optional[Dict[str, Any]] = None) -> None:\n \"\"\"\n Update a vector's data or payload.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector to update.\n vector (Optional[List[float]]): The new vector data.\n payload (Optional[Dict[str, Any]]): The new payload data.\n\n Raises:\n CollectionNotFoundError: If the specified collection does not exist.\n VectorNotFoundError: If the vector 
with the specified ID does not exist.\n StorageError: If there is an issue updating the vector.\n \"\"\"\n pass\n\n @abstractmethod\n def get_vector(self, collection_name: str, vector_id: str) -> Dict[str, Any]:\n \"\"\"\n Retrieve a vector by its ID.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector.\n\n Returns:\n Dict[str, Any]: A dictionary containing the vector data and payload.\n\n Raises:\n CollectionNotFoundError: If the specified collection does not exist.\n VectorNotFoundError: If the vector with the specified ID does not exist.\n StorageError: If there is an issue retrieving the vector.\n \"\"\"\n pass\n\n @abstractmethod\n def list_collections(self) -> List[str]:\n \"\"\"\n List all available vector collections.\n\n Returns:\n List[str]: A list of collection names.\n\n Raises:\n StorageError: If there is an issue retrieving the collection list.\n \"\"\"\n pass\n\n @abstractmethod\n def delete_collection(self, collection_name: str) -> None:\n \"\"\"\n Delete an entire vector collection.\n\n Args:\n collection_name (str): The name of the collection to delete.\n\n Raises:\n CollectionNotFoundError: If the specified collection does not exist.\n StorageError: If there is an issue deleting the collection.\n \"\"\"\n pass\n\n @abstractmethod\n def get_collection_info(self, collection_name: str) -> Dict[str, Any]:\n \"\"\"\n Get information about a collection.\n\n Args:\n collection_name (str): The name of the collection.\n\n Returns:\n Dict[str, Any]: Information about the collection, such as vector size and distance metric.\n\n Raises:\n CollectionNotFoundError: If the specified collection does not exist.\n StorageError: If there is an issue retrieving the collection information.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.vector_database.VectorDatabase.create_client","title":"create_client(uri, user=None, password=None, db_name=None, token=None, timeout=None, **kwargs)
abstractmethod
","text":"Initializes the client connection to the vector store.
Parameters:
Name Type Description Default uri
str
The URI of the vector store instance.
required user
Optional[str]
Username for authentication.
None
password
Optional[str]
Password for authentication.
None
db_name
Optional[str]
Name of the database.
None
token
Optional[str]
Access token for authentication.
None
timeout
Optional[float]
Timeout duration for operations.
None
**kwargs
Additional implementation-specific parameters.
{}
Raises:
Type Description ConnectionError
If the client fails to connect to the vector store.
Source code in src/aeiva/storage/vector_database.py
@abstractmethod\ndef create_client(\n self,\n uri: str,\n user: Optional[str] = None,\n password: Optional[str] = None,\n db_name: Optional[str] = None,\n token: Optional[str] = None,\n timeout: Optional[float] = None,\n **kwargs\n) -> None:\n \"\"\"\n Initializes the client connection to the vector store.\n\n Args:\n uri (str): The URI of the vector store instance.\n user (Optional[str]): Username for authentication.\n password (Optional[str]): Password for authentication.\n db_name (Optional[str]): Name of the database.\n token (Optional[str]): Access token for authentication.\n timeout (Optional[float]): Timeout duration for operations.\n **kwargs: Additional implementation-specific parameters.\n\n Raises:\n ConnectionError: If the client fails to connect to the vector store.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.vector_database.VectorDatabase.create_collection","title":"create_collection(collection_name, vector_size, distance_metric)
abstractmethod
","text":"Create a new vector collection.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required vector_size
int
The dimensionality of the vectors.
required distance_metric
str
The distance metric to use (e.g., 'euclidean', 'cosine').
required Raises:
Type Description CollectionAlreadyExistsError
If a collection with the given name already exists.
StorageError
If there is an issue creating the collection.
Source code in src/aeiva/storage/vector_database.py
@abstractmethod\ndef create_collection(self, collection_name: str, vector_size: int, distance_metric: str) -> None:\n \"\"\"\n Create a new vector collection.\n\n Args:\n collection_name (str): The name of the collection.\n vector_size (int): The dimensionality of the vectors.\n distance_metric (str): The distance metric to use (e.g., 'euclidean', 'cosine').\n\n Raises:\n CollectionAlreadyExistsError: If a collection with the given name already exists.\n StorageError: If there is an issue creating the collection.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.vector_database.VectorDatabase.delete_collection","title":"delete_collection(collection_name)
abstractmethod
","text":"Delete an entire vector collection.
Parameters:
Name Type Description Default collection_name
str
The name of the collection to delete.
required Raises:
Type Description CollectionNotFoundError
If the specified collection does not exist.
StorageError
If there is an issue deleting the collection.
Source code in src/aeiva/storage/vector_database.py
@abstractmethod\ndef delete_collection(self, collection_name: str) -> None:\n \"\"\"\n Delete an entire vector collection.\n\n Args:\n collection_name (str): The name of the collection to delete.\n\n Raises:\n CollectionNotFoundError: If the specified collection does not exist.\n StorageError: If there is an issue deleting the collection.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.vector_database.VectorDatabase.delete_vector","title":"delete_vector(collection_name, vector_id)
abstractmethod
","text":"Delete a vector from a collection by its ID.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required vector_id
str
The unique identifier of the vector to delete.
required Raises:
Type Description CollectionNotFoundError
If the specified collection does not exist.
VectorNotFoundError
If the vector with the specified ID does not exist.
StorageError
If there is an issue deleting the vector.
Source code in src/aeiva/storage/vector_database.py
@abstractmethod\ndef delete_vector(self, collection_name: str, vector_id: str) -> None:\n \"\"\"\n Delete a vector from a collection by its ID.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector to delete.\n\n Raises:\n CollectionNotFoundError: If the specified collection does not exist.\n VectorNotFoundError: If the vector with the specified ID does not exist.\n StorageError: If there is an issue deleting the vector.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.vector_database.VectorDatabase.get_collection_info","title":"get_collection_info(collection_name)
abstractmethod
","text":"Get information about a collection.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required Returns:
Type Description Dict[str, Any]
Dict[str, Any]: Information about the collection, such as vector size and distance metric.
Raises:
Type Description CollectionNotFoundError
If the specified collection does not exist.
StorageError
If there is an issue retrieving the collection information.
Source code in src/aeiva/storage/vector_database.py
@abstractmethod\ndef get_collection_info(self, collection_name: str) -> Dict[str, Any]:\n \"\"\"\n Get information about a collection.\n\n Args:\n collection_name (str): The name of the collection.\n\n Returns:\n Dict[str, Any]: Information about the collection, such as vector size and distance metric.\n\n Raises:\n CollectionNotFoundError: If the specified collection does not exist.\n StorageError: If there is an issue retrieving the collection information.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.vector_database.VectorDatabase.get_vector","title":"get_vector(collection_name, vector_id)
abstractmethod
","text":"Retrieve a vector by its ID.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required vector_id
str
The unique identifier of the vector.
required Returns:
Type Description Dict[str, Any]
Dict[str, Any]: A dictionary containing the vector data and payload.
Raises:
Type Description CollectionNotFoundError
If the specified collection does not exist.
VectorNotFoundError
If the vector with the specified ID does not exist.
StorageError
If there is an issue retrieving the vector.
Source code in src/aeiva/storage/vector_database.py
@abstractmethod\ndef get_vector(self, collection_name: str, vector_id: str) -> Dict[str, Any]:\n \"\"\"\n Retrieve a vector by its ID.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector.\n\n Returns:\n Dict[str, Any]: A dictionary containing the vector data and payload.\n\n Raises:\n CollectionNotFoundError: If the specified collection does not exist.\n VectorNotFoundError: If the vector with the specified ID does not exist.\n StorageError: If there is an issue retrieving the vector.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.vector_database.VectorDatabase.insert_vectors","title":"insert_vectors(collection_name, vectors, payloads=None, ids=None)
abstractmethod
","text":"Insert vectors into a collection.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required vectors
List[List[float]]
A list of vectors to insert.
required payloads
Optional[List[Dict[str, Any]]]
Optional metadata associated with each vector.
None
ids
Optional[List[str]]
Optional unique identifiers for each vector.
None
Raises:
Type Description CollectionNotFoundError
If the specified collection does not exist.
StorageError
If there is an issue inserting the vectors.
Source code in src/aeiva/storage/vector_database.py
@abstractmethod\ndef insert_vectors(self, collection_name: str, vectors: List[List[float]], payloads: Optional[List[Dict[str, Any]]] = None, ids: Optional[List[str]] = None) -> None:\n \"\"\"\n Insert vectors into a collection.\n\n Args:\n collection_name (str): The name of the collection.\n vectors (List[List[float]]): A list of vectors to insert.\n payloads (Optional[List[Dict[str, Any]]]): Optional metadata associated with each vector.\n ids (Optional[List[str]]): Optional unique identifiers for each vector.\n\n Raises:\n CollectionNotFoundError: If the specified collection does not exist.\n StorageError: If there is an issue inserting the vectors.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.vector_database.VectorDatabase.list_collections","title":"list_collections()
abstractmethod
","text":"List all available vector collections.
Returns:
Type Description List[str]
List[str]: A list of collection names.
Raises:
Type Description StorageError
If there is an issue retrieving the collection list.
Source code in src/aeiva/storage/vector_database.py
@abstractmethod\ndef list_collections(self) -> List[str]:\n \"\"\"\n List all available vector collections.\n\n Returns:\n List[str]: A list of collection names.\n\n Raises:\n StorageError: If there is an issue retrieving the collection list.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.vector_database.VectorDatabase.search_vectors","title":"search_vectors(collection_name, query_vector, top_k=5, filters=None)
abstractmethod
","text":"Search for similar vectors in a collection.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required query_vector
List[float]
The vector to search with.
required top_k
int
The number of top results to return.
5
filters
Optional[Dict[str, Any]]
Optional filters to apply to the search.
None
Returns:
Type Description List[Dict[str, Any]]
List[Dict[str, Any]]: A list of search results, each containing the vector ID, score, and payload.
Raises:
Type Description CollectionNotFoundError
If the specified collection does not exist.
StorageError
If there is an issue performing the search.
Source code in src/aeiva/storage/vector_database.py
@abstractmethod\ndef search_vectors(self, collection_name: str, query_vector: List[float], top_k: int = 5, filters: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:\n \"\"\"\n Search for similar vectors in a collection.\n\n Args:\n collection_name (str): The name of the collection.\n query_vector (List[float]): The vector to search with.\n top_k (int): The number of top results to return.\n filters (Optional[Dict[str, Any]]): Optional filters to apply to the search.\n\n Returns:\n List[Dict[str, Any]]: A list of search results, each containing the vector ID, score, and payload.\n\n Raises:\n CollectionNotFoundError: If the specified collection does not exist.\n StorageError: If there is an issue performing the search.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.vector_database.VectorDatabase.update_vector","title":"update_vector(collection_name, vector_id, vector=None, payload=None)
abstractmethod
","text":"Update a vector's data or payload.
Parameters:
Name Type Description Default collection_name
str
The name of the collection.
required vector_id
str
The unique identifier of the vector to update.
required vector
Optional[List[float]]
The new vector data.
None
payload
Optional[Dict[str, Any]]
The new payload data.
None
Raises:
Type Description CollectionNotFoundError
If the specified collection does not exist.
VectorNotFoundError
If the vector with the specified ID does not exist.
StorageError
If there is an issue updating the vector.
Source code in src/aeiva/storage/vector_database.py
@abstractmethod\ndef update_vector(self, collection_name: str, vector_id: str, vector: Optional[List[float]] = None, payload: Optional[Dict[str, Any]] = None) -> None:\n \"\"\"\n Update a vector's data or payload.\n\n Args:\n collection_name (str): The name of the collection.\n vector_id (str): The unique identifier of the vector to update.\n vector (Optional[List[float]]): The new vector data.\n payload (Optional[Dict[str, Any]]): The new payload data.\n\n Raises:\n CollectionNotFoundError: If the specified collection does not exist.\n VectorNotFoundError: If the vector with the specified ID does not exist.\n StorageError: If there is an issue updating the vector.\n \"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.storage.weaviate","title":"weaviate
","text":""},{"location":"reference/#src.aeiva.storage.weaviate.weaviate_config","title":"weaviate_config
","text":""},{"location":"reference/#src.aeiva.storage.weaviate.weaviate_config.WeaviateConfig","title":"WeaviateConfig
dataclass
","text":" Bases: BaseConfig
Configuration for Weaviate vector database.
Source code in src/aeiva/storage/weaviate/weaviate_config.py
@dataclass\nclass WeaviateConfig(BaseConfig):\n \"\"\"\n Configuration for Weaviate vector database.\n \"\"\"\n\n url: str = field(\n default='http://localhost:8080',\n metadata={\"help\": \"URL of the Weaviate instance (e.g., 'http://localhost:8080').\"}\n )\n api_key: Optional[str] = field(\n default=None,\n metadata={\"help\": \"API key for Weaviate authentication (if required).\"}\n )\n auth_client_secret: Optional[Dict[str, Any]] = field(\n default=None,\n metadata={\"help\": \"Authentication client secret for Weaviate (if using OIDC).\"}\n )\n timeout_config: Optional[Tuple[float, float]] = field(\n default=(2, 20),\n metadata={\"help\": \"Timeout configuration for requests (connect timeout, read timeout).\"}\n )\n additional_headers: Optional[Dict[str, str]] = field(\n default=None,\n metadata={\"help\": \"Additional headers to include in requests to Weaviate.\"}\n )\n embedding_model: Optional[str] = field(\n default=None,\n metadata={\"help\": \"Name of the embedding model used (if required).\"}\n )\n index_name: str = field(\n default='MyIndex',\n metadata={\"help\": \"Name of the Weaviate index (class).\"}\n )\n vector_dim: int = field(\n default=512,\n metadata={\"help\": \"Dimensionality of the vectors stored in Weaviate.\"}\n )\n distance_metric: str = field(\n default='cosine',\n metadata={\"help\": \"Distance metric to use (e.g., 'cosine', 'l2-squared', 'dot').\"}\n )\n\n def __post_init__(self):\n super().__post_init__()\n if not self.url:\n raise ValueError(\"The 'url' parameter is required for Weaviate configuration.\")\n
"},{"location":"reference/#src.aeiva.storage.weaviate.weaviate_database","title":"weaviate_database
","text":""},{"location":"reference/#src.aeiva.storage.weaviate.weaviate_database.WeaviateDatabase","title":"WeaviateDatabase
","text":" Bases: VectorDatabase
Concrete implementation of VectorStoreBase using Weaviate.
Source code in src/aeiva/storage/weaviate/weaviate_database.py
class WeaviateDatabase(VectorDatabase):\n \"\"\"\n Concrete implementation of VectorStoreBase using Weaviate.\n \"\"\"\n\n def __init__(self, config: Dict[str, Any]) -> None:\n \"\"\"\n Initialize the Weaviate vector store.\n\n Args:\n config (Dict[str, Any]): Configuration dictionary.\n \"\"\"\n self.config = config\n self.url = config.get('url', 'http://localhost:8080')\n self.api_key = config.get('api_key')\n self.auth_client_secret = config.get('auth_client_secret')\n self.timeout_config = config.get('timeout_config', (2, 20))\n self.additional_headers = config.get('additional_headers')\n self.embedding_model = config.get('embedding_model')\n self.index_name = config.get('index_name', 'MyIndex')\n self.vector_dim = config.get('vector_dim', 512)\n self.distance_metric = config.get('distance_metric', 'cosine')\n\n self.client = self.create_client()\n self.create_index(\n index_name=self.index_name,\n vector_dim=self.vector_dim,\n distance_metric=self.distance_metric\n )\n\n def create_client(self) -> Client:\n \"\"\"\n Initializes the client connection to the Weaviate vector store.\n\n Returns:\n Client: The Weaviate client instance.\n\n Raises:\n ConnectionError: If the client fails to connect to the Weaviate instance.\n \"\"\"\n try:\n if self.api_key:\n auth_config = AuthApiKey(api_key=self.api_key)\n elif self.auth_client_secret:\n auth_config = AuthClientPassword(**self.auth_client_secret)\n else:\n auth_config = None\n\n client = weaviate.Client(\n url=self.url,\n auth_client_secret=auth_config,\n timeout_config=self.timeout_config,\n additional_headers=self.additional_headers\n )\n\n if not client.is_ready():\n raise ConnectionError(f\"Weaviate at {self.url} is not ready.\")\n\n logger.info(f\"Connected to Weaviate at {self.url}.\")\n return client\n except Exception as e:\n logger.error(f\"Failed to connect to Weaviate: {e}\")\n raise ConnectionError(f\"Failed to connect to Weaviate: {e}\")\n\n def create_index(self, index_name: str, vector_dim: int, 
distance_metric: str) -> None:\n \"\"\"\n Create a new index (class) in Weaviate.\n\n Args:\n index_name (str): The name of the index.\n vector_dim (int): The dimensionality of the vectors.\n distance_metric (str): The distance metric to use.\n\n Raises:\n WeaviateException: If there is an issue creating the index.\n \"\"\"\n try:\n if self.client.schema.contains(index_name):\n logger.info(f\"Index {index_name} already exists. Skipping creation.\")\n return\n\n class_obj = {\n \"class\": index_name,\n \"vectorizer\": \"none\",\n \"vectorIndexType\": \"hnsw\",\n \"vectorIndexConfig\": {\n \"distance\": distance_metric\n },\n \"properties\": [\n {\n \"name\": \"id\",\n \"dataType\": [\"string\"],\n \"description\": \"Unique identifier\",\n },\n {\n \"name\": \"payload\",\n \"dataType\": [\"blob\"],\n \"description\": \"Payload data\",\n },\n ]\n }\n\n self.client.schema.create_class(class_obj)\n logger.info(f\"Index {index_name} created successfully.\")\n except WeaviateException as e:\n logger.error(f\"Failed to create index: {e}\")\n raise\n\n def insert_vectors(\n self,\n collection_name: str,\n vectors: List[List[float]],\n payloads: Optional[List[Dict[str, Any]]] = None,\n ids: Optional[List[str]] = None\n ) -> None:\n \"\"\"\n Insert vectors into the collection.\n\n Args:\n collection_name (str): The name of the collection (index).\n vectors (List[List[float]]): A list of vectors to insert.\n payloads (Optional[List[Dict[str, Any]]]): Optional metadata associated with each vector.\n ids (Optional[List[str]]): Optional unique identifiers for each vector.\n\n Raises:\n ValueError: If input data is invalid.\n WeaviateException: If there is an issue inserting vectors.\n \"\"\"\n if collection_name != self.index_name:\n raise ValueError(\"Collection name does not match initialized index name.\")\n\n if ids is None:\n raise ValueError(\"Weaviate requires IDs to be provided for each vector.\")\n\n if payloads is None:\n payloads = [{} for _ in range(len(vectors))]\n\n 
if not (len(ids) == len(vectors) == len(payloads)):\n raise ValueError(\"Lengths of ids, vectors, and payloads must be equal.\")\n\n try:\n with self.client.batch(batch_size=100) as batch:\n for id_, vector, payload in zip(ids, vectors, payloads):\n data_object = {\n \"id\": id_,\n \"payload\": payload\n }\n batch.add_data_object(\n data_object=data_object,\n class_name=collection_name,\n vector=vector\n )\n logger.info(f\"Inserted {len(vectors)} vectors into index {collection_name}.\")\n except WeaviateException as e:\n logger.error(f\"Failed to insert vectors: {e}\")\n raise\n\n def search_vectors(\n self,\n collection_name: str,\n query_vector: List[float],\n top_k: int = 5,\n filters: Optional[Dict[str, Any]] = None\n ) -> List[Dict[str, Any]]:\n \"\"\"\n Search for similar vectors in the collection.\n\n Args:\n collection_name (str): The name of the collection (index).\n query_vector (List[float]): The vector to search with.\n top_k (int): The number of top results to return.\n filters (Optional[Dict[str, Any]]): Optional filters to apply to the search.\n\n Returns:\n List[Dict[str, Any]]: A list of search results.\n\n Raises:\n ValueError: If collection name does not match.\n WeaviateException: If there is an issue performing the search.\n \"\"\"\n if collection_name != self.index_name:\n raise ValueError(\"Collection name does not match initialized index name.\")\n\n try:\n near_vector = {\n \"vector\": query_vector,\n }\n\n where_filter = self._build_filters(filters)\n\n result = self.client.query.get(\n class_name=collection_name,\n properties=[\"id\", \"payload\"]\n ).with_near_vector(near_vector).with_where(where_filter).with_limit(top_k).do()\n\n output = []\n for item in result[\"data\"][\"Get\"][collection_name]:\n result_item = {\n \"id\": item[\"id\"],\n \"score\": item[\"_additional\"][\"certainty\"], # or distance\n \"payload\": item[\"payload\"]\n }\n output.append(result_item)\n return output\n except WeaviateException as e:\n 
logger.error(f\"Failed to search vectors: {e}\")\n raise\n\n def _build_filters(self, filters: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:\n \"\"\"\n Build a Weaviate where filter from a dictionary.\n\n Args:\n filters (Optional[Dict[str, Any]]): Filters to apply.\n\n Returns:\n Optional[Dict[str, Any]]: A Weaviate where filter.\n \"\"\"\n if not filters:\n return None\n\n conditions = []\n for key, value in filters.items():\n condition = {\n \"path\": [key],\n \"operator\": \"Equal\",\n \"valueString\": value if isinstance(value, str) else None,\n \"valueInt\": value if isinstance(value, int) else None,\n \"valueBoolean\": value if isinstance(value, bool) else None,\n \"valueNumber\": value if isinstance(value, float) else None,\n }\n conditions.append(condition)\n\n where_filter = {\n \"operator\": \"And\",\n \"operands\": conditions\n }\n\n return where_filter\n\n def delete_vector(self, collection_name: str, vector_id: str) -> None:\n \"\"\"\n Delete a vector from the collection by its ID.\n\n Args:\n collection_name (str): The name of the collection (index).\n vector_id (str): The unique identifier of the vector to delete.\n\n Raises:\n ValueError: If collection name does not match.\n WeaviateException: If there is an issue deleting the vector.\n \"\"\"\n if collection_name != self.index_name:\n raise ValueError(\"Collection name does not match initialized index name.\")\n\n try:\n self.client.data_object.delete(\n uuid=vector_id,\n class_name=collection_name\n )\n logger.info(f\"Deleted vector with ID {vector_id} from index {collection_name}.\")\n except WeaviateException as e:\n logger.error(f\"Failed to delete vector: {e}\")\n raise\n\n def update_vector(\n self,\n collection_name: str,\n vector_id: str,\n vector: Optional[List[float]] = None,\n payload: Optional[Dict[str, Any]] = None\n ) -> None:\n \"\"\"\n Update a vector's data or payload.\n\n Args:\n collection_name (str): The name of the collection (index).\n vector_id (str): The unique 
identifier of the vector to update.\n vector (Optional[List[float]]): The new vector data.\n payload (Optional[Dict[str, Any]]): The new payload data.\n\n Raises:\n ValueError: If collection name does not match.\n WeaviateException: If there is an issue updating the vector.\n \"\"\"\n if collection_name != self.index_name:\n raise ValueError(\"Collection name does not match initialized index name.\")\n\n try:\n data_object = {}\n if payload is not None:\n data_object[\"payload\"] = payload\n\n self.client.data_object.update(\n data_object=data_object,\n class_name=collection_name,\n uuid=vector_id,\n vector=vector\n )\n logger.info(f\"Updated vector with ID {vector_id} in index {collection_name}.\")\n except WeaviateException as e:\n logger.error(f\"Failed to update vector: {e}\")\n raise\n\n def get_vector(self, collection_name: str, vector_id: str) -> Dict[str, Any]:\n \"\"\"\n Retrieve a vector by its ID.\n\n Args:\n collection_name (str): The name of the collection (index).\n vector_id (str): The unique identifier of the vector.\n\n Returns:\n Dict[str, Any]: A dictionary containing the vector data and payload.\n\n Raises:\n ValueError: If collection name does not match.\n KeyError: If the vector is not found.\n WeaviateException: If there is an issue retrieving the vector.\n \"\"\"\n if collection_name != self.index_name:\n raise ValueError(\"Collection name does not match initialized index name.\")\n\n try:\n result = self.client.data_object.get_by_id(\n uuid=vector_id,\n class_name=collection_name,\n additional_properties=[\"vector\"]\n )\n if result is None:\n raise KeyError(f\"Vector with ID {vector_id} not found in index {collection_name}.\")\n\n vector_data = {\n \"id\": result[\"id\"],\n \"vector\": result[\"vector\"],\n \"payload\": result[\"payload\"]\n }\n return vector_data\n except WeaviateException as e:\n logger.error(f\"Failed to retrieve vector: {e}\")\n raise\n\n def list_collections(self) -> List[str]:\n \"\"\"\n List all available indexes 
(classes).\n\n Returns:\n List[str]: A list of index names.\n \"\"\"\n try:\n schema = self.client.schema.get()\n return [clazz[\"class\"] for clazz in schema[\"classes\"]]\n except WeaviateException as e:\n logger.error(f\"Failed to list collections: {e}\")\n raise\n\n def delete_collection(self, collection_name: str) -> None:\n \"\"\"\n Delete an entire index (class).\n\n Args:\n collection_name (str): The name of the collection (index) to delete.\n\n Raises:\n WeaviateException: If there is an issue deleting the collection.\n \"\"\"\n try:\n self.client.schema.delete_class(collection_name)\n logger.info(f\"Deleted index {collection_name}.\")\n except WeaviateException as e:\n logger.error(f\"Failed to delete collection: {e}\")\n raise\n\n def get_collection_info(self, collection_name: str) -> Dict[str, Any]:\n \"\"\"\n Get information about a collection (index).\n\n Args:\n collection_name (str): The name of the collection (index).\n\n Returns:\n Dict[str, Any]: Information about the collection.\n\n Raises:\n WeaviateException: If there is an issue retrieving the collection info.\n \"\"\"\n try:\n class_schema = self.client.schema.get(class_name=collection_name)\n return class_schema\n except WeaviateException as e:\n logger.error(f\"Failed to get collection info: {e}\")\n raise\n\n def __del__(self):\n \"\"\"Clean up resources.\"\"\"\n if hasattr(self, 'client'):\n self.client.close()\n logger.info(\"Closed connection to Weaviate.\")\n
"},{"location":"reference/#src.aeiva.storage.weaviate.weaviate_database.WeaviateDatabase.__del__","title":"__del__()
","text":"Clean up resources.
Source code in src/aeiva/storage/weaviate/weaviate_database.py
def __del__(self):\n \"\"\"Clean up resources.\"\"\"\n if hasattr(self, 'client'):\n self.client.close()\n logger.info(\"Closed connection to Weaviate.\")\n
"},{"location":"reference/#src.aeiva.storage.weaviate.weaviate_database.WeaviateDatabase.__init__","title":"__init__(config)
","text":"Initialize the Weaviate vector store.
Parameters:
Name Type Description Default config
Dict[str, Any]
Configuration dictionary.
required Source code in src/aeiva/storage/weaviate/weaviate_database.py
def __init__(self, config: Dict[str, Any]) -> None:\n \"\"\"\n Initialize the Weaviate vector store.\n\n Args:\n config (Dict[str, Any]): Configuration dictionary.\n \"\"\"\n self.config = config\n self.url = config.get('url', 'http://localhost:8080')\n self.api_key = config.get('api_key')\n self.auth_client_secret = config.get('auth_client_secret')\n self.timeout_config = config.get('timeout_config', (2, 20))\n self.additional_headers = config.get('additional_headers')\n self.embedding_model = config.get('embedding_model')\n self.index_name = config.get('index_name', 'MyIndex')\n self.vector_dim = config.get('vector_dim', 512)\n self.distance_metric = config.get('distance_metric', 'cosine')\n\n self.client = self.create_client()\n self.create_index(\n index_name=self.index_name,\n vector_dim=self.vector_dim,\n distance_metric=self.distance_metric\n )\n
"},{"location":"reference/#src.aeiva.storage.weaviate.weaviate_database.WeaviateDatabase.create_client","title":"create_client()
","text":"Initializes the client connection to the Weaviate vector store.
Returns:
Name Type Description Client
Client
The Weaviate client instance.
Raises:
Type Description ConnectionError
If the client fails to connect to the Weaviate instance.
Source code in src/aeiva/storage/weaviate/weaviate_database.py
def create_client(self) -> Client:\n \"\"\"\n Initializes the client connection to the Weaviate vector store.\n\n Returns:\n Client: The Weaviate client instance.\n\n Raises:\n ConnectionError: If the client fails to connect to the Weaviate instance.\n \"\"\"\n try:\n if self.api_key:\n auth_config = AuthApiKey(api_key=self.api_key)\n elif self.auth_client_secret:\n auth_config = AuthClientPassword(**self.auth_client_secret)\n else:\n auth_config = None\n\n client = weaviate.Client(\n url=self.url,\n auth_client_secret=auth_config,\n timeout_config=self.timeout_config,\n additional_headers=self.additional_headers\n )\n\n if not client.is_ready():\n raise ConnectionError(f\"Weaviate at {self.url} is not ready.\")\n\n logger.info(f\"Connected to Weaviate at {self.url}.\")\n return client\n except Exception as e:\n logger.error(f\"Failed to connect to Weaviate: {e}\")\n raise ConnectionError(f\"Failed to connect to Weaviate: {e}\")\n
"},{"location":"reference/#src.aeiva.storage.weaviate.weaviate_database.WeaviateDatabase.create_index","title":"create_index(index_name, vector_dim, distance_metric)
","text":"Create a new index (class) in Weaviate.
Parameters:
Name Type Description Default index_name
str
The name of the index.
required vector_dim
int
The dimensionality of the vectors.
required distance_metric
str
The distance metric to use.
required Raises:
Type Description WeaviateException
If there is an issue creating the index.
Source code in src/aeiva/storage/weaviate/weaviate_database.py
def create_index(self, index_name: str, vector_dim: int, distance_metric: str) -> None:\n \"\"\"\n Create a new index (class) in Weaviate.\n\n Args:\n index_name (str): The name of the index.\n vector_dim (int): The dimensionality of the vectors.\n distance_metric (str): The distance metric to use.\n\n Raises:\n WeaviateException: If there is an issue creating the index.\n \"\"\"\n try:\n if self.client.schema.contains(index_name):\n logger.info(f\"Index {index_name} already exists. Skipping creation.\")\n return\n\n class_obj = {\n \"class\": index_name,\n \"vectorizer\": \"none\",\n \"vectorIndexType\": \"hnsw\",\n \"vectorIndexConfig\": {\n \"distance\": distance_metric\n },\n \"properties\": [\n {\n \"name\": \"id\",\n \"dataType\": [\"string\"],\n \"description\": \"Unique identifier\",\n },\n {\n \"name\": \"payload\",\n \"dataType\": [\"blob\"],\n \"description\": \"Payload data\",\n },\n ]\n }\n\n self.client.schema.create_class(class_obj)\n logger.info(f\"Index {index_name} created successfully.\")\n except WeaviateException as e:\n logger.error(f\"Failed to create index: {e}\")\n raise\n
"},{"location":"reference/#src.aeiva.storage.weaviate.weaviate_database.WeaviateDatabase.delete_collection","title":"delete_collection(collection_name)
","text":"Delete an entire index (class).
Parameters:
Name Type Description Default collection_name
str
The name of the collection (index) to delete.
required Raises:
Type Description WeaviateException
If there is an issue deleting the collection.
Source code in src/aeiva/storage/weaviate/weaviate_database.py
def delete_collection(self, collection_name: str) -> None:\n \"\"\"\n Delete an entire index (class).\n\n Args:\n collection_name (str): The name of the collection (index) to delete.\n\n Raises:\n WeaviateException: If there is an issue deleting the collection.\n \"\"\"\n try:\n self.client.schema.delete_class(collection_name)\n logger.info(f\"Deleted index {collection_name}.\")\n except WeaviateException as e:\n logger.error(f\"Failed to delete collection: {e}\")\n raise\n
"},{"location":"reference/#src.aeiva.storage.weaviate.weaviate_database.WeaviateDatabase.delete_vector","title":"delete_vector(collection_name, vector_id)
","text":"Delete a vector from the collection by its ID.
Parameters:
Name Type Description Default collection_name
str
The name of the collection (index).
required vector_id
str
The unique identifier of the vector to delete.
required Raises:
Type Description ValueError
If collection name does not match.
WeaviateException
If there is an issue deleting the vector.
Source code in src/aeiva/storage/weaviate/weaviate_database.py
def delete_vector(self, collection_name: str, vector_id: str) -> None:\n \"\"\"\n Delete a vector from the collection by its ID.\n\n Args:\n collection_name (str): The name of the collection (index).\n vector_id (str): The unique identifier of the vector to delete.\n\n Raises:\n ValueError: If collection name does not match.\n WeaviateException: If there is an issue deleting the vector.\n \"\"\"\n if collection_name != self.index_name:\n raise ValueError(\"Collection name does not match initialized index name.\")\n\n try:\n self.client.data_object.delete(\n uuid=vector_id,\n class_name=collection_name\n )\n logger.info(f\"Deleted vector with ID {vector_id} from index {collection_name}.\")\n except WeaviateException as e:\n logger.error(f\"Failed to delete vector: {e}\")\n raise\n
"},{"location":"reference/#src.aeiva.storage.weaviate.weaviate_database.WeaviateDatabase.get_collection_info","title":"get_collection_info(collection_name)
","text":"Get information about a collection (index).
Parameters:
Name Type Description Default collection_name
str
The name of the collection (index).
required Returns:
Type Description Dict[str, Any]
Dict[str, Any]: Information about the collection.
Raises:
Type Description WeaviateException
If there is an issue retrieving the collection info.
Source code in src/aeiva/storage/weaviate/weaviate_database.py
def get_collection_info(self, collection_name: str) -> Dict[str, Any]:\n \"\"\"\n Get information about a collection (index).\n\n Args:\n collection_name (str): The name of the collection (index).\n\n Returns:\n Dict[str, Any]: Information about the collection.\n\n Raises:\n WeaviateException: If there is an issue retrieving the collection info.\n \"\"\"\n try:\n class_schema = self.client.schema.get(class_name=collection_name)\n return class_schema\n except WeaviateException as e:\n logger.error(f\"Failed to get collection info: {e}\")\n raise\n
"},{"location":"reference/#src.aeiva.storage.weaviate.weaviate_database.WeaviateDatabase.get_vector","title":"get_vector(collection_name, vector_id)
","text":"Retrieve a vector by its ID.
Parameters:
Name Type Description Default collection_name
str
The name of the collection (index).
required vector_id
str
The unique identifier of the vector.
required Returns:
Type Description Dict[str, Any]
Dict[str, Any]: A dictionary containing the vector data and payload.
Raises:
Type Description ValueError
If collection name does not match.
KeyError
If the vector is not found.
WeaviateException
If there is an issue retrieving the vector.
Source code in src/aeiva/storage/weaviate/weaviate_database.py
def get_vector(self, collection_name: str, vector_id: str) -> Dict[str, Any]:\n \"\"\"\n Retrieve a vector by its ID.\n\n Args:\n collection_name (str): The name of the collection (index).\n vector_id (str): The unique identifier of the vector.\n\n Returns:\n Dict[str, Any]: A dictionary containing the vector data and payload.\n\n Raises:\n ValueError: If collection name does not match.\n KeyError: If the vector is not found.\n WeaviateException: If there is an issue retrieving the vector.\n \"\"\"\n if collection_name != self.index_name:\n raise ValueError(\"Collection name does not match initialized index name.\")\n\n try:\n result = self.client.data_object.get_by_id(\n uuid=vector_id,\n class_name=collection_name,\n additional_properties=[\"vector\"]\n )\n if result is None:\n raise KeyError(f\"Vector with ID {vector_id} not found in index {collection_name}.\")\n\n vector_data = {\n \"id\": result[\"id\"],\n \"vector\": result[\"vector\"],\n # stored properties are nested under \"properties\" in get_by_id results\n \"payload\": result[\"properties\"][\"payload\"]\n }\n return vector_data\n except WeaviateException as e:\n logger.error(f\"Failed to retrieve vector: {e}\")\n raise\n
"},{"location":"reference/#src.aeiva.storage.weaviate.weaviate_database.WeaviateDatabase.insert_vectors","title":"insert_vectors(collection_name, vectors, payloads=None, ids=None)
","text":"Insert vectors into the collection.
Parameters:
Name Type Description Default collection_name
str
The name of the collection (index).
required vectors
List[List[float]]
A list of vectors to insert.
required payloads
Optional[List[Dict[str, Any]]]
Optional metadata associated with each vector.
None
ids
Optional[List[str]]
Optional unique identifiers for each vector.
None
Raises:
Type Description ValueError
If input data is invalid.
WeaviateException
If there is an issue inserting vectors.
Source code in src/aeiva/storage/weaviate/weaviate_database.py
def insert_vectors(\n self,\n collection_name: str,\n vectors: List[List[float]],\n payloads: Optional[List[Dict[str, Any]]] = None,\n ids: Optional[List[str]] = None\n) -> None:\n \"\"\"\n Insert vectors into the collection.\n\n Args:\n collection_name (str): The name of the collection (index).\n vectors (List[List[float]]): A list of vectors to insert.\n payloads (Optional[List[Dict[str, Any]]]): Optional metadata associated with each vector.\n ids (Optional[List[str]]): Optional unique identifiers for each vector.\n\n Raises:\n ValueError: If input data is invalid.\n WeaviateException: If there is an issue inserting vectors.\n \"\"\"\n if collection_name != self.index_name:\n raise ValueError(\"Collection name does not match initialized index name.\")\n\n if ids is None:\n raise ValueError(\"Weaviate requires IDs to be provided for each vector.\")\n\n if payloads is None:\n payloads = [{} for _ in range(len(vectors))]\n\n if not (len(ids) == len(vectors) == len(payloads)):\n raise ValueError(\"Lengths of ids, vectors, and payloads must be equal.\")\n\n try:\n with self.client.batch(batch_size=100) as batch:\n for id_, vector, payload in zip(ids, vectors, payloads):\n data_object = {\n \"id\": id_,\n \"payload\": payload\n }\n batch.add_data_object(\n data_object=data_object,\n class_name=collection_name,\n uuid=id_, # pass the caller-supplied ID so later lookups by ID succeed\n vector=vector\n )\n logger.info(f\"Inserted {len(vectors)} vectors into index {collection_name}.\")\n except WeaviateException as e:\n logger.error(f\"Failed to insert vectors: {e}\")\n raise\n
"},{"location":"reference/#src.aeiva.storage.weaviate.weaviate_database.WeaviateDatabase.list_collections","title":"list_collections()
","text":"List all available indexes (classes).
Returns:
Type Description List[str]
List[str]: A list of index names.
Source code in src/aeiva/storage/weaviate/weaviate_database.py
def list_collections(self) -> List[str]:\n \"\"\"\n List all available indexes (classes).\n\n Returns:\n List[str]: A list of index names.\n \"\"\"\n try:\n schema = self.client.schema.get()\n return [clazz[\"class\"] for clazz in schema[\"classes\"]]\n except WeaviateException as e:\n logger.error(f\"Failed to list collections: {e}\")\n raise\n
"},{"location":"reference/#src.aeiva.storage.weaviate.weaviate_database.WeaviateDatabase.search_vectors","title":"search_vectors(collection_name, query_vector, top_k=5, filters=None)
","text":"Search for similar vectors in the collection.
Parameters:
Name Type Description Default collection_name
str
The name of the collection (index).
required query_vector
List[float]
The vector to search with.
required top_k
int
The number of top results to return.
5
filters
Optional[Dict[str, Any]]
Optional filters to apply to the search.
None
Returns:
Type Description List[Dict[str, Any]]
List[Dict[str, Any]]: A list of search results.
Raises:
Type Description ValueError
If collection name does not match.
WeaviateException
If there is an issue performing the search.
Source code in src/aeiva/storage/weaviate/weaviate_database.py
def search_vectors(\n self,\n collection_name: str,\n query_vector: List[float],\n top_k: int = 5,\n filters: Optional[Dict[str, Any]] = None\n) -> List[Dict[str, Any]]:\n \"\"\"\n Search for similar vectors in the collection.\n\n Args:\n collection_name (str): The name of the collection (index).\n query_vector (List[float]): The vector to search with.\n top_k (int): The number of top results to return.\n filters (Optional[Dict[str, Any]]): Optional filters to apply to the search.\n\n Returns:\n List[Dict[str, Any]]: A list of search results.\n\n Raises:\n ValueError: If collection name does not match.\n WeaviateException: If there is an issue performing the search.\n \"\"\"\n if collection_name != self.index_name:\n raise ValueError(\"Collection name does not match initialized index name.\")\n\n try:\n near_vector = {\n \"vector\": query_vector,\n }\n\n where_filter = self._build_filters(filters)\n\n query = self.client.query.get(\n class_name=collection_name,\n properties=[\"id\", \"payload\"]\n ).with_near_vector(near_vector).with_additional([\"certainty\"]).with_limit(top_k)\n # Only apply the where clause when filters were provided\n if where_filter is not None:\n query = query.with_where(where_filter)\n result = query.do()\n\n output = []\n for item in result[\"data\"][\"Get\"][collection_name]:\n result_item = {\n \"id\": item[\"id\"],\n \"score\": item[\"_additional\"][\"certainty\"], # or distance\n \"payload\": item[\"payload\"]\n }\n output.append(result_item)\n return output\n except WeaviateException as e:\n logger.error(f\"Failed to search vectors: {e}\")\n raise\n
"},{"location":"reference/#src.aeiva.storage.weaviate.weaviate_database.WeaviateDatabase.update_vector","title":"update_vector(collection_name, vector_id, vector=None, payload=None)
","text":"Update a vector's data or payload.
Parameters:
Name Type Description Default collection_name
str
The name of the collection (index).
required vector_id
str
The unique identifier of the vector to update.
required vector
Optional[List[float]]
The new vector data.
None
payload
Optional[Dict[str, Any]]
The new payload data.
None
Raises:
Type Description ValueError
If collection name does not match.
WeaviateException
If there is an issue updating the vector.
Source code in src/aeiva/storage/weaviate/weaviate_database.py
def update_vector(\n self,\n collection_name: str,\n vector_id: str,\n vector: Optional[List[float]] = None,\n payload: Optional[Dict[str, Any]] = None\n) -> None:\n \"\"\"\n Update a vector's data or payload.\n\n Args:\n collection_name (str): The name of the collection (index).\n vector_id (str): The unique identifier of the vector to update.\n vector (Optional[List[float]]): The new vector data.\n payload (Optional[Dict[str, Any]]): The new payload data.\n\n Raises:\n ValueError: If collection name does not match.\n WeaviateException: If there is an issue updating the vector.\n \"\"\"\n if collection_name != self.index_name:\n raise ValueError(\"Collection name does not match initialized index name.\")\n\n try:\n data_object = {}\n if payload is not None:\n data_object[\"payload\"] = payload\n\n self.client.data_object.update(\n data_object=data_object,\n class_name=collection_name,\n uuid=vector_id,\n vector=vector\n )\n logger.info(f\"Updated vector with ID {vector_id} in index {collection_name}.\")\n except WeaviateException as e:\n logger.error(f\"Failed to update vector: {e}\")\n raise\n
"},{"location":"reference/#src.aeiva.tool","title":"tool
","text":""},{"location":"reference/#src.aeiva.tool.api_server","title":"api_server
","text":""},{"location":"reference/#src.aeiva.tool.api_server.call_api_action","title":"call_api_action(api_name, action_name, request)
async
","text":"Endpoint to dynamically call an action within a specified API.
Parameters:
Name Type Description Default api_name
str
The name of the API.
required action_name
str
The name of the action/function to execute.
required request
Request
The incoming HTTP request.
required Returns:
Name Type Description dict
The result of the action or an error message.
Source code in src/aeiva/tool/api_server.py
@app.get(\"/api/{api_name}/{action_name}\")\nasync def call_api_action(api_name: str, action_name: str, request: Request):\n \"\"\"\n Endpoint to dynamically call an action within a specified API.\n\n Args:\n api_name (str): The name of the API.\n action_name (str): The name of the action/function to execute.\n request (Request): The incoming HTTP request.\n\n Returns:\n dict: The result of the action or an error message.\n \"\"\"\n try:\n logger.info(f\"Starting call_api_action for API '{api_name}', Action '{action_name}'\")\n\n # Load the API module\n module = load_api_module(api_name)\n\n # Retrieve the action function\n try:\n action = getattr(module, action_name)\n logger.info(f\"Retrieved action '{action_name}' from API '{api_name}'\")\n except AttributeError:\n logger.error(f\"Action '{action_name}' not found in API '{api_name}'\")\n raise HTTPException(status_code=404, detail=f\"Action '{action_name}' not found in API '{api_name}'\")\n\n # Extract parameters based on request method\n params = {}\n if request.method in [\"POST\", \"PUT\", \"PATCH\"]:\n try:\n params = await request.json()\n logger.info(f\"Received JSON payload: {params}\")\n except json.JSONDecodeError:\n logger.error(\"Invalid JSON payload\")\n raise HTTPException(status_code=400, detail=\"Invalid JSON payload\")\n else:\n # For GET requests, extract query parameters\n params = dict(request.query_params)\n logger.info(f\"Received query parameters: {params}\")\n\n # Get the function signature\n sig = signature(action)\n logger.info(f\"Function signature for '{action_name}': {sig}\")\n\n # Prepare to collect converted parameters\n converted_params = {}\n\n for param_name, param in sig.parameters.items():\n if param_name in params:\n value = params[param_name]\n param_type = param.annotation if param.annotation != Parameter.empty else str\n try:\n if param_type == bool:\n # Convert to boolean\n if isinstance(value, bool):\n converted_value = value\n elif isinstance(value, str):\n 
converted_value = value.lower() in (\"true\", \"1\", \"yes\")\n else:\n converted_value = bool(value)\n elif param_type in [int, float, str]:\n converted_value = param_type(value)\n elif param_type == list or param_type == dict:\n converted_value = json.loads(value)\n else:\n # For more complex types, assume Pydantic models or custom parsing\n converted_value = param_type(value)\n converted_params[param_name] = converted_value\n logger.debug(f\"Converted parameter '{param_name}': {converted_value} (Type: {param_type})\")\n except (ValueError, json.JSONDecodeError, TypeError) as e:\n logger.error(f\"Invalid value for parameter '{param_name}': {value} ({e})\")\n raise HTTPException(\n status_code=400,\n detail=f\"Invalid value for parameter '{param_name}': {value}. Expected type {param_type.__name__}.\"\n )\n else:\n if param.default == Parameter.empty:\n logger.error(f\"Missing required parameter: {param_name}\")\n raise HTTPException(status_code=400, detail=f\"Missing required parameter: {param_name}\")\n else:\n # Use default value\n converted_params[param_name] = param.default\n logger.debug(f\"Using default value for parameter '{param_name}': {param.default}\")\n\n # Determine if the action is asynchronous\n if asyncio.iscoroutinefunction(action):\n logger.info(f\"Action '{action_name}' is asynchronous. Awaiting execution.\")\n result = await action(**converted_params)\n else:\n logger.info(f\"Action '{action_name}' is synchronous. 
Executing directly.\")\n result = action(**converted_params)\n\n logger.info(f\"Action '{action_name}' executed successfully with result: {result}\")\n return {\"result\": result}\n\n except FileNotFoundError as e:\n logger.error(f\"API module not found: {e}\")\n raise HTTPException(status_code=404, detail=str(e))\n except HTTPException as he:\n # Re-raise HTTP exceptions to be handled by FastAPI\n raise he\n except Exception as e:\n logger.error(f\"Unhandled exception in call_api_action: {e}\", exc_info=True)\n raise HTTPException(status_code=500, detail=\"Internal Server Error\")\n
"},{"location":"reference/#src.aeiva.tool.api_server.load_api_module","title":"load_api_module(api_name)
","text":"Dynamically load the API module for the given api_name.
Parameters:
Name Type Description Default api_name
str
The name of the API.
required Returns:
Name Type Description module
The loaded API module.
Raises:
Type Description FileNotFoundError
If the API module does not exist.
ImportError
If the module cannot be imported.
Source code in src/aeiva/tool/api_server.py
def load_api_module(api_name: str):\n \"\"\"\n Dynamically load the API module for the given api_name.\n\n Args:\n api_name (str): The name of the API.\n\n Returns:\n module: The loaded API module.\n\n Raises:\n FileNotFoundError: If the API module does not exist.\n ImportError: If the module cannot be imported.\n \"\"\"\n # Construct the path to the API module\n api_path = BASE_DIR / \"api\" / api_name / \"api.py\"\n\n if not api_path.exists():\n logger.error(f\"API module not found at path: {api_path}\")\n raise FileNotFoundError(f\"API module not found at path: {api_path}\")\n\n module_name = f\"aeiva.tool.api.{api_name}.api\"\n spec = importlib.util.spec_from_file_location(module_name, str(api_path))\n module = importlib.util.module_from_spec(spec)\n try:\n spec.loader.exec_module(module)\n logger.info(f\"Successfully loaded module '{module_name}'\")\n except Exception as e:\n logger.error(f\"Failed to load module '{module_name}': {e}\")\n raise ImportError(f\"Failed to load module '{module_name}': {e}\")\n return module\n
"},{"location":"reference/#src.aeiva.tool.api_server.root","title":"root()
async
","text":"Root endpoint to confirm the API server is running.
Source code in src/aeiva/tool/api_server.py
@app.get(\"/\")\nasync def root():\n \"\"\"\n Root endpoint to confirm the API server is running.\n \"\"\"\n return {\"message\": \"Welcome to the AI Agent API system!\"}\n
"},{"location":"reference/#src.aeiva.tool.tool","title":"tool
","text":""},{"location":"reference/#src.aeiva.tool.tool.Tool","title":"Tool
","text":"Source code in src/aeiva/tool/tool.py
class Tool:\n def __init__(self, api_name: str):\n \"\"\"\n Initialize the tool, determining whether it should run locally or via an external service.\n Args:\n api_name (str): The name of the tool API (matches the function name).\n \"\"\"\n self.api_name = api_name\n self.schema = self.load_tool_schema(api_name)\n\n @classmethod\n def load_tool_schema(cls, api_name: str) -> dict:\n \"\"\"\n Load the tool's schema from the JSON file.\n Args:\n api_name (str): The name of the API or function.\n Returns:\n dict: The loaded schema from the JSON file.\n \"\"\"\n current_path = os.path.dirname(os.path.abspath(__file__))\n project_root = os.path.abspath(os.path.join(current_path, \"../../..\"))\n path = os.path.join(\n project_root,\n f\"src/aeiva/tool/api/{api_name}/{api_name}.json\",\n )\n with open(path, \"r\") as file:\n return json.load(file)\n\n async def aexecute(self, params: dict) -> Any:\n \"\"\"\n Execute the tool by calling the corresponding function (whether it's for a local function or encapsulated API call).\n Args:\n params (dict): Parameters to pass to the tool.\n Returns:\n Any: The result of the tool execution.\n \"\"\"\n function_module = f\"aeiva.tool.api.{self.api_name}.api\"\n func_module = import_module(function_module)\n\n # Check if the function is async\n function: Callable = getattr(func_module, self.api_name)\n if asyncio.iscoroutinefunction(function):\n return await function(**params)\n else:\n return function(**params)\n\n def execute(self, params: dict) -> Any:\n \"\"\"\n Execute the tool synchronously by calling the corresponding function.\n\n Args:\n params (dict): Parameters to pass to the tool.\n\n Returns:\n Any: The result of the tool execution.\n \"\"\"\n function_module = f\"aeiva.tool.api.{self.api_name}.api\"\n func_module = import_module(function_module)\n\n function: Callable = getattr(func_module, self.api_name)\n if asyncio.iscoroutinefunction(function):\n # If the function is async, attempt to run it in an event loop\n 
try:\n asyncio.get_running_loop()\n except RuntimeError:\n # No event loop running, use asyncio.run\n return asyncio.run(function(**params))\n # An event loop is already running; loop.run_until_complete would raise\n # here, so run the coroutine in a worker thread with its own loop\n import concurrent.futures\n with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:\n return pool.submit(asyncio.run, function(**params)).result()\n else:\n # If the function is synchronous, call it directly\n return function(**params)\n
"},{"location":"reference/#src.aeiva.tool.tool.Tool.__init__","title":"__init__(api_name)
","text":"Initialize the tool, determining whether it should run locally or via an external service. Args: api_name (str): The name of the tool API (matches the function name).
Source code in src/aeiva/tool/tool.py
def __init__(self, api_name: str):\n \"\"\"\n Initialize the tool, determining whether it should run locally or via an external service.\n Args:\n api_name (str): The name of the tool API (matches the function name).\n \"\"\"\n self.api_name = api_name\n self.schema = self.load_tool_schema(api_name)\n
"},{"location":"reference/#src.aeiva.tool.tool.Tool.aexecute","title":"aexecute(params)
async
","text":"Execute the tool by calling the corresponding function (whether it's for a local function or encapsulated API call). Args: params (dict): Parameters to pass to the tool. Returns: Any: The result of the tool execution.
Source code in src/aeiva/tool/tool.py
async def aexecute(self, params: dict) -> Any:\n \"\"\"\n Execute the tool by calling the corresponding function (whether it's for a local function or encapsulated API call).\n Args:\n params (dict): Parameters to pass to the tool.\n Returns:\n Any: The result of the tool execution.\n \"\"\"\n function_module = f\"aeiva.tool.api.{self.api_name}.api\"\n func_module = import_module(function_module)\n\n # Check if the function is async\n function: Callable = getattr(func_module, self.api_name)\n if asyncio.iscoroutinefunction(function):\n return await function(**params)\n else:\n return function(**params)\n
"},{"location":"reference/#src.aeiva.tool.tool.Tool.execute","title":"execute(params)
","text":"Execute the tool synchronously by calling the corresponding function.
Parameters:
Name Type Description Default params
dict
Parameters to pass to the tool.
required Returns:
Name Type Description Any
Any
The result of the tool execution.
Source code in src/aeiva/tool/tool.py
def execute(self, params: dict) -> Any:\n \"\"\"\n Execute the tool synchronously by calling the corresponding function.\n\n Args:\n params (dict): Parameters to pass to the tool.\n\n Returns:\n Any: The result of the tool execution.\n \"\"\"\n function_module = f\"aeiva.tool.api.{self.api_name}.api\"\n func_module = import_module(function_module)\n\n function: Callable = getattr(func_module, self.api_name)\n if asyncio.iscoroutinefunction(function):\n # If the function is async, run it without re-entering a running loop\n try:\n asyncio.get_running_loop()\n except RuntimeError:\n # No event loop running, use asyncio.run\n return asyncio.run(function(**params))\n # An event loop is already running; loop.run_until_complete would raise\n # here, so run the coroutine in a worker thread with its own loop\n import concurrent.futures\n with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:\n return pool.submit(asyncio.run, function(**params)).result()\n else:\n # If the function is synchronous, call it directly\n return function(**params)\n
"},{"location":"reference/#src.aeiva.tool.tool.Tool.load_tool_schema","title":"load_tool_schema(api_name)
classmethod
","text":"Load the tool's schema from the JSON file. Args: api_name (str): The name of the API or function. Returns: dict: The loaded schema from the JSON file.
Source code in src/aeiva/tool/tool.py
@classmethod\ndef load_tool_schema(cls, api_name: str) -> dict:\n \"\"\"\n Load the tool's schema from the JSON file.\n Args:\n api_name (str): The name of the API or function.\n Returns:\n dict: The loaded schema from the JSON file.\n \"\"\"\n current_path = os.path.dirname(os.path.abspath(__file__))\n project_root = os.path.abspath(os.path.join(current_path, \"../../..\"))\n path = os.path.join(\n project_root,\n f\"src/aeiva/tool/api/{api_name}/{api_name}.json\",\n )\n with open(path, \"r\") as file:\n return json.load(file)\n
"},{"location":"reference/#src.aeiva.tool.toolkit","title":"toolkit
","text":""},{"location":"reference/#src.aeiva.tool.toolkit.arxiv_toolkit","title":"arxiv_toolkit
","text":""},{"location":"reference/#src.aeiva.tool.toolkit.arxiv_toolkit.ArxivToolkit","title":"ArxivToolkit
","text":" Bases: Toolkit
A toolkit for interacting with the arXiv API.
Source code in src/aeiva/tool/toolkit/arxiv_toolkit.py
class ArxivToolkit(Toolkit):\n \"\"\"\n A toolkit for interacting with the arXiv API.\n \"\"\"\n\n def __init__(self, config=None):\n super().__init__(\n name=\"ArxivToolkit\",\n tool_names=[\n \"download_arxiv_papers\",\n \"search_arxiv_papers\"\n ],\n config=config\n )\n
"},{"location":"reference/#src.aeiva.tool.toolkit.auto_ui_toolkit","title":"auto_ui_toolkit
","text":""},{"location":"reference/#src.aeiva.tool.toolkit.auto_ui_toolkit.AutoUIToolkit","title":"AutoUIToolkit
","text":" Bases: Toolkit
A toolkit for automating GUI interactions.
Source code in src/aeiva/tool/toolkit/auto_ui_toolkit.py
class AutoUIToolkit(Toolkit):\n \"\"\"\n A toolkit for automating GUI interactions.\n \"\"\"\n\n def __init__(self, config=None):\n super().__init__(\n name=\"AutoUIToolkit\",\n tool_names=[\n \"analyze_gui\",\n \"analyze_gui_by_ocr\",\n \"click_mouse\",\n \"click_on_element\",\n \"move_mouse\",\n \"operate_computer\",\n \"scroll\",\n \"type_into_element\",\n \"type_keyboard\"\n ],\n config=config\n )\n
"},{"location":"reference/#src.aeiva.tool.toolkit.docx_toolkit","title":"docx_toolkit
","text":""},{"location":"reference/#src.aeiva.tool.toolkit.docx_toolkit.DocxToolkit","title":"DocxToolkit
","text":" Bases: Toolkit
A toolkit for interacting with Docx files.
Source code in src/aeiva/tool/toolkit/docx_toolkit.py
class DocxToolkit(Toolkit):\n \"\"\"\n A toolkit for interacting with Docx files.\n \"\"\"\n\n def __init__(self, config=None):\n super().__init__(\n name=\"DocxToolkit\",\n tool_names=[\n \"create_docx\",\n \"docx2html\",\n \"docx2images\",\n \"docx2markdown\",\n \"docx2metadata\",\n \"docx2pdf\",\n \"docx2text\",\n \"modify_docx\"\n ],\n config=config\n )\n
"},{"location":"reference/#src.aeiva.tool.toolkit.file_toolkit","title":"file_toolkit
","text":""},{"location":"reference/#src.aeiva.tool.toolkit.file_toolkit.FileToolkit","title":"FileToolkit
","text":" Bases: Toolkit
A toolkit for file-related operations.
Source code in src/aeiva/tool/toolkit/file_toolkit.py
class FileToolkit(Toolkit):\n \"\"\"\n A toolkit for file-related operations.\n \"\"\"\n\n def __init__(self, config=None):\n super().__init__(\n name=\"FileToolkit\",\n tool_names=[\n \"create_file_or_folder\",\n \"open_file_or_folder\",\n \"search_file_or_folder\",\n \"copy_file_or_folder\",\n \"move_file_or_folder\",\n \"change_permissions\",\n \"get_file_metadata\",\n \"delete_file\",\n \"edit_file\",\n \"find_file\",\n \"list_files\",\n \"read_file\",\n \"rename_file\",\n \"write_file\"\n ],\n config=config\n )\n
"},{"location":"reference/#src.aeiva.tool.toolkit.git_toolkit","title":"git_toolkit
","text":""},{"location":"reference/#src.aeiva.tool.toolkit.git_toolkit.GitToolkit","title":"GitToolkit
","text":" Bases: Toolkit
A toolkit for interacting with Git repositories.
Source code in src/aeiva/tool/toolkit/git_toolkit.py
class GitToolkit(Toolkit):\n \"\"\"\n A toolkit for interacting with Git repositories.\n \"\"\"\n\n def __init__(self, config=None):\n super().__init__(\n name=\"GitToolkit\",\n tool_names=[\n \"git_apply_patch\",\n \"git_clone\",\n \"git_custom\",\n \"git_patch\",\n \"git_repo_tree\"\n ],\n config=config\n )\n
"},{"location":"reference/#src.aeiva.tool.toolkit.math_toolkit","title":"math_toolkit
","text":""},{"location":"reference/#src.aeiva.tool.toolkit.math_toolkit.MathToolkit","title":"MathToolkit
","text":" Bases: Toolkit
A toolkit for mathematical operations.
Source code in src/aeiva/tool/toolkit/math_toolkit.py
class MathToolkit(Toolkit):\n \"\"\"\n A toolkit for mathematical operations.\n \"\"\"\n\n def __init__(self, config=None):\n super().__init__(\n name=\"MathToolkit\",\n tool_names=[\"calculator\"],\n config=config\n )\n
"},{"location":"reference/#src.aeiva.tool.toolkit.pdf_toolkit","title":"pdf_toolkit
","text":""},{"location":"reference/#src.aeiva.tool.toolkit.pdf_toolkit.PdfToolkit","title":"PdfToolkit
","text":" Bases: Toolkit
A toolkit for interacting with PDF files.
Source code in src/aeiva/tool/toolkit/pdf_toolkit.py
class PdfToolkit(Toolkit):\n \"\"\"\n A toolkit for interacting with PDF files.\n \"\"\"\n\n def __init__(self, config=None):\n super().__init__(\n name=\"PdfToolkit\",\n tool_names=[\n \"pdf2markdown\",\n \"pdf2text\",\n \"pdf2tables\",\n \"pdf2images\",\n \"pdf2metadata\",\n \"pdf2ocr\"\n ],\n config=config\n )\n
"},{"location":"reference/#src.aeiva.tool.toolkit.rbac","title":"rbac
","text":""},{"location":"reference/#src.aeiva.tool.toolkit.rbac.PermissionError","title":"PermissionError
","text":" Bases: Exception
Custom exception for permission-related errors.
Source code in src/aeiva/tool/toolkit/rbac.py
class PermissionError(Exception):\n \"\"\"Custom exception for permission-related errors.\"\"\"\n pass\n
"},{"location":"reference/#src.aeiva.tool.toolkit.rbac.check_permission","title":"check_permission(user_role, api_name, config)
","text":"Check if the user_role has permission to execute the given api_name.
Parameters:
Name Type Description Default user_role
str
The role of the user.
required api_name
str
The name of the API function.
required config
ToolkitConfig
The toolkit configuration containing role permissions.
required Returns:
Name Type Description bool
bool
True if permitted, False otherwise.
Source code in src/aeiva/tool/toolkit/rbac.py
def check_permission(user_role: str, api_name: str, config: ToolkitConfig) -> bool:\n    \"\"\"\n    Check if the user_role has permission to execute the given api_name.\n\n    Args:\n        user_role (str): The role of the user.\n        api_name (str): The name of the API function.\n        config (ToolkitConfig): The toolkit configuration containing role permissions.\n\n    Returns:\n        bool: True if permitted, False otherwise.\n    \"\"\"\n    allowed_apis: List[str] = config.role_permissions.get(user_role, [])\n    return api_name in allowed_apis\n
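Usage of the permission check can be sketched with a stand-in config. FakeToolkitConfig below is illustrative, not the real ToolkitConfig; it only carries the role_permissions mapping that check_permission reads:

```python
from dataclasses import dataclass, field
from typing import Dict, List

@dataclass
class FakeToolkitConfig:
    # Stand-in for ToolkitConfig; only the field check_permission reads.
    role_permissions: Dict[str, List[str]] = field(default_factory=dict)

def check_permission(user_role: str, api_name: str, config: FakeToolkitConfig) -> bool:
    # A role is permitted only if the API name appears in its allow-list;
    # unknown roles fall back to an empty list and are denied.
    allowed_apis = config.role_permissions.get(user_role, [])
    return api_name in allowed_apis

config = FakeToolkitConfig(
    role_permissions={"admin": ["delete_file"], "viewer": ["read_file"]}
)
```

Note that the function itself only reports permission as a boolean; the Toolkit caller is what raises PermissionError on a False result.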
"},{"location":"reference/#src.aeiva.tool.toolkit.security","title":"security
","text":""},{"location":"reference/#src.aeiva.tool.toolkit.security.sanitize_file_path","title":"sanitize_file_path(file_path, config)
","text":"Sanitize the file path to prevent directory traversal attacks.
Parameters:
Name Type Description Default file_path
str
The input file path.
required config
ToolkitConfig
The configuration instance.
required Returns:
Name Type Description str
str
The sanitized absolute file path.
Raises:
Type Description ValueError
If the file path is not within allowed directories.
Source code in src/aeiva/tool/toolkit/security.py
def sanitize_file_path(file_path: str, config: ToolkitConfig) -> str:\n \"\"\"\n Sanitize the file path to prevent directory traversal attacks.\n\n Args:\n file_path (str): The input file path.\n config (ToolkitConfig): The configuration instance.\n\n Returns:\n str: The sanitized absolute file path.\n\n Raises:\n ValueError: If the file path is not within allowed directories.\n \"\"\"\n # Resolve the absolute path\n try:\n absolute_path = Path(file_path).resolve(strict=False)\n except Exception as e:\n logger.error(f\"Error resolving file path '{file_path}': {e}\")\n raise ValueError(f\"Invalid file path: {e}\")\n\n # Check if the path is within allowed directories\n allowed = False\n for dir_path in config.allowed_directories:\n try:\n allowed_dir = Path(dir_path).resolve(strict=False)\n if allowed_dir in absolute_path.parents or allowed_dir == absolute_path.parent:\n allowed = True\n break\n except Exception as e:\n logger.error(f\"Error resolving allowed directory '{dir_path}': {e}\")\n continue\n\n if not allowed:\n logger.error(f\"Unauthorized file path access attempt: {absolute_path}\")\n raise ValueError(\"Unauthorized file path.\")\n\n return str(absolute_path)\n
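The containment rule at the heart of sanitize_file_path can be demonstrated without the config machinery. A minimal sketch (is_within_allowed and the example paths are illustrative assumptions): the resolved target must sit directly in, or below, one of the allowed directories.

```python
from pathlib import Path
from typing import List

def is_within_allowed(file_path: str, allowed_directories: List[str]) -> bool:
    # Resolve without requiring the path to exist, as the original does.
    target = Path(file_path).resolve(strict=False)
    for dir_path in allowed_directories:
        allowed = Path(dir_path).resolve(strict=False)
        # Permitted if the allowed dir is an ancestor (or direct parent).
        if allowed in target.parents or allowed == target.parent:
            return True
    return False
```

A traversal attempt such as "/srv/data/../../etc/passwd" resolves to "/etc/passwd" and is rejected when only "/srv/data" is allowed.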
"},{"location":"reference/#src.aeiva.tool.toolkit.shell_toolkit","title":"shell_toolkit
","text":""},{"location":"reference/#src.aeiva.tool.toolkit.shell_toolkit.ShellToolkit","title":"ShellToolkit
","text":" Bases: Toolkit
A toolkit for shell and terminal operations.
Source code in src/aeiva/tool/toolkit/shell_toolkit.py
class ShellToolkit(Toolkit):\n \"\"\"\n A toolkit for shell and terminal operations.\n \"\"\"\n\n def __init__(self, config=None):\n super().__init__(\n name=\"ShellToolkit\",\n tool_names=[\n \"chwdir\",\n \"execute_bash_command\",\n \"execute_script\",\n \"grep\",\n \"create_new_shell_session\"\n ],\n config=config\n )\n
"},{"location":"reference/#src.aeiva.tool.toolkit.system_toolkit","title":"system_toolkit
","text":""},{"location":"reference/#src.aeiva.tool.toolkit.system_toolkit.SystemToolkit","title":"SystemToolkit
","text":" Bases: Toolkit
A toolkit for interacting with system-level operations.
Source code in src/aeiva/tool/toolkit/system_toolkit.py
class SystemToolkit(Toolkit):\n \"\"\"\n A toolkit for interacting with system-level operations.\n \"\"\"\n\n def __init__(self, config=None):\n super().__init__(\n name=\"SystemToolkit\",\n tool_names=[\n \"get_system_info\",\n \"get_package_root\",\n \"get_user_home_path\",\n \"open_application\",\n \"close_application\",\n \"percept_terminal_input\",\n \"play_music\",\n \"stop_music\",\n \"take_screenshot\",\n \"list_processes\",\n \"kill_process\",\n \"monitor_process\",\n \"get_network_info\",\n \"check_internet_connection\",\n \"get_disk_usage\",\n \"clean_temp_files\",\n \"list_drives\",\n \"get_env_var\",\n \"set_env_var\",\n \"update_system_packages\",\n \"install_package\",\n \"create_user\",\n \"delete_user\",\n \"change_user_password\",\n \"view_system_logs\",\n \"monitor_system_resources\",\n ],\n config=config\n )\n
"},{"location":"reference/#src.aeiva.tool.toolkit.toolkit","title":"toolkit
","text":""},{"location":"reference/#src.aeiva.tool.toolkit.toolkit.Toolkit","title":"Toolkit
","text":"Toolkit class that manages multiple Tool instances, handles validation, enforces RBAC, and manages shared resources.
Source code in src/aeiva/tool/toolkit/toolkit.py
class Toolkit:\n \"\"\"\n Toolkit class that manages multiple Tool instances, handles validation,\n enforces RBAC, and manages shared resources.\n \"\"\"\n\n subclasses: Dict[str, Type['Toolkit']] = {}\n\n def __init_subclass__(cls, **kwargs):\n \"\"\"\n Automatically register subclasses in the Toolkit's subclasses dictionary.\n \"\"\"\n super().__init_subclass__(**kwargs)\n Toolkit.subclasses[cls.__name__] = cls\n\n def __init__(self, name: str, tool_names: List[str], config: Optional[ToolkitConfig] = None):\n \"\"\"\n Initialize the Toolkit with a name, list of tool names, and optional configuration.\n\n Args:\n name (str): The name of the toolkit.\n tool_names (List[str]): The names of tools to be managed by the toolkit.\n config (Optional[ToolkitConfig]): Configuration for security and roles.\n \"\"\"\n self.toolkit_name = name\n self.tool_names = tool_names\n self.config = config\n self.tools: Dict[str, Tool] = {}\n self.tool_schemas: Dict[str, Dict] = {}\n self.tool_models: Dict[str, Tuple[Type[BaseModel], Type[BaseModel]]] = {}\n self.shared_resources = None # Placeholder for shared resources\n\n # Setup the toolkit\n self.setup()\n\n def setup(self):\n \"\"\"\n Setup the toolkit by loading tools, their schemas, and initializing shared resources.\n \"\"\"\n logger.info(f\"Setting up toolkit '{self.toolkit_name}'.\")\n\n # Load tools and their schemas\n for tool_name in self.tool_names:\n tool = Tool(api_name=tool_name)\n self.tools[tool_name] = tool\n schema = tool.load_tool_schema(tool_name)\n self.tool_schemas[tool_name] = schema\n logger.debug(f\"Loaded schema for tool '{tool_name}': {schema}\")\n\n # Load Pydantic models for validation\n self.load_pydantic_models_for_all_tools()\n\n # Initialize shared resources\n self.init_shared_resources()\n\n def load_pydantic_models_for_all_tools(self):\n \"\"\"\n Load Pydantic models (Params and Result) for all tools for validation.\n \"\"\"\n logger.info(\"Loading Pydantic models for all tools.\")\n for tool_name 
in self.tool_names:\n try:\n param_model, result_model = self.load_pydantic_models_for_tool(tool_name)\n self.tool_models[tool_name] = (param_model, result_model)\n logger.debug(f\"Loaded models for tool '{tool_name}': Params={param_model}, Result={result_model}\")\n except Exception as e:\n logger.error(f\"Failed to load models for tool '{tool_name}': {e}\")\n raise\n\n def load_pydantic_models_for_tool(self, api_name: str) -> Tuple[Type[BaseModel], Type[BaseModel]]:\n \"\"\"\n Load the parameter and result Pydantic models for the given API.\n\n Args:\n api_name (str): The name of the API function.\n\n Returns:\n Tuple[Type[BaseModel], Type[BaseModel]]: The parameter and result model classes.\n\n Raises:\n ValueError: If models cannot be loaded.\n \"\"\"\n module_path = f\"aeiva.tool.api.{api_name}.model\" # Adjusted as per user's path\n try:\n models_module = importlib.import_module(module_path)\n param_model_class = getattr(models_module, f\"{snake_to_camel(api_name)}Params\", None)\n result_model_class = getattr(models_module, f\"{snake_to_camel(api_name)}Result\", None)\n if not (param_model_class and issubclass(param_model_class, BaseModel)):\n logger.error(f\"Param model class '{snake_to_camel(api_name)}Params' not found in '{module_path}'.\")\n raise ValueError(f\"Param model class '{snake_to_camel(api_name)}Params' not found in '{module_path}'.\")\n if not (result_model_class and issubclass(result_model_class, BaseModel)):\n logger.error(f\"Result model class '{snake_to_camel(api_name)}Result' not found in '{module_path}'.\")\n raise ValueError(f\"Result model class '{snake_to_camel(api_name)}Result' not found in '{module_path}'.\")\n return param_model_class, result_model_class\n except ImportError as e:\n logger.error(f\"Error importing models from '{module_path}': {e}\")\n raise ImportError(f\"Error importing models from '{module_path}': {e}\")\n except AttributeError as e:\n logger.error(f\"Error accessing model classes in '{module_path}': {e}\")\n 
raise ValueError(f\"Error accessing model classes in '{module_path}': {e}\")\n\n def init_shared_resources(self):\n \"\"\"\n Initialize shared resources required by the toolkit.\n Override this method in subclasses if needed.\n \"\"\"\n logger.info(\"Initializing shared resources.\")\n # Placeholder for initializing shared resources like databases, servers, etc.\n # Example:\n # self.shared_resources = initialize_database_connection()\n pass\n\n def teardown(self):\n \"\"\"\n Teardown the toolkit by unloading tools, their schemas, and cleaning up shared resources.\n \"\"\"\n logger.info(f\"Tearing down toolkit '{self.toolkit_name}'.\")\n\n # Clean up shared resources\n self.teardown_shared_resources()\n\n # Clear loaded data\n self.tools.clear()\n self.tool_schemas.clear()\n self.tool_models.clear()\n\n def teardown_shared_resources(self):\n \"\"\"\n Teardown shared resources.\n Override this method in subclasses if needed.\n \"\"\"\n logger.info(\"Tearing down shared resources.\")\n # Placeholder for tearing down shared resources\n # Example:\n # if self.shared_resources:\n # self.shared_resources.close()\n pass\n\n @asynccontextmanager\n async def acontext(self):\n \"\"\"\n Asynchronous context manager to handle setup and teardown of shared resources.\n\n Usage:\n async with toolkit.acontent():\n # Execute tools\n \"\"\"\n try:\n await self.asetup()\n yield self\n finally:\n await self.ateardown()\n\n @contextmanager\n def context(self):\n \"\"\"\n Synchronous context manager to handle setup and teardown of shared resources.\n\n Usage:\n with toolkit.context():\n # Execute tools\n \"\"\"\n try:\n self.setup()\n yield self\n finally:\n self.teardown()\n\n async def asetup(self):\n \"\"\"\n Asynchronously setup shared resources.\n \"\"\"\n logger.info(f\"Asynchronously setting up toolkit '{self.toolkit_name}'.\")\n # Override in subclasses if asynchronous setup is required\n pass\n\n async def ateardown(self):\n \"\"\"\n Asynchronously teardown shared resources.\n 
\"\"\"\n logger.info(f\"Asynchronously tearing down toolkit '{self.toolkit_name}'.\")\n # Override in subclasses if asynchronous teardown is required\n self.teardown()\n\n def execute(self, api_name: str, params: Dict[str, Any]) -> Any:\n \"\"\"\n Synchronously execute a tool's API function with validation and RBAC checks.\n\n Args:\n api_name (str): The name of the API function to execute.\n params (Dict[str, Any]): The parameters for the API function.\n\n Returns:\n Any: The result of the tool execution.\n\n Raises:\n ValueError: If tool not found or parameter validation fails.\n PermissionError: If user does not have permission.\n RuntimeError: If tool execution fails.\n \"\"\"\n tool = self.tools.get(api_name)\n if not tool:\n logger.error(f\"Tool '{api_name}' not found in toolkit '{self.toolkit_name}'.\")\n raise ValueError(f\"Tool '{api_name}' not found in toolkit '{self.toolkit_name}'.\")\n\n # Perform RBAC check if config is provided\n if self.config:\n # Automatically retrieve user role from OS\n os_user = get_os_user()\n user_role = self.config.user_role_mapping.get(os_user)\n if not user_role:\n logger.error(f\"OS user '{os_user}' does not have an assigned role.\")\n raise ValueError(f\"OS user '{os_user}' does not have an assigned role.\")\n if not check_permission(user_role, api_name, self.config):\n logger.error(f\"User role '{user_role}' does not have permission to execute '{api_name}'.\")\n raise PermissionError(f\"User role '{user_role}' does not have permission to execute '{api_name}'.\")\n\n # Load the Pydantic models for validation\n param_model, result_model = self.tool_models.get(api_name, (None, None))\n if not param_model or not result_model:\n logger.error(f\"Pydantic models for tool '{api_name}' are not loaded.\")\n raise ValueError(f\"Pydantic models for tool '{api_name}' are not loaded.\")\n\n # Instantiate and validate the parameter model\n try:\n param_instance = param_model(**params)\n logger.debug(f\"Validated input parameters for 
'{api_name}': {param_instance}\")\n except Exception as e:\n logger.error(f\"Error parsing parameters for '{api_name}': {e}\")\n raise ValueError(f\"Invalid parameters for '{api_name}': {e}\")\n\n # Perform security checks on parameters if needed\n param_instance = self.perform_security_checks(param_instance)\n\n # Execute the API function via the Tool\n try:\n raw_result = tool.execute(param_instance.dict())\n logger.debug(f\"Raw result from '{api_name}': {raw_result}\")\n except Exception as e:\n logger.error(f\"Error executing tool '{api_name}': {e}\")\n raise RuntimeError(f\"Error executing tool '{api_name}': {e}\")\n\n # Validate the result using the result model\n try:\n result_instance = result_model(**raw_result)\n logger.info(f\"Execution of '{api_name}' successful with result: {result_instance}\")\n return result_instance\n except Exception as e:\n logger.error(f\"Error parsing result for '{api_name}': {e}\")\n raise ValueError(f\"Invalid result from '{api_name}': {e}\")\n\n async def aexecute(self, api_name: str, params: Dict[str, Any]) -> Any:\n \"\"\"\n Asynchronously execute a tool's API function with validation and RBAC checks.\n\n Args:\n api_name (str): The name of the API function to execute.\n params (Dict[str, Any]): The parameters for the API function.\n\n Returns:\n Any: The result of the tool execution.\n\n Raises:\n ValueError: If tool not found or parameter validation fails.\n PermissionError: If user does not have permission.\n RuntimeError: If tool execution fails.\n \"\"\"\n tool = self.tools.get(api_name)\n if not tool:\n logger.error(f\"Tool '{api_name}' not found in toolkit '{self.toolkit_name}'.\")\n raise ValueError(f\"Tool '{api_name}' not found in toolkit '{self.toolkit_name}'.\")\n\n # Perform RBAC check if config is provided\n if self.config:\n # Automatically retrieve user role from OS\n os_user = get_os_user()\n user_role = self.config.user_role_mapping.get(os_user)\n if not user_role:\n logger.error(f\"OS user '{os_user}' 
does not have an assigned role.\")\n raise ValueError(f\"OS user '{os_user}' does not have an assigned role.\")\n if not check_permission(user_role, api_name, self.config):\n logger.error(f\"User role '{user_role}' does not have permission to execute '{api_name}'.\")\n raise PermissionError(f\"User role '{user_role}' does not have permission to execute '{api_name}'.\")\n\n # Load the Pydantic models for validation\n param_model, result_model = self.tool_models.get(api_name, (None, None))\n if not param_model or not result_model:\n logger.error(f\"Pydantic models for tool '{api_name}' are not loaded.\")\n raise ValueError(f\"Pydantic models for tool '{api_name}' are not loaded.\")\n\n # Instantiate and validate the parameter model\n try:\n param_instance = param_model(**params)\n logger.debug(f\"Validated input parameters for '{api_name}': {param_instance}\")\n except Exception as e:\n logger.error(f\"Error parsing parameters for '{api_name}': {e}\")\n raise ValueError(f\"Invalid parameters for '{api_name}': {e}\")\n\n # Perform security checks on parameters if needed\n param_instance = self.perform_security_checks(param_instance)\n\n # Execute the API function via the Tool\n try:\n raw_result = await tool.aexecute(param_instance.dict())\n logger.debug(f\"Raw result from '{api_name}': {raw_result}\")\n except Exception as e:\n logger.error(f\"Error executing tool '{api_name}': {e}\")\n raise RuntimeError(f\"Error executing tool '{api_name}': {e}\")\n\n # Validate the result using the result model\n try:\n result_instance = result_model(**raw_result)\n logger.info(f\"Execution of '{api_name}' successful with result: {result_instance}\")\n return result_instance\n except Exception as e:\n logger.error(f\"Error parsing result for '{api_name}': {e}\")\n raise ValueError(f\"Invalid result from '{api_name}': {e}\")\n\n def perform_security_checks(self, param_instance: BaseModel) -> BaseModel:\n \"\"\"\n Perform security checks on parameters that require sanitization.\n\n 
Args:\n param_instance (BaseModel): The validated parameter instance.\n\n Returns:\n BaseModel: The sanitized parameter instance.\n\n Raises:\n ValueError: If sanitization fails for any field or if config is required but not provided.\n \"\"\"\n sanitized_params = param_instance.dict()\n\n for field_name, field in param_instance.__fields__.items():\n sanitize = field.field_info.extra.get('sanitize', False)\n if not sanitize:\n continue # Skip fields that do not require sanitization\n\n field_type = field.type_\n origin = get_origin(field_type)\n args = get_args(field_type)\n\n # Determine if the field is a string type or contains string types\n is_string_field = False\n\n if field_type == str:\n is_string_field = True\n elif origin is Union and str in args:\n is_string_field = True\n elif origin is list and len(args) == 1 and args[0] == str:\n is_string_field = True\n elif origin is Optional and str in args:\n is_string_field = True\n # Add more conditions here if there are other complex types containing strings\n\n if is_string_field:\n original_value = sanitized_params.get(field_name)\n if original_value is None:\n continue # Skip if the field value is None\n\n if self.config is None:\n logger.error(\n f\"Configuration is required to sanitize field '{field_name}', \"\n f\"but config is not provided.\"\n )\n raise ValueError(\n f\"Configuration is required to sanitize field '{field_name}', \"\n f\"but config is not provided.\"\n )\n\n try:\n # If the field is a list of strings, sanitize each path individually\n if origin is list and len(args) == 1 and args[0] == str:\n if not isinstance(original_value, list):\n logger.error(\n f\"Expected a list for field '{field_name}', \"\n f\"got {type(original_value)}.\"\n )\n raise ValueError(\n f\"Expected a list for field '{field_name}'.\"\n )\n sanitized_list = []\n for idx, item in enumerate(original_value):\n sanitized_item = sanitize_file_path(item, self.config)\n sanitized_list.append(sanitized_item)\n logger.debug(\n 
f\"Sanitized '{field_name}[{idx}]': '{item}' -> '{sanitized_item}'\"\n )\n sanitized_params[field_name] = sanitized_list\n else:\n # Sanitize single string path\n sanitized_path = sanitize_file_path(original_value, self.config)\n sanitized_params[field_name] = sanitized_path\n logger.debug(\n f\"Sanitized '{field_name}': '{original_value}' -> '{sanitized_path}'\"\n )\n except ValueError as ve:\n logger.error(\n f\"Sanitization failed for field '{field_name}': {ve}\"\n )\n raise\n\n # Create a new instance of the parameter model with sanitized parameters\n sanitized_instance = param_instance.copy(update=sanitized_params)\n\n return sanitized_instance\n\n def generate_documentation(self) -> str:\n \"\"\"\n Generate documentation for all functions managed by this toolkit based on their schemas.\n\n Returns:\n str: Generated documentation as a markdown string.\n \"\"\"\n doc = f\"# Toolkit: {self.toolkit_name}\\n\\n\"\n for api_name, tool in self.tools.items():\n schema = self.tool_schemas.get(api_name, {})\n if not schema:\n continue\n doc += f\"## Function: {api_name}\\n\\n\"\n doc += f\"**Description:** {schema.get('description', 'No description provided.')}\\n\\n\"\n doc += \"### Parameters:\\n\\n\"\n parameters = schema.get(\"parameters\", {})\n for prop, details in parameters.get(\"properties\", {}).items():\n req = \" (required)\" if prop in parameters.get(\"required\", []) else \"\"\n description = details.get(\"description\", \"\")\n default = f\" (default: {details.get('default')})\" if \"default\" in details else \"\"\n doc += f\"- **{prop}** ({details.get('type', 'any')}): {description}{default}{req}\\n\"\n doc += \"\\n### Example:\\n\\n\"\n example = schema.get(\"example\", \"No example provided.\")\n if isinstance(example, dict):\n example = json.dumps(example, indent=4)\n doc += f\"```json\\n{example}\\n```\\n\\n\"\n return doc\n
"},{"location":"reference/#src.aeiva.tool.toolkit.toolkit.Toolkit.__init__","title":"__init__(name, tool_names, config=None)
","text":"Initialize the Toolkit with a name, list of tool names, and optional configuration.
Parameters:
Name Type Description Default name
str
The name of the toolkit.
required tool_names
List[str]
The names of tools to be managed by the toolkit.
required config
Optional[ToolkitConfig]
Configuration for security and roles.
None
Source code in src/aeiva/tool/toolkit/toolkit.py
def __init__(self, name: str, tool_names: List[str], config: Optional[ToolkitConfig] = None):\n \"\"\"\n Initialize the Toolkit with a name, list of tool names, and optional configuration.\n\n Args:\n name (str): The name of the toolkit.\n tool_names (List[str]): The names of tools to be managed by the toolkit.\n config (Optional[ToolkitConfig]): Configuration for security and roles.\n \"\"\"\n self.toolkit_name = name\n self.tool_names = tool_names\n self.config = config\n self.tools: Dict[str, Tool] = {}\n self.tool_schemas: Dict[str, Dict] = {}\n self.tool_models: Dict[str, Tuple[Type[BaseModel], Type[BaseModel]]] = {}\n self.shared_resources = None # Placeholder for shared resources\n\n # Setup the toolkit\n self.setup()\n
"},{"location":"reference/#src.aeiva.tool.toolkit.toolkit.Toolkit.__init_subclass__","title":"__init_subclass__(**kwargs)
","text":"Automatically register subclasses in the Toolkit's subclasses dictionary.
Source code in src/aeiva/tool/toolkit/toolkit.py
def __init_subclass__(cls, **kwargs):\n \"\"\"\n Automatically register subclasses in the Toolkit's subclasses dictionary.\n \"\"\"\n super().__init_subclass__(**kwargs)\n Toolkit.subclasses[cls.__name__] = cls\n
"},{"location":"reference/#src.aeiva.tool.toolkit.toolkit.Toolkit.acontext","title":"acontext()
async
","text":"Asynchronous context manager to handle setup and teardown of shared resources.
Usage async with toolkit.acontext(): # Execute tools
Source code in src/aeiva/tool/toolkit/toolkit.py
@asynccontextmanager\nasync def acontext(self):\n    \"\"\"\n    Asynchronous context manager to handle setup and teardown of shared resources.\n\n    Usage:\n        async with toolkit.acontext():\n            # Execute tools\n    \"\"\"\n    try:\n        await self.asetup()\n        yield self\n    finally:\n        await self.ateardown()\n
"},{"location":"reference/#src.aeiva.tool.toolkit.toolkit.Toolkit.aexecute","title":"aexecute(api_name, params)
async
","text":"Asynchronously execute a tool's API function with validation and RBAC checks.
Parameters:
Name Type Description Default api_name
str
The name of the API function to execute.
required params
Dict[str, Any]
The parameters for the API function.
required Returns:
Name Type Description Any
Any
The result of the tool execution.
Raises:
Type Description ValueError
If tool not found or parameter validation fails.
PermissionError
If user does not have permission.
RuntimeError
If tool execution fails.
Source code in src/aeiva/tool/toolkit/toolkit.py
async def aexecute(self, api_name: str, params: Dict[str, Any]) -> Any:\n \"\"\"\n Asynchronously execute a tool's API function with validation and RBAC checks.\n\n Args:\n api_name (str): The name of the API function to execute.\n params (Dict[str, Any]): The parameters for the API function.\n\n Returns:\n Any: The result of the tool execution.\n\n Raises:\n ValueError: If tool not found or parameter validation fails.\n PermissionError: If user does not have permission.\n RuntimeError: If tool execution fails.\n \"\"\"\n tool = self.tools.get(api_name)\n if not tool:\n logger.error(f\"Tool '{api_name}' not found in toolkit '{self.toolkit_name}'.\")\n raise ValueError(f\"Tool '{api_name}' not found in toolkit '{self.toolkit_name}'.\")\n\n # Perform RBAC check if config is provided\n if self.config:\n # Automatically retrieve user role from OS\n os_user = get_os_user()\n user_role = self.config.user_role_mapping.get(os_user)\n if not user_role:\n logger.error(f\"OS user '{os_user}' does not have an assigned role.\")\n raise ValueError(f\"OS user '{os_user}' does not have an assigned role.\")\n if not check_permission(user_role, api_name, self.config):\n logger.error(f\"User role '{user_role}' does not have permission to execute '{api_name}'.\")\n raise PermissionError(f\"User role '{user_role}' does not have permission to execute '{api_name}'.\")\n\n # Load the Pydantic models for validation\n param_model, result_model = self.tool_models.get(api_name, (None, None))\n if not param_model or not result_model:\n logger.error(f\"Pydantic models for tool '{api_name}' are not loaded.\")\n raise ValueError(f\"Pydantic models for tool '{api_name}' are not loaded.\")\n\n # Instantiate and validate the parameter model\n try:\n param_instance = param_model(**params)\n logger.debug(f\"Validated input parameters for '{api_name}': {param_instance}\")\n except Exception as e:\n logger.error(f\"Error parsing parameters for '{api_name}': {e}\")\n raise ValueError(f\"Invalid 
parameters for '{api_name}': {e}\")\n\n # Perform security checks on parameters if needed\n param_instance = self.perform_security_checks(param_instance)\n\n # Execute the API function via the Tool\n try:\n raw_result = await tool.aexecute(param_instance.dict())\n logger.debug(f\"Raw result from '{api_name}': {raw_result}\")\n except Exception as e:\n logger.error(f\"Error executing tool '{api_name}': {e}\")\n raise RuntimeError(f\"Error executing tool '{api_name}': {e}\")\n\n # Validate the result using the result model\n try:\n result_instance = result_model(**raw_result)\n logger.info(f\"Execution of '{api_name}' successful with result: {result_instance}\")\n return result_instance\n except Exception as e:\n logger.error(f\"Error parsing result for '{api_name}': {e}\")\n raise ValueError(f\"Invalid result from '{api_name}': {e}\")\n
"},{"location":"reference/#src.aeiva.tool.toolkit.toolkit.Toolkit.asetup","title":"asetup()
async
","text":"Asynchronously setup shared resources.
Source code in src/aeiva/tool/toolkit/toolkit.py
async def asetup(self):\n \"\"\"\n Asynchronously setup shared resources.\n \"\"\"\n logger.info(f\"Asynchronously setting up toolkit '{self.toolkit_name}'.\")\n # Override in subclasses if asynchronous setup is required\n pass\n
"},{"location":"reference/#src.aeiva.tool.toolkit.toolkit.Toolkit.ateardown","title":"ateardown()
async
","text":"Asynchronously teardown shared resources.
Source code in src/aeiva/tool/toolkit/toolkit.py
async def ateardown(self):\n \"\"\"\n Asynchronously teardown shared resources.\n \"\"\"\n logger.info(f\"Asynchronously tearing down toolkit '{self.toolkit_name}'.\")\n # Override in subclasses if asynchronous teardown is required\n self.teardown()\n
"},{"location":"reference/#src.aeiva.tool.toolkit.toolkit.Toolkit.context","title":"context()
","text":"Synchronous context manager to handle setup and teardown of shared resources.
Usage: with toolkit.context(): # Execute tools
Source code in src/aeiva/tool/toolkit/toolkit.py
@contextmanager\ndef context(self):\n \"\"\"\n Synchronous context manager to handle setup and teardown of shared resources.\n\n Usage:\n with toolkit.context():\n # Execute tools\n \"\"\"\n try:\n self.setup()\n yield self\n finally:\n self.teardown()\n
"},{"location":"reference/#src.aeiva.tool.toolkit.toolkit.Toolkit.execute","title":"execute(api_name, params)
","text":"Synchronously execute a tool's API function with validation and RBAC checks.
Parameters:
Name Type Description Default api_name
str
The name of the API function to execute.
required params
Dict[str, Any]
The parameters for the API function.
required Returns:
Name Type Description Any
Any
The result of the tool execution.
Raises:
Type Description ValueError
If tool not found or parameter validation fails.
PermissionError
If user does not have permission.
RuntimeError
If tool execution fails.
Source code in src/aeiva/tool/toolkit/toolkit.py
def execute(self, api_name: str, params: Dict[str, Any]) -> Any:\n \"\"\"\n Synchronously execute a tool's API function with validation and RBAC checks.\n\n Args:\n api_name (str): The name of the API function to execute.\n params (Dict[str, Any]): The parameters for the API function.\n\n Returns:\n Any: The result of the tool execution.\n\n Raises:\n ValueError: If tool not found or parameter validation fails.\n PermissionError: If user does not have permission.\n RuntimeError: If tool execution fails.\n \"\"\"\n tool = self.tools.get(api_name)\n if not tool:\n logger.error(f\"Tool '{api_name}' not found in toolkit '{self.toolkit_name}'.\")\n raise ValueError(f\"Tool '{api_name}' not found in toolkit '{self.toolkit_name}'.\")\n\n # Perform RBAC check if config is provided\n if self.config:\n # Automatically retrieve user role from OS\n os_user = get_os_user()\n user_role = self.config.user_role_mapping.get(os_user)\n if not user_role:\n logger.error(f\"OS user '{os_user}' does not have an assigned role.\")\n raise ValueError(f\"OS user '{os_user}' does not have an assigned role.\")\n if not check_permission(user_role, api_name, self.config):\n logger.error(f\"User role '{user_role}' does not have permission to execute '{api_name}'.\")\n raise PermissionError(f\"User role '{user_role}' does not have permission to execute '{api_name}'.\")\n\n # Load the Pydantic models for validation\n param_model, result_model = self.tool_models.get(api_name, (None, None))\n if not param_model or not result_model:\n logger.error(f\"Pydantic models for tool '{api_name}' are not loaded.\")\n raise ValueError(f\"Pydantic models for tool '{api_name}' are not loaded.\")\n\n # Instantiate and validate the parameter model\n try:\n param_instance = param_model(**params)\n logger.debug(f\"Validated input parameters for '{api_name}': {param_instance}\")\n except Exception as e:\n logger.error(f\"Error parsing parameters for '{api_name}': {e}\")\n raise ValueError(f\"Invalid parameters for 
'{api_name}': {e}\")\n\n # Perform security checks on parameters if needed\n param_instance = self.perform_security_checks(param_instance)\n\n # Execute the API function via the Tool\n try:\n raw_result = tool.execute(param_instance.dict())\n logger.debug(f\"Raw result from '{api_name}': {raw_result}\")\n except Exception as e:\n logger.error(f\"Error executing tool '{api_name}': {e}\")\n raise RuntimeError(f\"Error executing tool '{api_name}': {e}\")\n\n # Validate the result using the result model\n try:\n result_instance = result_model(**raw_result)\n logger.info(f\"Execution of '{api_name}' successful with result: {result_instance}\")\n return result_instance\n except Exception as e:\n logger.error(f\"Error parsing result for '{api_name}': {e}\")\n raise ValueError(f\"Invalid result from '{api_name}': {e}\")\n
"},{"location":"reference/#src.aeiva.tool.toolkit.toolkit.Toolkit.generate_documentation","title":"generate_documentation()
","text":"Generate documentation for all functions managed by this toolkit based on their schemas.
Returns:
Name Type Description str
str
Generated documentation as a markdown string.
Source code in src/aeiva/tool/toolkit/toolkit.py
def generate_documentation(self) -> str:\n \"\"\"\n Generate documentation for all functions managed by this toolkit based on their schemas.\n\n Returns:\n str: Generated documentation as a markdown string.\n \"\"\"\n doc = f\"# Toolkit: {self.toolkit_name}\\n\\n\"\n for api_name, tool in self.tools.items():\n schema = self.tool_schemas.get(api_name, {})\n if not schema:\n continue\n doc += f\"## Function: {api_name}\\n\\n\"\n doc += f\"**Description:** {schema.get('description', 'No description provided.')}\\n\\n\"\n doc += \"### Parameters:\\n\\n\"\n parameters = schema.get(\"parameters\", {})\n for prop, details in parameters.get(\"properties\", {}).items():\n req = \" (required)\" if prop in parameters.get(\"required\", []) else \"\"\n description = details.get(\"description\", \"\")\n default = f\" (default: {details.get('default')})\" if \"default\" in details else \"\"\n doc += f\"- **{prop}** ({details.get('type', 'any')}): {description}{default}{req}\\n\"\n doc += \"\\n### Example:\\n\\n\"\n example = schema.get(\"example\", \"No example provided.\")\n if isinstance(example, dict):\n example = json.dumps(example, indent=4)\n doc += f\"```json\\n{example}\\n```\\n\\n\"\n return doc\n
"},{"location":"reference/#src.aeiva.tool.toolkit.toolkit.Toolkit.init_shared_resources","title":"init_shared_resources()
","text":"Initialize shared resources required by the toolkit. Override this method in subclasses if needed.
Source code in src/aeiva/tool/toolkit/toolkit.py
def init_shared_resources(self):\n \"\"\"\n Initialize shared resources required by the toolkit.\n Override this method in subclasses if needed.\n \"\"\"\n logger.info(\"Initializing shared resources.\")\n # Placeholder for initializing shared resources like databases, servers, etc.\n # Example:\n # self.shared_resources = initialize_database_connection()\n pass\n
"},{"location":"reference/#src.aeiva.tool.toolkit.toolkit.Toolkit.load_pydantic_models_for_all_tools","title":"load_pydantic_models_for_all_tools()
","text":"Load Pydantic models (Params and Result) for all tools for validation.
Source code in src/aeiva/tool/toolkit/toolkit.py
def load_pydantic_models_for_all_tools(self):\n \"\"\"\n Load Pydantic models (Params and Result) for all tools for validation.\n \"\"\"\n logger.info(\"Loading Pydantic models for all tools.\")\n for tool_name in self.tool_names:\n try:\n param_model, result_model = self.load_pydantic_models_for_tool(tool_name)\n self.tool_models[tool_name] = (param_model, result_model)\n logger.debug(f\"Loaded models for tool '{tool_name}': Params={param_model}, Result={result_model}\")\n except Exception as e:\n logger.error(f\"Failed to load models for tool '{tool_name}': {e}\")\n raise\n
"},{"location":"reference/#src.aeiva.tool.toolkit.toolkit.Toolkit.load_pydantic_models_for_tool","title":"load_pydantic_models_for_tool(api_name)
","text":"Load the parameter and result Pydantic models for the given API.
Parameters:
Name Type Description Default api_name
str
The name of the API function.
required Returns:
Type Description Tuple[Type[BaseModel], Type[BaseModel]]
Tuple[Type[BaseModel], Type[BaseModel]]: The parameter and result model classes.
Raises:
Type Description ValueError
If models cannot be loaded.
Source code in src/aeiva/tool/toolkit/toolkit.py
def load_pydantic_models_for_tool(self, api_name: str) -> Tuple[Type[BaseModel], Type[BaseModel]]:\n \"\"\"\n Load the parameter and result Pydantic models for the given API.\n\n Args:\n api_name (str): The name of the API function.\n\n Returns:\n Tuple[Type[BaseModel], Type[BaseModel]]: The parameter and result model classes.\n\n Raises:\n ValueError: If models cannot be loaded.\n \"\"\"\n module_path = f\"aeiva.tool.api.{api_name}.model\" # Adjusted as per user's path\n try:\n models_module = importlib.import_module(module_path)\n param_model_class = getattr(models_module, f\"{snake_to_camel(api_name)}Params\", None)\n result_model_class = getattr(models_module, f\"{snake_to_camel(api_name)}Result\", None)\n if not (param_model_class and issubclass(param_model_class, BaseModel)):\n logger.error(f\"Param model class '{snake_to_camel(api_name)}Params' not found in '{module_path}'.\")\n raise ValueError(f\"Param model class '{snake_to_camel(api_name)}Params' not found in '{module_path}'.\")\n if not (result_model_class and issubclass(result_model_class, BaseModel)):\n logger.error(f\"Result model class '{snake_to_camel(api_name)}Result' not found in '{module_path}'.\")\n raise ValueError(f\"Result model class '{snake_to_camel(api_name)}Result' not found in '{module_path}'.\")\n return param_model_class, result_model_class\n except ImportError as e:\n logger.error(f\"Error importing models from '{module_path}': {e}\")\n raise ImportError(f\"Error importing models from '{module_path}': {e}\")\n except AttributeError as e:\n logger.error(f\"Error accessing model classes in '{module_path}': {e}\")\n raise ValueError(f\"Error accessing model classes in '{module_path}': {e}\")\n
"},{"location":"reference/#src.aeiva.tool.toolkit.toolkit.Toolkit.perform_security_checks","title":"perform_security_checks(param_instance)
","text":"Perform security checks on parameters that require sanitization.
Parameters:
Name Type Description Default param_instance
BaseModel
The validated parameter instance.
required Returns:
Name Type Description BaseModel
BaseModel
The sanitized parameter instance.
Raises:
Type Description ValueError
If sanitization fails for any field or if config is required but not provided.
Source code in src/aeiva/tool/toolkit/toolkit.py
def perform_security_checks(self, param_instance: BaseModel) -> BaseModel:\n \"\"\"\n Perform security checks on parameters that require sanitization.\n\n Args:\n param_instance (BaseModel): The validated parameter instance.\n\n Returns:\n BaseModel: The sanitized parameter instance.\n\n Raises:\n ValueError: If sanitization fails for any field or if config is required but not provided.\n \"\"\"\n sanitized_params = param_instance.dict()\n\n for field_name, field in param_instance.__fields__.items():\n sanitize = field.field_info.extra.get('sanitize', False)\n if not sanitize:\n continue # Skip fields that do not require sanitization\n\n field_type = field.type_\n origin = get_origin(field_type)\n args = get_args(field_type)\n\n # Determine if the field is a string type or contains string types\n is_string_field = False\n\n if field_type == str:\n is_string_field = True\n elif origin is Union and str in args:\n is_string_field = True\n elif origin is list and len(args) == 1 and args[0] == str:\n is_string_field = True\n elif origin is Optional and str in args:\n is_string_field = True\n # Add more conditions here if there are other complex types containing strings\n\n if is_string_field:\n original_value = sanitized_params.get(field_name)\n if original_value is None:\n continue # Skip if the field value is None\n\n if self.config is None:\n logger.error(\n f\"Configuration is required to sanitize field '{field_name}', \"\n f\"but config is not provided.\"\n )\n raise ValueError(\n f\"Configuration is required to sanitize field '{field_name}', \"\n f\"but config is not provided.\"\n )\n\n try:\n # If the field is a list of strings, sanitize each path individually\n if origin is list and len(args) == 1 and args[0] == str:\n if not isinstance(original_value, list):\n logger.error(\n f\"Expected a list for field '{field_name}', \"\n f\"got {type(original_value)}.\"\n )\n raise ValueError(\n f\"Expected a list for field '{field_name}'.\"\n )\n sanitized_list = []\n for 
idx, item in enumerate(original_value):\n sanitized_item = sanitize_file_path(item, self.config)\n sanitized_list.append(sanitized_item)\n logger.debug(\n f\"Sanitized '{field_name}[{idx}]': '{item}' -> '{sanitized_item}'\"\n )\n sanitized_params[field_name] = sanitized_list\n else:\n # Sanitize single string path\n sanitized_path = sanitize_file_path(original_value, self.config)\n sanitized_params[field_name] = sanitized_path\n logger.debug(\n f\"Sanitized '{field_name}': '{original_value}' -> '{sanitized_path}'\"\n )\n except ValueError as ve:\n logger.error(\n f\"Sanitization failed for field '{field_name}': {ve}\"\n )\n raise\n\n # Create a new instance of the parameter model with sanitized parameters\n sanitized_instance = param_instance.copy(update=sanitized_params)\n\n return sanitized_instance\n
"},{"location":"reference/#src.aeiva.tool.toolkit.toolkit.Toolkit.setup","title":"setup()
","text":"Setup the toolkit by loading tools, their schemas, and initializing shared resources.
Source code in src/aeiva/tool/toolkit/toolkit.py
def setup(self):\n \"\"\"\n Setup the toolkit by loading tools, their schemas, and initializing shared resources.\n \"\"\"\n logger.info(f\"Setting up toolkit '{self.toolkit_name}'.\")\n\n # Load tools and their schemas\n for tool_name in self.tool_names:\n tool = Tool(api_name=tool_name)\n self.tools[tool_name] = tool\n schema = tool.load_tool_schema(tool_name)\n self.tool_schemas[tool_name] = schema\n logger.debug(f\"Loaded schema for tool '{tool_name}': {schema}\")\n\n # Load Pydantic models for validation\n self.load_pydantic_models_for_all_tools()\n\n # Initialize shared resources\n self.init_shared_resources()\n
"},{"location":"reference/#src.aeiva.tool.toolkit.toolkit.Toolkit.teardown","title":"teardown()
","text":"Teardown the toolkit by unloading tools, their schemas, and cleaning up shared resources.
Source code in src/aeiva/tool/toolkit/toolkit.py
def teardown(self):\n \"\"\"\n Teardown the toolkit by unloading tools, their schemas, and cleaning up shared resources.\n \"\"\"\n logger.info(f\"Tearing down toolkit '{self.toolkit_name}'.\")\n\n # Clean up shared resources\n self.teardown_shared_resources()\n\n # Clear loaded data\n self.tools.clear()\n self.tool_schemas.clear()\n self.tool_models.clear()\n
"},{"location":"reference/#src.aeiva.tool.toolkit.toolkit.Toolkit.teardown_shared_resources","title":"teardown_shared_resources()
","text":"Teardown shared resources. Override this method in subclasses if needed.
Source code in src/aeiva/tool/toolkit/toolkit.py
def teardown_shared_resources(self):\n \"\"\"\n Teardown shared resources.\n Override this method in subclasses if needed.\n \"\"\"\n logger.info(\"Tearing down shared resources.\")\n # Placeholder for tearing down shared resources\n # Example:\n # if self.shared_resources:\n # self.shared_resources.close()\n pass\n
"},{"location":"reference/#src.aeiva.tool.toolkit.toolkit_config","title":"toolkit_config
","text":""},{"location":"reference/#src.aeiva.tool.toolkit.toolkit_config.ToolkitConfig","title":"ToolkitConfig
dataclass
","text":" Bases: BaseConfig
Configuration for the Toolkit.
Source code in src/aeiva/tool/toolkit/toolkit_config.py
@dataclass\nclass ToolkitConfig(BaseConfig):\n \"\"\"\n Configuration for the Toolkit.\n \"\"\"\n\n allowed_directories: List[str] = field(\n default_factory=lambda: [\"/tmp/\", \"/home/user/allowed_directory/\"],\n metadata={\"help\": \"Directories that tools are allowed to access.\"}\n )\n # Mapping from OS usernames to roles\n user_role_mapping: Dict[str, str] = field(\n default_factory=lambda: {\n \"admin_user\": \"admin\",\n \"regular_user\": \"user\"\n # Add more user-role mappings as needed\n },\n metadata={\"help\": \"Mapping of OS usernames to their roles.\"}\n )\n # Define permissions for each role\n role_permissions: Dict[str, List[str]] = field(\n default_factory=lambda: {\n \"admin\": [\"delete_file\", \"view_file\", \"create_file\"],\n \"user\": [\"view_file\", \"create_file\"]\n },\n metadata={\"help\": \"Mapping of roles to allowed API functions.\"}\n )\n
"},{"location":"reference/#src.aeiva.tool.toolkit.web_toolkit","title":"web_toolkit
","text":""},{"location":"reference/#src.aeiva.tool.toolkit.web_toolkit.WebToolkit","title":"WebToolkit
","text":" Bases: Toolkit
A toolkit for interacting with web pages.
Source code in src/aeiva/tool/toolkit/web_toolkit.py
class WebToolkit(Toolkit):\n \"\"\"\n A toolkit for interacting with web pages.\n \"\"\"\n\n def __init__(self, config=None):\n super().__init__(\n name=\"WebToolkit\",\n tool_names=[\n \"click_webpage_element\",\n \"crawl\",\n \"execute_js_script_on_webpage\",\n \"get_webpage_details\",\n \"get_webpage_elements\",\n \"navigate_browser_history\",\n \"navigate_to_webpage\",\n \"refresh_webpage\",\n \"scrape\",\n \"scroll_webpage\",\n \"type_text_in_webpage_element\",\n \"web_search\"\n ],\n config=config\n )\n
"},{"location":"reference/#src.aeiva.trainer","title":"trainer
","text":""},{"location":"reference/#src.aeiva.trainer.pl_trainer","title":"pl_trainer
","text":""},{"location":"reference/#src.aeiva.trainer.pl_trainer.LightningTrainer","title":"LightningTrainer
","text":" Bases: LightningModule
Source code in src/aeiva/trainer/pl_trainer.py
class LightningTrainer(pl.LightningModule):\n def __init__(self, model, tokenizer, config):\n super().__init__()\n self.model = model\n self.tokenizer = tokenizer\n self.config = config\n\n def forward(self, batch):\n outputs = self.model(batch)\n return outputs\n\n def training_step(self, batch, batch_idx):\n outputs = self(batch)\n loss = outputs.loss\n return {\"loss\": loss}\n\n def validation_step(self, batch, batch_idx):\n outputs = self(batch)\n loss = outputs.loss\n return {\"loss\": loss}\n\n def test_step(self, batch, batch_idx):\n outputs = self(batch)\n loss = outputs.loss\n return {\"loss\": loss}\n\n def configure_optimizers(self):\n \"\"\"\n Function to prepare the optimizer and learning rate scheduler for model training.\n This function separates model parameters into two categories: parameters that will experience weight decay, \n and those that will not (e.g., bias and layernorm weights). \n\n Returns:\n Tuple[Optimizer, Scheduler]: Tuple containing the optimizer and learning rate scheduler.\n \"\"\"\n\n # List of module types that will be subjected to weight decay.\n whitelist_weight_modules = (torch.nn.Linear, ) \n\n # List of module types that will not be subjected to weight decay.\n blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)\n\n # Parameter sets for decay and no decay.\n decay = set()\n no_decay = set()\n\n # Populate the decay and no_decay sets. \n # Loop over all modules to get module name (mn) and module (m).\n # !!!! 
revise later.\n # for mn, m in self.model.named_modules():\n # for pn, p in m.named_parameters():\n # fpn = '%s.%s' % (mn, pn) if mn else pn \n\n # if 'bias' in pn:\n # no_decay.add(fpn)\n # elif 'weight' in pn:\n # decay.add(fpn)\n\n param_dict = {pn: p for pn, p in self.model.named_parameters()}\n\n for mn, m in self.model.named_modules():\n for pn, p in m.named_parameters():\n fpn = '%s.%s' % (mn, pn) if mn else pn # full param name\n # random note: because named_modules and named_parameters are recursive\n # we will see the same tensors p many many times. but doing it this way\n # allows us to know which parent module any tensor p belongs to...\n # Adding new condition to check for the 'class_embedding' and 'logit_scale' parameters\n if pn.endswith('bias') or 'class_embedding' in pn or 'logit_scale' in pn:\n no_decay.add(fpn)\n elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):\n no_decay.add(fpn)\n for pn, p in param_dict.items():\n if pn not in no_decay:\n decay.add(pn)\n\n\n # # After this loop, print out all parameters in the intersection of decay and no_decay:\n # print(\"decay: \", decay)\n # print(\"no_decay: \", no_decay)\n # print(\"intersection: \", decay.intersection(no_decay))\n\n # print(\"difference: \", param_dict.keys() - (decay | no_decay))\n\n\n # # 'lm_head.weight' is tied to 'model.embed_tokens.weight', so it should not be decayed. 
\n # # This ensures that the same tensor is not optimized in different ways.\n # decay.remove('llm.lm_head.weight')\n\n # Validate that we considered every parameter.\n param_dict = {pn: p for pn, p in self.model.named_parameters()}\n assert len(decay & no_decay) == 0, \"Some parameters are in both decay and no_decay sets!\"\n assert len(param_dict.keys() - (decay | no_decay)) == 0, \"Some parameters are in neither decay nor no_decay sets!\"\n\n # Create the PyTorch optimizer object.\n optim_groups = [\n {\"params\": [param_dict[pn] for pn in sorted(list(decay))], \"weight_decay\": self.config.weight_decay},\n {\"params\": [param_dict[pn] for pn in sorted(list(no_decay))], \"weight_decay\": 0.0},\n ]\n # new PyTorch nightly has a new 'fused' option for AdamW that is much faster\n use_fused = (self.config.device == 'cuda') and (\n 'fused' in inspect.signature(torch.optim.AdamW).parameters)\n print(f\"using fused AdamW: {use_fused}\")\n extra_args = dict(fused=True) if use_fused else dict()\n optimizer = torch.optim.AdamW(\n optim_groups, lr=self.config.learning_rate, betas=(self.config.adam_beta1, self.config.adam_beta2), **extra_args)\n\n # Prepare learning rate scheduler.\n total_steps = self.config.max_steps\n pct_start = self.config.warmup_steps / total_steps\n final_div_factor = self.config.learning_rate / self.config.min_lr\n\n scheduler = {\n 'scheduler': torch.optim.lr_scheduler.OneCycleLR(\n optimizer,\n max_lr=self.config.learning_rate,\n total_steps=total_steps,\n pct_start=pct_start,\n final_div_factor=final_div_factor,\n div_factor=1.0, # No additional scaling for the initial learning rate\n anneal_strategy='cos', # Use cosine annealing\n cycle_momentum=False, # Disable momentum cycling\n ),\n 'interval': 'step',\n 'frequency': 1\n }\n\n return [optimizer], [scheduler]\n\n\n def get_num_params(self, non_embedding=True):\n \"\"\"\n Return the number of parameters in the model.\n For non-embedding count (default), the position embeddings get subtracted.\n 
The token embeddings would too, except due to the parameter sharing these\n params are actually used as weights in the final layer, so we include them.\n \"\"\"\n n_params = sum(p.numel() for p in self.model.parameters())\n if non_embedding:\n embedding_params = sum(p.numel() for m in self.model.modules() if isinstance(m, nn.Embedding) for p in m.parameters())\n n_params -= embedding_params\n return n_params\n\n def estimate_mfu(self, fwdbwd_per_iter, dt):\n \"\"\" estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS \"\"\"\n # first estimate the number of flops we do per iteration.\n # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311\n N = self.get_num_params()\n cfg = self.config\n L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size\n flops_per_token = 6*N + 12*L*H*Q*T\n flops_per_fwdbwd = flops_per_token * T\n flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter\n # express our flops throughput as ratio of A100 bfloat16 peak flops\n flops_achieved = flops_per_iter * (1.0/dt) # per second\n flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS\n mfu = flops_achieved / flops_promised\n return mfu\n
"},{"location":"reference/#src.aeiva.trainer.pl_trainer.LightningTrainer.configure_optimizers","title":"configure_optimizers()
","text":"Function to prepare the optimizer and learning rate scheduler for model training. This function separates model parameters into two categories: parameters that will experience weight decay, and those that will not (e.g., bias and layernorm weights).
Returns:
Type Description Tuple[Optimizer, Scheduler]: Tuple containing the optimizer and learning rate scheduler.
Source code in src/aeiva/trainer/pl_trainer.py
def configure_optimizers(self):\n \"\"\"\n Function to prepare the optimizer and learning rate scheduler for model training.\n This function separates model parameters into two categories: parameters that will experience weight decay, \n and those that will not (e.g., bias and layernorm weights). \n\n Returns:\n Tuple[Optimizer, Scheduler]: Tuple containing the optimizer and learning rate scheduler.\n \"\"\"\n\n # List of module types that will be subjected to weight decay.\n whitelist_weight_modules = (torch.nn.Linear, ) \n\n # List of module types that will not be subjected to weight decay.\n blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)\n\n # Parameter sets for decay and no decay.\n decay = set()\n no_decay = set()\n\n # Populate the decay and no_decay sets. \n # Loop over all modules to get module name (mn) and module (m).\n # !!!! revise later.\n # for mn, m in self.model.named_modules():\n # for pn, p in m.named_parameters():\n # fpn = '%s.%s' % (mn, pn) if mn else pn \n\n # if 'bias' in pn:\n # no_decay.add(fpn)\n # elif 'weight' in pn:\n # decay.add(fpn)\n\n param_dict = {pn: p for pn, p in self.model.named_parameters()}\n\n for mn, m in self.model.named_modules():\n for pn, p in m.named_parameters():\n fpn = '%s.%s' % (mn, pn) if mn else pn # full param name\n # random note: because named_modules and named_parameters are recursive\n # we will see the same tensors p many many times. 
but doing it this way\n # allows us to know which parent module any tensor p belongs to...\n # Adding new condition to check for the 'class_embedding' and 'logit_scale' parameters\n if pn.endswith('bias') or 'class_embedding' in pn or 'logit_scale' in pn:\n no_decay.add(fpn)\n elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):\n no_decay.add(fpn)\n for pn, p in param_dict.items():\n if pn not in no_decay:\n decay.add(pn)\n\n\n # # After this loop, print out all parameters in the intersection of decay and no_decay:\n # print(\"decay: \", decay)\n # print(\"no_decay: \", no_decay)\n # print(\"intersection: \", decay.intersection(no_decay))\n\n # print(\"difference: \", param_dict.keys() - (decay | no_decay))\n\n\n # # 'lm_head.weight' is tied to 'model.embed_tokens.weight', so it should not be decayed. \n # # This ensures that the same tensor is not optimized in different ways.\n # decay.remove('llm.lm_head.weight')\n\n # Validate that we considered every parameter.\n param_dict = {pn: p for pn, p in self.model.named_parameters()}\n assert len(decay & no_decay) == 0, \"Some parameters are in both decay and no_decay sets!\"\n assert len(param_dict.keys() - (decay | no_decay)) == 0, \"Some parameters are in neither decay nor no_decay sets!\"\n\n # Create the PyTorch optimizer object.\n optim_groups = [\n {\"params\": [param_dict[pn] for pn in sorted(list(decay))], \"weight_decay\": self.config.weight_decay},\n {\"params\": [param_dict[pn] for pn in sorted(list(no_decay))], \"weight_decay\": 0.0},\n ]\n # new PyTorch nightly has a new 'fused' option for AdamW that is much faster\n use_fused = (self.config.device == 'cuda') and (\n 'fused' in inspect.signature(torch.optim.AdamW).parameters)\n print(f\"using fused AdamW: {use_fused}\")\n extra_args = dict(fused=True) if use_fused else dict()\n optimizer = torch.optim.AdamW(\n optim_groups, lr=self.config.learning_rate, betas=(self.config.adam_beta1, self.config.adam_beta2), **extra_args)\n\n # Prepare 
learning rate scheduler.\n total_steps = self.config.max_steps\n pct_start = self.config.warmup_steps / total_steps\n final_div_factor = self.config.learning_rate / self.config.min_lr\n\n scheduler = {\n 'scheduler': torch.optim.lr_scheduler.OneCycleLR(\n optimizer,\n max_lr=self.config.learning_rate,\n total_steps=total_steps,\n pct_start=pct_start,\n final_div_factor=final_div_factor,\n div_factor=1.0, # No additional scaling for the initial learning rate\n anneal_strategy='cos', # Use cosine annealing\n cycle_momentum=False, # Disable momentum cycling\n ),\n 'interval': 'step',\n 'frequency': 1\n }\n\n return [optimizer], [scheduler]\n
"},{"location":"reference/#src.aeiva.trainer.pl_trainer.LightningTrainer.estimate_mfu","title":"estimate_mfu(fwdbwd_per_iter, dt)
","text":"estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS
Source code in src/aeiva/trainer/pl_trainer.py
def estimate_mfu(self, fwdbwd_per_iter, dt):\n \"\"\" estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS \"\"\"\n # first estimate the number of flops we do per iteration.\n # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311\n N = self.get_num_params()\n cfg = self.config\n L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size\n flops_per_token = 6*N + 12*L*H*Q*T\n flops_per_fwdbwd = flops_per_token * T\n flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter\n # express our flops throughput as ratio of A100 bfloat16 peak flops\n flops_achieved = flops_per_iter * (1.0/dt) # per second\n flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS\n mfu = flops_achieved / flops_promised\n return mfu\n
"},{"location":"reference/#src.aeiva.trainer.pl_trainer.LightningTrainer.get_num_params","title":"get_num_params(non_embedding=True)
","text":"Return the number of parameters in the model. For non-embedding count (default), the position embeddings get subtracted. The token embeddings would too, except due to the parameter sharing these params are actually used as weights in the final layer, so we include them.
Source code in src/aeiva/trainer/pl_trainer.py
def get_num_params(self, non_embedding=True):\n \"\"\"\n Return the number of parameters in the model.\n For non-embedding count (default), the position embeddings get subtracted.\n The token embeddings would too, except due to the parameter sharing these\n params are actually used as weights in the final layer, so we include them.\n \"\"\"\n n_params = sum(p.numel() for p in self.model.parameters())\n if non_embedding:\n embedding_params = sum(p.numel() for m in self.model.modules() if isinstance(m, nn.Embedding) for p in m.parameters())\n n_params -= embedding_params\n return n_params\n
"},{"location":"reference/#src.aeiva.util","title":"util
","text":""},{"location":"reference/#src.aeiva.util.file_utils","title":"file_utils
","text":""},{"location":"reference/#src.aeiva.util.file_utils.from_json_or_yaml","title":"from_json_or_yaml(filepath)
","text":"Load configuration from a JSON or YAML file based on the file extension.
Parameters:
Name Type Description Default filepath
str
The path to the configuration file.
required Returns:
Name Type Description dict
dict
The configuration dictionary.
Raises:
Type Description FileNotFoundError
If the file does not exist.
ValueError
If the file extension is unsupported or if parsing fails.
Source code in src/aeiva/util/file_utils.py
def from_json_or_yaml(filepath: str) -> dict:\n \"\"\"\n Load configuration from a JSON or YAML file based on the file extension.\n\n Args:\n filepath (str): The path to the configuration file.\n\n Returns:\n dict: The configuration dictionary.\n\n Raises:\n FileNotFoundError: If the file does not exist.\n ValueError: If the file extension is unsupported or if parsing fails.\n \"\"\"\n if not os.path.exists(filepath):\n logger.error(f\"Configuration file not found at path: {filepath}\")\n raise FileNotFoundError(f\"Configuration file not found at path: {filepath}\")\n\n _, ext = os.path.splitext(filepath)\n ext = ext.lower()\n\n try:\n with open(filepath, 'r', encoding='utf-8') as f:\n if ext == '.json':\n config = json.load(f)\n logger.info(f\"Loaded JSON configuration from {filepath}.\")\n return config\n elif ext in ['.yaml', '.yml']:\n config = yaml.safe_load(f)\n logger.info(f\"Loaded YAML configuration from {filepath}.\")\n return config\n else:\n logger.error(f\"Unsupported configuration file format: {ext}\")\n raise ValueError(f\"Unsupported configuration file format: {ext}\")\n except (json.JSONDecodeError, yaml.YAMLError) as e:\n logger.error(f\"Error parsing configuration file '{filepath}': {e}\")\n raise ValueError(f\"Error parsing configuration file '{filepath}': {e}\")\n
"},{"location":"reference/#src.aeiva.util.os_utils","title":"os_utils
","text":""},{"location":"reference/#src.aeiva.util.os_utils.get_os_user","title":"get_os_user()
","text":"Retrieve the current OS username.
Returns:
Name Type Description str
str
The current OS user's name.
Source code in src/aeiva/util/os_utils.py
def get_os_user() -> str:\n \"\"\"\n Retrieve the current OS username.\n\n Returns:\n str: The current OS user's name.\n \"\"\"\n return getpass.getuser()\n
"},{"location":"reference/#src.aeiva.util.path_utils","title":"path_utils
","text":""},{"location":"reference/#src.aeiva.util.path_utils.get_package_root","title":"get_package_root(package_name)
","text":"Obtain the root directory of a given package.
Parameters:
Name Type Description Default package_name
str
The name of the package.
required Returns:
Name Type Description str
str
The absolute path to the package root directory.
Source code in src/aeiva/util/path_utils.py
def get_package_root(package_name: str) -> str:\n \"\"\"\n Obtain the root directory of a given package.\n\n Args:\n package_name (str): The name of the package.\n\n Returns:\n str: The absolute path to the package root directory.\n \"\"\"\n spec = importlib.util.find_spec(package_name)\n if spec is None or spec.origin is None:\n raise ImportError(f\"Cannot find package '{package_name}'\")\n package_path = os.path.dirname(os.path.abspath(spec.origin))\n return package_path\n
"},{"location":"reference/#src.aeiva.util.path_utils.get_user_home_path","title":"get_user_home_path()
","text":"Retrieves the home directory of the current user across different platforms.
Supported Platforms - Windows
- macOS
- Linux
- iOS (best-effort)
- Android (best-effort)
Returns:
Name Type Description Path
Path
A Path
object representing the user's home directory.
Raises:
Type Description EnvironmentError
If the home directory cannot be determined.
Source code in src/aeiva/util/path_utils.py
def get_user_home_path() -> Path:\n \"\"\"\n Retrieves the home directory of the current user across different platforms.\n\n Supported Platforms:\n - Windows\n - macOS\n - Linux\n - iOS (best-effort)\n - Android (best-effort)\n\n Returns:\n Path: A `Path` object representing the user's home directory.\n\n Raises:\n EnvironmentError: If the home directory cannot be determined.\n \"\"\"\n system = platform.system()\n logger.info(f\"Detected operating system: {system}\")\n\n try:\n if system == \"Windows\":\n # Windows: Use USERPROFILE or combine HOMEDRIVE and HOMEPATH\n home = os.environ.get('USERPROFILE') or os.path.join(os.environ.get('HOMEDRIVE', ''), os.environ.get('HOMEPATH', ''))\n logger.debug(f\"Windows home directory: {home}\")\n elif system in [\"Linux\", \"Darwin\"]: # Darwin is macOS\n # Unix-like systems: Use expanduser\n home = os.path.expanduser(\"~\")\n logger.debug(f\"Unix-like home directory: {home}\")\n elif system == \"Java\": # Potentially Android (e.g., running via Jython or similar)\n # Android typically uses /data/user/0/<package_name>/ or /sdcard/\n # However, accessing these paths may require specific permissions\n # Here, we attempt to use the HOME environment variable\n home = os.environ.get('HOME') or '/sdcard/'\n logger.debug(f\"Android home directory (best-effort): {home}\")\n elif system == \"iOS\":\n # iOS applications are sandboxed; home directory is typically the app's sandbox\n # Accessing it might require specific APIs or configurations\n # Here, we return the current working directory as a placeholder\n home = Path.cwd()\n logger.debug(f\"iOS home directory (best-effort): {home}\")\n else:\n # Fallback for unknown systems\n home = os.path.expanduser(\"~\")\n logger.warning(f\"Unknown system '{system}'. 
Falling back to expanduser: {home}\")\n\n if home and os.path.isdir(home):\n return Path(home)\n else:\n raise EnvironmentError(\"Determined home directory does not exist or is not a directory.\")\n\n except Exception as e:\n logger.error(f\"Failed to determine the user's home directory: {e}\")\n raise EnvironmentError(\"Cannot determine the user's home directory.\") from e\n
"},{"location":"reference/#src.aeiva.util.path_utils.snake_to_camel","title":"snake_to_camel(snake_str)
","text":"Convert a snake_case string to CamelCase.
Parameters:
Name Type Description Default snake_str
str
The snake_case string.
required Returns:
Name Type Description str
str
The CamelCase string.
Source code in src/aeiva/util/path_utils.py
def snake_to_camel(snake_str: str) -> str:\n \"\"\"\n Convert a snake_case string to CamelCase.\n\n Args:\n snake_str (str): The snake_case string.\n\n Returns:\n str: The CamelCase string.\n \"\"\"\n components = snake_str.split('_')\n # Capitalize the first letter of each component\n return ''.join(x.title() for x in components)\n
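A quick usage sketch of snake_to_camel (the function body from the listing above is repeated so the snippet runs on its own):

```python
def snake_to_camel(snake_str: str) -> str:
    """Convert a snake_case string to CamelCase."""
    components = snake_str.split('_')
    # Capitalize the first letter of each component
    return ''.join(x.title() for x in components)

print(snake_to_camel("file_utils"))  # FileUtils
print(snake_to_camel("agent"))       # Agent
```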
"},{"location":"reference/#src.aeiva.util.token_utils","title":"token_utils
","text":""},{"location":"reference/#src.aeiva.util.token_utils.pad_or_truncate_tokens","title":"pad_or_truncate_tokens(tokens, max_length, pad_token_id)
","text":"This function aims to pad or truncate tokens to max_length.
Parameters:
Name Type Description Default tokens
list
the list of tokens.
required max_length
int
the max length of tokens.
required pad_token_id
int
the id of pad token.
required Returns:
Name Type Description tokens
list
the list of tokens after padding or truncating.
Source code in src/aeiva/util/token_utils.py
def pad_or_truncate_tokens(tokens, max_length, pad_token_id):\n \"\"\" This function aims to pad or truncate tokens to max_length.\n\n Args:\n tokens (list): the list of tokens.\n max_length (int): the max length of tokens.\n pad_token_id (int): the id of pad token.\n\n Returns:\n tokens (list): the list of tokens after padding or truncating.\n \"\"\"\n if len(tokens) > max_length:\n tokens = tokens[:max_length]\n elif len(tokens) < max_length:\n tokens = tokens + [pad_token_id] * (max_length - len(tokens))\n return tokens\n
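A quick usage sketch of pad_or_truncate_tokens (the function body from the listing above is repeated so the snippet runs on its own):

```python
def pad_or_truncate_tokens(tokens, max_length, pad_token_id):
    """Pad or truncate a token list to exactly max_length."""
    if len(tokens) > max_length:
        tokens = tokens[:max_length]  # too long: cut off the tail
    elif len(tokens) < max_length:
        # too short: append pad tokens until the target length is reached
        tokens = tokens + [pad_token_id] * (max_length - len(tokens))
    return tokens

print(pad_or_truncate_tokens([5, 6, 7], 5, 0))           # [5, 6, 7, 0, 0]
print(pad_or_truncate_tokens([1, 2, 3, 4, 5, 6], 4, 0))  # [1, 2, 3, 4]
```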
"},{"location":"tutorials/","title":"Tutorials","text":"Here we summarize some of the experience we gained while developing Aeiva.
How to generate project documentation automatically from docstrings
"},{"location":"blogs/Thoughts_on_key_concepts_for_agentic_intelligence/","title":"Thoughts on Several Key Concepts for Agentic Intelligence","text":"Author: Bang Liu
Date: 2023-10-21
In building an intelligent agent system, especially one designed to perform complex tasks and learn from experience, it is crucial to clearly define core concepts that guide its behavior. These concepts shape how the agent interacts with its environment, executes tasks, learns from past experiences, and acquires new knowledge. Below are my thoughts on several key concepts, enriched with examples to make them more tangible.
"},{"location":"blogs/Thoughts_on_key_concepts_for_agentic_intelligence/#1-what-is-a-plan","title":"1. What is a Plan?","text":"A Plan is a structured, goal-driven roadmap for an agent to achieve a specific task. The key feature of a Plan is that it decomposes the primary task into subtasks, forming a hierarchical structure. The agent follows this roadmap, completing one subtask after another. Since a plan ultimately governs execution, it must be well-structured\u2014most naturally as a Directed Acyclic Graph (DAG).
Each node in the DAG represents a Task or subtask, and the edges describe dependencies between them. This ensures a logical, stepwise execution where subtasks cannot begin until their dependencies are satisfied.
- Example: Consider an agent tasked with preparing a meal. The plan breaks the main task (\"Cook meal\") into subtasks like \"Chop vegetables,\" \"Boil water,\" \"Cook rice,\" and \"Serve meal.\" Some tasks must precede others (e.g., \"Boil water\" must happen before \"Cook rice\"). This structure forms a DAG, ensuring tasks are completed in the correct order without cycles or deadlocks.
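The cooking example can be sketched with Python's standard-library graphlib; the subtask names and dependencies below are illustrative, not part of Aeiva's API:

```python
from graphlib import TopologicalSorter

# Hypothetical "Cook meal" plan: each subtask maps to the subtasks it depends on.
plan = {
    "Chop vegetables": set(),
    "Boil water": set(),
    "Cook rice": {"Boil water"},
    "Serve meal": {"Chop vegetables", "Cook rice"},
}

# static_order() yields an execution order that respects every dependency;
# it also raises CycleError if the "plan" is not actually a DAG.
order = list(TopologicalSorter(plan).static_order())
print(order)  # a dependency-respecting order, e.g. ending with 'Serve meal'
```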
"},{"location":"blogs/Thoughts_on_key_concepts_for_agentic_intelligence/#2-what-is-a-task","title":"2. What is a Task?","text":"A Task is the fundamental unit of work in a plan. Each task has a clear status, which can be one of: - Not Executed: The task is yet to be started. - Executing: The task is currently being performed. - Success: The task has been completed successfully. - Fail: The task has failed, possibly requiring intervention or retry.
Tasks can have meta-data such as the task owner, creation time, priority, or other relevant attributes. A task also needs a mechanism to check whether it has been completed successfully, which might involve running tests or checking outputs against expectations.
- Example: In a factory, an agent may have a task like \"Assemble component A.\" The task could have metadata such as who is responsible (agent A or robot arm B), creation time (timestamp when this task was queued), and a priority level (perhaps \"high\" because component A is needed soon). After execution, the task might check the assembled part for defects before marking itself as \"Success.\"
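A minimal sketch of such a task object (the TaskStatus enum, the field names, and the work/check callbacks are illustrative assumptions, not Aeiva's actual Task class):

```python
from dataclasses import dataclass, field
from enum import Enum
import time

class TaskStatus(Enum):
    NOT_EXECUTED = "not_executed"
    EXECUTING = "executing"
    SUCCESS = "success"
    FAIL = "fail"

@dataclass
class Task:
    name: str
    owner: str = "agent"        # metadata: who is responsible
    priority: str = "normal"    # metadata: scheduling hint
    created_at: float = field(default_factory=time.time)
    status: TaskStatus = TaskStatus.NOT_EXECUTED

    def run(self, work, check) -> None:
        """Execute `work`, then use `check` on its result to decide success."""
        self.status = TaskStatus.EXECUTING
        result = work()
        self.status = TaskStatus.SUCCESS if check(result) else TaskStatus.FAIL

task = Task("Assemble component A", owner="robot arm B", priority="high")
task.run(work=lambda: "assembled", check=lambda r: r == "assembled")
print(task.status)  # TaskStatus.SUCCESS
```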
"},{"location":"blogs/Thoughts_on_key_concepts_for_agentic_intelligence/#3-what-is-a-tool","title":"3. What is a Tool?","text":"A Tool provides functionality that the agent can use to perform actions. In modern software, a tool often takes the form of an API\u2014a set of operations that accept inputs (parameters) and return outputs (results).
Tools can be seen as atomic units of functionality that are executed in isolation, but their results can influence the broader task or plan. Tools are often reusable across different tasks or plans.
- Example: Consider a research assistant agent that interacts with a remote API to retrieve scientific papers. Here, the \"Arxiv API\" is a tool. The agent calls this API (providing search parameters), and the tool returns a list of papers in a structured format. The agent uses this tool to complete tasks like \"Find papers related to quantum computing.\"
"},{"location":"blogs/Thoughts_on_key_concepts_for_agentic_intelligence/#4-what-is-an-action","title":"4. What is an Action?","text":"An Action is a higher-level operation the agent can take. While it may use a tool (or multiple tools), it is broader than just invoking a function. An Action might involve decision-making, performing logic internally, or combining the output of multiple tools.
Whereas tools are about \"doing one thing well,\" actions are more about how the agent decides to use tools or perform processes. Some actions may not even require external tools but might involve manipulating data internally.
- Example: A warehouse robot's action could be \"Pick up an item from shelf A and place it in bin B.\" The action uses the robot\u2019s sensors and movement tools, but the decision-making on how to execute it\u2014like which arm to use or which path to follow\u2014is part of the action.
"},{"location":"blogs/Thoughts_on_key_concepts_for_agentic_intelligence/#5-what-is-a-skill","title":"5. What is a Skill?","text":"A Skill is the procedural knowledge an agent uses to complete a task. It represents a series of actions or steps the agent follows to solve a problem. Skills can be encoded as DAGs, with each node representing an action, and the edges defining the flow or dependencies between actions.
What distinguishes a Skill from hardcoded instructions is its flexibility. For instance, a skill may allow for different actions to be taken in varying orders, or certain parameters may be adjusted dynamically. In other words, a skill isn\u2019t rigid but adaptable to different contexts or environments.
- Example: An agent trained to clean a room could have a \"Cleaning skill.\" It involves subtasks like \"vacuum the floor,\" \"wipe the table,\" and \"empty the trash.\" In some cases, the agent may vacuum first and then wipe the table, but in others, it may reverse the order depending on room conditions. The ability to adapt while following a general cleaning procedure is what makes it a skill.
"},{"location":"blogs/Thoughts_on_key_concepts_for_agentic_intelligence/#6-what-is-an-experience","title":"6. What is an Experience?","text":"An Experience is a personal record of how an agent solved a particular task. While the structure of an Experience may resemble that of a Skill, it is tied to a specific instance.
The main distinction is that Experiences are not generalized. Instead, they capture the details of how a task was solved under particular circumstances, including all the decisions, parameters, and actions taken during the process. Over time, multiple experiences can be analyzed to derive common patterns, which evolve into Skills.
- Example: After attempting to solve several puzzles, an agent might log each experience\u2014how it solved the puzzle, what tools it used, how long it took, etc. After analyzing several such experiences, the agent may extract a general strategy (skill) for solving puzzles of this type.
"},{"location":"blogs/Thoughts_on_key_concepts_for_agentic_intelligence/#7-what-is-memory","title":"7. What is Memory?","text":"Memory is the broader concept that includes all the data an agent remembers about its past actions, interactions, and decisions. Memory could encompass many forms, including: - Experiential memory: Specific memories about how the agent solved tasks (as described in Experience). - Episodic memory: Memory of specific events or interactions the agent has been part of. - Semantic memory: Knowledge the agent has learned about its environment or domain.
Memory plays a critical role in making an agent \"intelligent,\" as it allows the agent to learn from past mistakes, reuse successful strategies, and adapt to new situations by recalling prior experiences.
- Example: A personal assistant agent might have episodic memory of the last time it scheduled a meeting for the user. The next time the user asks it to schedule a meeting, it can reference that memory to understand the user's preferences, such as their preferred meeting time.
"},{"location":"blogs/Thoughts_on_key_concepts_for_agentic_intelligence/#8-what-is-knowledge","title":"8. What is Knowledge?","text":"Knowledge is a validated, generalizable form of learning. While an experience is a personal, one-off record, knowledge has been abstracted and validated across multiple situations. Knowledge allows the agent to generalize beyond specific experiences and apply what it has learned to new, similar tasks.
In many cases, a Skill represents a particular type of knowledge\u2014the procedural knowledge required to complete a task. Knowledge might also be sharable between agents, or taught from one agent to another, making it reusable.
- Example: An agent that has learned to solve various types of math problems can generalize its knowledge into a set of skills. When faced with a new math problem, it can apply this knowledge, even if the problem differs slightly from the ones it has solved before.
"},{"location":"blogs/Thoughts_on_key_concepts_for_agentic_intelligence/#closing-thoughts","title":"Closing Thoughts","text":"These key concepts\u2014Plan, Task, Tool, Action, Skill, Experience, Memory, and Knowledge\u2014form the foundation of agentic intelligence. Together, they allow an agent to: - Decompose tasks into executable steps (Plan), - Perform specific actions (Task, Action, Tool), - Learn from both immediate tasks and general experiences (Experience, Memory), - Generalize that learning into knowledge that improves future performance (Knowledge, Skill).
By keeping these concepts clear and well-defined, an agent can operate in a structured, intelligent way, continually learning and improving over time.
"},{"location":"blogs/unity_dev_notes/","title":"Notes about how I developed the AI Konoha Village","text":""},{"location":"blogs/unity_dev_notes/#obtain-the-hidden-leaf-village-3d-model","title":"Obtain the \"hidden leaf village\" 3D model","text":"I downloaded it from here:
https://mega.nz/file/vkcHSYLT#t5gG06y65gEp8g3U8N8Yic5BijvZ0PA_7UstCmnoG38
https://www.deviantart.com/naruko-uzumaki-the-9/art/Hidden-Leaf-Village-Complete-DL-Fixed-809223977
"},{"location":"blogs/unity_dev_notes/#import-to-blender","title":"Import to blender","text":"In my case, I could not open the files directly, but I could import the .fbx file in Blender 3.6 (Mac M1). Change the Render Engine from Eevee to Workbench, and then in the Color drop-down menu, select Texture. Then press \"Z\" and select the \"Rendered\" mode. You will see the colored model there.
"},{"location":"blogs/unity_dev_notes/#import-export-fbx-file-from-blender","title":"import & export .fbx file from blender","text":"When exporting a .fbx file from Blender and loading it into Unity, you may encounter errors such as mesh tangent or self-intersection warnings. The way to solve this is: 1. Install the Better FBX Importer Exporter plugin for Blender (it solves the mesh tangent problem); 2. When exporting with the plugin, select triangulate (it solves the intersection problem).
"},{"location":"blogs/unity_dev_notes/#import-fbx-or-dae-file-to-unity","title":"import .fbx or .dae file to unity","text":"I found the best way is to directly drag the whole folder, including the materials/textures, into the asset folder of the Unity project. Unity will then load the assets in the folder and generate .meta data. After that, we can drag the assets into the project from the \"project\" window. Note that Unity 2022 doesn't seem to show the project and inspector windows by default, while Unity 2021 does. For Unity 2022, we can select \"Window -> Layouts -> Default\" to get the desired layout.
I also compared the .dae and .fbx files for the hidden leaf village model. In Unity, the \"Hidden Leaf Village - Complete.dae\" file seems to look better.
"},{"location":"tutorials/How_to_Make_a_Python_Package/","title":"How to Make Your Python Project a Pip-Installable Package","text":"Author: Bang Liu
Date: 2024-11-23
This guide walks you through the process of creating a Python package that others can install using pip
.
"},{"location":"tutorials/How_to_Make_a_Python_Package/#step-1-structure-your-project","title":"Step 1: Structure Your Project","text":"Organize your project with a proper directory structure:
your_project/\n\u251c\u2500\u2500 src/\n\u2502 \u2514\u2500\u2500 your_project/\n\u2502 \u251c\u2500\u2500 __init__.py # Makes this a package\n\u2502 \u251c\u2500\u2500 module.py # Your module files\n\u251c\u2500\u2500 setup.py # Metadata and build script\n\u251c\u2500\u2500 README.md # Project description\n\u251c\u2500\u2500 LICENSE # License file (optional but recommended)\n\u251c\u2500\u2500 requirements.txt # Dependency file (optional)\n
src/your_project/
: Contains your package code. __init__.py
: Makes the folder a Python package. setup.py
: Defines metadata and installation behavior.
"},{"location":"tutorials/How_to_Make_a_Python_Package/#step-2-create-setuppy","title":"Step 2: Create setup.py
","text":"setup.py
is the script used to build and install your package. Here's a sample:
from setuptools import setup, find_packages\n\nsetup(\n name=\"your_project\", # Your package name\n version=\"0.1.0\", # Package version\n author=\"Your Name\", # Your name\n author_email=\"your.email@example.com\", # Your email\n description=\"A brief description\", # Short description\n long_description=open('README.md').read(), # Long description from README\n long_description_content_type='text/markdown', # Markdown format\n url=\"https://github.com/username/repository\", # Project repository\n packages=find_packages(where=\"src\"), # Find packages in src/\n package_dir={\"\": \"src\"}, # Root directory for packages\n classifiers=[ # Metadata for PyPI\n \"Programming Language :: Python :: 3\",\n \"License :: OSI Approved :: MIT License\",\n \"Operating System :: OS Independent\",\n ],\n python_requires='>=3.6', # Minimum Python version\n install_requires=[ # Dependencies\n \"numpy\", # Example dependency\n ],\n)\n
"},{"location":"tutorials/How_to_Make_a_Python_Package/#step-3-create-readmemd","title":"Step 3: Create README.md
","text":"Write a README.md
file to describe your project. Use Markdown for formatting. Example:
# Your Project Name\n\nA short description of your project.\n\n## Installation\n\n```bash\npip install your_project\n
"},{"location":"tutorials/How_to_Make_a_Python_Package/#usage","title":"Usage","text":"import your_project\nyour_project.some_function()\n
"},{"location":"tutorials/How_to_Make_a_Python_Package/#step-4-test-your-package-locally","title":"Step 4: Test Your Package Locally","text":"Test your package before publishing:
-
Navigate to your project root: bash cd /path/to/your_project
-
Install it in editable mode: bash pip install -e .
-
Import your package to verify: python import your_project your_project.some_function()
"},{"location":"tutorials/How_to_Make_a_Python_Package/#step-5-build-the-package","title":"Step 5: Build the Package","text":"Install the necessary tools:
pip install build\n
Build your package:
python -m build\n
This creates a dist/
directory with .tar.gz
and .whl
files.
"},{"location":"tutorials/How_to_Make_a_Python_Package/#step-6-upload-to-pypi","title":"Step 6: Upload to PyPI","text":" - Register on PyPI:
- Create an account at PyPI.
-
Optionally, register on TestPyPI for testing.
-
Install Twine: bash pip install twine
-
Upload Your Package: bash python -m twine upload dist/*
To test uploads on TestPyPI: bash python -m twine upload --repository testpypi dist/*
- Provide Your PyPI Token:
-
If prompted, enter your PyPI API token.
-
Alternative way to upload:
python -m twine upload --repository-url https://upload.pypi.org/legacy/ dist/* -u __token__ -p pypi-<your token password here>\n
"},{"location":"tutorials/How_to_Make_a_Python_Package/#step-7-verify-installation","title":"Step 7: Verify Installation","text":"Install your package from PyPI:
pip install your_project\n
Verify it works as expected:
python\n>>> import your_project\n>>> your_project.some_function()\n
"},{"location":"tutorials/How_to_Make_a_Python_Package/#tips-and-best-practices","title":"Tips and Best Practices","text":" - Include a License: Add a
LICENSE
file to clarify usage terms. - Automate Versioning: Use tools like
bumpversion
to manage versions. - Test Thoroughly: Use TestPyPI before uploading to the main PyPI repository.
- Secure Tokens: Use project-specific tokens for uploads.
Congratulations! Your project is now a pip-installable Python package.
"},{"location":"tutorials/coding_guidelines/","title":"Coding Guidelines","text":""},{"location":"tutorials/coding_guidelines/#code-hierarchy","title":"Code Hierarchy","text":"Generally, our code can be organized into three different levels:
-
Framework: This level forms the architectural backbone of your project. It houses the core functionalities that define the basic structure and shared logic for your project. These files, typically stored under the package_name/bases/
directory, establish the protocols and high-level operations that the rest of your project will adhere to.
-
Brick: The \"Brick\" level acts as a collection of modular, reusable components used across your project. These components, which are stored in the package_name/xxx/
directory, promote code reusability and reduce redundancy, thereby enhancing the efficiency of your codebase.
-
Applications: This level contains specific implementations associated with particular datasets, models, or experiments. These files, which are stored in the package_name/
directory, are separate from the abstract base classes and reusable functions found in the other levels. This separation aids in code navigation and readability, making it easier to locate and understand the specific components of your project.
By adhering to this structure, your codebase will be well-organized, easily navigable, and efficient. This organization adheres to best practices in software development, promoting code reusability and a clear separation of concerns.
"},{"location":"tutorials/coding_guidelines/#generate-requirementstxt","title":"Generate requirements.txt","text":"Use pipreqs: pipreqs is a useful tool that generates a requirements.txt file based on the imports in your Python project, not on the installed packages in your current environment. You can install it and use it as follows:
pip install pipreqs\npipreqs --force /path/to/your/project\n
"},{"location":"tutorials/coding_guidelines/#args-and-kwargs","title":"args and *kwargs","text":"*args
and **kwargs
in Python allow a function to accept optional arguments, meaning that the user can pass a variable number of arguments to these functions. Here's when you might want to use them:
-
When you're not sure how many arguments might be passed to your function: *args
is used to send a non-keyworded variable-length argument list to your function. You might use it when you're not sure how many arguments might be passed to your function, or if you want to support an arbitrary number of arguments.
-
When you want to write a function that must accept a dictionary: **kwargs
is used to pass a keyworded, variable-length argument list. You would use this if you want your function to be able to accept a dictionary of attributes.
-
When creating wrapper functions or decorators: *args
and **kwargs
are commonly used when you're writing higher-order functions or decorators that need to manipulate the inputs to another function that they're wrapping.
-
When subclassing and you want to extend the parent class's methods: In this case, you may not know exactly what the parent class's method takes as arguments. *args
and **kwargs
let you pass any parameters from the child class to the parent class's method without having to know what those parameters are.
However, while *args
and **kwargs
are very helpful, they should be used judiciously. Overuse can make your code harder to understand and debug since it's not immediately clear what arguments a function expects. When writing a function, if you know the exact number and role of each argument, it's better to list them explicitly.
In summary, *args
and **kwargs
are powerful tools that make Python functions more flexible. However, as with any tool, they should be used judiciously and appropriately.
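The wrapper/decorator case above is the most common one in practice; here is a minimal generic sketch (an illustration, not project code):

```python
import functools

def log_calls(func):
    """Decorator that forwards any combination of arguments to `func`."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # *args / **kwargs let the wrapper work for any signature.
        print(f"calling {func.__name__} args={args} kwargs={kwargs}")
        return func(*args, **kwargs)
    return wrapper

@log_calls
def add(a, b=0):
    return a + b

print(add(1, b=2))  # logs the call, then prints 3
```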
"},{"location":"tutorials/coding_guidelines/#order-of-function-arguments-in-python","title":"Order of Function Arguments in Python","text":"In Python, the recommended order of function parameters is as follows:
-
Required positional arguments: These are arguments that need to be in a specific positional order. When calling the function, Python interprets them based on their order.
Example: def func(name, age):
-
Optional positional arguments / Default Parameters: These are arguments that are optional and have a default value. They are also interpreted based on their order.
Example: def func(name, age=22):
-
Required keyword-only arguments: These are arguments that must be supplied by keyword and follow a \"*,\" in the function definition.
Example: def func(name, age, *, city):
-
Optional keyword-only arguments / Default Keyword Parameters: These are keyword arguments that are optional. The function will use the default value if no value is provided.
Example: def func(name, age, *, city='New York'):
-
Arbitrary argument lists: The *args
and **kwargs
parameters, which collect all positional and keyword arguments that are not already caught by other parameters.
Example: def func(name, age, *args, city='New York', **kwargs):
This order can help make your function definitions clear and easy to read. It also helps prevent common bugs caused by confusing positional and keyword arguments.
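All of these parameter kinds can appear together, as in the final example above; a runnable (and purely hypothetical) illustration:

```python
def describe(name, age=22, *hobbies, city="New York", **extras):
    # name: required positional; age: optional positional with default;
    # *hobbies: arbitrary extra positionals; city: keyword-only with default;
    # **extras: arbitrary extra keyword arguments.
    parts = [f"{name} ({age}) from {city}"]
    if hobbies:
        parts.append("hobbies: " + ", ".join(hobbies))
    if extras:
        parts.append(", ".join(f"{k}={v}" for k, v in extras.items()))
    return "; ".join(parts)

print(describe("Ada", 30, "chess", city="London", job="engineer"))
# Ada (30) from London; hobbies: chess; job=engineer
```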
"},{"location":"tutorials/coding_guidelines/#naming-noun-or-verb","title":"Naming: Noun or Verb?","text":"Thing Choice of Word Modules Noun Data types Noun or Adjective Functions Noun or Verb Constants/Variables Noun - Try to keep names short; avoid names longer than three words if possible.
- Whether to use a verb or a noun for a function or method depends on what you want to emphasize: the returned result or the process that produces it.
To better choose verbs for functions, below are some suggestions:
- Is the function a test? -> test_\\_\\.
-
Does the function have a @property decorator? -> don\u2019t use a verb in the function name.
-
Does the function use a disk or a network:
3.1. \u2026 to store data? -> save_to, send, write_to
3.2. \u2026 to receive data? -> fetch, load, read
-
Does the function output any data? -> print, output
-
Returns boolean value? -> is_, has_/have_, can_, check_if_\\_\\
-
Aggregates data? -> calculate, extract, analyze
-
Put data from one form to another:
7.1. Creates a single meaningful object? -> create
7.2. Fills an existing object with data? -> initialize, configure
7.3. Clean raw data? -> clean
7.4. Receive a string as input? -> parse
7.5. Return a string as output? -> render
7.6. Return an iterator as output? ->iter
7.7. Mutates its arguments or some global state? -> update, mutate, add, remove, insert, set
7.8. Return a list of errors? -> validate
7.9. Checks data items recursively? -> walk
7.10. Finds appropriate item in data? -> find, search, match
7.11. Transform data type? -> \\_to_\\
7.12. None of the above, but still works with data? -> Check one of those: morph, compose, prepare, extract, generate, initialize, filter, map, aggregate, export, import, normalize, calculate .
"},{"location":"tutorials/coding_guidelines/#install-package","title":"Install package","text":"We can install the package we are developing by the following command:
pip install -e .\n
It means we are installing it in editable mode. In Python, if you want to be able to edit your package and have the changes be reflected immediately without needing to reinstall the package every time, you can use pip to install the package in \"editable\" mode.
If you are worried about the state of your package affecting other parts of your system or other projects, you might consider using a virtual environment. A virtual environment is an isolated Python environment, separate from your system Python and other virtual environments. You can install your package in a virtual environment and make changes and test without worrying about affecting other projects.
"},{"location":"tutorials/coding_guidelines/#reference","title":"Reference","text":"[1] https://ahsmart.com/pub/naming-things-properly/
[2] https://melevir.medium.com/python-functions-naming-the-algorithm-74320a18278d
"},{"location":"tutorials/generate_docs/","title":"How to generate docs automatically","text":"Author: Bang Liu
Date: 2023-08-05
In this document, I will introduce how to automatically generate the documentation for your python project with several tools.
"},{"location":"tutorials/generate_docs/#install-libraries","title":"Install libraries","text":"We use the following python packages:
- MkDocs for building static pages from Markdown
- mkdocstrings for auto-generating documentation from docstrings in your code
- Material for MkDocs for styling your documentation
pip install --upgrade pip\npip install mkdocs\npip install mkdocstrings\npip install mkdocs-material\n
You can install support for specific languages using extras, for example:
pip install 'mkdocstrings[crystal,python]'\n
Note: support for specific languages is not installed by default, so I recommend installing it with the command above.
"},{"location":"tutorials/generate_docs/#create-mkdocs-project","title":"Create mkdocs project","text":"Now assume you are in the root directory of your project:
mkdocs new .\n
You will see:
INFO - Writing config file: ./mkdocs.yml\nINFO - Writing initial docs: ./docs/index.md\n
MkDocs comes with a built-in dev-server that lets you preview your documentation as you work on it. Make sure you're in the same directory as the mkdocs.yml
configuration file, and then start the server by running the mkdocs serve
command:
% mkdocs serve\nINFO - Building documentation...\nINFO - Cleaning site directory\nWARNING - Excluding 'README.md' from the site because it conflicts with\n 'index.md'.\nINFO - Documentation built in 0.08 seconds\nINFO - [14:25:59] Watching paths for changes: 'docs', 'mkdocs.yml'\nINFO - [14:25:59] Serving on http://127.0.0.1:8000/\nINFO - [14:26:11] Browser connected: http://127.0.0.1:8000/\n
Open up http://127.0.0.1:8000/ in your browser, and you'll see the default home page being displayed.
"},{"location":"tutorials/generate_docs/#customize-your-mkdocsyml","title":"Customize your mkdocs.yml","text":"We can customize the style of our documentation. Edit the ./mkdocs.yml file:
site_name: your-project-name\nsite_url: your-project-website\nnav:\n - Home: index.md\ntheme:\n name: \"material\"\n
This way, we can use the material theme. You can also use other themes [1,2].
"},{"location":"tutorials/generate_docs/#add-more-markdown-files-to-the-documentation","title":"Add more markdown files to the documentation","text":"As described in [1], we can follow the structure proposed in the Di\u00e1taxis documentation framework, which suggests splitting your documentation into four distinct parts:
- Tutorials
- How-To Guides
- Reference
- Explanation
Therefore, we can create these markdown files and put them into the ./docs/ folder. Then we edit our mkdocs.yml configuration file to add them:
site_name: your-project-name\nsite_url: your-project-website\n\nnav:\n - index.md\n - tutorials.md\n - how-to-guides.md\n - reference.md\n - explanation.md\n\ntheme:\n name: \"material\"\n
We can also edit the titles for each page, adjust their order, and so on. See [1] for more details.
"},{"location":"tutorials/generate_docs/#generate-document-from-docstrings","title":"Generate document from Docstrings","text":"We need to use mkdocstrings
package for this purpose.
MkDocs is a static-site generator geared toward writing documentation. However, you can\u2019t fetch docstring information from your code using MkDocs alone. You can make it work with an additional package called mkdocstrings.
You already installed mkdocstrings into your virtual environment at the beginning of this tutorial, so you only need to add it as a plugin to your MkDocs configuration file:
site_name: your-project-name\nsite_url: your-project-website\n\nplugins:\n - mkdocstrings\n\nnav:\n - index.md\n - tutorials.md\n - how-to-guides.md\n - reference.md\n - explanation.md\n\ntheme:\n name: \"material\"\n
Now, to generate documentation from source code docstrings, we can select a markdown file, e.g., the reference.md file we have created, and put identifiers in it.
Mkdocstrings allows you to insert docstring information right into your Markdown pages using a special syntax of three colons (:::) followed by the code identifier that you want to document:
::: identifier\n
The identifier is a string identifying the object you want to document. The format of an identifier
can vary from one handler to another. For example, the Python handler expects the full dotted-path to a Python object: my_package.my_module.MyClass.my_method
[3].
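For example, assuming a hypothetical package layout (my_package and my_module are placeholder names, not part of this project), a reference.md page could contain:

```md
# Reference

::: my_package.my_module.MyClass.my_method
```

When the site is built, mkdocstrings replaces the identifier line with the rendered docstring of that method.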
The syntax for using an identifier together with configuration options is:
::: identifier\n YAML block\n
See https://mkdocstrings.github.io/usage/ for more details.
Basically, the YAML block is optional, and contains some configuration options.
For global options, we can put them in mkdocs.yml. For example:
plugins:\n- mkdocstrings:\n enabled: !ENV [ENABLE_MKDOCSTRINGS, true]\n custom_templates: templates\n default_handler: python\n handlers:\n python:\n options:\n show_source: false\n
And global configurations can be overridden by local configurations.
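As an illustration (the module name is hypothetical), a per-identifier YAML block in a markdown page can override the global show_source setting from mkdocs.yml:

```md
::: my_package.my_module
    options:
      show_source: true
```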
See [3] for more detailed tutorials. In brief, with mkdocstrings we can use identifiers to gather the docstrings in our code and turn them into documentation.
Tip: maintaining a good docstring style is very important. I prefer the docstring style described here: https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings
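As a minimal sketch of that style (the function and its contents are hypothetical, not from this project), here is a Google-style docstring that the mkdocstrings Python handler can render into a documentation page:

```python
def scale(values, factor=2.0):
    """Scale a list of numbers by a constant factor.

    Args:
        values: A list of numbers to scale.
        factor: The multiplier applied to each element.

    Returns:
        A new list with each element multiplied by factor.
    """
    return [v * factor for v in values]
```

The `Args:` and `Returns:` sections are parsed by the Python handler and rendered as structured parameter and return-value tables.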
"},{"location":"tutorials/generate_docs/#automatically-collect-all-the-docstrings-in-a-module","title":"Automatically collect all the docstrings in a module","text":"To avoid manually write the identifiers for each submodule/class/method in a markdown file to include the corresponding docstrings in our documentation, we can use the following option:
::: src.aeiva.agent\n options:\n show_submodules: true\n
The example above automatically pulls every docstring in the aeiva.agent package into our documentation.
"},{"location":"tutorials/generate_docs/#advanced-theme-customization","title":"Advanced Theme Customization","text":""},{"location":"tutorials/generate_docs/#changing-the-logo-and-icons","title":"Changing the logo and icons","text":"See: https://squidfunk.github.io/mkdocs-material/setup/changing-the-logo-and-icons/
"},{"location":"tutorials/generate_docs/#customize-the-landing-home-page","title":"Customize the landing home page","text":"We can further customize the home page of our documentation.
First, set your custom_dir in mkdocs.yml:
theme:\n custom_dir: docs/overrides\n...\n\n
The setting above uses the overrides directory in docs/ as the custom directory.
We then copy all the contents of https://github.com/squidfunk/mkdocs-material/tree/master/src/.overrides into our docs/overrides/ folder.
Next, in the front matter of your index.md, you need to specify the template to use (copy below to index.md):
---\ntitle: Title\ntemplate: home.html\n---\n
One important thing that took me a while to realize: you need a newline at the end of your md file. If you don't have one, the content will not display [6].
Finally, we can customize the home.html
and main.html
in the overrides folder to make it consistent with our project.
See [6] for a reference.
Note: I found the landing page on https://squidfunk.github.io/mkdocs-material/ really fancy! It is based on a parallax image effect using HTML and CSS. To DIY the effect, I downloaded the source of the webpage directly, then replaced every assets/images/layers/ in the HTML source with ./Material for MkDocs_files/, since that is the only folder the download gives me. I haven't finished understanding and customizing the landing page based on this template; to be tested in the future. :) (I put this version in docs/overrides-dev/)
"},{"location":"tutorials/generate_docs/#organize-your-documentation","title":"Organize your documentation","text":""},{"location":"tutorials/generate_docs/#navbar-nesting","title":"Navbar nesting","text":"You can add an additional level to your navbar like this:
nav:\n - Home: index.md\n - About: about.md\n - Foo:\n - Overview: foo/index.md\n - Bar: foo/bar.md\n
"},{"location":"tutorials/generate_docs/#reference-to-another-markdown-file","title":"Reference to another markdown file","text":"In a markdown document, we can refer to another file from one file, like the following:
[How to generate project documentation automatically from docstrings](./GENERATE_DOCS.md)\n
"},{"location":"tutorials/generate_docs/#deploy-your-documentation-to-github","title":"Deploy Your Documentation to GitHub","text":"GitHub repositories automatically serve static content when committed to a branch named gh-pages. MkDocs integrates with that and allows you to build and deploy your project documentation in a single step:
mkdocs gh-deploy\n
Running this command rebuilds the documentation from your Markdown files and source code and pushes it to the gh-pages branch on your remote GitHub repository.
Because of GitHub\u2019s default configuration, that\u2019ll make your documentation available at the URL that MkDocs shows you at the end of your terminal output:
INFO - Your documentation should shortly be available at:\n https://user-name.github.io/project-name/\n
"},{"location":"tutorials/generate_docs/#summarize","title":"Summarize","text":"So we basically follow the following procedures to create our documentation:
- Create virtual env for your project. Create your project. Create your github repository.
- Install the libraries: mkdocs, mkdocstrings, mkdocs-material
- Go to the project root directory.
- Use mkdocs to create the docs. It will produce mkdocs.yml and ./docs/index.md.
- Customize mkdocs.yml. Basically, this is the global setting of the documentation. See [2] for details. For example, you can switch your documentation theme to the material theme provided by the mkdocs-material python package.
- Customize the contents in ./docs/. You can create different markdown files here, and you can automatically generate documentation from your code's docstrings using the ::: identifier syntax supported by mkdocstrings. See [4] for details.
- Customize the organization of your documentation. For example, you can use nested navigation, cross-references, and so on.
- Build your documentation using mkdocs build.
- Host your documentation using mkdocs gh-deploy. Your documentation should shortly be available at https://user-name.github.io/project-name/.
"},{"location":"tutorials/generate_docs/#more","title":"More","text":"Please read [1,2,3,4] for more detailed tutorials.
"},{"location":"tutorials/generate_docs/#reference","title":"Reference","text":"1(https://realpython.com/python-project-documentation-with-mkdocs/)
2(https://www.mkdocs.org/getting-started/)
3(https://mkdocstrings.github.io/)
4(https://github.com/squidfunk/mkdocs-material)
[5] Di\u00e1taxis A systematic approach to technical documentation authoring.
6(https://github.com/squidfunk/mkdocs-material/issues/1996)
"},{"location":"tutorials/install_minedojo/","title":"Install MineDojo platform on MacBook Pro with M1 Chip","text":"Author: Bang Liu
Date: 2023-08-01
"},{"location":"tutorials/install_minedojo/#setup-java-environment","title":"Setup Java Environment","text":"I followed the instructions on: https://docs.minedojo.org/sections/getting_started/install.html#prerequisites
Specifically, remember to list all installed Java versions and export the temurin8 version of Java:
/usr/libexec/java_home -V\nexport JAVA_HOME=path/to/eclipse/temurin8\n
After running
java -version\n
I got
openjdk version \"1.8.0_332\"\nOpenJDK Runtime Environment (Temurin)(build 1.8.0_332-b09)\nOpenJDK 64-Bit Server VM (Temurin)(build 25.332-b09, mixed mode)\n
"},{"location":"tutorials/install_minedojo/#install-minedojo","title":"Install MineDojo","text":"I used the following command: (Assume Java JDK 8 is already installed)
pip3 install setuptools==65.5.0 pip==21\npip3 install gym==0.21\ngit clone https://github.com/MineDojo/MineDojo && cd MineDojo\npip install -e .\n
Note: I found that if I install from source, I cannot later remove the source directory. So after resolving all the bugs described below, I reinstalled minedojo via pip in my conda virtual env:
pip install minedojo\n
So I would recommend installing via pip rather than from source.
"},{"location":"tutorials/install_minedojo/#debug-experience","title":"Debug experience","text":"There are many different bugs when I try to run
python scripts/validate_install.py\n
Below, I list all the operations I have done.
"},{"location":"tutorials/install_minedojo/#upgraded-gradle","title":"Upgraded gradle","text":"Check the following: https://gradle.org/install/
After installing the new gradle, I got:
>>> gradle -v\n\n------------------------------------------------------------\nGradle 8.2.1\n------------------------------------------------------------\n\nBuild time: 2023-07-10 12:12:35 UTC\nRevision: a38ec64d3c4612da9083cc506a1ccb212afeecaa\n\nKotlin: 1.8.20\nGroovy: 3.0.17\nAnt: Apache Ant(TM) version 1.10.13 compiled on January 4 2023\nJVM: 1.8.0_332 (Temurin 25.332-b09)\nOS: Mac OS X 10.16 x86_64\n\n
"},{"location":"tutorials/install_minedojo/#malmo-errors","title":"Malmo errors","text":"I referred to: https://github.com/MineDojo/MineDojo/issues/32#issuecomment-1237247417 It says:
For Deprecated Gradle feature --> Go to Malmo project download latest prebuild version https://github.com/Microsoft/malmo/releases. Then find and replace the Malmo directory in your python package directory @ xxx/minedojo/sim/Malmo on your computer. (Reminder directory shall keep the same name \"Malmo\")
For \"OpenGL: ERROR RuntimeException: No OpenGL context found in the current thread.\" (X Error & bad value) --> make sure you run sudo apt update && sudo apt upgrade before you compile the minecraft java program as the same problem has been described in https://stackoverflow.com/questions/28867285/lwjgl-reports-that-opengl-is-not-supported-on-a-modern-nvidia-card. This works for me.
Before running python Minedojo code, go xxx/minedojo/sim/Malmo/Minecraft/ where your python put minedojo package and execute ./launchClient.sh (for linux/unix) or .\\launchClient (for windows, there's a launchClient.bat file) and make sure it can run normally before you start with Minedojo.
Specifically, when I tried to run ./launchClient.sh, I got an error about tools.jar, so I did the following:
copy tools.jar from\n/Library/Java/JavaVirtualMachines/temurin-8.jdk/Contents/Home/lib\nto\n/Library/Internet Plug-Ins/JavaAppletPlugin.plugin/Contents/Home/lib\n\n>>> sudo cp XXX XXX\npasswd: (for me, it is the same as the password I use to log in to my MacBook Pro)\n
Then it still failed, so I went back to the original Malmo from the MineDojo installation (i.e., maybe we DON'T need to download the latest prebuilt version from https://github.com/Microsoft/malmo/releases and replace the Malmo directory in the python package directory).
Now it can run, but there is still an error:
raise NoSuchProcess(self.pid, self._name)\npsutil.NoSuchProcess: process no longer exists (pid=50957, name='bash')\n
I removed the
env.close()\n
in the script and it works.
This is not the end of the story: I found the script doesn't always work. Sometimes I don't need to remove the env.close()
and it still works; sometimes it fails with errors like
...\n at org.apache.http.impl.DefaultBHttpClientConnection.receiveResponseHeader(DefaultBHttpClientConnection.java:163)\n at org.apache.http.impl.conn.CPoolProxy.receiveResponseHeader(CPoolProxy.java:165)\n at org.apache.http.protocol.HttpRequestExecutor.doReceiveResponse(HttpRequestExecutor.java:273)\n at org.apache.http.protocol.HttpRequestExecutor.execute(HttpRequestExecutor.java:125)\n at org.apache.http.impl.execchain.MainClientExec.createTunnelToTarget(MainClientExec.java:473)\n at org.apache.http.impl.execchain.MainClientExec.establishRoute(MainClientExec.java:398)\n at org.apache.http.impl.execchain.MainClientExec.execute(MainClientExec.java:237)\n at org.apache.http.impl.execchain.ProtocolExec.execute(ProtocolExec.java:185)\n at org.apache.http.impl.execchain.RetryExec.execute(RetryExec.java:89)\n at org.apache.http.impl.execchain.RedirectExec.execute(RedirectExec.java:111)\n at org.apache.http.impl.client.InternalHttpClient.doExecute(InternalHttpClient.java:185)\n at org.apache.http.impl.client.CloseableHttpClient.execute(CloseableHttpClient.java:83)\n at org.gradle.internal.resource.transport.http.HttpClientHelper.performHttpRequest(HttpClientHelper.java:148)\n at org.gradle.internal.resource.transport.http.HttpClientHelper.performHttpRequest(HttpClientHelper.java:126)\n at org.gradle.internal.resource.transport.http.HttpClientHelper.executeGetOrHead(HttpClientHelper.java:103)\n at org.gradle.internal.resource.transport.http.HttpClientHelper.performRequest(HttpClientHelper.java:94)\n ... 171 more\n\n\n* Get more help at https://help.gradle.org\n\nBUILD FAILED in 31s\n\n\nMinecraft process finished unexpectedly. There was an error with Malmo.\n
I suppose it is due to some network connection errors?
Anyway, now it can work.
"}]}