From 85f5e451682d339f4e246b4e2a0281641deba0f8 Mon Sep 17 00:00:00 2001 From: FluxCapacitor2 <31071265+FluxCapacitor2@users.noreply.github.com> Date: Sun, 24 Nov 2024 02:52:01 -0500 Subject: [PATCH] Implement hybrid (FTS + vector) search and add an API endpoint for it --- app/config/config.go | 12 ++ app/database/db.go | 12 ++ app/database/db_sqlite.go | 152 +++++++++++++++++++++++ app/server/server.go | 96 ++++++++++++++ app/server/templates/highlight.html.tmpl | 18 +-- app/server/templates/index.html.tmpl | 22 ++-- app/server/templates/results.html.tmpl | 56 +++++---- 7 files changed, 321 insertions(+), 47 deletions(-) diff --git a/app/config/config.go b/app/config/config.go index c5030ad..525959f 100644 --- a/app/config/config.go +++ b/app/config/config.go @@ -1,7 +1,9 @@ package config import ( + "fmt" "os" + "regexp" "gopkg.in/yaml.v3" ) @@ -62,6 +64,8 @@ type Source struct { } } +var sourceIDPattern = regexp.MustCompile("^[a-zA-Z0-9_]+$") + func Read() (*Config, error) { data, err := os.ReadFile("./config.yml") @@ -76,5 +80,13 @@ func Read() (*Config, error) { return nil, err } + // Validate the loaded configuration + + for _, src := range config.Sources { + if !sourceIDPattern.MatchString(src.ID) { + panic(fmt.Sprintf("Invalid source ID: %v. Source IDs may only contain alphanumeric characters and underscores.", src.ID)) + } + } + return config, nil } diff --git a/app/database/db.go b/app/database/db.go index 58be835..e8e226f 100644 --- a/app/database/db.go +++ b/app/database/db.go @@ -46,6 +46,7 @@ type Database interface { StartEmbeddings(getChunkDetails func(sourceID string) (chunkSize int, chunkOverlap int)) error SimilaritySearch(sourceID string, query []float32, limit int) ([]SimilarityResult, error) + HybridSearch(sources []string, queryString string, embeddedQueries map[string][]float32, limit int) ([]HybridResult, error) } type Page struct { @@ -77,6 +78,17 @@ type SimilarityResult struct { Similarity float32 `json:"similarity"` } +type HybridResult struct { + URL string `json:"url"` + Title string `json:"title"` + Description string `json:"description"` + Content []Match `json:"content"` + FTSRank *int `json:"ftsRank"` + VecRank *int `json:"vecRank"` + VecDistance *float64 `json:"vecDistance"` + HybridRank float64 `json:"rank"` +} + type Match struct { Highlighted bool `json:"highlighted"` Content string `json:"content"` diff --git a/app/database/db_sqlite.go b/app/database/db_sqlite.go index 2ee82df..d6909e1 100644 --- a/app/database/db_sqlite.go +++ b/app/database/db_sqlite.go @@ -1,10 +1,12 @@ package database import ( + "bytes" "database/sql" "fmt" "regexp" "strings" + "text/template" _ "embed" @@ -304,6 +306,156 @@ func (db *SQLiteDatabase) SimilaritySearch(sourceID string, query []float32, lim return results, err } +var tmpl *template.Template = template.Must(template.New("hybrid-search").Parse(` +WITH {{ range $index, $value := .Sources -}} + vec_subquery_{{ $value }} AS ( + SELECT + vec_chunks.page AS page, + row_number() OVER (ORDER BY distance) AS rank_number, + vec_chunks.chunk AS chunk, + distance + FROM pages_vec_{{ $value }} + JOIN vec_chunks USING (id) + WHERE embedding MATCH ? AND k = ? + -- Select only the most relevant chunk for each page + GROUP BY vec_chunks.page + HAVING MIN(distance) + ORDER BY distance +), {{ end }}fts_subquery AS ( + SELECT + pages_fts.rowid AS page, + highlight(pages_fts, 1, ?, ?) AS title, + snippet(pages_fts, 2, ?, ?, '…', 8) AS description, + snippet(pages_fts, 3, ?, ?, '…', 24) AS content, + rank + FROM pages_fts + JOIN pages ON pages.id = pages_fts.rowid + WHERE + pages.source IN ( + {{- range $index, $value := .Sources -}} + {{- if gt $index 0 }}, {{ end -}} + ? + {{- end -}} + ) + AND pages.status = ? + AND pages_fts MATCH ? + LIMIT ? +), fts_ordered AS ( + SELECT *, row_number() OVER (ORDER BY rank) AS rank_number + FROM fts_subquery +) +SELECT + pages.url, + coalesce(fts_ordered.title, pages.title) AS title, + coalesce(fts_ordered.description, pages.description) AS description, + + coalesce( + fts_ordered.content, {{ range $index, $value := .Sources -}} + {{- if gt $index 0 }}, {{ end -}} + vec_subquery_{{ $value }}.chunk + {{- end }} + ) AS content, + + coalesce( + {{ range $index, $value := .Sources -}} + {{- if gt $index 0 }}, {{ end -}} + vec_subquery_{{ $value }}.distance + {{- end }} + ) AS vec_distance, + + coalesce( + {{ range $index, $value := .Sources -}} + {{- if gt $index 0 }}, {{ end -}} + vec_subquery_{{ $value }}.rank_number + {{- end }} + ) AS vec_rank, + + fts_ordered.rank_number AS fts_rank, + + ( + {{ range $index, $value := .Sources -}} + coalesce(1.0 / (60 + vec_subquery_{{ $value }}.rank_number) * 0.5, 0.0) + + {{ end -}} + coalesce(1.0 / (60 + fts_ordered.rank_number), 0.0) + ) AS combined_rank +FROM fts_ordered +{{ range $index, $value := .Sources -}} + FULL OUTER JOIN vec_subquery_{{ $value }} USING (page) +{{ end -}} +JOIN pages ON pages.id = coalesce( + fts_ordered.page, {{ range $index, $value := .Sources -}} + {{- if gt $index 0 }}, {{ end -}} + vec_subquery_{{ $value }}.page + {{- end }} +) +ORDER BY combined_rank DESC; +`)) + +func (db *SQLiteDatabase) HybridSearch(sources []string, queryString string, embeddedQueries map[string][]float32, limit int) ([]HybridResult, error) { + + // Convert the query vectors to a blob format that `sqlite-vec` will accept + serializedQueries := make(map[string][]byte) + + for sourceID, query := range embeddedQueries { + serialized, err := vec.SerializeFloat32(query) + if err != nil { + return nil, err + } + serializedQueries[sourceID] = serialized + } + + type TemplateData struct { + Sources []string + } + + var query bytes.Buffer + err := tmpl.Execute(&query, TemplateData{Sources: sources}) + if err != nil { + return nil, fmt.Errorf("error formatting query: %v", err) + } + + args := []any{} + + // Vector query args + for _, src := range sources { + args = append(args, serializedQueries[src], limit) + } + + // FTS query args + start := uuid.New().String() + end := uuid.New().String() + + args = append(args, start, end, start, end, start, end) + + for _, src := range sources { + args = append(args, src) + } + + args = append(args, Finished, queryString, limit) + + rows, err := db.conn.Query(query.String(), args...) + + if err != nil { + return nil, err + } + + results := make([]HybridResult, 0) + + for rows.Next() { + res := HybridResult{} + var content string + err := rows.Scan(&res.URL, &res.Title, &res.Description, &content, &res.VecDistance, &res.VecRank, &res.FTSRank, &res.HybridRank) + if err != nil { + return nil, err + } + res.Content = processResult(content, start, end) + // res.Content = []Match{{Content: content, Highlighted: false}} + results = append(results, res) + } + + return results, err +} + func (db *SQLiteDatabase) AddToQueue(source string, referrer string, urls []string, depth int32, isRefresh bool) error { tx, err := db.conn.Begin() diff --git a/app/server/server.go b/app/server/server.go index 6a4693c..11e20c6 100644 --- a/app/server/server.go +++ b/app/server/server.go @@ -210,6 +210,102 @@ func Start(db database.Database, cfg *config.Config) { }) }) + http.HandleFunc("/api/hybrid-search", func(w http.ResponseWriter, req *http.Request) { + type httpResponse struct { + status int16 + Success bool `json:"success"` + Error string `json:"error,omitempty"` + Results []database.HybridResult `json:"results"` + ResponseTime float64 `json:"responseTime"` + } + + timeStart := time.Now().UnixMicro() + + respond := func(response httpResponse) { + w.Header().Add("Content-Type", "application/json") + w.WriteHeader(int(response.status)) + response.ResponseTime = float64(time.Now().UnixMicro()-timeStart) / 1e6 + str, err := json.Marshal(response) + if err != nil { + w.Write([]byte(`{"success":"false","error":"Failed to marshal struct into JSON"}`)) + } else { + w.Write([]byte(str)) + } + } + + src := req.URL.Query()["source"] + q := req.URL.Query().Get("q") + + if q == "" || src == nil || len(src) == 0 { + respond(httpResponse{ + status: 400, + Success: false, + Error: "Bad request", + }) + return + } + + foundSources := make([]config.Source, 0, len(src)) + + for _, sourceID := range src { + for _, s := range cfg.Sources { + if s.ID == sourceID { + foundSources = append(foundSources, s) + break + } + } + } + + queryEmbeds := make(map[string][]float32) + + for _, s := range foundSources { + if s.Embeddings.Enabled && queryEmbeds[s.Embeddings.Model] == nil { + vector, err := embedding.GetEmbeddings(s.Embeddings.OpenAIBaseURL, s.Embeddings.Model, s.Embeddings.APIKey, q) + if err != nil { + fmt.Printf("Error getting embeddings for search query: %v\n", err) + respond(httpResponse{ + status: 500, + Success: false, + Error: "Internal server error", + }) + return + } + queryEmbeds[s.Embeddings.Model] = vector + + } + } + + embeddedQueries := make(map[string][]float32) + + for _, s := range foundSources { + if s.Embeddings.Enabled { + embeddedQueries[s.ID] = queryEmbeds[s.Embeddings.Model] + } + } + + sourceList := make([]string, 0) + for _, s := range foundSources { + sourceList = append(sourceList, s.ID) + } + + results, err := db.HybridSearch(sourceList, q, embeddedQueries, 10) + if err != nil { + fmt.Printf("Error generating search results: %v\n", err) + respond(httpResponse{ + status: 500, + Success: false, + Error: "Internal server error", + }) + return + } + + respond(httpResponse{ + status: 200, + Success: true, + Results: results, + }) + }) + addr := fmt.Sprintf("%v:%v", cfg.HTTP.Listen, cfg.HTTP.Port) fmt.Printf("Listening on http://%v\n", addr) log.Fatal(http.ListenAndServe(addr, nil)) diff --git a/app/server/templates/highlight.html.tmpl b/app/server/templates/highlight.html.tmpl index f9978aa..1b59d86 100644 --- a/app/server/templates/highlight.html.tmpl +++ b/app/server/templates/highlight.html.tmpl @@ -1,9 +1,9 @@ -{{ define "highlight" }} - {{ range $index, $value := . }} - {{ if $value.Highlighted }} - {{ $value.Content }} - {{ else }} - {{ $value.Content }} - {{ end }} - {{ end }} -{{ end }} +{{- define "highlight" -}} + {{- range $index, $value := . -}} + {{- if $value.Highlighted -}} + {{- $value.Content -}} + {{- else -}} + {{- $value.Content -}} + {{- end -}} + {{- end -}} +{{- end -}} diff --git a/app/server/templates/index.html.tmpl b/app/server/templates/index.html.tmpl index 4a5c087..8af45b7 100644 --- a/app/server/templates/index.html.tmpl +++ b/app/server/templates/index.html.tmpl @@ -1,4 +1,4 @@ -{{ define "index" }} +{{- define "index" -}} @@ -28,7 +28,7 @@ crossorigin="anonymous" defer > - {{ .CustomHTML }} + {{- .CustomHTML -}}
@@ -44,26 +44,26 @@

Search

Include results from: - {{ range $index, $value := .Sources }} + {{- range $index, $value := .Sources -}} - {{ end }} + {{- end -}}

-{{ end }} +{{- end -}} diff --git a/app/server/templates/results.html.tmpl b/app/server/templates/results.html.tmpl index c072425..2000c5a 100644 --- a/app/server/templates/results.html.tmpl +++ b/app/server/templates/results.html.tmpl @@ -1,48 +1,50 @@ -{{ define "results" }} +{{- define "results" -}}
- {{ if .Results }} - {{ range $index, $result := .Results }} + {{- if .Results -}} + {{- range $index, $result := .Results -}}
- {{ template "highlight" $result.Title }} -

{{ template "highlight" $result.Content }}

+ {{- template "highlight" $result.Title -}} +

{{- template "highlight" $result.Content -}}

- {{ end }} + {{- end -}} - {{ range $index, $value := .Pages }} - {{ if $value.Current }} + {{- range $index, $value := .Pages -}} + {{- if $value.Current -}} - {{ else }} - + {{- else -}} + - {{ end }} - {{ end }} + {{- end -}} + {{- end -}} -

{{ .Total }} results found in {{ .Time }}s.

- {{ else if .Query }} +

{{- .Total -}} results found in {{- .Time -}}s.

+ {{- else if .Query -}} No results found. - {{ end }} -{{ end }} + {{- end -}} +{{- end -}}