Skip to content

Commit

Permalink
add perplexity prompter
Browse files Browse the repository at this point in the history
  • Loading branch information
Southclaws committed Jan 11, 2025
1 parent 7f022f9 commit 565c6f3
Show file tree
Hide file tree
Showing 7 changed files with 255 additions and 129 deletions.
55 changes: 28 additions & 27 deletions app/services/semdex/asker/asker.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,34 @@ import (
"github.com/Southclaws/storyden/internal/infrastructure/ai"
)

type Asker struct {
searcher semdex.Searcher
prompter ai.Prompter
}

func New(cfg config.Config, searcher semdex.Searcher, prompter ai.Prompter) (*Asker, error) {
func New(cfg config.Config, searcher semdex.Searcher, prompter ai.Prompter) (semdex.Asker, error) {
if cfg.SemdexProvider != "" && cfg.LanguageModelProvider == "" {
return nil, fault.New("semdex requires a language model provider to be enabled")
}

return &Asker{
searcher: searcher,
prompter: prompter,
}, nil
switch cfg.AskerProvider {
case "perplexity":
// NOTE: While Perplexity looks like it could satisfy the language model
// provider interface, it does not provide an embedding func, it's only
// functional for chat-like interactions so it's only an Asker for now.
// This means that if you wish to use Perplexity, you must also provide
// a language model provider such as OpenAI along with an API key. Keep
// this in mind when considering the cost of your Storyden installation.
return newPerplexityAsker(cfg, searcher)

default:

return &defaultAsker{
searcher: searcher,
prompter: prompter,
}, nil
}
}

var AnswerPrompt = template.Must(template.New("").Parse(`
You are an expert assistant. Answer the user's question accurately and concisely using the provided sources. Cite the sources in a separate list at the end of your answer.
Ensure that the source URLs (in "sdr" format) are kept exactly as they appear, without modification or breaking them across lines.
You MUST include references to the sources below in your answer in addition to other sources you may have.
Sources:
{{- range .Context }}
Expand All @@ -48,26 +57,22 @@ Question: {{ .Question }}
Answer:
1. Provide your answer here in clear and concise paragraphs.
2. Use information from the sources above to support your answer, but do not include citations inline.
3. Include a "References" section with the source URLs listed, like this:
3. Include a "Sources" section with the source URLs listed, like this:
References:
- (the url to the source): (Short description of the source content)
Sources:
- [title](url): Short description of the source content
`))

const maxContextForRAG = 10

func (a *Asker) Ask(ctx context.Context, q string) (chan string, chan error) {
chunks, err := a.searcher.SearchChunks(ctx, q, pagination.NewPageParams(1, 200), searcher.Options{})
func buildContextPrompt(ctx context.Context, s semdex.Searcher, q string) (string, error) {
chunks, err := s.SearchChunks(ctx, q, pagination.NewPageParams(1, 200), searcher.Options{})
if err != nil {
ech := make(chan error, 1)
ech <- fault.Wrap(err, fctx.With(ctx))
return nil, ech
return "", fault.Wrap(err, fctx.With(ctx))
}

if len(chunks) == 0 {
ech := make(chan error, 1)
ech <- fault.New("no context found for question", fctx.With(ctx), ftag.With(ftag.NotFound))
return nil, ech
return "", fault.New("no context found for question", fctx.With(ctx), ftag.With(ftag.NotFound))
}

if len(chunks) > maxContextForRAG {
Expand All @@ -80,12 +85,8 @@ func (a *Asker) Ask(ctx context.Context, q string) (chan string, chan error) {
"Question": q,
})
if err != nil {
ech := make(chan error, 1)
ech <- fault.Wrap(err, fctx.With(ctx))
return nil, ech
return "", fault.Wrap(err, fctx.With(ctx))
}

chch, ech := a.prompter.PromptStream(ctx, t.String())

return chch, ech
return t.String(), nil
}
30 changes: 30 additions & 0 deletions app/services/semdex/asker/default.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package asker

import (
"context"

"github.com/Southclaws/fault"
"github.com/Southclaws/fault/fctx"

"github.com/Southclaws/storyden/app/services/semdex"
"github.com/Southclaws/storyden/internal/infrastructure/ai"
)

// defaultAsker uses whatever prompter is available and performs RAG prompting.
type defaultAsker struct {
searcher semdex.Searcher
prompter ai.Prompter
}

func (a *defaultAsker) Ask(ctx context.Context, q string) (chan string, chan error) {
t, err := buildContextPrompt(ctx, a.searcher, q)
if err != nil {
ech := make(chan error, 1)
ech <- fault.Wrap(err, fctx.With(ctx))
return nil, ech
}

chch, ech := a.prompter.PromptStream(ctx, t)

return chch, ech
}
178 changes: 178 additions & 0 deletions app/services/semdex/asker/perplexity.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
package asker

import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"time"

"github.com/Southclaws/fault"
"github.com/Southclaws/fault/fctx"
"github.com/Southclaws/storyden/app/services/semdex"
"github.com/Southclaws/storyden/internal/config"
"github.com/openai/openai-go/packages/ssestream"
)

const (
DefaultEndpoint = "https://api.perplexity.ai/chat/completions"
DefautTimeout = 10 * time.Second
)

const (
Llama_3_1SonarSmall_128kChat = "llama-3.1-sonar-small-128k-chat"
Llama_3_1SonarLarge_128kChat = "llama-3.1-sonar-large-128k-chat"
Llama_3_1SonarSmall_128kOnline = "llama-3.1-sonar-small-128k-online"
Llama_3_1SonarLarge_128kOnline = "llama-3.1-sonar-large-128k-online"
Llama_3_1_8bInstruct = "llama-3.1-8b-instruct"
Llama_3_1_70bInstruct = "llama-3.1-70b-instruct"
)

type Perplexity struct {
endpoint string
apiKey string
model string
httpClient *http.Client
httpTimeout time.Duration
searcher semdex.Searcher
}

func newPerplexityAsker(cfg config.Config, searcher semdex.Searcher) (*Perplexity, error) {
s := &Perplexity{
apiKey: cfg.PerplexityAPIKey,
endpoint: DefaultEndpoint,
model: Llama_3_1SonarSmall_128kOnline,
httpClient: &http.Client{},
httpTimeout: DefautTimeout,
searcher: searcher,
}
return s, nil
}

func (a *Perplexity) Ask(ctx context.Context, q string) (chan string, chan error) {
outch := make(chan string)
errch := make(chan error)

t, err := buildContextPrompt(ctx, a.searcher, q)
if err != nil {
ech := make(chan error, 1)
ech <- fault.Wrap(err, fctx.With(ctx))
return nil, ech
}

fmt.Println(t)

resp, err := func() (*http.Response, error) {
reqBody := CompletionRequest{
Stream: true,
Messages: []Message{{Role: "user", Content: t}},
Model: a.model,
}

requestBody, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("failed to marshal request body: %w", err)
}

req, err := http.NewRequestWithContext(ctx, "POST", a.endpoint, bytes.NewBuffer(requestBody))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}

req.Header.Set("Authorization", "Bearer "+a.apiKey)
req.Header.Set("Content-Type", "application/json")

resp, err := a.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}

return resp, nil
}()
if err != nil {
errch <- err
return outch, errch
}

dec := ssestream.NewDecoder(resp)

go func() {
defer resp.Body.Close()
defer close(outch)
defer close(errch)

for dec.Next() {
event := dec.Event()
var cr CompletionResponse

if err := json.Unmarshal(event.Data, &cr); err != nil {
errch <- fmt.Errorf("failed to unmarshal SSE event: %w", err)
return
}

if len(cr.Choices) == 0 {
errch <- fmt.Errorf("no choices in response")
return
}

if len(cr.Citations) == 0 {
fmt.Println(string(event.Data))
errch <- fmt.Errorf("no citations in response")
return
}

choice := cr.Choices[0]

outch <- choice.Delta.Content

if choice.FinishReason == "stop" {
break
}
}

if dec.Err() != nil {
errch <- fmt.Errorf("failed to read SSE stream: %w", dec.Err())
}
}()

return outch, errch
}

func replaceCitations(message string, citations []string) string {
return message
}

type Message struct {
Role string `json:"role"`
Content string `json:"content"`
}

type CompletionRequest struct {
Messages []Message `json:"messages"`
Model string `json:"model"`
Stream bool `json:"stream"`
}

type Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}

type Choice struct {
Index int `json:"index"`
FinishReason string `json:"finish_reason"`
Message Message `json:"message"`
Delta Message `json:"delta"`
}

type CompletionResponse struct {
ID string `json:"id"`
Model string `json:"model"`
Created int `json:"created"`
Usage Usage `json:"usage"`
Citations []string `json:"citations"`
Object string `json:"object"`
Choices []Choice `json:"choices"`
}
5 changes: 1 addition & 4 deletions app/services/semdex/semdexer/semdexer.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,7 @@ func newSemdexer(
func Build() fx.Option {
return fx.Options(
fx.Provide(
fx.Annotate(
asker.New,
fx.As(new(semdex.Asker)),
),
asker.New,
),
fx.Provide(
fx.Annotate(
Expand Down
10 changes: 3 additions & 7 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ require (
github.com/ThreeDotsLabs/watermill v1.4.1
github.com/ThreeDotsLabs/watermill-amqp/v3 v3.0.0
github.com/a8m/enter v0.0.0-20230407172335-1834787a98fe
github.com/alitto/pond/v2 v2.1.6
github.com/bwmarrin/discordgo v0.28.1
github.com/cixtor/readability v1.0.0
github.com/dave/jennifer v1.7.1
Expand All @@ -46,11 +47,9 @@ require (
github.com/glebarez/go-sqlite v1.22.0
github.com/golang-jwt/jwt/v5 v5.2.1
github.com/google/go-github/v66 v66.0.0
github.com/google/go-github/v68 v68.0.0
github.com/iancoleman/strcase v0.3.0
github.com/jackc/pgx/v5 v5.7.1
github.com/jmoiron/sqlx v1.4.0
github.com/k3a/html2text v1.2.1
github.com/klippa-app/go-pdfium v1.13.0
github.com/matcornic/hermes/v2 v2.1.0
github.com/mazznoer/colorgrad v0.10.0
Expand All @@ -65,6 +64,7 @@ require (
github.com/openai/openai-go v0.1.0-alpha.39
github.com/pb33f/libopenapi v0.18.7
github.com/philippgille/chromem-go v0.7.0
github.com/pinecone-io/go-pinecone v1.1.2-0.20241220212044-af29d07e7c68
github.com/puzpuzpuz/xsync/v3 v3.4.0
github.com/redis/rueidis v1.0.49
github.com/rs/cors v1.11.1
Expand All @@ -85,8 +85,6 @@ require (
github.com/Masterminds/goutils v1.1.1 // indirect
github.com/Masterminds/semver v1.5.0 // indirect
github.com/Masterminds/sprig v2.22.0+incompatible // indirect
github.com/alitto/pond v1.9.2 // indirect
github.com/alitto/pond/v2 v2.1.6 // indirect
github.com/andybalholm/cascadia v1.3.2 // indirect
github.com/apparentlymart/go-textseg/v15 v15.0.0 // indirect
github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 // indirect
Expand Down Expand Up @@ -137,7 +135,6 @@ require (
github.com/oklog/ulid v1.3.1 // indirect
github.com/olekukonko/tablewriter v0.0.5 // indirect
github.com/perimeterx/marshmallow v1.1.5 // indirect
github.com/pinecone-io/go-pinecone v1.1.2-0.20241220212044-af29d07e7c68 // indirect
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
github.com/rabbitmq/amqp091-go v1.10.0 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
Expand Down Expand Up @@ -193,7 +190,6 @@ require (
github.com/go-openapi/inflect v0.21.0 // indirect
github.com/go-openapi/jsonpointer v0.21.0 // indirect
github.com/go-openapi/swag v0.23.0 // indirect
github.com/go-resty/resty/v2 v2.16.2
github.com/go-webauthn/webauthn v0.11.2
github.com/golang-jwt/jwt v3.2.2+incompatible // indirect
github.com/google/go-cmp v0.6.0 // indirect
Expand Down Expand Up @@ -224,6 +220,6 @@ require (
golang.org/x/sys v0.28.0 // indirect
golang.org/x/text v0.21.0
golang.org/x/time v0.8.0 // indirect
google.golang.org/protobuf v1.36.1 // indirect
google.golang.org/protobuf v1.36.1
gopkg.in/yaml.v3 v3.0.1 // indirect
)
Loading

0 comments on commit 565c6f3

Please sign in to comment.