Skip to content

Commit

Permalink
feat: Add enum and default value support in task processing
Browse files Browse the repository at this point in the history
Signed-off-by: Anupam Kumar <kyteinsky@gmail.com>
  • Loading branch information
kyteinsky committed Aug 9, 2024
1 parent 3d26987 commit ce6ece1
Showing 1 changed file with 97 additions and 30 deletions.
127 changes: 97 additions & 30 deletions nc_py_api/ex_app/providers/task_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,35 +3,104 @@
import contextlib
import dataclasses
import typing
from enum import IntEnum

from pydantic import RootModel
from pydantic.dataclasses import dataclass

from ..._exceptions import NextcloudException, NextcloudExceptionNotFound
from ..._misc import clear_from_params_empty, require_capabilities
from ..._misc import require_capabilities
from ..._session import AsyncNcSessionApp, NcSessionApp

_EP_SUFFIX: str = "ai_provider/task_processing"


@dataclasses.dataclass
class TaskProcessingProvider:
"""TaskProcessing provider description."""
class ShapeType(IntEnum):
"""Enum for shape types."""

NUMBER = 0
TEXT = 1
IMAGE = 2
AUDIO = 3
VIDEO = 4
FILE = 5
ENUM = 6
LISTOFNUMBERS = 10
LISTOFTEXTS = 11
LISTOFIMAGES = 12
LISTOFAUDIOS = 13
LISTOFVIDEOS = 14
LISTOFFILES = 15


@dataclass
class ShapeEnumValue:
"""Data object for input output shape enum slot value."""

name: str
"""Name of the enum slot value which will be displayed in the UI"""
value: str
"""Value of the enum slot value"""


@dataclass
class ShapeDescriptor:
"""Data object for input output shape entries."""

def __init__(self, raw_data: dict):
self._raw_data = raw_data
name: str
"""Name of the shape entry"""
description: str
"""Description of the shape entry"""
shape_type: ShapeType
"""Type of the shape entry"""

@property
def name(self) -> str:
"""Unique ID for the provider."""
return self._raw_data["name"]

@property
def display_name(self) -> str:
"""Providers display name."""
return self._raw_data["display_name"]
@dataclass
class TaskType:
"""TaskType description for the provider."""

@property
def task_type(self) -> str:
"""The TaskType provided by this provider."""
return self._raw_data["task_type"]
id: str
"""The unique ID for the task type."""
name: str
"""The localized name of the task type."""
description: str
"""The localized description of the task type."""
input_shape: list[ShapeDescriptor]
"""The input shape of the task."""
output_shape: list[ShapeDescriptor]
"""The output shape of the task."""


@dataclass
class TaskProcessingProvider:
"""TaskProcessing provider description."""

# pylint: disable=too-many-instance-attributes

id: str
"""Unique ID for the provider."""
name: str
"""The localized name of this provider"""
task_type: str
"""The TaskType provided by this provider."""
expected_runtime: int = dataclasses.field(default=0)
"""Expected runtime of the task in seconds."""
optional_input_shape: list[ShapeDescriptor] = dataclasses.field(default_factory=list)
"""Optional input shape of the task."""
optional_output_shape: list[ShapeDescriptor] = dataclasses.field(default_factory=list)
"""Optional output shape of the task."""
input_shape_enum_values: dict[str, list[ShapeEnumValue]] = dataclasses.field(default_factory=dict)
"""The option dict for each input shape ENUM slot."""
input_shape_defaults: dict[str, str | int | float] = dataclasses.field(default_factory=dict)
"""The default values for input shape slots."""
optional_input_shape_enum_values: dict[str, list[ShapeEnumValue]] = dataclasses.field(default_factory=dict)
"""The option list for each optional input shape ENUM slot."""
optional_input_shape_defaults: dict[str, str | int | float] = dataclasses.field(default_factory=dict)
"""The default values for optional input shape slots."""
output_shape_enum_values: dict[str, list[ShapeEnumValue]] = dataclasses.field(default_factory=dict)
"""The option list for each output shape ENUM slot."""
optional_output_shape_enum_values: dict[str, list[ShapeEnumValue]] = dataclasses.field(default_factory=dict)
"""The option list for each optional output shape ENUM slot."""

def __repr__(self):
return f"<{self.__class__.__name__} name={self.name}, type={self.task_type}>"
Expand All @@ -44,17 +113,16 @@ def __init__(self, session: NcSessionApp):
self._session = session

def register(
self, name: str, display_name: str, task_type: str, custom_task_type: dict[str, typing.Any] | None = None
self,
provider: TaskProcessingProvider,
custom_task_type: TaskType | None = None,
) -> None:
"""Registers or edit the TaskProcessing provider."""
require_capabilities("app_api", self._session.capabilities)
params = {
"name": name,
"displayName": display_name,
"taskType": task_type,
"customTaskType": custom_task_type,
"provider": RootModel(provider).model_dump(),
**({"customTaskType": RootModel(custom_task_type).model_dump()} if custom_task_type else {}),
}
clear_from_params_empty(["customTaskType"], params)
self._session.ocs("POST", f"{self._session.ae_url}/{_EP_SUFFIX}", json=params)

def unregister(self, name: str, not_fail=True) -> None:
Expand Down Expand Up @@ -123,17 +191,16 @@ def __init__(self, session: AsyncNcSessionApp):
self._session = session

async def register(
self, name: str, display_name: str, task_type: str, custom_task_type: dict[str, typing.Any] | None = None
self,
provider: TaskProcessingProvider,
custom_task_type: TaskType | None = None,
) -> None:
"""Registers or edit the TaskProcessing provider."""
require_capabilities("app_api", await self._session.capabilities)
params = {
"name": name,
"displayName": display_name,
"taskType": task_type,
"customTaskType": custom_task_type,
"provider": RootModel(provider).model_dump(),
**({"customTaskType": RootModel(custom_task_type).model_dump()} if custom_task_type else {}),
}
clear_from_params_empty(["customTaskType"], params)
await self._session.ocs("POST", f"{self._session.ae_url}/{_EP_SUFFIX}", json=params)

async def unregister(self, name: str, not_fail=True) -> None:
Expand Down

0 comments on commit ce6ece1

Please sign in to comment.