From 703bb0b23288962891922e07f6295d230dd9b780 Mon Sep 17 00:00:00 2001 From: Gary Date: Wed, 10 Jul 2024 03:55:45 +0800 Subject: [PATCH] fix(kb): fixed some bugs in file-to-embedding process (#35) Because 1. wrong source table and uid 2. wrong destination 3. fail to save vector type in db This commit fixed the bug above --- .gitignore | 1 + pkg/client/grpc/pipeline_test.go | 10 +- pkg/handler/knowledgebasefiles.go | 1 - pkg/mock/repository_i_mock.gen.go | 511 +++++++++++++++++++++++++++++- pkg/repository/chunk.go | 3 +- pkg/repository/convertedfile.go | 25 +- pkg/repository/embedding.go | 65 +++- pkg/repository/embedding_test.go | 37 +++ pkg/service/pipeline.go | 12 +- pkg/worker/worker.go | 28 +- 10 files changed, 663 insertions(+), 30 deletions(-) create mode 100644 pkg/repository/embedding_test.go diff --git a/.gitignore b/.gitignore index 4e0e0f7..1f1f6a7 100644 --- a/.gitignore +++ b/.gitignore @@ -120,3 +120,4 @@ tmp /config/config_local.yaml test_.pdf test_.md +test_pdf_base64.txt diff --git a/pkg/client/grpc/pipeline_test.go b/pkg/client/grpc/pipeline_test.go index 04cd0bc..a427cb9 100644 --- a/pkg/client/grpc/pipeline_test.go +++ b/pkg/client/grpc/pipeline_test.go @@ -112,7 +112,7 @@ package grpcclient // if err != nil { // fmt.Println(err) // } -// fmt.Println("current working diretor:", dir) +// fmt.Println("current working director:", dir) // pipelinePublicGrpcConn, err := NewGRPCConn("localhost:8081", "", "") // if err != nil { // t.Fatalf("failed to create grpc connection: %v", err) @@ -123,7 +123,7 @@ package grpcclient // ctx := metadata.NewOutgoingContext(context.Background(), md) // pipelinePublicServiceClient := pipelinev1beta.NewPipelinePublicServiceClient(pipelinePublicGrpcConn) -// base64PDF, err := readPDFtoBase64("../../../test_.pdf") +// base64PDF, err := readFileToBase64("../../../test_.pdf") // if err != nil { // t.Fatalf("failed to read pdf file: %v", err) // } @@ -139,8 +139,8 @@ package grpcclient // fmt.Println("convert result\n", 
res.Outputs[0].GetFields()["convert_result"].GetStringValue()[:100]) // } -// // readPDFtoBase64 read the pdf file and convert it to base64 -// func readPDFtoBase64(path string) (string, error) { +// // readFileToBase64 read the pdf file and convert it to base64 +// func readFileToBase64(path string) (string, error) { // // Open the file // file, err := os.Open(path) // if err != nil { @@ -237,7 +237,7 @@ package grpcclient // if err != nil { // fmt.Println(err) // } -// fmt.Println("current working diretor:", dir) +// fmt.Println("current working director:", dir) // pipelinePublicGrpcConn, err := NewGRPCConn("localhost:8081", "", "") // if err != nil { // t.Fatalf("failed to create grpc connection: %v", err) diff --git a/pkg/handler/knowledgebasefiles.go b/pkg/handler/knowledgebasefiles.go index 800ea6e..931f73b 100644 --- a/pkg/handler/knowledgebasefiles.go +++ b/pkg/handler/knowledgebasefiles.go @@ -138,7 +138,6 @@ func checkValidFileType(t artifactpb.FileType) bool { func (ph *PublicHandler) ListKnowledgeBaseFiles(ctx context.Context, req *artifactpb.ListKnowledgeBaseFilesRequest) (*artifactpb.ListKnowledgeBaseFilesResponse, error) { log, _ := logger.GetZapLogger(ctx) - fmt.Println("ListKnowledgeBaseFiles>>>", req) uid, err := getUserUIDFromContext(ctx) if err != nil { log.Error("failed to get user id from header", zap.Error(err)) diff --git a/pkg/mock/repository_i_mock.gen.go b/pkg/mock/repository_i_mock.gen.go index 7608a4b..0b966cd 100644 --- a/pkg/mock/repository_i_mock.gen.go +++ b/pkg/mock/repository_i_mock.gen.go @@ -26,8 +26,8 @@ type RepositoryIMock struct { beforeConvertedFileTableNameCounter uint64 ConvertedFileTableNameMock mRepositoryIMockConvertedFileTableName - funcCreateConvertedFile func(ctx context.Context, cf mm_repository.ConvertedFile, callExternalService func(convertedFileUID uuid.UUID) error) (cp1 *mm_repository.ConvertedFile, err error) - inspectFuncCreateConvertedFile func(ctx context.Context, cf mm_repository.ConvertedFile, 
callExternalService func(convertedFileUID uuid.UUID) error) + funcCreateConvertedFile func(ctx context.Context, cf mm_repository.ConvertedFile, callExternalService func(convertedFileUID uuid.UUID) (map[string]any, error)) (cp1 *mm_repository.ConvertedFile, err error) + inspectFuncCreateConvertedFile func(ctx context.Context, cf mm_repository.ConvertedFile, callExternalService func(convertedFileUID uuid.UUID) (map[string]any, error)) afterCreateConvertedFileCounter uint64 beforeCreateConvertedFileCounter uint64 CreateConvertedFileMock mRepositoryIMockCreateConvertedFile @@ -104,6 +104,12 @@ type RepositoryIMock struct { beforeGetConvertedFileByFileUIDCounter uint64 GetConvertedFileByFileUIDMock mRepositoryIMockGetConvertedFileByFileUID + funcGetEmbeddingByUIDs func(ctx context.Context, embUIDs []uuid.UUID) (ea1 []mm_repository.Embedding, err error) + inspectFuncGetEmbeddingByUIDs func(ctx context.Context, embUIDs []uuid.UUID) + afterGetEmbeddingByUIDsCounter uint64 + beforeGetEmbeddingByUIDsCounter uint64 + GetEmbeddingByUIDsMock mRepositoryIMockGetEmbeddingByUIDs + funcGetIncompleteFile func(ctx context.Context) (ka1 []mm_repository.KnowledgeBaseFile) inspectFuncGetIncompleteFile func(ctx context.Context) afterGetIncompleteFileCounter uint64 @@ -152,6 +158,12 @@ type RepositoryIMock struct { beforeProcessKnowledgeBaseFilesCounter uint64 ProcessKnowledgeBaseFilesMock mRepositoryIMockProcessKnowledgeBaseFiles + funcTextChunkTableName func() (s1 string) + inspectFuncTextChunkTableName func() + afterTextChunkTableNameCounter uint64 + beforeTextChunkTableNameCounter uint64 + TextChunkTableNameMock mRepositoryIMockTextChunkTableName + funcUpdateKnowledgeBase func(ctx context.Context, ownerUID string, kb mm_repository.KnowledgeBase) (kp1 *mm_repository.KnowledgeBase, err error) inspectFuncUpdateKnowledgeBase func(ctx context.Context, ownerUID string, kb mm_repository.KnowledgeBase) afterUpdateKnowledgeBaseCounter uint64 @@ -226,6 +238,9 @@ func NewRepositoryIMock(t 
minimock.Tester) *RepositoryIMock { m.GetConvertedFileByFileUIDMock = mRepositoryIMockGetConvertedFileByFileUID{mock: m} m.GetConvertedFileByFileUIDMock.callArgs = []*RepositoryIMockGetConvertedFileByFileUIDParams{} + m.GetEmbeddingByUIDsMock = mRepositoryIMockGetEmbeddingByUIDs{mock: m} + m.GetEmbeddingByUIDsMock.callArgs = []*RepositoryIMockGetEmbeddingByUIDsParams{} + m.GetIncompleteFileMock = mRepositoryIMockGetIncompleteFile{mock: m} m.GetIncompleteFileMock.callArgs = []*RepositoryIMockGetIncompleteFileParams{} @@ -249,6 +264,8 @@ func NewRepositoryIMock(t minimock.Tester) *RepositoryIMock { m.ProcessKnowledgeBaseFilesMock = mRepositoryIMockProcessKnowledgeBaseFiles{mock: m} m.ProcessKnowledgeBaseFilesMock.callArgs = []*RepositoryIMockProcessKnowledgeBaseFilesParams{} + m.TextChunkTableNameMock = mRepositoryIMockTextChunkTableName{mock: m} + m.UpdateKnowledgeBaseMock = mRepositoryIMockUpdateKnowledgeBase{mock: m} m.UpdateKnowledgeBaseMock.callArgs = []*RepositoryIMockUpdateKnowledgeBaseParams{} @@ -453,14 +470,14 @@ type RepositoryIMockCreateConvertedFileExpectation struct { type RepositoryIMockCreateConvertedFileParams struct { ctx context.Context cf mm_repository.ConvertedFile - callExternalService func(convertedFileUID uuid.UUID) error + callExternalService func(convertedFileUID uuid.UUID) (map[string]any, error) } // RepositoryIMockCreateConvertedFileParamPtrs contains pointers to parameters of the RepositoryI.CreateConvertedFile type RepositoryIMockCreateConvertedFileParamPtrs struct { ctx *context.Context cf *mm_repository.ConvertedFile - callExternalService *func(convertedFileUID uuid.UUID) error + callExternalService *func(convertedFileUID uuid.UUID) (map[string]any, error) } // RepositoryIMockCreateConvertedFileResults contains results of the RepositoryI.CreateConvertedFile @@ -470,7 +487,7 @@ type RepositoryIMockCreateConvertedFileResults struct { } // Expect sets up expected params for RepositoryI.CreateConvertedFile -func (mmCreateConvertedFile 
*mRepositoryIMockCreateConvertedFile) Expect(ctx context.Context, cf mm_repository.ConvertedFile, callExternalService func(convertedFileUID uuid.UUID) error) *mRepositoryIMockCreateConvertedFile { +func (mmCreateConvertedFile *mRepositoryIMockCreateConvertedFile) Expect(ctx context.Context, cf mm_repository.ConvertedFile, callExternalService func(convertedFileUID uuid.UUID) (map[string]any, error)) *mRepositoryIMockCreateConvertedFile { if mmCreateConvertedFile.mock.funcCreateConvertedFile != nil { mmCreateConvertedFile.mock.t.Fatalf("RepositoryIMock.CreateConvertedFile mock is already set by Set") } @@ -538,7 +555,7 @@ func (mmCreateConvertedFile *mRepositoryIMockCreateConvertedFile) ExpectCfParam2 } // ExpectCallExternalServiceParam3 sets up expected param callExternalService for RepositoryI.CreateConvertedFile -func (mmCreateConvertedFile *mRepositoryIMockCreateConvertedFile) ExpectCallExternalServiceParam3(callExternalService func(convertedFileUID uuid.UUID) error) *mRepositoryIMockCreateConvertedFile { +func (mmCreateConvertedFile *mRepositoryIMockCreateConvertedFile) ExpectCallExternalServiceParam3(callExternalService func(convertedFileUID uuid.UUID) (map[string]any, error)) *mRepositoryIMockCreateConvertedFile { if mmCreateConvertedFile.mock.funcCreateConvertedFile != nil { mmCreateConvertedFile.mock.t.Fatalf("RepositoryIMock.CreateConvertedFile mock is already set by Set") } @@ -560,7 +577,7 @@ func (mmCreateConvertedFile *mRepositoryIMockCreateConvertedFile) ExpectCallExte } // Inspect accepts an inspector function that has same arguments as the RepositoryI.CreateConvertedFile -func (mmCreateConvertedFile *mRepositoryIMockCreateConvertedFile) Inspect(f func(ctx context.Context, cf mm_repository.ConvertedFile, callExternalService func(convertedFileUID uuid.UUID) error)) *mRepositoryIMockCreateConvertedFile { +func (mmCreateConvertedFile *mRepositoryIMockCreateConvertedFile) Inspect(f func(ctx context.Context, cf mm_repository.ConvertedFile, 
callExternalService func(convertedFileUID uuid.UUID) (map[string]any, error))) *mRepositoryIMockCreateConvertedFile { if mmCreateConvertedFile.mock.inspectFuncCreateConvertedFile != nil { mmCreateConvertedFile.mock.t.Fatalf("Inspect function is already set for RepositoryIMock.CreateConvertedFile") } @@ -584,7 +601,7 @@ func (mmCreateConvertedFile *mRepositoryIMockCreateConvertedFile) Return(cp1 *mm } // Set uses given function f to mock the RepositoryI.CreateConvertedFile method -func (mmCreateConvertedFile *mRepositoryIMockCreateConvertedFile) Set(f func(ctx context.Context, cf mm_repository.ConvertedFile, callExternalService func(convertedFileUID uuid.UUID) error) (cp1 *mm_repository.ConvertedFile, err error)) *RepositoryIMock { +func (mmCreateConvertedFile *mRepositoryIMockCreateConvertedFile) Set(f func(ctx context.Context, cf mm_repository.ConvertedFile, callExternalService func(convertedFileUID uuid.UUID) (map[string]any, error)) (cp1 *mm_repository.ConvertedFile, err error)) *RepositoryIMock { if mmCreateConvertedFile.defaultExpectation != nil { mmCreateConvertedFile.mock.t.Fatalf("Default expectation is already set for the RepositoryI.CreateConvertedFile method") } @@ -599,7 +616,7 @@ func (mmCreateConvertedFile *mRepositoryIMockCreateConvertedFile) Set(f func(ctx // When sets expectation for the RepositoryI.CreateConvertedFile which will trigger the result defined by the following // Then helper -func (mmCreateConvertedFile *mRepositoryIMockCreateConvertedFile) When(ctx context.Context, cf mm_repository.ConvertedFile, callExternalService func(convertedFileUID uuid.UUID) error) *RepositoryIMockCreateConvertedFileExpectation { +func (mmCreateConvertedFile *mRepositoryIMockCreateConvertedFile) When(ctx context.Context, cf mm_repository.ConvertedFile, callExternalService func(convertedFileUID uuid.UUID) (map[string]any, error)) *RepositoryIMockCreateConvertedFileExpectation { if mmCreateConvertedFile.mock.funcCreateConvertedFile != nil { 
mmCreateConvertedFile.mock.t.Fatalf("RepositoryIMock.CreateConvertedFile mock is already set by Set") } @@ -639,7 +656,7 @@ func (mmCreateConvertedFile *mRepositoryIMockCreateConvertedFile) invocationsDon } // CreateConvertedFile implements repository.RepositoryI -func (mmCreateConvertedFile *RepositoryIMock) CreateConvertedFile(ctx context.Context, cf mm_repository.ConvertedFile, callExternalService func(convertedFileUID uuid.UUID) error) (cp1 *mm_repository.ConvertedFile, err error) { +func (mmCreateConvertedFile *RepositoryIMock) CreateConvertedFile(ctx context.Context, cf mm_repository.ConvertedFile, callExternalService func(convertedFileUID uuid.UUID) (map[string]any, error)) (cp1 *mm_repository.ConvertedFile, err error) { mm_atomic.AddUint64(&mmCreateConvertedFile.beforeCreateConvertedFileCounter, 1) defer mm_atomic.AddUint64(&mmCreateConvertedFile.afterCreateConvertedFileCounter, 1) @@ -4639,6 +4656,311 @@ func (m *RepositoryIMock) MinimockGetConvertedFileByFileUIDInspect() { } } +type mRepositoryIMockGetEmbeddingByUIDs struct { + mock *RepositoryIMock + defaultExpectation *RepositoryIMockGetEmbeddingByUIDsExpectation + expectations []*RepositoryIMockGetEmbeddingByUIDsExpectation + + callArgs []*RepositoryIMockGetEmbeddingByUIDsParams + mutex sync.RWMutex + + expectedInvocations uint64 +} + +// RepositoryIMockGetEmbeddingByUIDsExpectation specifies expectation struct of the RepositoryI.GetEmbeddingByUIDs +type RepositoryIMockGetEmbeddingByUIDsExpectation struct { + mock *RepositoryIMock + params *RepositoryIMockGetEmbeddingByUIDsParams + paramPtrs *RepositoryIMockGetEmbeddingByUIDsParamPtrs + results *RepositoryIMockGetEmbeddingByUIDsResults + Counter uint64 +} + +// RepositoryIMockGetEmbeddingByUIDsParams contains parameters of the RepositoryI.GetEmbeddingByUIDs +type RepositoryIMockGetEmbeddingByUIDsParams struct { + ctx context.Context + embUIDs []uuid.UUID +} + +// RepositoryIMockGetEmbeddingByUIDsParamPtrs contains pointers to parameters of the 
RepositoryI.GetEmbeddingByUIDs +type RepositoryIMockGetEmbeddingByUIDsParamPtrs struct { + ctx *context.Context + embUIDs *[]uuid.UUID +} + +// RepositoryIMockGetEmbeddingByUIDsResults contains results of the RepositoryI.GetEmbeddingByUIDs +type RepositoryIMockGetEmbeddingByUIDsResults struct { + ea1 []mm_repository.Embedding + err error +} + +// Expect sets up expected params for RepositoryI.GetEmbeddingByUIDs +func (mmGetEmbeddingByUIDs *mRepositoryIMockGetEmbeddingByUIDs) Expect(ctx context.Context, embUIDs []uuid.UUID) *mRepositoryIMockGetEmbeddingByUIDs { + if mmGetEmbeddingByUIDs.mock.funcGetEmbeddingByUIDs != nil { + mmGetEmbeddingByUIDs.mock.t.Fatalf("RepositoryIMock.GetEmbeddingByUIDs mock is already set by Set") + } + + if mmGetEmbeddingByUIDs.defaultExpectation == nil { + mmGetEmbeddingByUIDs.defaultExpectation = &RepositoryIMockGetEmbeddingByUIDsExpectation{} + } + + if mmGetEmbeddingByUIDs.defaultExpectation.paramPtrs != nil { + mmGetEmbeddingByUIDs.mock.t.Fatalf("RepositoryIMock.GetEmbeddingByUIDs mock is already set by ExpectParams functions") + } + + mmGetEmbeddingByUIDs.defaultExpectation.params = &RepositoryIMockGetEmbeddingByUIDsParams{ctx, embUIDs} + for _, e := range mmGetEmbeddingByUIDs.expectations { + if minimock.Equal(e.params, mmGetEmbeddingByUIDs.defaultExpectation.params) { + mmGetEmbeddingByUIDs.mock.t.Fatalf("Expectation set by When has same params: %#v", *mmGetEmbeddingByUIDs.defaultExpectation.params) + } + } + + return mmGetEmbeddingByUIDs +} + +// ExpectCtxParam1 sets up expected param ctx for RepositoryI.GetEmbeddingByUIDs +func (mmGetEmbeddingByUIDs *mRepositoryIMockGetEmbeddingByUIDs) ExpectCtxParam1(ctx context.Context) *mRepositoryIMockGetEmbeddingByUIDs { + if mmGetEmbeddingByUIDs.mock.funcGetEmbeddingByUIDs != nil { + mmGetEmbeddingByUIDs.mock.t.Fatalf("RepositoryIMock.GetEmbeddingByUIDs mock is already set by Set") + } + + if mmGetEmbeddingByUIDs.defaultExpectation == nil { + mmGetEmbeddingByUIDs.defaultExpectation = 
&RepositoryIMockGetEmbeddingByUIDsExpectation{} + } + + if mmGetEmbeddingByUIDs.defaultExpectation.params != nil { + mmGetEmbeddingByUIDs.mock.t.Fatalf("RepositoryIMock.GetEmbeddingByUIDs mock is already set by Expect") + } + + if mmGetEmbeddingByUIDs.defaultExpectation.paramPtrs == nil { + mmGetEmbeddingByUIDs.defaultExpectation.paramPtrs = &RepositoryIMockGetEmbeddingByUIDsParamPtrs{} + } + mmGetEmbeddingByUIDs.defaultExpectation.paramPtrs.ctx = &ctx + + return mmGetEmbeddingByUIDs +} + +// ExpectEmbUIDsParam2 sets up expected param embUIDs for RepositoryI.GetEmbeddingByUIDs +func (mmGetEmbeddingByUIDs *mRepositoryIMockGetEmbeddingByUIDs) ExpectEmbUIDsParam2(embUIDs []uuid.UUID) *mRepositoryIMockGetEmbeddingByUIDs { + if mmGetEmbeddingByUIDs.mock.funcGetEmbeddingByUIDs != nil { + mmGetEmbeddingByUIDs.mock.t.Fatalf("RepositoryIMock.GetEmbeddingByUIDs mock is already set by Set") + } + + if mmGetEmbeddingByUIDs.defaultExpectation == nil { + mmGetEmbeddingByUIDs.defaultExpectation = &RepositoryIMockGetEmbeddingByUIDsExpectation{} + } + + if mmGetEmbeddingByUIDs.defaultExpectation.params != nil { + mmGetEmbeddingByUIDs.mock.t.Fatalf("RepositoryIMock.GetEmbeddingByUIDs mock is already set by Expect") + } + + if mmGetEmbeddingByUIDs.defaultExpectation.paramPtrs == nil { + mmGetEmbeddingByUIDs.defaultExpectation.paramPtrs = &RepositoryIMockGetEmbeddingByUIDsParamPtrs{} + } + mmGetEmbeddingByUIDs.defaultExpectation.paramPtrs.embUIDs = &embUIDs + + return mmGetEmbeddingByUIDs +} + +// Inspect accepts an inspector function that has same arguments as the RepositoryI.GetEmbeddingByUIDs +func (mmGetEmbeddingByUIDs *mRepositoryIMockGetEmbeddingByUIDs) Inspect(f func(ctx context.Context, embUIDs []uuid.UUID)) *mRepositoryIMockGetEmbeddingByUIDs { + if mmGetEmbeddingByUIDs.mock.inspectFuncGetEmbeddingByUIDs != nil { + mmGetEmbeddingByUIDs.mock.t.Fatalf("Inspect function is already set for RepositoryIMock.GetEmbeddingByUIDs") + } + + 
mmGetEmbeddingByUIDs.mock.inspectFuncGetEmbeddingByUIDs = f + + return mmGetEmbeddingByUIDs +} + +// Return sets up results that will be returned by RepositoryI.GetEmbeddingByUIDs +func (mmGetEmbeddingByUIDs *mRepositoryIMockGetEmbeddingByUIDs) Return(ea1 []mm_repository.Embedding, err error) *RepositoryIMock { + if mmGetEmbeddingByUIDs.mock.funcGetEmbeddingByUIDs != nil { + mmGetEmbeddingByUIDs.mock.t.Fatalf("RepositoryIMock.GetEmbeddingByUIDs mock is already set by Set") + } + + if mmGetEmbeddingByUIDs.defaultExpectation == nil { + mmGetEmbeddingByUIDs.defaultExpectation = &RepositoryIMockGetEmbeddingByUIDsExpectation{mock: mmGetEmbeddingByUIDs.mock} + } + mmGetEmbeddingByUIDs.defaultExpectation.results = &RepositoryIMockGetEmbeddingByUIDsResults{ea1, err} + return mmGetEmbeddingByUIDs.mock +} + +// Set uses given function f to mock the RepositoryI.GetEmbeddingByUIDs method +func (mmGetEmbeddingByUIDs *mRepositoryIMockGetEmbeddingByUIDs) Set(f func(ctx context.Context, embUIDs []uuid.UUID) (ea1 []mm_repository.Embedding, err error)) *RepositoryIMock { + if mmGetEmbeddingByUIDs.defaultExpectation != nil { + mmGetEmbeddingByUIDs.mock.t.Fatalf("Default expectation is already set for the RepositoryI.GetEmbeddingByUIDs method") + } + + if len(mmGetEmbeddingByUIDs.expectations) > 0 { + mmGetEmbeddingByUIDs.mock.t.Fatalf("Some expectations are already set for the RepositoryI.GetEmbeddingByUIDs method") + } + + mmGetEmbeddingByUIDs.mock.funcGetEmbeddingByUIDs = f + return mmGetEmbeddingByUIDs.mock +} + +// When sets expectation for the RepositoryI.GetEmbeddingByUIDs which will trigger the result defined by the following +// Then helper +func (mmGetEmbeddingByUIDs *mRepositoryIMockGetEmbeddingByUIDs) When(ctx context.Context, embUIDs []uuid.UUID) *RepositoryIMockGetEmbeddingByUIDsExpectation { + if mmGetEmbeddingByUIDs.mock.funcGetEmbeddingByUIDs != nil { + mmGetEmbeddingByUIDs.mock.t.Fatalf("RepositoryIMock.GetEmbeddingByUIDs mock is already set by Set") + } + + 
expectation := &RepositoryIMockGetEmbeddingByUIDsExpectation{ + mock: mmGetEmbeddingByUIDs.mock, + params: &RepositoryIMockGetEmbeddingByUIDsParams{ctx, embUIDs}, + } + mmGetEmbeddingByUIDs.expectations = append(mmGetEmbeddingByUIDs.expectations, expectation) + return expectation +} + +// Then sets up RepositoryI.GetEmbeddingByUIDs return parameters for the expectation previously defined by the When method +func (e *RepositoryIMockGetEmbeddingByUIDsExpectation) Then(ea1 []mm_repository.Embedding, err error) *RepositoryIMock { + e.results = &RepositoryIMockGetEmbeddingByUIDsResults{ea1, err} + return e.mock +} + +// Times sets number of times RepositoryI.GetEmbeddingByUIDs should be invoked +func (mmGetEmbeddingByUIDs *mRepositoryIMockGetEmbeddingByUIDs) Times(n uint64) *mRepositoryIMockGetEmbeddingByUIDs { + if n == 0 { + mmGetEmbeddingByUIDs.mock.t.Fatalf("Times of RepositoryIMock.GetEmbeddingByUIDs mock can not be zero") + } + mm_atomic.StoreUint64(&mmGetEmbeddingByUIDs.expectedInvocations, n) + return mmGetEmbeddingByUIDs +} + +func (mmGetEmbeddingByUIDs *mRepositoryIMockGetEmbeddingByUIDs) invocationsDone() bool { + if len(mmGetEmbeddingByUIDs.expectations) == 0 && mmGetEmbeddingByUIDs.defaultExpectation == nil && mmGetEmbeddingByUIDs.mock.funcGetEmbeddingByUIDs == nil { + return true + } + + totalInvocations := mm_atomic.LoadUint64(&mmGetEmbeddingByUIDs.mock.afterGetEmbeddingByUIDsCounter) + expectedInvocations := mm_atomic.LoadUint64(&mmGetEmbeddingByUIDs.expectedInvocations) + + return totalInvocations > 0 && (expectedInvocations == 0 || expectedInvocations == totalInvocations) +} + +// GetEmbeddingByUIDs implements repository.RepositoryI +func (mmGetEmbeddingByUIDs *RepositoryIMock) GetEmbeddingByUIDs(ctx context.Context, embUIDs []uuid.UUID) (ea1 []mm_repository.Embedding, err error) { + mm_atomic.AddUint64(&mmGetEmbeddingByUIDs.beforeGetEmbeddingByUIDsCounter, 1) + defer mm_atomic.AddUint64(&mmGetEmbeddingByUIDs.afterGetEmbeddingByUIDsCounter, 1) + + if 
mmGetEmbeddingByUIDs.inspectFuncGetEmbeddingByUIDs != nil { + mmGetEmbeddingByUIDs.inspectFuncGetEmbeddingByUIDs(ctx, embUIDs) + } + + mm_params := RepositoryIMockGetEmbeddingByUIDsParams{ctx, embUIDs} + + // Record call args + mmGetEmbeddingByUIDs.GetEmbeddingByUIDsMock.mutex.Lock() + mmGetEmbeddingByUIDs.GetEmbeddingByUIDsMock.callArgs = append(mmGetEmbeddingByUIDs.GetEmbeddingByUIDsMock.callArgs, &mm_params) + mmGetEmbeddingByUIDs.GetEmbeddingByUIDsMock.mutex.Unlock() + + for _, e := range mmGetEmbeddingByUIDs.GetEmbeddingByUIDsMock.expectations { + if minimock.Equal(*e.params, mm_params) { + mm_atomic.AddUint64(&e.Counter, 1) + return e.results.ea1, e.results.err + } + } + + if mmGetEmbeddingByUIDs.GetEmbeddingByUIDsMock.defaultExpectation != nil { + mm_atomic.AddUint64(&mmGetEmbeddingByUIDs.GetEmbeddingByUIDsMock.defaultExpectation.Counter, 1) + mm_want := mmGetEmbeddingByUIDs.GetEmbeddingByUIDsMock.defaultExpectation.params + mm_want_ptrs := mmGetEmbeddingByUIDs.GetEmbeddingByUIDsMock.defaultExpectation.paramPtrs + + mm_got := RepositoryIMockGetEmbeddingByUIDsParams{ctx, embUIDs} + + if mm_want_ptrs != nil { + + if mm_want_ptrs.ctx != nil && !minimock.Equal(*mm_want_ptrs.ctx, mm_got.ctx) { + mmGetEmbeddingByUIDs.t.Errorf("RepositoryIMock.GetEmbeddingByUIDs got unexpected parameter ctx, want: %#v, got: %#v%s\n", *mm_want_ptrs.ctx, mm_got.ctx, minimock.Diff(*mm_want_ptrs.ctx, mm_got.ctx)) + } + + if mm_want_ptrs.embUIDs != nil && !minimock.Equal(*mm_want_ptrs.embUIDs, mm_got.embUIDs) { + mmGetEmbeddingByUIDs.t.Errorf("RepositoryIMock.GetEmbeddingByUIDs got unexpected parameter embUIDs, want: %#v, got: %#v%s\n", *mm_want_ptrs.embUIDs, mm_got.embUIDs, minimock.Diff(*mm_want_ptrs.embUIDs, mm_got.embUIDs)) + } + + } else if mm_want != nil && !minimock.Equal(*mm_want, mm_got) { + mmGetEmbeddingByUIDs.t.Errorf("RepositoryIMock.GetEmbeddingByUIDs got unexpected parameters, want: %#v, got: %#v%s\n", *mm_want, mm_got, minimock.Diff(*mm_want, mm_got)) + } + + mm_results 
:= mmGetEmbeddingByUIDs.GetEmbeddingByUIDsMock.defaultExpectation.results + if mm_results == nil { + mmGetEmbeddingByUIDs.t.Fatal("No results are set for the RepositoryIMock.GetEmbeddingByUIDs") + } + return (*mm_results).ea1, (*mm_results).err + } + if mmGetEmbeddingByUIDs.funcGetEmbeddingByUIDs != nil { + return mmGetEmbeddingByUIDs.funcGetEmbeddingByUIDs(ctx, embUIDs) + } + mmGetEmbeddingByUIDs.t.Fatalf("Unexpected call to RepositoryIMock.GetEmbeddingByUIDs. %v %v", ctx, embUIDs) + return +} + +// GetEmbeddingByUIDsAfterCounter returns a count of finished RepositoryIMock.GetEmbeddingByUIDs invocations +func (mmGetEmbeddingByUIDs *RepositoryIMock) GetEmbeddingByUIDsAfterCounter() uint64 { + return mm_atomic.LoadUint64(&mmGetEmbeddingByUIDs.afterGetEmbeddingByUIDsCounter) +} + +// GetEmbeddingByUIDsBeforeCounter returns a count of RepositoryIMock.GetEmbeddingByUIDs invocations +func (mmGetEmbeddingByUIDs *RepositoryIMock) GetEmbeddingByUIDsBeforeCounter() uint64 { + return mm_atomic.LoadUint64(&mmGetEmbeddingByUIDs.beforeGetEmbeddingByUIDsCounter) +} + +// Calls returns a list of arguments used in each call to RepositoryIMock.GetEmbeddingByUIDs. +// The list is in the same order as the calls were made (i.e. 
recent calls have a higher index) +func (mmGetEmbeddingByUIDs *mRepositoryIMockGetEmbeddingByUIDs) Calls() []*RepositoryIMockGetEmbeddingByUIDsParams { + mmGetEmbeddingByUIDs.mutex.RLock() + + argCopy := make([]*RepositoryIMockGetEmbeddingByUIDsParams, len(mmGetEmbeddingByUIDs.callArgs)) + copy(argCopy, mmGetEmbeddingByUIDs.callArgs) + + mmGetEmbeddingByUIDs.mutex.RUnlock() + + return argCopy +} + +// MinimockGetEmbeddingByUIDsDone returns true if the count of the GetEmbeddingByUIDs invocations corresponds +// the number of defined expectations +func (m *RepositoryIMock) MinimockGetEmbeddingByUIDsDone() bool { + for _, e := range m.GetEmbeddingByUIDsMock.expectations { + if mm_atomic.LoadUint64(&e.Counter) < 1 { + return false + } + } + + return m.GetEmbeddingByUIDsMock.invocationsDone() +} + +// MinimockGetEmbeddingByUIDsInspect logs each unmet expectation +func (m *RepositoryIMock) MinimockGetEmbeddingByUIDsInspect() { + for _, e := range m.GetEmbeddingByUIDsMock.expectations { + if mm_atomic.LoadUint64(&e.Counter) < 1 { + m.t.Errorf("Expected call to RepositoryIMock.GetEmbeddingByUIDs with params: %#v", *e.params) + } + } + + afterGetEmbeddingByUIDsCounter := mm_atomic.LoadUint64(&m.afterGetEmbeddingByUIDsCounter) + // if default expectation was set then invocations count should be greater than zero + if m.GetEmbeddingByUIDsMock.defaultExpectation != nil && afterGetEmbeddingByUIDsCounter < 1 { + if m.GetEmbeddingByUIDsMock.defaultExpectation.params == nil { + m.t.Error("Expected call to RepositoryIMock.GetEmbeddingByUIDs") + } else { + m.t.Errorf("Expected call to RepositoryIMock.GetEmbeddingByUIDs with params: %#v", *m.GetEmbeddingByUIDsMock.defaultExpectation.params) + } + } + // if func was set then invocations count should be greater than zero + if m.funcGetEmbeddingByUIDs != nil && afterGetEmbeddingByUIDsCounter < 1 { + m.t.Error("Expected call to RepositoryIMock.GetEmbeddingByUIDs") + } + + if !m.GetEmbeddingByUIDsMock.invocationsDone() && 
afterGetEmbeddingByUIDsCounter > 0 { + m.t.Errorf("Expected %d calls to RepositoryIMock.GetEmbeddingByUIDs but found %d calls", + mm_atomic.LoadUint64(&m.GetEmbeddingByUIDsMock.expectedInvocations), afterGetEmbeddingByUIDsCounter) + } +} + type mRepositoryIMockGetIncompleteFile struct { mock *RepositoryIMock defaultExpectation *RepositoryIMockGetIncompleteFileExpectation @@ -7106,6 +7428,169 @@ func (m *RepositoryIMock) MinimockProcessKnowledgeBaseFilesInspect() { } } +type mRepositoryIMockTextChunkTableName struct { + mock *RepositoryIMock + defaultExpectation *RepositoryIMockTextChunkTableNameExpectation + expectations []*RepositoryIMockTextChunkTableNameExpectation + + expectedInvocations uint64 +} + +// RepositoryIMockTextChunkTableNameExpectation specifies expectation struct of the RepositoryI.TextChunkTableName +type RepositoryIMockTextChunkTableNameExpectation struct { + mock *RepositoryIMock + + results *RepositoryIMockTextChunkTableNameResults + Counter uint64 +} + +// RepositoryIMockTextChunkTableNameResults contains results of the RepositoryI.TextChunkTableName +type RepositoryIMockTextChunkTableNameResults struct { + s1 string +} + +// Expect sets up expected params for RepositoryI.TextChunkTableName +func (mmTextChunkTableName *mRepositoryIMockTextChunkTableName) Expect() *mRepositoryIMockTextChunkTableName { + if mmTextChunkTableName.mock.funcTextChunkTableName != nil { + mmTextChunkTableName.mock.t.Fatalf("RepositoryIMock.TextChunkTableName mock is already set by Set") + } + + if mmTextChunkTableName.defaultExpectation == nil { + mmTextChunkTableName.defaultExpectation = &RepositoryIMockTextChunkTableNameExpectation{} + } + + return mmTextChunkTableName +} + +// Inspect accepts an inspector function that has same arguments as the RepositoryI.TextChunkTableName +func (mmTextChunkTableName *mRepositoryIMockTextChunkTableName) Inspect(f func()) *mRepositoryIMockTextChunkTableName { + if mmTextChunkTableName.mock.inspectFuncTextChunkTableName != nil { + 
mmTextChunkTableName.mock.t.Fatalf("Inspect function is already set for RepositoryIMock.TextChunkTableName") + } + + mmTextChunkTableName.mock.inspectFuncTextChunkTableName = f + + return mmTextChunkTableName +} + +// Return sets up results that will be returned by RepositoryI.TextChunkTableName +func (mmTextChunkTableName *mRepositoryIMockTextChunkTableName) Return(s1 string) *RepositoryIMock { + if mmTextChunkTableName.mock.funcTextChunkTableName != nil { + mmTextChunkTableName.mock.t.Fatalf("RepositoryIMock.TextChunkTableName mock is already set by Set") + } + + if mmTextChunkTableName.defaultExpectation == nil { + mmTextChunkTableName.defaultExpectation = &RepositoryIMockTextChunkTableNameExpectation{mock: mmTextChunkTableName.mock} + } + mmTextChunkTableName.defaultExpectation.results = &RepositoryIMockTextChunkTableNameResults{s1} + return mmTextChunkTableName.mock +} + +// Set uses given function f to mock the RepositoryI.TextChunkTableName method +func (mmTextChunkTableName *mRepositoryIMockTextChunkTableName) Set(f func() (s1 string)) *RepositoryIMock { + if mmTextChunkTableName.defaultExpectation != nil { + mmTextChunkTableName.mock.t.Fatalf("Default expectation is already set for the RepositoryI.TextChunkTableName method") + } + + if len(mmTextChunkTableName.expectations) > 0 { + mmTextChunkTableName.mock.t.Fatalf("Some expectations are already set for the RepositoryI.TextChunkTableName method") + } + + mmTextChunkTableName.mock.funcTextChunkTableName = f + return mmTextChunkTableName.mock +} + +// Times sets number of times RepositoryI.TextChunkTableName should be invoked +func (mmTextChunkTableName *mRepositoryIMockTextChunkTableName) Times(n uint64) *mRepositoryIMockTextChunkTableName { + if n == 0 { + mmTextChunkTableName.mock.t.Fatalf("Times of RepositoryIMock.TextChunkTableName mock can not be zero") + } + mm_atomic.StoreUint64(&mmTextChunkTableName.expectedInvocations, n) + return mmTextChunkTableName +} + +func (mmTextChunkTableName 
*mRepositoryIMockTextChunkTableName) invocationsDone() bool { + if len(mmTextChunkTableName.expectations) == 0 && mmTextChunkTableName.defaultExpectation == nil && mmTextChunkTableName.mock.funcTextChunkTableName == nil { + return true + } + + totalInvocations := mm_atomic.LoadUint64(&mmTextChunkTableName.mock.afterTextChunkTableNameCounter) + expectedInvocations := mm_atomic.LoadUint64(&mmTextChunkTableName.expectedInvocations) + + return totalInvocations > 0 && (expectedInvocations == 0 || expectedInvocations == totalInvocations) +} + +// TextChunkTableName implements repository.RepositoryI +func (mmTextChunkTableName *RepositoryIMock) TextChunkTableName() (s1 string) { + mm_atomic.AddUint64(&mmTextChunkTableName.beforeTextChunkTableNameCounter, 1) + defer mm_atomic.AddUint64(&mmTextChunkTableName.afterTextChunkTableNameCounter, 1) + + if mmTextChunkTableName.inspectFuncTextChunkTableName != nil { + mmTextChunkTableName.inspectFuncTextChunkTableName() + } + + if mmTextChunkTableName.TextChunkTableNameMock.defaultExpectation != nil { + mm_atomic.AddUint64(&mmTextChunkTableName.TextChunkTableNameMock.defaultExpectation.Counter, 1) + + mm_results := mmTextChunkTableName.TextChunkTableNameMock.defaultExpectation.results + if mm_results == nil { + mmTextChunkTableName.t.Fatal("No results are set for the RepositoryIMock.TextChunkTableName") + } + return (*mm_results).s1 + } + if mmTextChunkTableName.funcTextChunkTableName != nil { + return mmTextChunkTableName.funcTextChunkTableName() + } + mmTextChunkTableName.t.Fatalf("Unexpected call to RepositoryIMock.TextChunkTableName.") + return +} + +// TextChunkTableNameAfterCounter returns a count of finished RepositoryIMock.TextChunkTableName invocations +func (mmTextChunkTableName *RepositoryIMock) TextChunkTableNameAfterCounter() uint64 { + return mm_atomic.LoadUint64(&mmTextChunkTableName.afterTextChunkTableNameCounter) +} + +// TextChunkTableNameBeforeCounter returns a count of RepositoryIMock.TextChunkTableName 
invocations +func (mmTextChunkTableName *RepositoryIMock) TextChunkTableNameBeforeCounter() uint64 { + return mm_atomic.LoadUint64(&mmTextChunkTableName.beforeTextChunkTableNameCounter) +} + +// MinimockTextChunkTableNameDone returns true if the count of the TextChunkTableName invocations corresponds +// the number of defined expectations +func (m *RepositoryIMock) MinimockTextChunkTableNameDone() bool { + for _, e := range m.TextChunkTableNameMock.expectations { + if mm_atomic.LoadUint64(&e.Counter) < 1 { + return false + } + } + + return m.TextChunkTableNameMock.invocationsDone() +} + +// MinimockTextChunkTableNameInspect logs each unmet expectation +func (m *RepositoryIMock) MinimockTextChunkTableNameInspect() { + for _, e := range m.TextChunkTableNameMock.expectations { + if mm_atomic.LoadUint64(&e.Counter) < 1 { + m.t.Error("Expected call to RepositoryIMock.TextChunkTableName") + } + } + + afterTextChunkTableNameCounter := mm_atomic.LoadUint64(&m.afterTextChunkTableNameCounter) + // if default expectation was set then invocations count should be greater than zero + if m.TextChunkTableNameMock.defaultExpectation != nil && afterTextChunkTableNameCounter < 1 { + m.t.Error("Expected call to RepositoryIMock.TextChunkTableName") + } + // if func was set then invocations count should be greater than zero + if m.funcTextChunkTableName != nil && afterTextChunkTableNameCounter < 1 { + m.t.Error("Expected call to RepositoryIMock.TextChunkTableName") + } + + if !m.TextChunkTableNameMock.invocationsDone() && afterTextChunkTableNameCounter > 0 { + m.t.Errorf("Expected %d calls to RepositoryIMock.TextChunkTableName but found %d calls", + mm_atomic.LoadUint64(&m.TextChunkTableNameMock.expectedInvocations), afterTextChunkTableNameCounter) + } +} + type mRepositoryIMockUpdateKnowledgeBase struct { mock *RepositoryIMock defaultExpectation *RepositoryIMockUpdateKnowledgeBaseExpectation @@ -8442,6 +8927,8 @@ func (m *RepositoryIMock) MinimockFinish() { 
m.MinimockGetConvertedFileByFileUIDInspect() + m.MinimockGetEmbeddingByUIDsInspect() + m.MinimockGetIncompleteFileInspect() m.MinimockGetKnowledgeBaseByOwnerAndIDInspect() @@ -8458,6 +8945,8 @@ func (m *RepositoryIMock) MinimockFinish() { m.MinimockProcessKnowledgeBaseFilesInspect() + m.MinimockTextChunkTableNameInspect() + m.MinimockUpdateKnowledgeBaseInspect() m.MinimockUpdateKnowledgeBaseFileInspect() @@ -8503,6 +8992,7 @@ func (m *RepositoryIMock) minimockDone() bool { m.MinimockDeleteKnowledgeBaseFileDone() && m.MinimockDeleteRepositoryTagDone() && m.MinimockGetConvertedFileByFileUIDDone() && + m.MinimockGetEmbeddingByUIDsDone() && m.MinimockGetIncompleteFileDone() && m.MinimockGetKnowledgeBaseByOwnerAndIDDone() && m.MinimockGetRepositoryTagDone() && @@ -8511,6 +9001,7 @@ func (m *RepositoryIMock) minimockDone() bool { m.MinimockListKnowledgeBaseFilesDone() && m.MinimockListKnowledgeBasesDone() && m.MinimockProcessKnowledgeBaseFilesDone() && + m.MinimockTextChunkTableNameDone() && m.MinimockUpdateKnowledgeBaseDone() && m.MinimockUpdateKnowledgeBaseFileDone() && m.MinimockUpsertEmbeddingsDone() && diff --git a/pkg/repository/chunk.go b/pkg/repository/chunk.go index 2d9e2f9..eacbf56 100644 --- a/pkg/repository/chunk.go +++ b/pkg/repository/chunk.go @@ -11,6 +11,7 @@ import ( ) type TextChunkI interface { + TextChunkTableName() string DeleteAndCreateChunks(ctx context.Context, sourceTable string, sourceUID uuid.UUID, chunks []*TextChunk, externalServiceCall func(chunkUIDs []string) (map[string]any, error)) ([]*TextChunk, error) DeleteChunksBySource(ctx context.Context, sourceTable string, sourceUID uuid.UUID) error DeleteChunksByUIDs(ctx context.Context, chunkUIDs []uuid.UUID) error @@ -63,7 +64,7 @@ var TextChunkColumn = TextChunkColumns{ } // TableName returns the table name of the TextChunk -func (TextChunk) TableName() string { +func (r *Repository) TextChunkTableName() string { return "text_chunk" } diff --git a/pkg/repository/convertedfile.go 
b/pkg/repository/convertedfile.go index 6e851c0..ce0f525 100644 --- a/pkg/repository/convertedfile.go +++ b/pkg/repository/convertedfile.go @@ -12,7 +12,7 @@ import ( type ConvertedFileI interface { ConvertedFileTableName() string - CreateConvertedFile(ctx context.Context, cf ConvertedFile, callExternalService func(convertedFileUID uuid.UUID) error) (*ConvertedFile, error) + CreateConvertedFile(ctx context.Context, cf ConvertedFile, callExternalService func(convertedFileUID uuid.UUID) (map[string]any, error)) (*ConvertedFile, error) DeleteConvertedFile(ctx context.Context, uid uuid.UUID) error GetConvertedFileByFileUID(ctx context.Context, fileUID uuid.UUID) (*ConvertedFile, error) } @@ -58,7 +58,7 @@ func (r *Repository) ConvertedFileTableName() string { return "converted_file" } -func (r *Repository) CreateConvertedFile(ctx context.Context, cf ConvertedFile, callExternalService func(convertedFileUID uuid.UUID) error) (*ConvertedFile, error) { +func (r *Repository) CreateConvertedFile(ctx context.Context, cf ConvertedFile, callExternalService func(convertedFileUID uuid.UUID) (map[string]any, error)) (*ConvertedFile, error) { err := r.db.Transaction(func(tx *gorm.DB) error { // Check if file_uid exists var existingFile ConvertedFile @@ -88,9 +88,17 @@ func (r *Repository) CreateConvertedFile(ctx context.Context, cf ConvertedFile, if callExternalService != nil { // Call the external service using the created record's UID - if err := callExternalService(cf.UID); err != nil { + if output, err := callExternalService(cf.UID); err != nil { // If the external service returns an error, return the error to trigger a rollback return err + } else { + // get dest from output and update the record + if dest, ok := output[ConvertedFileColumn.Destination].(string); ok { + update := map[string]any{ConvertedFileColumn.Destination: dest} + if err := tx.Model(&cf).Updates(update).Error; err != nil { + return err + } + } } } @@ -113,7 +121,6 @@ func (r *Repository) 
GetConvertedFileByFileUID(ctx context.Context, fileUID uuid return &cf, nil } - // DeleteConvertedFile deletes the record by UID func (r *Repository) DeleteConvertedFile(ctx context.Context, uid uuid.UUID) error { err := r.db.Transaction(func(tx *gorm.DB) error { @@ -129,3 +136,13 @@ func (r *Repository) DeleteConvertedFile(ctx context.Context, uid uuid.UUID) err } return nil } + +// UpdateConvertedFile updates the record by UID using update map. +func (r *Repository) UpdateConvertedFile(ctx context.Context, uid uuid.UUID, update map[string]any) error { + // Specify the condition to find the record by its UID + where := fmt.Sprintf("%s = ?", ConvertedFileColumn.UID) + if err := r.db.WithContext(ctx).Model(&ConvertedFile{}).Where(where, uid).Updates(update).Error; err != nil { + return err + } + return nil +} diff --git a/pkg/repository/embedding.go b/pkg/repository/embedding.go index 5c79da5..af471ac 100644 --- a/pkg/repository/embedding.go +++ b/pkg/repository/embedding.go @@ -2,6 +2,8 @@ package repository import ( "context" + "database/sql/driver" + "encoding/json" "fmt" "time" @@ -16,17 +18,68 @@ type EmbeddingI interface { UpsertEmbeddings(ctx context.Context, embeddings []Embedding, externalServiceCall func(embUIDs []string) error) ([]Embedding, error) DeleteEmbeddingsBySource(ctx context.Context, sourceTable string, sourceUID uuid.UUID) error DeleteEmbeddingsByUIDs(ctx context.Context, embUIDs []uuid.UUID) error + // GetEmbeddingByUIDs fetches embeddings by their UIDs. 
+ GetEmbeddingByUIDs(ctx context.Context, embUIDs []uuid.UUID) ([]Embedding, error) } type Embedding struct { UID uuid.UUID `gorm:"column:uid;type:uuid;default:gen_random_uuid();primaryKey" json:"uid"` SourceUID uuid.UUID `gorm:"column:source_uid;type:uuid;not null" json:"source_uid"` SourceTable string `gorm:"column:source_table;size:255;not null" json:"source_table"` - Vector []float32 `gorm:"column:vector;type:jsonb;not null" json:"vector"` + Vector Vector `gorm:"column:vector;type:jsonb;not null" json:"vector"` Collection string `gorm:"column:collection;size:255;not null" json:"collection"` CreateTime *time.Time `gorm:"column:create_time;not null;default:CURRENT_TIMESTAMP" json:"create_time"` UpdateTime *time.Time `gorm:"column:update_time;not null;default:CURRENT_TIMESTAMP" json:"update_time"` } +type Vector []float32 + +func (v Vector) Value() (driver.Value, error) { + if v == nil { + return nil, nil + } + r, err := json.Marshal(v) + if err != nil { + return nil, err + } + return string(r), nil +} + +func (v *Vector) Scan(value interface{}) error { + if value == nil { + *v = nil + return nil + } + + b, ok := value.([]byte) + if !ok { + return fmt.Errorf("type assertion to []byte failed") + } + + return json.Unmarshal(b, v) +} + +// MarshalJSON implements the json.Marshaler interface +func (v Vector) MarshalJSON() ([]byte, error) { + if v == nil { + return []byte("null"), nil + } + return json.Marshal([]float32(v)) +} + +// UnmarshalJSON implements the json.Unmarshaler interface +func (v *Vector) UnmarshalJSON(data []byte) error { + if string(data) == "null" { + *v = nil + return nil + } + var slice []float32 + if err := json.Unmarshal(data, &slice); err != nil { + return err + } + *v = Vector(slice) + return nil +} + type EmbeddingColumns struct { UID string SourceUID string @@ -106,3 +159,13 @@ func (r *Repository) DeleteEmbeddingsByUIDs(ctx context.Context, embUIDs []uuid. 
where := fmt.Sprintf("%s IN (?)", EmbeddingColumn.UID) return r.db.WithContext(ctx).Where(where, embUIDs).Delete(&Embedding{}).Error } + +// GetEmbeddingByUIDs fetches embeddings by their UIDs. +func (r *Repository) GetEmbeddingByUIDs(ctx context.Context, embUIDs []uuid.UUID) ([]Embedding, error) { + var embeddings []Embedding + where := fmt.Sprintf("%s IN (?)", EmbeddingColumn.UID) + if err := r.db.WithContext(ctx).Where(where, embUIDs).Find(&embeddings).Error; err != nil { + return nil, err + } + return embeddings, nil +} diff --git a/pkg/repository/embedding_test.go b/pkg/repository/embedding_test.go new file mode 100644 index 0000000..2500f78 --- /dev/null +++ b/pkg/repository/embedding_test.go @@ -0,0 +1,37 @@ +package repository + +// import ( +// "context" +// "fmt" +// "os" +// "testing" + +// "github.com/google/uuid" +// "github.com/instill-ai/artifact-backend/config" +// "github.com/instill-ai/artifact-backend/pkg/db" +// ) + +// func TestGetEmbeddingByUIDs(t *testing.T) { +// // set file flag +// // fs := flag.NewFlagSet(os.Args[0], flag.ExitOnError) +// os.Args = []string{"", "-file", "../../config/config_local.yaml"} +// config.Init() +// // get db connection +// db := db.GetConnection() +// // get repository +// repo := NewRepository(db) +// // get embeddings +// uid := "006db525-ad0f-4951-8dd0-d226156b789b" +// // turn uid into uuid +// uidUUID, err := uuid.Parse(uid) +// if err != nil { +// t.Fatalf("Failed to parse uid: %v", err) +// } + +// embeddings, err := repo.GetEmbeddingByUIDs(context.TODO(), []uuid.UUID{uidUUID}) +// if err != nil { +// t.Fatalf("Failed to get embeddings: %v", err) +// } +// fmt.Println(embeddings) + +// } diff --git a/pkg/service/pipeline.go b/pkg/service/pipeline.go index 57516a9..8bdff24 100644 --- a/pkg/service/pipeline.go +++ b/pkg/service/pipeline.go @@ -4,16 +4,22 @@ import ( "context" "errors" + "github.com/google/uuid" + "github.com/instill-ai/artifact-backend/pkg/logger" pipelinev1beta 
"github.com/instill-ai/protogen-go/vdp/pipeline/v1beta" + "go.uber.org/zap" "google.golang.org/grpc/metadata" "google.golang.org/protobuf/types/known/structpb" ) -// ConcertPDFToMD using converting pipeline to convert PDF to MD and consume caller's credits -func (s *Service) ConcertPDFToMD(ctx context.Context, caller uuid.UUID, pdfBase64 string) (string, error) { + +// ConvertPDFToMD using converting pipeline to convert PDF to MD and consume caller's credits +func (s *Service) ConvertPDFToMD(ctx context.Context, caller uuid.UUID, pdfBase64 string) (string, error) { + logger, _ := logger.GetZapLogger(ctx) md := metadata.New(map[string]string{"Instill-User-Uid": caller.String(), "Instill-Auth-Type": "user"}) ctx = metadata.NewOutgoingContext(ctx, md) + req := &pipelinev1beta.TriggerOrganizationPipelineReleaseRequest{ Name: "organizations/preset/pipelines/indexing-convert-pdf/releases/v1.0.0", Inputs: []*structpb.Struct{ @@ -26,10 +32,12 @@ func (s *Service) ConcertPDFToMD(ctx context.Context, caller uuid.UUID, pdfBase6 } resp, err := s.PipelinePub.TriggerOrganizationPipelineRelease(ctx, req) if err != nil { + logger.Error("failed to trigger pipeline", zap.Error(err)) return "", err } result, err := getConvertResult(resp) if err != nil { + logger.Error("failed to get convert result", zap.Error(err)) return "", err } return result, nil diff --git a/pkg/worker/worker.go b/pkg/worker/worker.go index 2cd449a..e15a9e5 100644 --- a/pkg/worker/worker.go +++ b/pkg/worker/worker.go @@ -4,6 +4,7 @@ import ( "context" "encoding/base64" "fmt" + "runtime/debug" "sync" "time" @@ -110,6 +111,19 @@ func (wp *fileToEmbWorkerPool) startWorker(ctx context.Context, workerID int) { logger, _ := logger.GetZapLogger(ctx) logger.Info("Worker started", zap.Int("WorkerID", workerID)) defer wp.wg.Done() + // Defer a function to catch panics + defer func() { + if r := recover(); r != nil { + logger.Error("Panic recovered in worker", + zap.Int("WorkerID", workerID), + zap.Any("panic", r), + 
zap.String("stack", string(debug.Stack()))) + // Start a new worker + logger.Info("Restarting worker after panic", zap.Int("WorkerID", workerID)) + wp.wg.Add(1) + go wp.startWorker(ctx, workerID) + } + }() for { select { case <-ctx.Done(): @@ -352,7 +366,7 @@ func (wp *fileToEmbWorkerPool) processConvertingFile(ctx context.Context, file r base64Data := base64.StdEncoding.EncodeToString(data) // convert the pdf file to md - convertedMD, err := wp.svc.ConcertPDFToMD(ctx, file.CreatorUID, base64Data) + convertedMD, err := wp.svc.ConvertPDFToMD(ctx, file.CreatorUID, base64Data) if err != nil { logger.Error("Failed to convert pdf to md.", zap.String("File path", fileInMinIOPath)) return nil, artifactpb.FileProcessStatus_FILE_PROCESS_STATUS_UNSPECIFIED, err @@ -565,8 +579,8 @@ func (wp *fileToEmbWorkerPool) processEmbeddingFile(ctx context.Context, file re embeddings := make([]repository.Embedding, len(vectors)) for i, v := range vectors { embeddings[i] = repository.Embedding{ - SourceUID: sourceUID, - SourceTable: sourceTable, + SourceTable: wp.svc.Repository.TextChunkTableName(), + SourceUID: chunks[i].UID, Vector: v, Collection: collection, } @@ -602,13 +616,15 @@ func (wp *fileToEmbWorkerPool) saveConvertedFile(ctx context.Context, kbUID, fil _, err := wp.svc.Repository.CreateConvertedFile( ctx, repository.ConvertedFile{KbUID: kbUID, FileUID: fileUID, Name: name, Type: "text/markdown", Destination: "destination"}, - func(convertedFileUID uuid.UUID) error { + func(convertedFileUID uuid.UUID) (map[string]any, error) { // save the converted file into object storage err := wp.svc.MinIO.SaveConvertedFile(ctx, kbUID.String(), convertedFileUID.String(), "md", convertedFile) if err != nil { - return err + return nil, err } - return nil + output := make(map[string]any) + output[repository.ConvertedFileColumn.Destination] = wp.svc.MinIO.GetConvertedFilePathInKnowledgeBase(kbUID.String(), convertedFileUID.String(), "md") + return output, nil }) if err != nil { 
logger.Error("Failed to save converted file into object storage and metadata into database.", zap.String("FileUID", fileUID.String()))