Skip to content

Commit

Permalink
feat: add json pointer support
Browse files Browse the repository at this point in the history
  • Loading branch information
HoKim98 committed Jul 17, 2024
1 parent bc6c171 commit 481c3c9
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 26 deletions.
14 changes: 9 additions & 5 deletions crates/cassette-core/src/task.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use garde::Validate;
use schemars::JsonSchema;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::{Map, Value};
use serde_json::Value;
#[cfg(feature = "ui")]
use yew::prelude::*;

Expand Down Expand Up @@ -70,7 +70,7 @@ pub enum TaskState {

#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(transparent)]
pub struct TaskSpec(Map<String, Value>);
pub struct TaskSpec(Value);

impl TaskSpec {
fn preserve_arbitrary(
Expand All @@ -86,9 +86,13 @@ impl TaskSpec {
#[cfg(feature = "ui")]
impl TaskSpec {
fn get(&self, key: &str) -> TaskResult<&Value> {
self.0
.get(&key[1..])
.ok_or_else(|| format!("no such key: {key}"))
match key {
"" | "/" => Ok(&self.0),
key => self
.0
.pointer(key)
.ok_or_else(|| format!("no such key: {key}")),
}
}

pub fn get_string(&self, key: &str) -> TaskResult<String> {
Expand Down
11 changes: 5 additions & 6 deletions crates/cassette-plugin-openai-chat/src/hooks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use yew::prelude::*;
use crate::schema::{Request, Response};

#[hook]
pub fn use_fetch(base_url: &str, request: Request) -> UseStateHandle<FetchState<Response>> {
pub fn use_fetch(base_url: &str, request: &Request) -> UseStateHandle<FetchState<Response>> {
let state = use_state(|| FetchState::Pending);
{
let state = state.clone();
Expand All @@ -20,13 +20,12 @@ pub fn use_fetch(base_url: &str, request: Request) -> UseStateHandle<FetchState<
method: Method::POST,
name: "chat completions",
url: "/chat/completions",
body: Some(Body::Json(request)),
body: Some(Body::Json(request.clone())),
};

let f: Box<dyn FnOnce()> = if stream {
Box::new(move || request.try_stream_with(&base_url, state, try_stream))
} else {
Box::new(move || request.try_fetch_unchecked(&base_url, state))
let f: Box<dyn FnOnce()> = match stream {
Some(true) => Box::new(move || request.try_stream_with(&base_url, state, try_stream)),
Some(false) | None => Box::new(move || request.try_fetch_unchecked(&base_url, state)),
};
use_effect(f)
}
Expand Down
16 changes: 5 additions & 11 deletions crates/cassette-plugin-openai-chat/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,32 +11,26 @@ use patternfly_yew::prelude::*;
use yew::prelude::*;
use yew_markdown::Markdown;

use crate::schema::{Message, Request, RequestOptions, Response};
use crate::schema::{Request, Response};

pub fn render(_state: &UseStateHandle<CassetteState>, spec: &TaskSpec) -> TaskResult {
let base_url = spec.get_string("/baseUrl")?;
let messages: Vec<Message> = spec.get_model("/messages")?;
let request: Request = spec.get_model("/")?;

Ok(TaskState::Continue {
body: html! { <Component { base_url } { messages } /> },
body: html! { <Component { base_url } { request } /> },
})
}

#[derive(Clone, Debug, PartialEq, Properties)]
struct Props {
base_url: String,
messages: Vec<Message>,
request: Request,
}

#[function_component(Component)]
fn component(props: &Props) -> Html {
let Props { base_url, messages } = props;

let request = Request {
model: "tgi".into(),
options: RequestOptions { stream: true },
messages: messages.clone(),
};
let Props { base_url, request } = props;

let value = self::hooks::use_fetch(base_url, request);
match &*value {
Expand Down
19 changes: 15 additions & 4 deletions crates/cassette-plugin-openai-chat/src/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,28 @@ use std::collections::VecDeque;

use serde::{Deserialize, Serialize};

#[derive(Clone, Debug, PartialEq, Serialize)]
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct Request {
#[serde(default = "Request::default_model")]
pub model: String,
#[serde(flatten)]
#[serde(default, flatten)]
pub options: RequestOptions,
#[serde(default)]
pub messages: Vec<Message>,
}

#[derive(Clone, Debug, Default, PartialEq, Serialize)]
impl Request {
fn default_model() -> String {
"any".into()
}
}

#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct RequestOptions {
pub stream: bool,
#[serde(default)]
pub stream: Option<bool>,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
Expand All @@ -27,9 +36,11 @@ pub struct Message {
#[derive(Clone, Debug, PartialEq, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct MessageChoice {
#[serde(default)]
pub index: u32,
#[serde(alias = "delta")]
pub message: Message,
#[serde(default)]
pub finish_reason: Option<MessageFinishReason>,
}

Expand Down
1 change: 1 addition & 0 deletions examples/openai_chat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ spec:
kind: OpenAIChat
spec:
baseUrl: /v1
stream: true
messages:
- role: user
content: What is your name?

0 comments on commit 481c3c9

Please sign in to comment.