From d453904f003f8d9e64656670fb48b5dfc341bd2b Mon Sep 17 00:00:00 2001 From: FluxCapacitor2 <31071265+FluxCapacitor2@users.noreply.github.com> Date: Mon, 18 Nov 2024 00:56:49 -0500 Subject: [PATCH] Allow the embedding model and dimensions to be changed per-source --- app/config/config.go | 16 ++-- app/crawler/crawler_test.go | 2 +- app/database/db.go | 9 +- app/database/db_sqlite.go | 51 ++++++------ app/database/db_sqlite_embedding.sql | 25 ++++++ app/database/db_sqlite_setup.sql | 15 ---- app/database/db_sqlite_test.go | 33 +++++++- app/easysearch.go | 11 ++- app/embedding/embedding.go | 5 +- app/processqueue.go | 16 +++- app/server/server.go | 118 +++++++++++++++++---------- config-sample.yml | 24 +++--- 12 files changed, 205 insertions(+), 120 deletions(-) create mode 100644 app/database/db_sqlite_embedding.sql diff --git a/app/config/config.go b/app/config/config.go index 57c771b..0214abb 100644 --- a/app/config/config.go +++ b/app/config/config.go @@ -15,14 +15,6 @@ type Config struct { Driver string ConnectionString string `yaml:"connectionString"` } `yaml:"db"` - Embeddings struct { - OpenAIBaseURL string `yaml:"openaiBaseUrl"` - APIKey string `yaml:"apiKey"` - Model string - Dimensions int - ChunkSize int `yaml:"chunkSize"` - ChunkOverlap int `yaml:"chunkOverlap"` - } Sources []Source ResultsPage ResultsPageConfig `yaml:"resultsPage"` } @@ -58,7 +50,13 @@ type Source struct { Embeddings struct { Enabled bool // The maximum number of requests per minute to the embeddings API - Speed int + Speed int + OpenAIBaseURL string `yaml:"openaiBaseUrl"` + APIKey string `yaml:"apiKey"` + Model string + Dimensions int + ChunkSize int `yaml:"chunkSize"` + ChunkOverlap int `yaml:"chunkOverlap"` } } diff --git a/app/crawler/crawler_test.go b/app/crawler/crawler_test.go index 34b8c2f..2694f79 100644 --- a/app/crawler/crawler_test.go +++ b/app/crawler/crawler_test.go @@ -15,7 +15,7 @@ func createDB(t *testing.T) database.Database { t.Fatalf("database creation failed: %v", 
err) } - if err := db.Setup(1536); err != nil { + if err := db.Setup(); err != nil { t.Fatalf("database setup failed: %v", err) } diff --git a/app/database/db.go b/app/database/db.go index b66f877..f2ebb11 100644 --- a/app/database/db.go +++ b/app/database/db.go @@ -2,7 +2,8 @@ package database type Database interface { // Create necessary tables - Setup(vectorDimension int) error + Setup() error + SetupVectorTables(sourceID string, dimension int) error // Set the status of items that have been Processing for over a minute to Pending and remove any Finished entries Cleanup() error @@ -38,12 +39,12 @@ type Database interface { AddToEmbedQueue(id int64, chunks []string) error PopEmbedQueue(source string) (*EmbedQueueItem, error) UpdateEmbedQueueEntry(id int64, status QueueItemStatus) error - AddEmbedding(pageID int64, chunkIndex int, chunk string, vector []float32) error + AddEmbedding(pageID int64, sourceID string, chunkIndex int, chunk string, vector []float32) error // Adds pages with no embeddings to the embed queue - StartEmbeddings(chunkSize int, chunkOverlap int) error + StartEmbeddings(getChunkDetails func(sourceID string) (chunkSize int, chunkOverlap int)) error - SimilaritySearch(sources []string, query []float32, limit int) ([]SimilarityResult, error) + SimilaritySearch(sourceID string, query []float32, limit int) ([]SimilarityResult, error) } type Page struct { diff --git a/app/database/db_sqlite.go b/app/database/db_sqlite.go index db0c681..ca3f474 100644 --- a/app/database/db_sqlite.go +++ b/app/database/db_sqlite.go @@ -24,9 +24,16 @@ type SQLiteDatabase struct { //go:embed db_sqlite_setup.sql var setupCommands string -func (db *SQLiteDatabase) Setup(vectorDimension int) error { - vec.Auto() // Load the `sqlite-vec` extension - _, err := db.conn.Exec(fmt.Sprintf(setupCommands, vectorDimension)) +//go:embed db_sqlite_embedding.sql +var embedSetupCommands string + +func (db *SQLiteDatabase) Setup() error { + _, err := db.conn.Exec(setupCommands) + return 
err +} + +func (db *SQLiteDatabase) SetupVectorTables(sourceID string, dimensions int) error { + _, err := db.conn.Exec(fmt.Sprintf(embedSetupCommands, sourceID, sourceID, sourceID, sourceID, sourceID, dimensions)) return err } @@ -251,37 +258,24 @@ func processResult(input string, start string, end string) []Match { } } -func (db *SQLiteDatabase) SimilaritySearch(sources []string, query []float32, limit int) ([]SimilarityResult, error) { +func (db *SQLiteDatabase) SimilaritySearch(sourceID string, query []float32, limit int) ([]SimilarityResult, error) { serialized, err := vec.SerializeFloat32(query) if err != nil { return nil, err } - args := []any{serialized} - for _, src := range sources { - args = append(args, src) - } - args = append(args, Finished, limit, limit*5) - - // TODO: In this query, we use `k = limit * 5` to select the first 5N results and then limit them to N with a LIMIT clause. - // When `sqlite-vec` adds support for metadata columns in the next release, this workaround will be unnecessary - // and we can filter the sources directly. - - sourcesString := strings.Repeat("?, ", len(sources)-1) + "?" - rows, err := db.conn.Query(fmt.Sprintf(` - SELECT pages_vec.distance, pages.url, pages.title, vec_chunks.chunk FROM pages_vec + SELECT pages_vec_%s.distance, pages.url, pages.title, vec_chunks.chunk FROM pages_vec_%s JOIN vec_chunks USING (id) JOIN pages ON pages.id = vec_chunks.page WHERE - pages_vec.embedding MATCH ? AND - pages.source IN (%s) AND + pages_vec_%s.embedding MATCH ? AND pages.status = ? AND k = ? - ORDER BY pages_vec.distance + ORDER BY pages_vec_%s.distance LIMIT ?; - `, sourcesString), args...) 
+ `, sourceID, sourceID, sourceID, sourceID), serialized, Finished, limit, limit) if err != nil { return nil, err @@ -401,7 +395,7 @@ func (db *SQLiteDatabase) PopEmbedQueue(source string) (*EmbedQueueItem, error) return item, nil } -func (db *SQLiteDatabase) AddEmbedding(pageID int64, chunkIndex int, chunk string, vector []float32) error { +func (db *SQLiteDatabase) AddEmbedding(pageID int64, sourceID string, chunkIndex int, chunk string, vector []float32) error { serialized, err := vec.SerializeFloat32(vector) if err != nil { return err @@ -424,7 +418,7 @@ func (db *SQLiteDatabase) AddEmbedding(pageID int64, chunkIndex int, chunk strin return err } - _, err = tx.Exec("INSERT INTO pages_vec (id, embedding) VALUES (?, ?);", id, serialized) + _, err = tx.Exec(fmt.Sprintf("INSERT INTO pages_vec_%s (id, embedding) VALUES (?, ?);", sourceID), id, serialized) if err != nil { if err := tx.Rollback(); err != nil { return err @@ -479,20 +473,22 @@ func (db *SQLiteDatabase) Cleanup() error { return err } -func (db *SQLiteDatabase) StartEmbeddings(chunkSize int, chunkOverlap int) error { +func (db *SQLiteDatabase) StartEmbeddings(getChunkDetails func(sourceID string) (chunkSize int, chunkOverlap int)) error { // If a page has been indexed but has no embeddings, make sure an embedding job has been queued - rows, err := db.conn.Query("SELECT id, content FROM pages WHERE status = ? AND id NOT IN (SELECT page FROM vec_chunks) AND id NOT IN (SELECT page FROM embed_queue);", Finished) + rows, err := db.conn.Query("SELECT id, source, content FROM pages WHERE status = ? 
AND id NOT IN (SELECT page FROM vec_chunks) AND id NOT IN (SELECT page FROM embed_queue);", Finished) if err != nil { return fmt.Errorf("error finding pages without embeddings: %v", err) } for rows.Next() { var id int64 + var sourceID string var content string - err := rows.Scan(&id, &content) + err := rows.Scan(&id, &sourceID, &content) if err != nil { return err } + chunkSize, chunkOverlap := getChunkDetails(sourceID) chunks, err := embedding.ChunkText(content, chunkSize, chunkOverlap) if err != nil { @@ -563,13 +559,14 @@ func (db *SQLiteDatabase) SetCanonical(source string, url string, canonical stri } func SQLiteFromFile(fileName string) (*SQLiteDatabase, error) { + vec.Auto() // Load the `sqlite-vec` extension conn, err := sql.Open("sqlite3", fileName) if err != nil { return nil, err } - return &SQLiteDatabase{conn}, nil + return SQLite(conn) } func SQLite(conn *sql.DB) (*SQLiteDatabase, error) { diff --git a/app/database/db_sqlite_embedding.sql b/app/database/db_sqlite_embedding.sql new file mode 100644 index 0000000..2610fef --- /dev/null +++ b/app/database/db_sqlite_embedding.sql @@ -0,0 +1,25 @@ +-- This script creates the embedding tables for one source. + +-- Required format string placeholders: +-- (Repeated 5 times) source ID (string) +-- vector size (integer) + +-- Why use separate tables for each source?
+-- * Faster query times when there are many sources with lots of embeddings that aren't included in the user's query +-- * More accurate `k` limit when there are many sources that aren't included in the query +-- * In the future, different sources could use different embedding sources with different vector sizes + +CREATE TRIGGER IF NOT EXISTS pages_refresh_vector_embeddings_%s AFTER UPDATE ON pages +WHEN old.url != new.url OR old.title != new.title OR old.description != new.description OR old.content != new.content BEGIN + -- If the page has associated vector embeddings, they must be recomputed when the text changes + DELETE FROM pages_vec_%s WHERE id IN (SELECT * FROM vec_chunks WHERE page = old.id); +END; + +CREATE TRIGGER IF NOT EXISTS delete_embedding_on_delete_chunk_%s AFTER DELETE ON vec_chunks BEGIN + DELETE FROM pages_vec_%s WHERE id = old.id; +END; + +CREATE VIRTUAL TABLE IF NOT EXISTS pages_vec_%s USING vec0( + id INTEGER PRIMARY KEY, + embedding FLOAT[%d] distance_metric=cosine +); diff --git a/app/database/db_sqlite_setup.sql b/app/database/db_sqlite_setup.sql index a7e7055..6ec1829 100644 --- a/app/database/db_sqlite_setup.sql +++ b/app/database/db_sqlite_setup.sql @@ -107,11 +107,6 @@ CREATE TABLE IF NOT EXISTS vec_chunks( CREATE UNIQUE INDEX IF NOT EXISTS vec_chunks_page_chunk_unique ON vec_chunks(page, chunkIndex); -CREATE VIRTUAL TABLE IF NOT EXISTS pages_vec USING vec0( - id INTEGER PRIMARY KEY, - embedding FLOAT[%d] distance_metric=cosine -- This number is populated based on the config -); - CREATE TABLE IF NOT EXISTS embed_queue( id INTEGER PRIMARY KEY, page INTEGER NOT NULL, @@ -124,13 +119,3 @@ CREATE TABLE IF NOT EXISTS embed_queue( ) STRICT; CREATE UNIQUE INDEX IF NOT EXISTS embed_queue_page_chunk_unique ON embed_queue(page, chunkIndex); - -CREATE TRIGGER IF NOT EXISTS pages_refresh_vector_embeddings AFTER UPDATE ON pages -WHEN old.url != new.url OR old.title != new.title OR old.description != new.description OR old.content != new.content 
BEGIN - -- If the page has associated vector embeddings, they must be recomputed when the text changes - DELETE FROM pages_vec WHERE rowid IN (SELECT * FROM vec_chunks WHERE page = old.id); -END; - -CREATE TRIGGER IF NOT EXISTS delete_embedding_on_delete_chunk AFTER DELETE ON vec_chunks BEGIN - DELETE FROM pages_vec WHERE id = old.id; -END; diff --git a/app/database/db_sqlite_test.go b/app/database/db_sqlite_test.go index 12c4d26..5d0c1cc 100644 --- a/app/database/db_sqlite_test.go +++ b/app/database/db_sqlite_test.go @@ -5,16 +5,19 @@ import ( "reflect" "testing" "time" + + vec "github.com/asg017/sqlite-vec-go-bindings/cgo" ) func createDB(t *testing.T) Database { + vec.Auto() db, err := SQLiteFromFile(path.Join(t.TempDir(), "temp.db")) if err != nil { t.Fatalf("database creation failed: %v", err) } - if err := db.Setup(1536); err != nil { + if err := db.Setup(); err != nil { t.Fatalf("database setup failed: %v", err) } @@ -25,6 +28,34 @@ func TestSetup(t *testing.T) { createDB(t) } +func TestVecSetup(t *testing.T) { + db := createDB(t) + err := db.SetupVectorTables("1", 768) + if err != nil { + t.Fatalf("error setting up vector tables: %v\n", err) + } +} + +func TestCleanup(t *testing.T) { + db := createDB(t) + err := db.Cleanup() + if err != nil { + t.Fatalf("error occurred in Cleanup: %v\n", err) + } +} + +func TestStartEmbeddings(t *testing.T) { + db := createDB(t) + err := db.SetupVectorTables("1", 768) + if err != nil { + t.Fatalf("error creating vector table: %v\n", err) + } + err = db.StartEmbeddings(func(sourceID string) (chunkSize int, chunkOverlap int) { return 200, 30 }) + if err != nil { + t.Fatalf("error occurred in StartEmbeddings: %v\n", err) + } +} + func TestEscape(t *testing.T) { testCases := []struct { input string diff --git a/app/easysearch.go b/app/easysearch.go index c23ce97..3905bef 100644 --- a/app/easysearch.go +++ b/app/easysearch.go @@ -44,11 +44,20 @@ func main() { { // Create DB tables if they don't exist (and set SQLite to WAL mode) 
- err := db.Setup(config.Embeddings.Dimensions) + err := db.Setup() if err != nil { panic(fmt.Sprintf("Failed to set up database: %v", err)) } + + for _, src := range config.Sources { + if src.Embeddings.Enabled { + err := db.SetupVectorTables(src.ID, src.Embeddings.Dimensions) + if err != nil { + panic(fmt.Sprintf("Failed to set up embeddings database tables for source %v: %v", src.ID, err)) + } + } + } } // Continuously pop items off each source's queue and crawl them diff --git a/app/embedding/embedding.go b/app/embedding/embedding.go index 55493e5..689c427 100644 --- a/app/embedding/embedding.go +++ b/app/embedding/embedding.go @@ -10,7 +10,10 @@ import ( ) func GetEmbeddings(openAIBaseURL string, model string, apiKey string, chunk string) ([]float32, error) { - + if apiKey == "" { + // `langchaingo` emits an error when the OpenAI API key is empty, even if the API URL has been changed to one that doesn't require authentication. + apiKey = "-" + } llm, err := openai.New(openai.WithBaseURL(openAIBaseURL), openai.WithEmbeddingModel(model), openai.WithToken(apiKey)) if err != nil { return nil, fmt.Errorf("error setting up LLM for embedding: %v", err) diff --git a/app/processqueue.go b/app/processqueue.go index c67fa50..d941ab6 100644 --- a/app/processqueue.go +++ b/app/processqueue.go @@ -54,7 +54,15 @@ func startQueueJob(db database.Database, config *config.Config) { if err != nil { fmt.Printf("Error cleaning queue: %v\n", err) } - err = db.StartEmbeddings(config.Embeddings.ChunkSize, config.Embeddings.ChunkOverlap) + + err = db.StartEmbeddings(func(sourceID string) (chunkSize int, chunkOverlap int) { + for _, src := range config.Sources { + if src.ID == sourceID { + return src.Embeddings.ChunkSize, src.Embeddings.ChunkOverlap + } + } + return 200, 30 + }) if err != nil { fmt.Printf("Error queueing pages that need embeddings: %v\n", err) } @@ -82,14 +90,14 @@ func processEmbedQueue(db database.Database, config *config.Config, src config.S } } - vector, err := 
embedding.GetEmbeddings(config.Embeddings.OpenAIBaseURL, config.Embeddings.Model, config.Embeddings.APIKey, item.Content) + vector, err := embedding.GetEmbeddings(src.Embeddings.OpenAIBaseURL, src.Embeddings.Model, src.Embeddings.APIKey, item.Content) if err != nil { fmt.Printf("error getting embeddings: %v\n", err) markFailure() return } - err = db.AddEmbedding(item.PageID, item.ChunkIndex, item.Content, vector) + err = db.AddEmbedding(item.PageID, src.ID, item.ChunkIndex, item.Content, vector) if err != nil { fmt.Printf("error saving embedding: %v\n", err) markFailure() @@ -138,7 +146,7 @@ func processCrawlQueue(db database.Database, config *config.Config, src config.S } else { // Chunk the page into sections and add it to the embedding queue if result.PageID > 0 { - chunks, err := embedding.ChunkText(result.Content.Content, config.Embeddings.ChunkSize, config.Embeddings.ChunkOverlap) + chunks, err := embedding.ChunkText(result.Content.Content, src.Embeddings.ChunkSize, src.Embeddings.ChunkOverlap) if err != nil { fmt.Printf("error chunking page: %v\n", err) diff --git a/app/server/server.go b/app/server/server.go index 7175331..6a4693c 100644 --- a/app/server/server.go +++ b/app/server/server.go @@ -1,6 +1,7 @@ package server import ( + "cmp" "embed" "encoding/json" "fmt" @@ -28,9 +29,9 @@ type paginationInfo struct { Total uint32 `json:"total"` } -func Start(db database.Database, config *config.Config) { +func Start(db database.Database, cfg *config.Config) { - if config.ResultsPage.Enabled { + if cfg.ResultsPage.Enabled { http.Handle("/static/", http.FileServerFS(content)) t, err := template.ParseFS(content, "templates/*.tmpl") @@ -40,7 +41,7 @@ func Start(db database.Database, config *config.Config) { } http.HandleFunc("/{$}", func(w http.ResponseWriter, req *http.Request) { - renderTemplateWithResults(db, config, req, w, t, "index") + renderTemplateWithResults(db, cfg, req, w, t, "index") }) http.HandleFunc("/results", func(w http.ResponseWriter, req 
*http.Request) { @@ -55,7 +56,7 @@ func Start(db database.Database, config *config.Config) { w.Header().Set("HX-Replace-URL", url.String()) } - renderTemplateWithResults(db, config, req, w, t, "results") + renderTemplateWithResults(db, cfg, req, w, t, "results") }) } @@ -70,7 +71,7 @@ func Start(db database.Database, config *config.Config) { } timeStart := time.Now().UnixMicro() - var response *httpResponse + var response httpResponse src := req.URL.Query()["source"] q := req.URL.Query().Get("q") @@ -79,7 +80,7 @@ func Start(db database.Database, config *config.Config) { if q != "" && src != nil && len(src) > 0 && err == nil { results, total, err := db.Search(src, q, uint32(page), 10) if err != nil { - response = &httpResponse{ + response = httpResponse{ status: 500, Success: false, Error: "Internal server error", @@ -87,7 +88,7 @@ func Start(db database.Database, config *config.Config) { fmt.Printf("Error generating search results: %v\n", err) } else { - response = &httpResponse{ + response = httpResponse{ status: 200, Success: true, Results: results, @@ -99,7 +100,7 @@ func Start(db database.Database, config *config.Config) { } } } else { - response = &httpResponse{ + response = httpResponse{ status: 400, Success: false, Error: "Bad request", @@ -127,60 +128,89 @@ func Start(db database.Database, config *config.Config) { } timeStart := time.Now().UnixMicro() - var response *httpResponse + + respond := func(response httpResponse) { + w.Header().Add("Content-Type", "application/json") + w.WriteHeader(int(response.status)) + response.ResponseTime = float64(time.Now().UnixMicro()-timeStart) / 1e6 + str, err := json.Marshal(response) + if err != nil { + w.Write([]byte(`{"success":"false","error":"Failed to marshal struct into JSON"}`)) + } else { + w.Write([]byte(str)) + } + } src := req.URL.Query()["source"] q := req.URL.Query().Get("q") - if q != "" && src != nil && len(src) > 0 { + if q == "" || src == nil || len(src) == 0 { + respond(httpResponse{ + status: 400, + 
Success: false, + Error: "Bad request", + }) + return + } - vector, err := embedding.GetEmbeddings(config.Embeddings.OpenAIBaseURL, config.Embeddings.Model, config.Embeddings.APIKey, q) - if err != nil { - response = &httpResponse{ - status: 500, - Success: false, - Error: "Internal server error", + foundSources := make([]config.Source, 0, len(src)) + + for _, sourceID := range src { + for _, s := range cfg.Sources { + if s.ID == sourceID { + foundSources = append(foundSources, s) + break } + } + } - fmt.Printf("Error generating embeddings for search query: %v\n", err) - } else { - results, err := db.SimilaritySearch(src, vector, 10) + queryEmbeds := make(map[string][]float32) + + for _, s := range foundSources { + if s.Embeddings.Enabled && queryEmbeds[s.Embeddings.Model] == nil { + vector, err := embedding.GetEmbeddings(s.Embeddings.OpenAIBaseURL, s.Embeddings.Model, s.Embeddings.APIKey, q) if err != nil { - response = &httpResponse{ + fmt.Printf("Error getting embeddings for search query: %v\n", err) + respond(httpResponse{ status: 500, Success: false, Error: "Internal server error", - } - - fmt.Printf("Error generating search results: %v\n", err) - } else { - response = &httpResponse{ - status: 200, - Success: true, - Results: results, - } + }) + return } - } - } else { - response = &httpResponse{ - status: 400, - Success: false, - Error: "Bad request", + queryEmbeds[s.Embeddings.Model] = vector + } } - w.Header().Add("Content-Type", "application/json") - w.WriteHeader(int(response.status)) - response.ResponseTime = float64(time.Now().UnixMicro()-timeStart) / 1e6 - str, err := json.Marshal(response) - if err != nil { - w.Write([]byte(`{"success":"false","error":"Failed to marshal struct into JSON"}`)) - } else { - w.Write([]byte(str)) + allResults := make([]database.SimilarityResult, 0) + + for _, s := range foundSources { + results, err := db.SimilaritySearch(s.ID, queryEmbeds[s.Embeddings.Model], 10) + if err != nil { + fmt.Printf("Error generating search 
results: %v\n", err) + respond(httpResponse{ + status: 500, + Success: false, + Error: "Internal server error", + }) + return + } + allResults = append(allResults, results...) } + + slices.SortFunc(allResults, func(a database.SimilarityResult, b database.SimilarityResult) int { + return cmp.Compare(a.Similarity, b.Similarity) + }) + + respond(httpResponse{ + status: 200, + Success: true, + Results: allResults, + }) }) - addr := fmt.Sprintf("%v:%v", config.HTTP.Listen, config.HTTP.Port) + addr := fmt.Sprintf("%v:%v", cfg.HTTP.Listen, cfg.HTTP.Port) fmt.Printf("Listening on http://%v\n", addr) log.Fatal(http.ListenAndServe(addr, nil)) } diff --git a/config-sample.yml b/config-sample.yml index 4a49504..1570eb3 100644 --- a/config-sample.yml +++ b/config-sample.yml @@ -45,18 +45,16 @@ sources: # This number is used to start a scheduled task, so don't set this number too high to conserve CPU cycles. speed: 30 -# This section can be omitted if `embeddings.enabled` is set to `false` for all your sources. -embeddings: - # Use OpenAI's embedding model: - # openaiBaseUrl: https://api.openai.com/v1/ - # model: text-embedding-3-small - # dimensions: 1536 - # apiKey: sk-************************************* + # Use OpenAI's embedding model: + # openaiBaseUrl: https://api.openai.com/v1/ + # model: text-embedding-3-small + # dimensions: 1536 + # apiKey: sk-************************************* - # You can also use any OpenAI-compatible API, like a local Ollama server: - openaiBaseUrl: http://localhost:11434/v1/ - model: bge-m3 - dimensions: 1024 + # You can also use any OpenAI-compatible API, like a local Ollama server: + openaiBaseUrl: http://localhost:11434/v1/ + model: bge-m3 + dimensions: 1024 - chunkSize: 200 - chunkOverlap: 30 # 15% overlap + chunkSize: 200 + chunkOverlap: 30 # 15% overlap