Skip to content

Commit

Permalink
feat(openai): add markdown output support
Browse files Browse the repository at this point in the history
  • Loading branch information
HoKim98 committed Jul 17, 2024
1 parent b750a9c commit bc6c171
Show file tree
Hide file tree
Showing 8 changed files with 104 additions and 33 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -84,4 +84,5 @@ uuid = { version = "1.10", default-features = false, features = ["serde"] }
wasm-streams = { version = "0.4" }
web-sys = { version = "0.3", features = ["MediaQueryList", "Url", "Window"] }
yew = { version = "0.21", features = ["csr"] }
yew-markdown = { git = "https://github.com/ulagbulag/yew-markdown" }
yew-nested-router = { version = "0.7" }
28 changes: 20 additions & 8 deletions crates/cassette-core/src/net/fetch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,8 @@ impl<Url, Req> FetchRequest<Url, Req> {
Req: 'static + Serialize,
Res: 'static + DeserializeOwned,
Url: fmt::Display,
F: 'static + FnOnce(UseStateSetter<FetchState<Res>>, IntoStream<'reader>) -> Fut,
Fut: Future<Output = ::anyhow::Result<()>>,
F: 'static + FnOnce(FetchStateSetter<Res>, IntoStream<'reader>) -> Fut,
Fut: Future<Output = ::anyhow::Result<Res>>,
{
if matches!(&*state, FetchState::Pending) {
state.set(FetchState::Fetching);
Expand Down Expand Up @@ -146,12 +146,14 @@ impl<Url, Req> FetchRequest<Url, Req> {
.map(ReadableStream::from_raw)
.map(ReadableStream::into_stream)
{
Some(body) => match handler(state.setter(), body).await {
Ok(()) => return,
Err(error) => FetchState::Error(format!(
"Failed to parse the {name}: {error}"
)),
},
Some(body) => {
match handler(FetchStateSetter(state.setter()), body).await {
Ok(data) => FetchState::Completed(data),
Err(error) => FetchState::Error(format!(
"Failed to parse the {name}: {error}"
)),
}
}
None => FetchState::Error(format!("Empty body: {name}")),
},
Err(error) => {
Expand All @@ -168,11 +170,20 @@ impl<Url, Req> FetchRequest<Url, Req> {
}
}

pub struct FetchStateSetter<T>(UseStateSetter<FetchState<T>>);

impl<T> FetchStateSetter<T> {
pub fn set(&self, value: T) {
self.0.set(FetchState::Collecting(value))
}
}

#[derive(Clone, Debug, Default)]
pub enum FetchState<T> {
#[default]
Pending,
Fetching,
Collecting(T),
Completed(T),
Error(String),
}
Expand All @@ -185,6 +196,7 @@ where
match self {
Self::Pending => "pending".fmt(f),
Self::Fetching => "loading".fmt(f),
Self::Collecting(data) => data.fmt(f),
Self::Completed(data) => data.fmt(f),
Self::Error(error) => error.fmt(f),
}
Expand Down
2 changes: 1 addition & 1 deletion crates/cassette-plugin-kubernetes-list/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ fn component(props: &Props) -> Html {
<p>{ "Loading..." }</p>
</Content>
},
FetchState::Completed(data) => html! {
FetchState::Collecting(data) | FetchState::Completed(data) => html! {
<ComponentBody<DynamicObject> list={ data.items.clone() } />
},
FetchState::Error(error) => html! {
Expand Down
1 change: 1 addition & 0 deletions crates/cassette-plugin-openai-chat/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,4 @@ patternfly-yew = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
yew = { workspace = true }
yew-markdown = { workspace = true }
78 changes: 64 additions & 14 deletions crates/cassette-plugin-openai-chat/src/hooks.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use anyhow::{anyhow, bail, Result};
use cassette_core::net::fetch::{Body, FetchRequest, FetchState, IntoStream, Method};
use cassette_core::net::fetch::{
Body, FetchRequest, FetchState, FetchStateSetter, IntoStream, Method,
};
use futures::TryStreamExt;
use itertools::Itertools;
use js_sys::Uint8Array;
use yew::prelude::*;

Expand Down Expand Up @@ -31,29 +34,76 @@ pub fn use_fetch(base_url: &str, request: Request) -> UseStateHandle<FetchState<
}

async fn try_stream(
state: UseStateSetter<FetchState<Response>>,
setter: FetchStateSetter<Response>,
stream: IntoStream<'_>,
) -> Result<()> {
) -> Result<Response> {
let mut stream = stream
.map_ok(|chunk| Uint8Array::new(&chunk).to_vec())
.map_err(|error| match error.as_string() {
Some(error) => anyhow!("{error}"),
None => anyhow!("{error:?}"),
});

let mut output = Response::default();
while let Some(chunk) = stream.try_next().await? {
let (opcode, data) = chunk.split_at(6);
match ::core::str::from_utf8(opcode)? {
"data: " => {
let Response { mut choices } = ::serde_json::from_slice(&data)?;
if let Some(choice) = choices.pop_front() {
output.choices.push_back(choice);
state.set(FetchState::Completed(output.clone()));
const PATTERN: &[u8] = "\n\ndata: ".as_bytes();

struct TokenStream {
output: Response,
setter: FetchStateSetter<Response>,
}

impl TokenStream {
fn new(setter: FetchStateSetter<Response>) -> Self {
Self {
output: Response::default(),
setter,
}
}

fn feed(&mut self, data: &[u8]) -> Result<()> {
self.feed_with(data, true)
}

fn feed_with(&mut self, data: &[u8], update: bool) -> Result<()> {
if !data.starts_with(PATTERN) {
bail!("unexpected opcode");
}
let data = &data[PATTERN.len()..];

let Response { mut choices } = ::serde_json::from_slice(data)?;
if let Some(choice) = choices.pop_front() {
self.output.choices.push_back(choice);
if update {
self.setter.set(self.output.clone());
}
}
opcode => bail!("unexpected opcode: {opcode:?}"),
Ok(())
}

fn finish(mut self, data: &[u8]) -> Result<Response> {
self.feed_with(data, false)?;
Ok(self.output)
}
}

let mut heystack = "\n\n".to_string().into_bytes();
let mut output = TokenStream::new(setter);
while let Some(chunk) = stream.try_next().await? {
heystack.extend(chunk);

let token_indices: Vec<_> = heystack
.windows(PATTERN.len())
.positions(|window| window == PATTERN)
.collect();

// [PATTERN, first, PATTERN, ...] => take "first"
if token_indices.len() >= 2 {
let mut offset = 0;
for (start, end) in token_indices.into_iter().tuple_windows() {
output.feed(&heystack[start..end])?;
offset = end;
}
heystack.drain(..offset);
}
}
Ok(())
output.finish(&heystack)
}
21 changes: 14 additions & 7 deletions crates/cassette-plugin-openai-chat/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use cassette_core::{
use itertools::Itertools;
use patternfly_yew::prelude::*;
use yew::prelude::*;
use yew_markdown::Markdown;

use crate::schema::{Message, Request, RequestOptions, Response};

Expand Down Expand Up @@ -44,8 +45,11 @@ fn component(props: &Props) -> Html {
<p>{ "Loading..." }</p>
</Content>
},
FetchState::Collecting(tokens) => html! {
<ComponentBody completed=false tokens={ tokens.clone() } />
},
FetchState::Completed(tokens) => html! {
<ComponentBody tokens={ tokens.clone() } />
<ComponentBody completed=true tokens={ tokens.clone() } />
},
FetchState::Error(error) => html! {
<Alert inline=true title="Error" r#type={AlertType::Danger}>
Expand All @@ -57,24 +61,27 @@ fn component(props: &Props) -> Html {

#[derive(Clone, Debug, PartialEq, Properties)]
struct BodyProps {
completed: bool,
tokens: Response,
}

#[function_component(ComponentBody)]
fn component_body(props: &BodyProps) -> Html {
let BodyProps { tokens } = props;
let BodyProps { completed, tokens } = props;

let content = tokens
.choices
.iter()
.map(|choice| &choice.message.content)
.join("");

let style = if *completed { "" } else { "color: #FF3333;" };

html! {
<CodeBlock>
<CodeBlockCode>
{ content }
</CodeBlockCode>
</CodeBlock>
<Content>
<div { style }>
<Markdown src={ content } />
</div>
</Content>
}
}
2 changes: 1 addition & 1 deletion crates/cassette/src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ pub fn app_page(props: &AppPageProps) -> Html {
let cassette_list = use_cassette_list();
let cassette_list = match &*cassette_list {
FetchState::Pending | FetchState::Fetching => Err(html! { <p>{ "Loading..." }</p> }),
FetchState::Completed(list) => Ok(list.as_slice()),
FetchState::Collecting(list) | FetchState::Completed(list) => Ok(list.as_slice()),
FetchState::Error(error) => Err(html! { <p>{ format!("Error: {error}") }</p> }),
};

Expand Down
4 changes: 2 additions & 2 deletions crates/cassette/src/pages/cassette.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ pub fn cassette(props: &Props) -> Html {
FetchState::Pending | FetchState::Fetching => html! {
<CassetteFallback />
},
FetchState::Completed(Some(data)) => html! {
FetchState::Collecting(Some(data)) | FetchState::Completed(Some(data)) => html! {
<CassetteView data={ data.clone() } />
},
FetchState::Completed(None) => html! {
FetchState::Collecting(None) | FetchState::Completed(None) => html! {
<crate::pages::error::Error kind={ ErrorKind::NotFound } />
},
FetchState::Error(error) => html! {
Expand Down

0 comments on commit bc6c171

Please sign in to comment.