Skip to content

Commit

Permalink
Implement hybrid (FTS + vector) search and add an API endpoint for it
Browse files Browse the repository at this point in the history
  • Loading branch information
FluxCapacitor2 committed Nov 24, 2024
1 parent 2a7eca0 commit 85f5e45
Show file tree
Hide file tree
Showing 7 changed files with 321 additions and 47 deletions.
12 changes: 12 additions & 0 deletions app/config/config.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package config

import (
"fmt"
"os"
"regexp"

"gopkg.in/yaml.v3"
)
Expand Down Expand Up @@ -62,6 +64,8 @@ type Source struct {
}
}

var sourceIDPattern = regexp.MustCompile("^[a-zA-Z0-9_]+$")

func Read() (*Config, error) {

data, err := os.ReadFile("./config.yml")
Expand All @@ -76,5 +80,13 @@ func Read() (*Config, error) {
return nil, err
}

// Validate the loaded configuration

for _, src := range config.Sources {
if !sourceIDPattern.MatchString(src.ID) {
panic(fmt.Sprintf("Invalid source ID: %v. Source IDs may only contain alphanumeric characters and underscores.", src.ID))
}
}

return config, nil
}
12 changes: 12 additions & 0 deletions app/database/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ type Database interface {
StartEmbeddings(getChunkDetails func(sourceID string) (chunkSize int, chunkOverlap int)) error

SimilaritySearch(sourceID string, query []float32, limit int) ([]SimilarityResult, error)
HybridSearch(sources []string, queryString string, embeddedQueries map[string][]float32, limit int) ([]HybridResult, error)
}

type Page struct {
Expand Down Expand Up @@ -77,6 +78,17 @@ type SimilarityResult struct {
Similarity float32 `json:"similarity"`
}

type HybridResult struct {
URL string `json:"url"`
Title string `json:"title"`
Description string `json:"description"`
Content []Match `json:"content"`
FTSRank *int `json:"ftsRank"`
VecRank *int `json:"vecRank"`
VecDistance *float64 `json:"vecDistance"`
HybridRank float64 `json:"rank"`
}

type Match struct {
Highlighted bool `json:"highlighted"`
Content string `json:"content"`
Expand Down
152 changes: 152 additions & 0 deletions app/database/db_sqlite.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package database

import (
"bytes"
"database/sql"
"fmt"
"regexp"
"strings"
"text/template"

_ "embed"

Expand Down Expand Up @@ -304,6 +306,156 @@ func (db *SQLiteDatabase) SimilaritySearch(sourceID string, query []float32, lim
return results, err
}

var tmpl *template.Template = template.Must(template.New("hybrid-search").Parse(`
WITH {{ range $index, $value := .Sources -}}
vec_subquery_{{ $value }} AS (
SELECT
vec_chunks.page AS page,
row_number() OVER (ORDER BY distance) AS rank_number,
vec_chunks.chunk AS chunk,
distance
FROM pages_vec_{{ $value }}
JOIN vec_chunks USING (id)
WHERE embedding MATCH ? AND k = ?
-- Select only the most relevant chunk for each page
GROUP BY vec_chunks.page
HAVING MIN(distance)
ORDER BY distance
), {{ end }}fts_subquery AS (
SELECT
pages_fts.rowid AS page,
highlight(pages_fts, 1, ?, ?) AS title,
snippet(pages_fts, 2, ?, ?, '…', 8) AS description,
snippet(pages_fts, 3, ?, ?, '…', 24) AS content,
rank
FROM pages_fts
JOIN pages ON pages.id = pages_fts.rowid
WHERE
pages.source IN (
{{- range $index, $value := .Sources -}}
{{- if gt $index 0 }}, {{ end -}}
?
{{- end -}}
)
AND pages.status = ?
AND pages_fts MATCH ?
LIMIT ?
), fts_ordered AS (
SELECT *, row_number() OVER (ORDER BY rank) AS rank_number
FROM fts_subquery
)
SELECT
pages.url,
coalesce(fts_ordered.title, pages.title) AS title,
coalesce(fts_ordered.description, pages.description) AS description,
coalesce(
fts_ordered.content, {{ range $index, $value := .Sources -}}
{{- if gt $index 0 }}, {{ end -}}
vec_subquery_{{ $value }}.chunk
{{- end }}
) AS content,
coalesce(
{{ range $index, $value := .Sources -}}
{{- if gt $index 0 }}, {{ end -}}
vec_subquery_{{ $value }}.distance
{{- end }}
) AS vec_distance,
coalesce(
{{ range $index, $value := .Sources -}}
{{- if gt $index 0 }}, {{ end -}}
vec_subquery_{{ $value }}.rank_number
{{- end }}
) AS vec_rank,
fts_ordered.rank_number AS fts_rank,
(
{{ range $index, $value := .Sources -}}
coalesce(1.0 / (60 + vec_subquery_{{ $value }}.rank_number) * 0.5, 0.0) +
{{ end -}}
coalesce(1.0 / (60 + fts_ordered.rank_number), 0.0)
) AS combined_rank
FROM fts_ordered
{{ range $index, $value := .Sources -}}
FULL OUTER JOIN vec_subquery_{{ $value }} USING (page)
{{ end -}}
JOIN pages ON pages.id = coalesce(
fts_ordered.page, {{ range $index, $value := .Sources -}}
{{- if gt $index 0 }}, {{ end -}}
vec_subquery_{{ $value }}.page
{{- end }}
)
ORDER BY combined_rank DESC;
`))

func (db *SQLiteDatabase) HybridSearch(sources []string, queryString string, embeddedQueries map[string][]float32, limit int) ([]HybridResult, error) {

// Convert the query vectors to a blob format that `sqlite-vec` will accept
serializedQueries := make(map[string][]byte)

for sourceID, query := range embeddedQueries {
serialized, err := vec.SerializeFloat32(query)
if err != nil {
return nil, err
}
serializedQueries[sourceID] = serialized
}

type TemplateData struct {
Sources []string
}

var query bytes.Buffer
err := tmpl.Execute(&query, TemplateData{Sources: sources})
if err != nil {
return nil, fmt.Errorf("error formatting query: %v", err)
}

args := []any{}

// Vector query args
for _, src := range sources {
args = append(args, serializedQueries[src], limit)
}

// FTS query args
start := uuid.New().String()
end := uuid.New().String()

args = append(args, start, end, start, end, start, end)

for _, src := range sources {
args = append(args, src)
}

args = append(args, Finished, queryString, limit)

rows, err := db.conn.Query(query.String(), args...)

if err != nil {
return nil, err
}

results := make([]HybridResult, 0)

for rows.Next() {
res := HybridResult{}
var content string
err := rows.Scan(&res.URL, &res.Title, &res.Description, &content, &res.VecDistance, &res.VecRank, &res.FTSRank, &res.HybridRank)
if err != nil {
return nil, err
}
res.Content = processResult(content, start, end)
// res.Content = []Match{{Content: content, Highlighted: false}}
results = append(results, res)
}

return results, err
}

func (db *SQLiteDatabase) AddToQueue(source string, referrer string, urls []string, depth int32, isRefresh bool) error {

tx, err := db.conn.Begin()
Expand Down
96 changes: 96 additions & 0 deletions app/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,102 @@ func Start(db database.Database, cfg *config.Config) {
})
})

http.HandleFunc("/api/hybrid-search", func(w http.ResponseWriter, req *http.Request) {
type httpResponse struct {
status int16
Success bool `json:"success"`
Error string `json:"error,omitempty"`
Results []database.HybridResult `json:"results"`
ResponseTime float64 `json:"responseTime"`
}

timeStart := time.Now().UnixMicro()

respond := func(response httpResponse) {
w.Header().Add("Content-Type", "application/json")
w.WriteHeader(int(response.status))
response.ResponseTime = float64(time.Now().UnixMicro()-timeStart) / 1e6
str, err := json.Marshal(response)
if err != nil {
w.Write([]byte(`{"success":"false","error":"Failed to marshal struct into JSON"}`))
} else {
w.Write([]byte(str))
}
}

src := req.URL.Query()["source"]
q := req.URL.Query().Get("q")

if q == "" || src == nil || len(src) == 0 {
respond(httpResponse{
status: 400,
Success: false,
Error: "Bad request",
})
return
}

foundSources := make([]config.Source, 0, len(src))

for _, sourceID := range src {
for _, s := range cfg.Sources {
if s.ID == sourceID {
foundSources = append(foundSources, s)
break
}
}
}

queryEmbeds := make(map[string][]float32)

for _, s := range foundSources {
if s.Embeddings.Enabled && queryEmbeds[s.Embeddings.Model] == nil {
vector, err := embedding.GetEmbeddings(s.Embeddings.OpenAIBaseURL, s.Embeddings.Model, s.Embeddings.APIKey, q)
if err != nil {
fmt.Printf("Error getting embeddings for search query: %v\n", err)
respond(httpResponse{
status: 500,
Success: false,
Error: "Internal server error",
})
return
}
queryEmbeds[s.Embeddings.Model] = vector

}
}

embeddedQueries := make(map[string][]float32)

for _, s := range foundSources {
if s.Embeddings.Enabled {
embeddedQueries[s.ID] = queryEmbeds[s.Embeddings.Model]
}
}

sourceList := make([]string, 0)
for _, s := range foundSources {
sourceList = append(sourceList, s.ID)
}

results, err := db.HybridSearch(sourceList, q, embeddedQueries, 10)
if err != nil {
fmt.Printf("Error generating search results: %v\n", err)
respond(httpResponse{
status: 500,
Success: false,
Error: "Internal server error",
})
return
}

respond(httpResponse{
status: 200,
Success: true,
Results: results,
})
})

addr := fmt.Sprintf("%v:%v", cfg.HTTP.Listen, cfg.HTTP.Port)
fmt.Printf("Listening on http://%v\n", addr)
log.Fatal(http.ListenAndServe(addr, nil))
Expand Down
18 changes: 9 additions & 9 deletions app/server/templates/highlight.html.tmpl
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
{{ define "highlight" }}
{{ range $index, $value := . }}
{{ if $value.Highlighted }}
<b>{{ $value.Content }}</b>
{{ else }}
{{ $value.Content }}
{{ end }}
{{ end }}
{{ end }}
{{- define "highlight" -}}
{{- range $index, $value := . -}}
{{- if $value.Highlighted -}}
<b>{{- $value.Content -}}</b>
{{- else -}}
{{- $value.Content -}}
{{- end -}}
{{- end -}}
{{- end -}}
Loading

0 comments on commit 85f5e45

Please sign in to comment.