diff --git a/pkg/service/pipeline.go b/pkg/service/pipeline.go index 08c1f2d..17ab8f9 100644 --- a/pkg/service/pipeline.go +++ b/pkg/service/pipeline.go @@ -20,13 +20,13 @@ const chunkLength = 1024 const chunkOverlap = 200 const NamespaceID = "preset" - // Note: this pipeline is for the old indexing pipeline const ConvertDocToMDPipelineID = "indexing-convert-pdf" const DocToMDVersion = "v1.1.1" // TODO: the pipeline id is not correct, need to update the pipeline id const ConvertDocToMDPipelineID2 = "indexing-advanced-convert-doc" + // TODO: the version is not correct, need to update the version const DocToMDVersion2 = "v1.0.1" @@ -300,13 +300,29 @@ func (s *Service) SplitTextPipe(ctx context.Context, caller uuid.UUID, requester return filteredResult, nil } -// EmbeddingTextPipe uses the embedding pipeline to convert text into vectors and consume caller's credits. -// It processes the input texts in batches, triggers the embedding pipeline for each batch, and collects the results. -// The function returns a 2D slice of float32 representing the vectors for the input texts. +// EmbeddingTextPipe converts multiple text inputs into vector embeddings using a pipeline service. +// It processes texts in parallel batches for efficiency while managing resource usage. +// +// Parameters: +// - ctx: Context for the operation +// - caller: UUID of the calling user +// - requester: UUID of the requesting entity (optional) +// - texts: Slice of strings to be converted to embeddings +// +// Returns: +// - [][]float32: 2D slice where each inner slice is a vector embedding +// - error: Any error encountered during processing +// +// The function: +// - Processes texts in batches of 32 +// - Limits concurrent processing to 5 goroutines +// - Maintains input order in the output +// - Cancels all operations if any batch fails func (s *Service) EmbeddingTextPipe(ctx context.Context, caller uuid.UUID, requester uuid.UUID, texts []string) ([][]float32, error) { ctx, ctxCancel := context.WithCancel(ctx) defer ctxCancel() const maxBatchSize = 32 + const maxConcurrentGoroutines = 5 var md metadata.MD if requester != uuid.Nil { md = metadata.New(map[string]string{ @@ -338,7 +354,10 @@ func (s *Service) EmbeddingTextPipe(ctx context.Context, caller uuid.UUID, reque // - Extract the vector from the response. // - Send the result to the results channel. // If an error occurs, send the error to the error channel. + // Create a semaphore channel to limit concurrent goroutines to maxConcurrentGoroutines + sem := make(chan struct{}, maxConcurrentGoroutines) for i := 0; i < len(texts); i += maxBatchSize { + end := i + maxBatchSize if end > len(texts) { end = len(texts) @@ -346,11 +365,17 @@ func (s *Service) EmbeddingTextPipe(ctx context.Context, caller uuid.UUID, reque batch := texts[i:end] batchIndex := i / maxBatchSize + + // Acquire semaphore before starting goroutine + sem <- struct{}{} wg.Add(1) go utils.GoRecover(func() { + // Release semaphore when goroutine completes + defer func() { <-sem }() + defer wg.Done() + func(batch []string, index int) { ctx_ := metadata.NewOutgoingContext(ctx, md) - defer wg.Done() inputs := make([]*structpb.Struct, 0, len(batch)) for _, text := range batch { diff --git a/pkg/worker/worker.go b/pkg/worker/worker.go index 9590bf7..a3d9cf7 100644 --- a/pkg/worker/worker.go +++ b/pkg/worker/worker.go @@ -652,9 +652,34 @@ func (wp *fileToEmbWorkerPool) processChunkingFile(ctx context.Context, file rep } -// processEmbeddingFile processes a file with embedding status. -// It retrieves chunks from MinIO, calls the embedding pipeline, saves the embeddings into the vector database and metadata into the database, -// and updates the file status to completed in the database. +// processEmbeddingFile processes a file that is ready for embedding by: +// 1. Validating the file's process status is "EMBEDDING" +// 2. Retrieving text chunks from MinIO storage and database metadata +// - Will retry once if initial chunk retrieval fails +// +// 3. Updating file metadata with embedding pipeline version info +// - Uses TextEmbedPipelineID and TextEmbedVersion from service config +// +// 4. Calling the embedding pipeline to generate vectors from text chunks +// - Uses file creator and requester UIDs for pipeline execution +// +// 5. Saving embeddings to vector database (Milvus) and metadata to SQL database +// - Creates embeddings collection named after knowledge base UID +// - Links embeddings to source text chunks and file metadata +// +// 6. Updating file status to "COMPLETED" in database +// +// Parameters: +// - ctx: Context for the operation +// - file: KnowledgeBaseFile struct containing file metadata +// +// Returns: +// - updatedFile: Updated KnowledgeBaseFile after processing +// - nextStatus: Next file process status (COMPLETED if successful) +// - err: Error if any step fails +// +// The function handles errors at each step and returns appropriate status codes. +// If chunk retrieval fails initially, it will retry once after a 1 second delay. func (wp *fileToEmbWorkerPool) processEmbeddingFile(ctx context.Context, file repository.KnowledgeBaseFile) (updatedFile *repository.KnowledgeBaseFile, nextStatus artifactpb.FileProcessStatus, err error) { logger, _ := logger.GetZapLogger(ctx) // check the file status is embedding @@ -822,34 +847,74 @@ type MilvusEmbedding struct { } // saveEmbeddings saves embeddings into the vector database and updates the metadata in the database. +// Processes embeddings in batches of 50 to avoid timeout issues. +const batchSize = 50 + func (wp *fileToEmbWorkerPool) saveEmbeddings(ctx context.Context, kbUID string, embeddings []repository.Embedding) error { logger, _ := logger.GetZapLogger(ctx) - externalServiceCall := func(embUIDs []string) error { - // save the embeddings into vector database - milvusEmbeddings := make([]milvus.Embedding, len(embeddings)) - for i, emb := range embeddings { - milvusEmbeddings[i] = milvus.Embedding{ - SourceTable: emb.SourceTable, - SourceUID: emb.SourceUID.String(), - EmbeddingUID: emb.UID.String(), - Vector: emb.Vector, + if len(embeddings) == 0 { + logger.Debug("No embeddings to save") + return nil + } + + totalEmbeddings := len(embeddings) + + // Process embeddings in batches + for i := 0; i < totalEmbeddings; i += batchSize { + // Add context check + if err := ctx.Err(); err != nil { + return fmt.Errorf("context cancelled while processing embeddings: %w", err) + } + + end := i + batchSize + if end > totalEmbeddings { + end = totalEmbeddings + } + + currentBatch := embeddings[i:end] + + externalServiceCall := func(_ []string) error { + // save the embeddings into vector database + milvusEmbeddings := make([]milvus.Embedding, len(currentBatch)) + for j, emb := range currentBatch { + milvusEmbeddings[j] = milvus.Embedding{ + SourceTable: emb.SourceTable, + SourceUID: emb.SourceUID.String(), + EmbeddingUID: emb.UID.String(), + Vector: emb.Vector, + } } + err := wp.svc.MilvusClient.InsertVectorsToKnowledgeBaseCollection(ctx, kbUID, milvusEmbeddings) + if err != nil { + logger.Error("Failed to save embeddings batch into vector database.", + zap.String("KbUID", kbUID), + zap.Int("batch", i/batchSize+1), + zap.Int("batchSize", len(currentBatch))) + return err + } + return nil } - err := wp.svc.MilvusClient.InsertVectorsToKnowledgeBaseCollection(ctx, kbUID, milvusEmbeddings) + + _, err := wp.svc.Repository.UpsertEmbeddings(ctx, currentBatch, externalServiceCall) if err != nil { - logger.Error("Failed to save embeddings into vector database.", zap.String("KbUID", kbUID)) + logger.Error("Failed to save embeddings batch into vector database and metadata into database.", + zap.String("KbUID", kbUID), + zap.Int("batch", i/batchSize+1), + zap.Int("batchSize", len(currentBatch))) return err } - return nil - } - _, err := wp.svc.Repository.UpsertEmbeddings(ctx, embeddings, externalServiceCall) - if err != nil { - logger.Error("Failed to save embeddings into vector database and metadata into database.", zap.String("KbUID", kbUID)) - return err + + logger.Info("Embeddings batch saved successfully", + zap.String("KbUID", kbUID), + zap.Int("batch", i/batchSize+1), + zap.Int("batchSize", len(currentBatch)), + zap.Int("progress", end), + zap.Int("total", totalEmbeddings)) } - // info how many embeddings saved in which kb - logger.Info("Embeddings saved into vector database and metadata into database.", - zap.String("KbUID", kbUID), zap.Int("Embeddings count", len(embeddings))) + + logger.Info("All embeddings saved into vector database and metadata into database.", + zap.String("KbUID", kbUID), + zap.Int("total embeddings", totalEmbeddings)) return nil }