Skip to content

Commit

Permalink
Allow the embedding model and dimensions to be changed per-source
Browse files Browse the repository at this point in the history
  • Loading branch information
FluxCapacitor2 committed Nov 18, 2024
1 parent e7270e1 commit d453904
Show file tree
Hide file tree
Showing 12 changed files with 205 additions and 120 deletions.
16 changes: 7 additions & 9 deletions app/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,6 @@ type Config struct {
Driver string
ConnectionString string `yaml:"connectionString"`
} `yaml:"db"`
Embeddings struct {
OpenAIBaseURL string `yaml:"openaiBaseUrl"`
APIKey string `yaml:"apiKey"`
Model string
Dimensions int
ChunkSize int `yaml:"chunkSize"`
ChunkOverlap int `yaml:"chunkOverlap"`
}
Sources []Source
ResultsPage ResultsPageConfig `yaml:"resultsPage"`
}
Expand Down Expand Up @@ -58,7 +50,13 @@ type Source struct {
Embeddings struct {
Enabled bool
// The maximum number of requests per minute to the embeddings API
Speed int
Speed int
OpenAIBaseURL string `yaml:"openaiBaseUrl"`
APIKey string `yaml:"apiKey"`
Model string
Dimensions int
ChunkSize int `yaml:"chunkSize"`
ChunkOverlap int `yaml:"chunkOverlap"`
}
}

Expand Down
2 changes: 1 addition & 1 deletion app/crawler/crawler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func createDB(t *testing.T) database.Database {
t.Fatalf("database creation failed: %v", err)
}

if err := db.Setup(1536); err != nil {
if err := db.Setup(); err != nil {
t.Fatalf("database setup failed: %v", err)
}

Expand Down
9 changes: 5 additions & 4 deletions app/database/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ package database

type Database interface {
// Create necessary tables
Setup(vectorDimension int) error
Setup() error
SetupVectorTables(sourceID string, dimension int) error

// Set the status of items that have been Processing for over a minute to Pending and remove any Finished entries
Cleanup() error
Expand Down Expand Up @@ -38,12 +39,12 @@ type Database interface {
AddToEmbedQueue(id int64, chunks []string) error
PopEmbedQueue(source string) (*EmbedQueueItem, error)
UpdateEmbedQueueEntry(id int64, status QueueItemStatus) error
AddEmbedding(pageID int64, chunkIndex int, chunk string, vector []float32) error
AddEmbedding(pageID int64, sourceID string, chunkIndex int, chunk string, vector []float32) error

// Adds pages with no embeddings to the embed queue
StartEmbeddings(chunkSize int, chunkOverlap int) error
StartEmbeddings(getChunkDetails func(sourceID string) (chunkSize int, chunkOverlap int)) error

SimilaritySearch(sources []string, query []float32, limit int) ([]SimilarityResult, error)
SimilaritySearch(sourceID string, query []float32, limit int) ([]SimilarityResult, error)
}

type Page struct {
Expand Down
51 changes: 24 additions & 27 deletions app/database/db_sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,16 @@ type SQLiteDatabase struct {
//go:embed db_sqlite_setup.sql
var setupCommands string

func (db *SQLiteDatabase) Setup(vectorDimension int) error {
vec.Auto() // Load the `sqlite-vec` extension
_, err := db.conn.Exec(fmt.Sprintf(setupCommands, vectorDimension))
//go:embed db_sqlite_embedding.sql
var embedSetupCommands string

// Setup creates the application's core tables by running the embedded
// setup script. Per-source vector tables are created separately via
// SetupVectorTables, since their dimensions vary per source.
func (db *SQLiteDatabase) Setup() error {
	if _, err := db.conn.Exec(setupCommands); err != nil {
		return err
	}
	return nil
}

// SetupVectorTables creates the vector table and triggers for a single source.
// The source ID becomes part of SQL identifiers (e.g. `pages_vec_<id>`) via
// direct string interpolation, so it is validated first: an ID containing a
// hyphen, space, quote, or other punctuation would produce malformed SQL or
// allow SQL injection through the config file.
func (db *SQLiteDatabase) SetupVectorTables(sourceID string, dimensions int) error {
	if sourceID == "" {
		return fmt.Errorf("invalid source ID: must not be empty")
	}
	for _, r := range sourceID {
		if !(r >= 'a' && r <= 'z' || r >= 'A' && r <= 'Z' || r >= '0' && r <= '9' || r == '_') {
			return fmt.Errorf("invalid source ID %q: only letters, digits, and underscores are allowed", sourceID)
		}
	}
	if dimensions <= 0 {
		return fmt.Errorf("invalid vector dimension %d for source %q", dimensions, sourceID)
	}
	_, err := db.conn.Exec(fmt.Sprintf(embedSetupCommands, sourceID, sourceID, sourceID, sourceID, sourceID, dimensions))
	return err
}

Expand Down Expand Up @@ -251,37 +258,24 @@ func processResult(input string, start string, end string) []Match {
}
}

func (db *SQLiteDatabase) SimilaritySearch(sources []string, query []float32, limit int) ([]SimilarityResult, error) {
func (db *SQLiteDatabase) SimilaritySearch(sourceID string, query []float32, limit int) ([]SimilarityResult, error) {

serialized, err := vec.SerializeFloat32(query)
if err != nil {
return nil, err
}

args := []any{serialized}
for _, src := range sources {
args = append(args, src)
}
args = append(args, Finished, limit, limit*5)

// TODO: In this query, we use `k = limit * 5` to select the first 5N results and then limit them to N with a LIMIT clause.
// When `sqlite-vec` adds support for metadata columns in the next release, this workaround will be unnecessary
// and we can filter the sources directly.

sourcesString := strings.Repeat("?, ", len(sources)-1) + "?"

rows, err := db.conn.Query(fmt.Sprintf(`
SELECT pages_vec.distance, pages.url, pages.title, vec_chunks.chunk FROM pages_vec
SELECT pages_vec_%s.distance, pages.url, pages.title, vec_chunks.chunk FROM pages_vec_%s
JOIN vec_chunks USING (id)
JOIN pages ON pages.id = vec_chunks.page
WHERE
pages_vec.embedding MATCH ? AND
pages.source IN (%s) AND
pages_vec_%s.embedding MATCH ? AND
pages.status = ? AND
k = ?
ORDER BY pages_vec.distance
ORDER BY pages_vec_%s.distance
LIMIT ?;
`, sourcesString), args...)
`, sourceID, sourceID, sourceID, sourceID), serialized, Finished, limit, limit)

if err != nil {
return nil, err
Expand Down Expand Up @@ -401,7 +395,7 @@ func (db *SQLiteDatabase) PopEmbedQueue(source string) (*EmbedQueueItem, error)
return item, nil
}

func (db *SQLiteDatabase) AddEmbedding(pageID int64, chunkIndex int, chunk string, vector []float32) error {
func (db *SQLiteDatabase) AddEmbedding(pageID int64, sourceID string, chunkIndex int, chunk string, vector []float32) error {
serialized, err := vec.SerializeFloat32(vector)
if err != nil {
return err
Expand All @@ -424,7 +418,7 @@ func (db *SQLiteDatabase) AddEmbedding(pageID int64, chunkIndex int, chunk strin
return err
}

_, err = tx.Exec("INSERT INTO pages_vec (id, embedding) VALUES (?, ?);", id, serialized)
_, err = tx.Exec(fmt.Sprintf("INSERT INTO pages_vec_%s (id, embedding) VALUES (?, ?);", sourceID), id, serialized)
if err != nil {
if err := tx.Rollback(); err != nil {
return err
Expand Down Expand Up @@ -479,20 +473,22 @@ func (db *SQLiteDatabase) Cleanup() error {
return err
}

func (db *SQLiteDatabase) StartEmbeddings(chunkSize int, chunkOverlap int) error {
func (db *SQLiteDatabase) StartEmbeddings(getChunkDetails func(sourceID string) (chunkSize int, chunkOverlap int)) error {
// If a page has been indexed but has no embeddings, make sure an embedding job has been queued
rows, err := db.conn.Query("SELECT id, content FROM pages WHERE status = ? AND id NOT IN (SELECT page FROM vec_chunks) AND id NOT IN (SELECT page FROM embed_queue);", Finished)
rows, err := db.conn.Query("SELECT id, source, content FROM pages WHERE status = ? AND id NOT IN (SELECT page FROM vec_chunks) AND id NOT IN (SELECT page FROM embed_queue);", Finished)
if err != nil {
return fmt.Errorf("error finding pages without embeddings: %v", err)
}
for rows.Next() {
var id int64
var sourceID string
var content string
err := rows.Scan(&id, &content)
err := rows.Scan(&id, &sourceID, &content)
if err != nil {
return err
}

chunkSize, chunkOverlap := getChunkDetails(sourceID)
chunks, err := embedding.ChunkText(content, chunkSize, chunkOverlap)

if err != nil {
Expand Down Expand Up @@ -563,13 +559,14 @@ func (db *SQLiteDatabase) SetCanonical(source string, url string, canonical stri
}

// SQLiteFromFile opens (or creates) a SQLite database at the given file path
// and wraps the connection in a SQLiteDatabase.
func SQLiteFromFile(fileName string) (*SQLiteDatabase, error) {
	// Register the `sqlite-vec` extension so that vector tables are
	// available on the connection opened below.
	vec.Auto()

	conn, openErr := sql.Open("sqlite3", fileName)
	if openErr != nil {
		return nil, openErr
	}

	return SQLite(conn)
}

func SQLite(conn *sql.DB) (*SQLiteDatabase, error) {
Expand Down
25 changes: 25 additions & 0 deletions app/database/db_sqlite_embedding.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
-- This script creates the embedding tables for one source.

-- Required format string placeholders:
-- (Repeated 5 times) source ID (string)
-- vector size (integer)

-- Why use separate tables for each source?
-- * Faster query times when there are many sources with lots of embeddings that aren't included in the user's query
-- * More accurate `k` limit when there are many sources that aren't included in the query
-- * In the future, different sources could use different embedding models with different vector sizes

CREATE TRIGGER IF NOT EXISTS pages_refresh_vector_embeddings_%s AFTER UPDATE ON pages
WHEN old.url != new.url OR old.title != new.title OR old.description != new.description OR old.content != new.content BEGIN
    -- If the page has associated vector embeddings, they must be recomputed when the text changes.
    -- Select only the `id` column: a `SELECT *` here returns every column of `vec_chunks`,
    -- which makes the scalar IN subquery fail ("sub-select returns N columns - expected 1")
    -- as soon as the trigger program is compiled by an UPDATE on `pages`.
    -- NOTE(review): this deletes the vectors but appears to leave the `vec_chunks` rows
    -- in place — confirm that re-embedding is still triggered for such pages.
    DELETE FROM pages_vec_%s WHERE id IN (SELECT id FROM vec_chunks WHERE page = old.id);
END;

-- Chunk IDs are unique across all sources, so firing this trigger for a chunk
-- belonging to a different source is a harmless no-op on this source's table.
CREATE TRIGGER IF NOT EXISTS delete_embedding_on_delete_chunk_%s AFTER DELETE ON vec_chunks BEGIN
    DELETE FROM pages_vec_%s WHERE id = old.id;
END;

CREATE VIRTUAL TABLE IF NOT EXISTS pages_vec_%s USING vec0(
    id INTEGER PRIMARY KEY,
    embedding FLOAT[%d] distance_metric=cosine
);
15 changes: 0 additions & 15 deletions app/database/db_sqlite_setup.sql
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,6 @@ CREATE TABLE IF NOT EXISTS vec_chunks(

CREATE UNIQUE INDEX IF NOT EXISTS vec_chunks_page_chunk_unique ON vec_chunks(page, chunkIndex);

CREATE VIRTUAL TABLE IF NOT EXISTS pages_vec USING vec0(
id INTEGER PRIMARY KEY,
embedding FLOAT[%d] distance_metric=cosine -- This number is populated based on the config
);

CREATE TABLE IF NOT EXISTS embed_queue(
id INTEGER PRIMARY KEY,
page INTEGER NOT NULL,
Expand All @@ -124,13 +119,3 @@ CREATE TABLE IF NOT EXISTS embed_queue(
) STRICT;

CREATE UNIQUE INDEX IF NOT EXISTS embed_queue_page_chunk_unique ON embed_queue(page, chunkIndex);

CREATE TRIGGER IF NOT EXISTS pages_refresh_vector_embeddings AFTER UPDATE ON pages
WHEN old.url != new.url OR old.title != new.title OR old.description != new.description OR old.content != new.content BEGIN
-- If the page has associated vector embeddings, they must be recomputed when the text changes
DELETE FROM pages_vec WHERE rowid IN (SELECT * FROM vec_chunks WHERE page = old.id);
END;

CREATE TRIGGER IF NOT EXISTS delete_embedding_on_delete_chunk AFTER DELETE ON vec_chunks BEGIN
DELETE FROM pages_vec WHERE id = old.id;
END;
33 changes: 32 additions & 1 deletion app/database/db_sqlite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,19 @@ import (
"reflect"
"testing"
"time"

vec "github.com/asg017/sqlite-vec-go-bindings/cgo"
)

func createDB(t *testing.T) Database {
vec.Auto()
db, err := SQLiteFromFile(path.Join(t.TempDir(), "temp.db"))

if err != nil {
t.Fatalf("database creation failed: %v", err)
}

if err := db.Setup(1536); err != nil {
if err := db.Setup(); err != nil {
t.Fatalf("database setup failed: %v", err)
}

Expand All @@ -25,6 +28,34 @@ func TestSetup(t *testing.T) {
createDB(t)
}

// TestVecSetup verifies that per-source vector tables can be created
// on a freshly set-up database.
func TestVecSetup(t *testing.T) {
	db := createDB(t)
	if err := db.SetupVectorTables("1", 768); err != nil {
		t.Fatalf("error setting up vector tables: %v\n", err)
	}
}

// TestCleanup verifies that Cleanup succeeds on a fresh, empty database.
func TestCleanup(t *testing.T) {
	db := createDB(t)
	if err := db.Cleanup(); err != nil {
		t.Fatalf("error occurred in Cleanup: %v\n", err)
	}
}

// TestStartEmbeddings verifies that StartEmbeddings runs without error once
// a source's vector tables exist, using fixed chunking parameters.
func TestStartEmbeddings(t *testing.T) {
	db := createDB(t)
	if err := db.SetupVectorTables("1", 768); err != nil {
		t.Fatalf("error creating vector table: %v\n", err)
	}
	// Every source gets the same chunking configuration in this test.
	fixedChunking := func(sourceID string) (chunkSize int, chunkOverlap int) {
		return 200, 30
	}
	if err := db.StartEmbeddings(fixedChunking); err != nil {
		t.Fatalf("error occurred in StartEmbeddings: %v\n", err)
	}
}

func TestEscape(t *testing.T) {
testCases := []struct {
input string
Expand Down
11 changes: 10 additions & 1 deletion app/easysearch.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,20 @@ func main() {

{
// Create DB tables if they don't exist (and set SQLite to WAL mode)
err := db.Setup(config.Embeddings.Dimensions)
err := db.Setup()

if err != nil {
panic(fmt.Sprintf("Failed to set up database: %v", err))
}

for _, src := range config.Sources {
if src.Embeddings.Enabled {
err := db.SetupVectorTables(src.ID, src.Embeddings.Dimensions)
if err != nil {
panic(fmt.Sprintf("Failed to set up embeddings database tables for source %v: %v", src.ID, err))
}
}
}
}

// Continuously pop items off each source's queue and crawl them
Expand Down
5 changes: 4 additions & 1 deletion app/embedding/embedding.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ import (
)

func GetEmbeddings(openAIBaseURL string, model string, apiKey string, chunk string) ([]float32, error) {

if apiKey == "" {
// `langchaingo` emits an error when the OpenAI API key is empty, even if the API URL has been changed to one that doesn't require authentication.
apiKey = "-"
}
llm, err := openai.New(openai.WithBaseURL(openAIBaseURL), openai.WithEmbeddingModel(model), openai.WithToken(apiKey))
if err != nil {
return nil, fmt.Errorf("error setting up LLM for embedding: %v", err)
Expand Down
16 changes: 12 additions & 4 deletions app/processqueue.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,15 @@ func startQueueJob(db database.Database, config *config.Config) {
if err != nil {
fmt.Printf("Error cleaning queue: %v\n", err)
}
err = db.StartEmbeddings(config.Embeddings.ChunkSize, config.Embeddings.ChunkOverlap)

err = db.StartEmbeddings(func(sourceID string) (chunkSize int, chunkOverlap int) {
for _, src := range config.Sources {
if src.ID == sourceID {
return src.Embeddings.ChunkSize, src.Embeddings.ChunkOverlap
}
}
return 200, 30
})
if err != nil {
fmt.Printf("Error queueing pages that need embeddings: %v\n", err)
}
Expand Down Expand Up @@ -82,14 +90,14 @@ func processEmbedQueue(db database.Database, config *config.Config, src config.S
}
}

vector, err := embedding.GetEmbeddings(config.Embeddings.OpenAIBaseURL, config.Embeddings.Model, config.Embeddings.APIKey, item.Content)
vector, err := embedding.GetEmbeddings(src.Embeddings.OpenAIBaseURL, src.Embeddings.Model, src.Embeddings.APIKey, item.Content)
if err != nil {
fmt.Printf("error getting embeddings: %v\n", err)
markFailure()
return
}

err = db.AddEmbedding(item.PageID, item.ChunkIndex, item.Content, vector)
err = db.AddEmbedding(item.PageID, src.ID, item.ChunkIndex, item.Content, vector)
if err != nil {
fmt.Printf("error saving embedding: %v\n", err)
markFailure()
Expand Down Expand Up @@ -138,7 +146,7 @@ func processCrawlQueue(db database.Database, config *config.Config, src config.S
} else {
// Chunk the page into sections and add it to the embedding queue
if result.PageID > 0 {
chunks, err := embedding.ChunkText(result.Content.Content, config.Embeddings.ChunkSize, config.Embeddings.ChunkOverlap)
chunks, err := embedding.ChunkText(result.Content.Content, src.Embeddings.ChunkSize, src.Embeddings.ChunkOverlap)

if err != nil {
fmt.Printf("error chunking page: %v\n", err)
Expand Down
Loading

0 comments on commit d453904

Please sign in to comment.