Skip to content

Commit

Permalink
feat: error handling & adapt queries (#6)
Browse files Browse the repository at this point in the history
* Error display WIP

* Assistant error handling

* Adapt query & types

Co-authored-by: Etienne Soulard-Geoffrion <etiennecl@users.noreply.github.com>

* Fix assistant passages & use latest client

---------

Co-authored-by: Etienne Soulard-Geoffrion <etiennecl@users.noreply.github.com>
  • Loading branch information
xWiiLLz and etiennecl authored Sep 9, 2024
1 parent 1d4a859 commit 02cfce3
Show file tree
Hide file tree
Showing 10 changed files with 1,647 additions and 1,006 deletions.
8 changes: 4 additions & 4 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
},
"dependencies": {
"@auth0/nextjs-auth0": "^3.5.0",
"@clinia-ui/icons": "^0.1.8",
"@clinia-ui/react": "^0.1.8",
"@clinia/client-common": "1.0.10-hgs-84d73454.0",
"@clinia/client-datapartition": "1.0.10-hgs-84d73454.0",
"@clinia-ui/icons": "^0.1.22",
"@clinia-ui/react": "^0.1.22",
"@clinia/client-common": "1.0.10-edge-842481d.0",
"@clinia/client-datapartition": "1.0.10-edge-842481d.0",
"@clinia/search-sdk-core": "^0.1.0",
"@clinia/search-sdk-react": "^0.1.0",
"@clinia/tritonclient": "^1.0.0",
Expand Down
2,415 changes: 1,457 additions & 958 deletions pnpm-lock.yaml

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion src/components/article-drawer.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ export const ArticleDrawer = () => {

const hitsHighlights = allhighlights.filter(
(highlight): highlight is HitsHighlight =>
'type' in highlight && highlight.type === 'hits'
'type' in highlight && highlight.type === 'vector'
);
if (hitsHighlights.length === 0) {
return allhighlights.map(getHighlightText);
Expand Down
9 changes: 3 additions & 6 deletions src/components/article-hit.tsx
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
'use client';

import { Article, Hit, HitsHighlight } from '@/lib/client';
import { Article, Hit } from '@/lib/client';
import { useMemo } from 'react';
import { getHighlightText } from '../lib/client/util';
import { getHighlightText, isVectorHighlight } from '../lib/client/util';
import { HtmlDisplay } from './html-display';
import { useSearchLayout } from './search-layout';

Expand All @@ -15,10 +15,7 @@ export const ArticleHit = ({ hit }: { hit: Hit<Article> }) => {
if (allHighlights.length === 0) {
return undefined;
}
const hitsHighlights = allHighlights.filter(
(highlight): highlight is HitsHighlight =>
'type' in highlight && highlight.type === 'hits'
);
const hitsHighlights = allHighlights.filter(isVectorHighlight);
if (hitsHighlights.length === 0) {
// We fallback to displaying the first text highlight
return getHighlightText(allHighlights[0]);
Expand Down
158 changes: 132 additions & 26 deletions src/components/assistant.tsx
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
'use client';

import { Sparkles } from 'lucide-react';
import { getHighlightText } from '@/lib/client/util';
import { CircleAlertIcon, RefreshCw, Sparkles } from 'lucide-react';
import { twMerge } from 'tailwind-merge';
import { useCallback, useEffect, useRef, useState } from 'react';
import Markdown from 'react-markdown';
import { V1Hit } from '@clinia/client-common';
import { useHits, useQuery } from '@clinia/search-sdk-react';
import { V1HighlightingHitVector, V1Hit } from '@clinia/client-common';
import { useHits, useLoading, useQuery } from '@clinia/search-sdk-react';
import { Button } from '@clinia-ui/react';
import styles from './assistant.module.css';
import { useStreamRequest } from './use-stream-request';

Expand All @@ -15,7 +17,18 @@ export type AssistantProps = {

export const Assistant = ({ className }: AssistantProps) => {
const hits = useHits();
const querying = useLoading();
const [query] = useQuery();
const [seenHits, setSeenHits] = useState(false);
useEffect(() => {
if (hits.length > 0) {
setSeenHits(true);
}
}, [hits]);

if (!seenHits) {
return null;
}

return (
<div className={twMerge('rounded-lg border p-6', className)}>
Expand All @@ -24,7 +37,11 @@ export const Assistant = ({ className }: AssistantProps) => {
<h1 className="text-base font-medium text-primary">Assistant</h1>
</header>
<div>
<AssistantListener hits={hits as any} query={query} />
<AssistantListener
querying={querying}
hits={hits as any}
query={query}
/>
</div>
<footer></footer>
</div>
Expand All @@ -33,52 +50,60 @@ export const Assistant = ({ className }: AssistantProps) => {

type AssistantListenerProps = {
query: string;
querying: boolean;
hits: V1Hit[];
};
const AssistantListener = ({ hits, query }: AssistantListenerProps) => {
const AssistantListener = ({
hits,
query,
querying,
}: AssistantListenerProps) => {
const [summary, setSummary] = useState('');
// Little hack to ensure we do not display the error message between assistant requests
const [disableErrors, setDisableErrors] = useState(true);
const queryRef = useRef(query);

useEffect(() => {
// We store the query in a ref so that we only refetch the assistant when new articles are coming.
// This avoids doing a double-query in between the request-response from the query API.
queryRef.current = query;
setSummary('');
setDisableErrors(true);
}, [query]);

const { refetch, status } = useStreamRequest(
useCallback(
(chunk: string) => {
setSummary((s) => s + chunk);
setSummary((s) => (s !== null ? s + chunk : chunk));
},
[setSummary]
)
);

// Reset summary every time the query changes
const handleRefetch = useCallback(() => {
refetchHandlerFromHits(hits, queryRef.current, refetch)?.();
}, [hits, refetch]);

useEffect(() => {
if (hits.length === 0) return;
const passages = hits.flatMap((h) =>
(h.highlighting?.['abstract.passages'] ?? []).slice(0, 1).map((x) =>
JSON.stringify({
id: h.resource.id,
text: '',
title: h.resource.data.title,
passages: [x.highlight],
})
)
);
refetch(`/api/assistant`, {
method: 'POST',
body: JSON.stringify({
query: queryRef.current,
articles: passages.slice(0, 3),
}),
});
refetchHandlerFromHits(hits, queryRef.current, refetch)?.();
setDisableErrors(false);
}, [hits, refetch]);

if (
!querying &&
!disableErrors &&
(status === 'error' || (status === 'success' && summary.trim() === ''))
) {
return (
<ErrorDisplay
disabled={['idle', 'loading'].includes(status)}
onRetry={handleRefetch}
/>
);
}

const classnames = [];
if (status === 'loading' || status === 'idle') {
if (status === 'loading' || status === 'idle' || querying) {
classnames.push(styles.type);
}

Expand All @@ -98,3 +123,84 @@ const AssistantListener = ({ hits, query }: AssistantListenerProps) => {
</Markdown>
);
};

type ErrorDisplayProps = {
onRetry?: () => void;
disabled?: boolean;
};

const ErrorDisplay = ({ disabled, onRetry }: ErrorDisplayProps) => {
return (
<div className="flex items-center gap-2.5 rounded-lg bg-accent p-2.5 text-sm">
<CircleAlertIcon className="text-accent-foreground" />
<div className="flex flex-1 flex-col">
<h3 className="font-medium text-accent-foreground">
Couldn&apos;t generate a summary
</h3>
<p className="text-accent-foreground">
You can still browse results below, or try regenerating a summary
</p>
</div>
{onRetry ? (
<Button
type="button"
className="items-center justify-center gap-2"
disabled={disabled}
onClick={onRetry}
>
<RefreshCw className="h-3.5 w-3.5" />
Retry
</Button>
) : null}
</div>
);
};

const refetchHandlerFromHits = (
hits: V1Hit[],
query: string,
refetch: (url: string, request: RequestInit) => Promise<void>
): undefined | (() => void) => {
if (hits.length === 0) return undefined;
const passages = hits.flatMap((h) =>
// Find the highest scoring passage from each hit
{
const allHits = Object.values(h.highlighting ?? {}).flat();
if (allHits.length === 0) {
return [];
}

const vectorHits = allHits
.filter(
(highlight): highlight is V1HighlightingHitVector =>
'type' in highlight && highlight.type === 'vector'
)
.sort((a, b) => b.score - a.score);
if (vectorHits.length === 0) {
return [
JSON.stringify({
id: h.resource.id,
text: '',
title: h.resource.data.title,
passages: [getHighlightText(allHits[0])],
}),
];
}

return JSON.stringify({
id: h.resource.id,
text: '',
title: h.resource.data.title,
passages: [getHighlightText(vectorHits[0])],
});
}
);
return () =>
refetch(`/api/assistant`, {
method: 'POST',
body: JSON.stringify({
query,
articles: passages.slice(0, 3),
}),
});
};
5 changes: 3 additions & 2 deletions src/components/hits.tsx
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
'use client';

import { Article, Hit } from '@/lib/client';
import { useHits } from '@clinia/search-sdk-react';
import { useHits, useLoading } from '@clinia/search-sdk-react';
import { ArticleHit } from './article-hit';

export const Hits = () => {
const hits = useHits() as Hit<Article>[];
const loading = useLoading();

if (hits.length === 0) {
if (hits.length === 0 || loading) {
return null;
}

Expand Down
34 changes: 31 additions & 3 deletions src/components/search-provider.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ export const SearchProvider = ({ children, state }: SearchProviderProps) => {
match: {
'abstract.passages': {
value: params.query ?? '',
type: 'word',
type: 'phrase',
},
},
},
Expand All @@ -68,12 +68,40 @@ export const SearchProvider = ({ children, state }: SearchProviderProps) => {
},
},
},
{
match: {
title: {
value: params.query ?? '',
type: 'word',
},
},
},
{
match: {
keywords: {
value: params.query ?? '',
type: 'word',
},
},
},
{
knn: {
'content.text.passages.vector': {
value: params.query ?? '',
},
},
},
],
},
highlighting: ['abstract.passages'],
highlighting: [
'abstract.passages',
'abstract.passages.vector',
'title',
'keywords',
'content.text.passages.vector',
],
},
});
// const resp = await client.search({ query: params.query ?? '' });
return resp;
};

Expand Down
1 change: 0 additions & 1 deletion src/components/use-stream-request.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ export function useStreamRequest(onData: (data: string) => void) {
async (url: string, request: RequestInit) => {
setStatus('loading');
if (controllerRef.current) {
console.warn('Aborting previous request');
controllerRef.current.abort();
}
const controller = new AbortController();
Expand Down
6 changes: 4 additions & 2 deletions src/lib/client/types.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import { V1Hit } from '@clinia/client-datapartition';

export type InformationPocClient = {
search: <T = Resource>(params: SearchRequest) => Promise<SearchResponse<T>>;
};
Expand All @@ -15,7 +17,7 @@ export type SearchResponse<T = Resource> = {

export type Hit<T = Resource> = {
resource: T;
highlighting?: Record<string, Highlight[]>;
highlighting?: V1Hit['highlighting'];
};

export type Resource = {
Expand Down Expand Up @@ -47,7 +49,7 @@ export type Highlight =
| HitsHighlight;

export type HitsHighlight = {
type: 'hits';
type: 'vector';
score: number;
data: string;
// content.0.passages.0
Expand Down
15 changes: 12 additions & 3 deletions src/lib/client/util.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,22 @@
import type { Highlight } from './types';
import {
V1HighlightingHit,
V1HighlightingHitVector,
} from '@clinia/client-common';

export const getHighlightText = (highlight: Highlight): string => {
export const getHighlightText = (highlight: V1HighlightingHit): string => {
if ('highlight' in highlight) {
return highlight.highlight;
}

if (highlight.type === 'hits' || 'data' in highlight) {
if (highlight.type === 'vector' || 'data' in highlight) {
return highlight.data;
}

return '';
};

export const isVectorHighlight = (
highlight: V1HighlightingHit
): highlight is V1HighlightingHitVector => {
return 'type' in highlight && highlight.type === 'vector';
};

0 comments on commit 02cfce3

Please sign in to comment.