From 4b490d4313a1d74f52c8400463ad2bf93bbced41 Mon Sep 17 00:00:00 2001 From: "gary.y" Date: Fri, 29 Nov 2024 16:40:03 +0800 Subject: [PATCH] feat(artifact): implement search chunks and sources --- go.mod | 2 +- go.sum | 4 +- pkg/handler/chunks.go | 164 ++++++++++++-- pkg/handler/knowledgebase.go | 2 +- pkg/mock/repository_i_mock.gen.go | 317 ++++++++++++++++++++++++++++ pkg/repository/knowledgebase.go | 11 + pkg/repository/knowledgebasefile.go | 23 +- pkg/service/permission.go | 12 ++ 8 files changed, 507 insertions(+), 28 deletions(-) diff --git a/go.mod b/go.mod index e1ac530..98e33c8 100644 --- a/go.mod +++ b/go.mod @@ -13,7 +13,7 @@ require ( github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 github.com/grpc-ecosystem/grpc-gateway/v2 v2.19.1 github.com/influxdata/influxdb-client-go/v2 v2.12.3 - github.com/instill-ai/protogen-go v0.3.3-alpha.0.20241125163328-c29704e47ba4 + github.com/instill-ai/protogen-go v0.3.3-alpha.0.20241129082755-59b3c0c34fe0 github.com/instill-ai/usage-client v0.3.0-alpha.0.20240319060111-4a3a39f2fd61 github.com/instill-ai/x v0.3.0-alpha.0.20231219052200-6230a89e386c github.com/knadh/koanf v1.5.0 diff --git a/go.sum b/go.sum index a78749a..3f14fbe 100644 --- a/go.sum +++ b/go.sum @@ -344,8 +344,8 @@ github.com/influxdata/influxdb-client-go/v2 v2.12.3 h1:28nRlNMRIV4QbtIUvxhWqaxn0 github.com/influxdata/influxdb-client-go/v2 v2.12.3/go.mod h1:IrrLUbCjjfkmRuaCiGQg4m2GbkaeJDcuWoxiWdQEbA0= github.com/influxdata/line-protocol v0.0.0-20200327222509-2487e7298839 h1:W9WBk7wlPfJLvMCdtV4zPulc4uCPrlywQOmbFOhgQNU= github.com/influxdata/line-protocol v0.0.0-20200327222509-2487e7298839/go.mod h1:xaLFMmpvUxqXtVkUJfg9QmT88cDaCJ3ZKgdZ78oO8Qo= -github.com/instill-ai/protogen-go v0.3.3-alpha.0.20241125163328-c29704e47ba4 h1:k8X9gMiCwHWShB1FITaWwmlzthFnor1Jj0tSaFG+9x8= -github.com/instill-ai/protogen-go v0.3.3-alpha.0.20241125163328-c29704e47ba4/go.mod h1:rf0UY7VpEgpaLudYEcjx5rnbuwlBaaLyD4FQmWLtgAY= +github.com/instill-ai/protogen-go v0.3.3-alpha.0.20241129082755-59b3c0c34fe0 h1:Fok/s7GQoNMUA++1WbDdiZ6Ut8AXSuWNkTU2Q/0G9QA= +github.com/instill-ai/protogen-go v0.3.3-alpha.0.20241129082755-59b3c0c34fe0/go.mod h1:rf0UY7VpEgpaLudYEcjx5rnbuwlBaaLyD4FQmWLtgAY= github.com/instill-ai/usage-client v0.3.0-alpha.0.20240319060111-4a3a39f2fd61 h1:smPTvmXDhn/QC7y/TPXyMTqbbRd0gvzmFgWBChwTfhE= github.com/instill-ai/usage-client v0.3.0-alpha.0.20240319060111-4a3a39f2fd61/go.mod h1:/TAHs4ybuylk5icuy+MQtHRc4XUnIyXzeNKxX9qDFhw= github.com/instill-ai/x v0.3.0-alpha.0.20231219052200-6230a89e386c h1:a2RVkpIV2QcrGnSHAou+t/L+vBsaIfFvk5inVg5Uh4s= diff --git a/pkg/handler/chunks.go b/pkg/handler/chunks.go index ebb9e00..e2a5471 100644 --- a/pkg/handler/chunks.go +++ b/pkg/handler/chunks.go @@ -15,6 +15,19 @@ import ( "google.golang.org/protobuf/types/known/timestamppb" ) +// convertToProtoChunk +func convertToProtoChunk(chunk repository.TextChunk) *artifactpb.Chunk { + return &artifactpb.Chunk{ + ChunkUid: chunk.UID.String(), + Retrievable: chunk.Retrievable, + StartPos: uint32(chunk.StartPos), + EndPos: uint32(chunk.EndPos), + Tokens: uint32(chunk.Tokens), + CreateTime: timestamppb.New(*chunk.CreateTime), + OriginalFileUid: chunk.KbFileUID.String(), + } +} + func (ph *PublicHandler) ListChunks(ctx context.Context, req *artifactpb.ListChunksRequest) (*artifactpb.ListChunksResponse, error) { log, _ := logger.GetZapLogger(ctx) authUID, err := getUserUIDFromContext(ctx) @@ -73,15 +86,7 @@ func (ph *PublicHandler) ListChunks(ctx context.Context, req *artifactpb.ListChu res := make([]*artifactpb.Chunk, 0, len(chunks)) for _, chunk := range chunks { - res = append(res, &artifactpb.Chunk{ - ChunkUid: chunk.UID.String(), - Retrievable: chunk.Retrievable, - StartPos: uint32(chunk.StartPos), - EndPos: uint32(chunk.EndPos), - Tokens: uint32(chunk.Tokens), - CreateTime: timestamppb.New(*chunk.CreateTime), - OriginalFileUid: kbf.UID.String(), - }) + res = append(res, convertToProtoChunk(chunk)) } return &artifactpb.ListChunksResponse{ @@ -89,6 +94,69 @@ func (ph *PublicHandler) ListChunks(ctx context.Context, req *artifactpb.ListChu }, nil } +func (ph *PublicHandler) SearchChunks(ctx context.Context, req *artifactpb.SearchChunksRequest) (*artifactpb.SearchChunksResponse, error) { + log, _ := logger.GetZapLogger(ctx) + _, err := getUserUIDFromContext(ctx) + if err != nil { + log.Error("failed to get user id from header", zap.Error(err)) + return nil, fmt.Errorf("failed to get user id from header: %v. err: %w", err, customerror.ErrUnauthenticated) + } + // check if user can access the namespace + ns, err := ph.service.GetNamespaceAndCheckPermission(ctx, req.NamespaceId) + if err != nil { + log.Error("failed to get namespace and check permission", zap.Error(err)) + return nil, fmt.Errorf("failed to get namespace and check permission: %w", err) + } + + chunkUIDs := make([]uuid.UUID, 0, len(req.ChunkUids)) + for _, chunkUID := range req.ChunkUids { + chunkUID, err := uuid.FromString(chunkUID) + if err != nil { + log.Error("failed to parse chunk uid", zap.Error(err)) + return nil, fmt.Errorf("failed to parse chunk uid: %w", err) + } + chunkUIDs = append(chunkUIDs, chunkUID) + } + // check if the chunkUIs is more than 20 + if len(chunkUIDs) > 25 { + log.Error("chunk uids is more than 20", zap.Int("chunk_uids_count", len(chunkUIDs))) + return nil, fmt.Errorf("chunk uids is more than 20") + } + chunks, err := ph.service.Repository.GetChunksByUIDs(ctx, chunkUIDs) + if err != nil { + log.Error("failed to get chunks by uids", zap.Error(err)) + return nil, fmt.Errorf("failed to get chunks by uids: %w", err) + } + + // get the kbUIDs from chunks + kbUIDs := make([]uuid.UUID, 0, len(chunks)) + for _, chunk := range chunks { + kbUIDs = append(kbUIDs, chunk.KbUID) + } + // use kbUIDs to get the knowledge bases + knowledgeBases, err := ph.service.Repository.GetKnowledgeBasesByUIDs(ctx, kbUIDs) + if err != nil { + log.Error("failed to get knowledge bases by uids", zap.Error(err)) + return nil, fmt.Errorf("failed to get knowledge bases by uids: %w", err) + } + // check if the chunks's knowledge base's owner(namespace uid) is the same as namespace uuid in path + for _, knowledgeBase := range knowledgeBases { + if knowledgeBase.Owner != ns.NsUID.String() { + log.Error("chunks's namespace is not the same as namespace in path", zap.String("namespace_id_in_path", ns.NsUID.String()), zap.String("namespace_id_in_chunks", knowledgeBase.Owner)) + return nil, fmt.Errorf("chunks's namespace is not the same as namespace in path") + } + } + + // populate the response + protoChunks := make([]*artifactpb.Chunk, 0, len(chunks)) + for _, chunk := range chunks { + protoChunks = append(protoChunks, convertToProtoChunk(chunk)) + } + return &artifactpb.SearchChunksResponse{ + Chunks: protoChunks, + }, nil +} + func (ph *PublicHandler) UpdateChunk(ctx context.Context, req *artifactpb.UpdateChunkRequest) (*artifactpb.UpdateChunkResponse, error) { log, _ := logger.GetZapLogger(ctx) authUID, err := getUserUIDFromContext(ctx) @@ -131,15 +199,7 @@ func (ph *PublicHandler) UpdateChunk(ctx context.Context, req *artifactpb.Update return &artifactpb.UpdateChunkResponse{ // Populate the response fields appropriately - Chunk: &artifactpb.Chunk{ - ChunkUid: chunk.UID.String(), - Retrievable: chunk.Retrievable, - StartPos: uint32(chunk.StartPos), - EndPos: uint32(chunk.EndPos), - Tokens: uint32(chunk.Tokens), - CreateTime: timestamppb.New(*chunk.CreateTime), - // OriginalFileUid: chunk.FileUID.String(), - }, + Chunk: convertToProtoChunk(*chunk), }, nil } @@ -189,3 +249,71 @@ func (ph *PublicHandler) GetSourceFile(ctx context.Context, req *artifactpb.GetS }, }, nil } + +// SearchSourceFiles +func (ph *PublicHandler) SearchSourceFiles(ctx context.Context, req *artifactpb.SearchSourceFilesRequest) (*artifactpb.SearchSourceFilesResponse, error) { + log, _ := logger.GetZapLogger(ctx) + authUID, err := getUserUIDFromContext(ctx) + if err != nil { + log.Error("failed to get user id from header", zap.Error(err)) + return nil, fmt.Errorf("failed to get user id from header: %v. err: %w", err, customerror.ErrUnauthenticated) + } + + // Check if user can access the namespace + _, err = ph.service.GetNamespaceAndCheckPermission(ctx, req.NamespaceId) + if err != nil { + log.Error("failed to get namespace and check permission", zap.Error(err)) + return nil, fmt.Errorf("failed to get namespace and check permission: %w", err) + } + + fileUIDs := make([]uuid.UUID, 0, len(req.FileUids)) + for _, fileUID := range req.FileUids { + uid, err := uuid.FromString(fileUID) + if err != nil { + log.Error("failed to parse file uid", zap.Error(err)) + return nil, fmt.Errorf("failed to parse file uid: %v. err: %w", err, customerror.ErrInvalidArgument) + } + fileUIDs = append(fileUIDs, uid) + } + + sources := make([]*artifactpb.SourceFile, 0, len(fileUIDs)) + for _, fileUID := range fileUIDs { + source, err := ph.service.Repository.GetTruthSourceByFileUID(ctx, fileUID) + if err != nil { + log.Error("failed to get truth source by file uid", zap.Error(err)) + return nil, fmt.Errorf("failed to get truth source by file uid. err: %w", err) + } + + // ACL check for each source file + granted, err := ph.service.ACLClient.CheckPermission(ctx, "knowledgebase", source.KbUID, "reader") + if err != nil { + log.Error("failed to check permission", zap.Error(err)) + return nil, fmt.Errorf("failed to check permission. err: %w", err) + } + if !granted { + log.Error("no permission to access source file", + zap.String("user_uid", authUID), + zap.String("kb_uid", source.KbUID.String())) + return nil, fmt.Errorf("no permission to access source file. err: %w. user_uid: %s. kb_uid: %s", customerror.ErrNoPermission, authUID, source.KbUID.String()) + } + + // Get file content from MinIO + content, err := ph.service.MinIO.GetFile(ctx, minio.KnowledgeBaseBucketName, source.Dest) + if err != nil { + log.Error("failed to get file from minio", zap.Error(err)) + continue + } + + sources = append(sources, &artifactpb.SourceFile{ + OriginalFileUid: source.OriginalFileUID.String(), + OriginalFileName: source.OriginalFileName, + Content: string(content), + CreateTime: timestamppb.New(source.CreateTime), + UpdateTime: timestamppb.New(source.UpdateTime), + }) + } + + return &artifactpb.SearchSourceFilesResponse{ + SourceFiles: sources, + }, nil +} diff --git a/pkg/handler/knowledgebase.go b/pkg/handler/knowledgebase.go index dba916f..acc89f6 100644 --- a/pkg/handler/knowledgebase.go +++ b/pkg/handler/knowledgebase.go @@ -41,7 +41,7 @@ func (ph *PublicHandler) CreateCatalog(ctx context.Context, req *artifactpb.Crea ns, err := ph.service.GetNamespaceByNsID(ctx, req.GetNamespaceId()) if err != nil { log.Error( - "failed to check namespace permission", + "failed to get namespace", zap.Error(err), zap.String("owner_id(ns_id)", req.GetNamespaceId()), zap.String("auth_uid", authUID)) diff --git a/pkg/mock/repository_i_mock.gen.go b/pkg/mock/repository_i_mock.gen.go index f560272..87a44ae 100644 --- a/pkg/mock/repository_i_mock.gen.go +++ b/pkg/mock/repository_i_mock.gen.go @@ -196,6 +196,12 @@ type RepositoryIMock struct { beforeGetKnowledgeBaseFilesByFileUIDsCounter uint64 GetKnowledgeBaseFilesByFileUIDsMock mRepositoryIMockGetKnowledgeBaseFilesByFileUIDs + funcGetKnowledgeBasesByUIDs func(ctx context.Context, kbUIDs []uuid.UUID) (ka1 []mm_repository.KnowledgeBase, err error) + inspectFuncGetKnowledgeBasesByUIDs func(ctx context.Context, kbUIDs []uuid.UUID) + afterGetKnowledgeBasesByUIDsCounter uint64 + beforeGetKnowledgeBasesByUIDsCounter uint64 + GetKnowledgeBasesByUIDsMock mRepositoryIMockGetKnowledgeBasesByUIDs + funcGetKnowledgebaseFileByKbUIDAndFileID func(ctx context.Context, kbUID uuid.UUID, fileID string) (kp1 *mm_repository.KnowledgeBaseFile, err error) inspectFuncGetKnowledgebaseFileByKbUIDAndFileID func(ctx context.Context, kbUID uuid.UUID, fileID string) afterGetKnowledgebaseFileByKbUIDAndFileIDCounter uint64 @@ -531,6 +537,9 @@ func NewRepositoryIMock(t minimock.Tester) *RepositoryIMock { m.GetKnowledgeBaseFilesByFileUIDsMock = mRepositoryIMockGetKnowledgeBaseFilesByFileUIDs{mock: m} m.GetKnowledgeBaseFilesByFileUIDsMock.callArgs = []*RepositoryIMockGetKnowledgeBaseFilesByFileUIDsParams{} + m.GetKnowledgeBasesByUIDsMock = mRepositoryIMockGetKnowledgeBasesByUIDs{mock: m} + m.GetKnowledgeBasesByUIDsMock.callArgs = []*RepositoryIMockGetKnowledgeBasesByUIDsParams{} + m.GetKnowledgebaseFileByKbUIDAndFileIDMock = mRepositoryIMockGetKnowledgebaseFileByKbUIDAndFileID{mock: m} m.GetKnowledgebaseFileByKbUIDAndFileIDMock.callArgs = []*RepositoryIMockGetKnowledgebaseFileByKbUIDAndFileIDParams{} @@ -9397,6 +9406,311 @@ func (m *RepositoryIMock) MinimockGetKnowledgeBaseFilesByFileUIDsInspect() { } } +type mRepositoryIMockGetKnowledgeBasesByUIDs struct { + mock *RepositoryIMock + defaultExpectation *RepositoryIMockGetKnowledgeBasesByUIDsExpectation + expectations []*RepositoryIMockGetKnowledgeBasesByUIDsExpectation + + callArgs []*RepositoryIMockGetKnowledgeBasesByUIDsParams + mutex sync.RWMutex + + expectedInvocations uint64 +} + +// RepositoryIMockGetKnowledgeBasesByUIDsExpectation specifies expectation struct of the RepositoryI.GetKnowledgeBasesByUIDs +type RepositoryIMockGetKnowledgeBasesByUIDsExpectation struct { + mock *RepositoryIMock + params *RepositoryIMockGetKnowledgeBasesByUIDsParams + paramPtrs *RepositoryIMockGetKnowledgeBasesByUIDsParamPtrs + results *RepositoryIMockGetKnowledgeBasesByUIDsResults + Counter uint64 +} + +// RepositoryIMockGetKnowledgeBasesByUIDsParams contains parameters of the RepositoryI.GetKnowledgeBasesByUIDs +type RepositoryIMockGetKnowledgeBasesByUIDsParams struct { + ctx context.Context + kbUIDs []uuid.UUID +} + +// RepositoryIMockGetKnowledgeBasesByUIDsParamPtrs contains pointers to parameters of the RepositoryI.GetKnowledgeBasesByUIDs +type RepositoryIMockGetKnowledgeBasesByUIDsParamPtrs struct { + ctx *context.Context + kbUIDs *[]uuid.UUID +} + +// RepositoryIMockGetKnowledgeBasesByUIDsResults contains results of the RepositoryI.GetKnowledgeBasesByUIDs +type RepositoryIMockGetKnowledgeBasesByUIDsResults struct { + ka1 []mm_repository.KnowledgeBase + err error +} + +// Expect sets up expected params for RepositoryI.GetKnowledgeBasesByUIDs +func (mmGetKnowledgeBasesByUIDs *mRepositoryIMockGetKnowledgeBasesByUIDs) Expect(ctx context.Context, kbUIDs []uuid.UUID) *mRepositoryIMockGetKnowledgeBasesByUIDs { + if mmGetKnowledgeBasesByUIDs.mock.funcGetKnowledgeBasesByUIDs != nil { + mmGetKnowledgeBasesByUIDs.mock.t.Fatalf("RepositoryIMock.GetKnowledgeBasesByUIDs mock is already set by Set") + } + + if mmGetKnowledgeBasesByUIDs.defaultExpectation == nil { + mmGetKnowledgeBasesByUIDs.defaultExpectation = &RepositoryIMockGetKnowledgeBasesByUIDsExpectation{} + } + + if mmGetKnowledgeBasesByUIDs.defaultExpectation.paramPtrs != nil { + mmGetKnowledgeBasesByUIDs.mock.t.Fatalf("RepositoryIMock.GetKnowledgeBasesByUIDs mock is already set by ExpectParams functions") + } + + mmGetKnowledgeBasesByUIDs.defaultExpectation.params = &RepositoryIMockGetKnowledgeBasesByUIDsParams{ctx, kbUIDs} + for _, e := range mmGetKnowledgeBasesByUIDs.expectations { + if minimock.Equal(e.params, mmGetKnowledgeBasesByUIDs.defaultExpectation.params) { + mmGetKnowledgeBasesByUIDs.mock.t.Fatalf("Expectation set by When has same params: %#v", *mmGetKnowledgeBasesByUIDs.defaultExpectation.params) + } + } + + return mmGetKnowledgeBasesByUIDs +} + +// ExpectCtxParam1 sets up expected param ctx for RepositoryI.GetKnowledgeBasesByUIDs +func (mmGetKnowledgeBasesByUIDs *mRepositoryIMockGetKnowledgeBasesByUIDs) ExpectCtxParam1(ctx context.Context) *mRepositoryIMockGetKnowledgeBasesByUIDs { + if mmGetKnowledgeBasesByUIDs.mock.funcGetKnowledgeBasesByUIDs != nil { + mmGetKnowledgeBasesByUIDs.mock.t.Fatalf("RepositoryIMock.GetKnowledgeBasesByUIDs mock is already set by Set") + } + + if mmGetKnowledgeBasesByUIDs.defaultExpectation == nil { + mmGetKnowledgeBasesByUIDs.defaultExpectation = &RepositoryIMockGetKnowledgeBasesByUIDsExpectation{} + } + + if mmGetKnowledgeBasesByUIDs.defaultExpectation.params != nil { + mmGetKnowledgeBasesByUIDs.mock.t.Fatalf("RepositoryIMock.GetKnowledgeBasesByUIDs mock is already set by Expect") + } + + if mmGetKnowledgeBasesByUIDs.defaultExpectation.paramPtrs == nil { + mmGetKnowledgeBasesByUIDs.defaultExpectation.paramPtrs = &RepositoryIMockGetKnowledgeBasesByUIDsParamPtrs{} + } + mmGetKnowledgeBasesByUIDs.defaultExpectation.paramPtrs.ctx = &ctx + + return mmGetKnowledgeBasesByUIDs +} + +// ExpectKbUIDsParam2 sets up expected param kbUIDs for RepositoryI.GetKnowledgeBasesByUIDs +func (mmGetKnowledgeBasesByUIDs *mRepositoryIMockGetKnowledgeBasesByUIDs) ExpectKbUIDsParam2(kbUIDs []uuid.UUID) *mRepositoryIMockGetKnowledgeBasesByUIDs { + if mmGetKnowledgeBasesByUIDs.mock.funcGetKnowledgeBasesByUIDs != nil { + mmGetKnowledgeBasesByUIDs.mock.t.Fatalf("RepositoryIMock.GetKnowledgeBasesByUIDs mock is already set by Set") + } + + if mmGetKnowledgeBasesByUIDs.defaultExpectation == nil { + mmGetKnowledgeBasesByUIDs.defaultExpectation = &RepositoryIMockGetKnowledgeBasesByUIDsExpectation{} + } + + if mmGetKnowledgeBasesByUIDs.defaultExpectation.params != nil { + mmGetKnowledgeBasesByUIDs.mock.t.Fatalf("RepositoryIMock.GetKnowledgeBasesByUIDs mock is already set by Expect") + } + + if mmGetKnowledgeBasesByUIDs.defaultExpectation.paramPtrs == nil { + mmGetKnowledgeBasesByUIDs.defaultExpectation.paramPtrs = &RepositoryIMockGetKnowledgeBasesByUIDsParamPtrs{} + } + mmGetKnowledgeBasesByUIDs.defaultExpectation.paramPtrs.kbUIDs = &kbUIDs + + return mmGetKnowledgeBasesByUIDs +} + +// Inspect accepts an inspector function that has same arguments as the RepositoryI.GetKnowledgeBasesByUIDs +func (mmGetKnowledgeBasesByUIDs *mRepositoryIMockGetKnowledgeBasesByUIDs) Inspect(f func(ctx context.Context, kbUIDs []uuid.UUID)) *mRepositoryIMockGetKnowledgeBasesByUIDs { + if mmGetKnowledgeBasesByUIDs.mock.inspectFuncGetKnowledgeBasesByUIDs != nil { + mmGetKnowledgeBasesByUIDs.mock.t.Fatalf("Inspect function is already set for RepositoryIMock.GetKnowledgeBasesByUIDs") + } + + mmGetKnowledgeBasesByUIDs.mock.inspectFuncGetKnowledgeBasesByUIDs = f + + return mmGetKnowledgeBasesByUIDs +} + +// Return sets up results that will be returned by RepositoryI.GetKnowledgeBasesByUIDs +func (mmGetKnowledgeBasesByUIDs *mRepositoryIMockGetKnowledgeBasesByUIDs) Return(ka1 []mm_repository.KnowledgeBase, err error) *RepositoryIMock { + if mmGetKnowledgeBasesByUIDs.mock.funcGetKnowledgeBasesByUIDs != nil { + mmGetKnowledgeBasesByUIDs.mock.t.Fatalf("RepositoryIMock.GetKnowledgeBasesByUIDs mock is already set by Set") + } + + if mmGetKnowledgeBasesByUIDs.defaultExpectation == nil { + mmGetKnowledgeBasesByUIDs.defaultExpectation = &RepositoryIMockGetKnowledgeBasesByUIDsExpectation{mock: mmGetKnowledgeBasesByUIDs.mock} + } + mmGetKnowledgeBasesByUIDs.defaultExpectation.results = &RepositoryIMockGetKnowledgeBasesByUIDsResults{ka1, err} + return mmGetKnowledgeBasesByUIDs.mock +} + +// Set uses given function f to mock the RepositoryI.GetKnowledgeBasesByUIDs method +func (mmGetKnowledgeBasesByUIDs *mRepositoryIMockGetKnowledgeBasesByUIDs) Set(f func(ctx context.Context, kbUIDs []uuid.UUID) (ka1 []mm_repository.KnowledgeBase, err error)) *RepositoryIMock { + if mmGetKnowledgeBasesByUIDs.defaultExpectation != nil { + mmGetKnowledgeBasesByUIDs.mock.t.Fatalf("Default expectation is already set for the RepositoryI.GetKnowledgeBasesByUIDs method") + } + + if len(mmGetKnowledgeBasesByUIDs.expectations) > 0 { + mmGetKnowledgeBasesByUIDs.mock.t.Fatalf("Some expectations are already set for the RepositoryI.GetKnowledgeBasesByUIDs method") + } + + mmGetKnowledgeBasesByUIDs.mock.funcGetKnowledgeBasesByUIDs = f + return mmGetKnowledgeBasesByUIDs.mock +} + +// When sets expectation for the RepositoryI.GetKnowledgeBasesByUIDs which will trigger the result defined by the following +// Then helper +func (mmGetKnowledgeBasesByUIDs *mRepositoryIMockGetKnowledgeBasesByUIDs) When(ctx context.Context, kbUIDs []uuid.UUID) *RepositoryIMockGetKnowledgeBasesByUIDsExpectation { + if mmGetKnowledgeBasesByUIDs.mock.funcGetKnowledgeBasesByUIDs != nil { + mmGetKnowledgeBasesByUIDs.mock.t.Fatalf("RepositoryIMock.GetKnowledgeBasesByUIDs mock is already set by Set") + } + + expectation := &RepositoryIMockGetKnowledgeBasesByUIDsExpectation{ + mock: mmGetKnowledgeBasesByUIDs.mock, + params: &RepositoryIMockGetKnowledgeBasesByUIDsParams{ctx, kbUIDs}, + } + mmGetKnowledgeBasesByUIDs.expectations = append(mmGetKnowledgeBasesByUIDs.expectations, expectation) + return expectation +} + +// Then sets up RepositoryI.GetKnowledgeBasesByUIDs return parameters for the expectation previously defined by the When method +func (e *RepositoryIMockGetKnowledgeBasesByUIDsExpectation) Then(ka1 []mm_repository.KnowledgeBase, err error) *RepositoryIMock { + e.results = &RepositoryIMockGetKnowledgeBasesByUIDsResults{ka1, err} + return e.mock +} + +// Times sets number of times RepositoryI.GetKnowledgeBasesByUIDs should be invoked +func (mmGetKnowledgeBasesByUIDs *mRepositoryIMockGetKnowledgeBasesByUIDs) Times(n uint64) *mRepositoryIMockGetKnowledgeBasesByUIDs { + if n == 0 { + mmGetKnowledgeBasesByUIDs.mock.t.Fatalf("Times of RepositoryIMock.GetKnowledgeBasesByUIDs mock can not be zero") + } + mm_atomic.StoreUint64(&mmGetKnowledgeBasesByUIDs.expectedInvocations, n) + return mmGetKnowledgeBasesByUIDs +} + +func (mmGetKnowledgeBasesByUIDs *mRepositoryIMockGetKnowledgeBasesByUIDs) invocationsDone() bool { + if len(mmGetKnowledgeBasesByUIDs.expectations) == 0 && mmGetKnowledgeBasesByUIDs.defaultExpectation == nil && mmGetKnowledgeBasesByUIDs.mock.funcGetKnowledgeBasesByUIDs == nil { + return true + } + + totalInvocations := mm_atomic.LoadUint64(&mmGetKnowledgeBasesByUIDs.mock.afterGetKnowledgeBasesByUIDsCounter) + expectedInvocations := mm_atomic.LoadUint64(&mmGetKnowledgeBasesByUIDs.expectedInvocations) + + return totalInvocations > 0 && (expectedInvocations == 0 || expectedInvocations == totalInvocations) +} + +// GetKnowledgeBasesByUIDs implements repository.RepositoryI +func (mmGetKnowledgeBasesByUIDs *RepositoryIMock) GetKnowledgeBasesByUIDs(ctx context.Context, kbUIDs []uuid.UUID) (ka1 []mm_repository.KnowledgeBase, err error) { + mm_atomic.AddUint64(&mmGetKnowledgeBasesByUIDs.beforeGetKnowledgeBasesByUIDsCounter, 1) + defer mm_atomic.AddUint64(&mmGetKnowledgeBasesByUIDs.afterGetKnowledgeBasesByUIDsCounter, 1) + + if mmGetKnowledgeBasesByUIDs.inspectFuncGetKnowledgeBasesByUIDs != nil { + mmGetKnowledgeBasesByUIDs.inspectFuncGetKnowledgeBasesByUIDs(ctx, kbUIDs) + } + + mm_params := RepositoryIMockGetKnowledgeBasesByUIDsParams{ctx, kbUIDs} + + // Record call args + mmGetKnowledgeBasesByUIDs.GetKnowledgeBasesByUIDsMock.mutex.Lock() + mmGetKnowledgeBasesByUIDs.GetKnowledgeBasesByUIDsMock.callArgs = append(mmGetKnowledgeBasesByUIDs.GetKnowledgeBasesByUIDsMock.callArgs, &mm_params) + mmGetKnowledgeBasesByUIDs.GetKnowledgeBasesByUIDsMock.mutex.Unlock() + + for _, e := range mmGetKnowledgeBasesByUIDs.GetKnowledgeBasesByUIDsMock.expectations { + if minimock.Equal(*e.params, mm_params) { + mm_atomic.AddUint64(&e.Counter, 1) + return e.results.ka1, e.results.err + } + } + + if mmGetKnowledgeBasesByUIDs.GetKnowledgeBasesByUIDsMock.defaultExpectation != nil { + mm_atomic.AddUint64(&mmGetKnowledgeBasesByUIDs.GetKnowledgeBasesByUIDsMock.defaultExpectation.Counter, 1) + mm_want := mmGetKnowledgeBasesByUIDs.GetKnowledgeBasesByUIDsMock.defaultExpectation.params + mm_want_ptrs := mmGetKnowledgeBasesByUIDs.GetKnowledgeBasesByUIDsMock.defaultExpectation.paramPtrs + + mm_got := RepositoryIMockGetKnowledgeBasesByUIDsParams{ctx, kbUIDs} + + if mm_want_ptrs != nil { + + if mm_want_ptrs.ctx != nil && !minimock.Equal(*mm_want_ptrs.ctx, mm_got.ctx) { + mmGetKnowledgeBasesByUIDs.t.Errorf("RepositoryIMock.GetKnowledgeBasesByUIDs got unexpected parameter ctx, want: %#v, got: %#v%s\n", *mm_want_ptrs.ctx, mm_got.ctx, minimock.Diff(*mm_want_ptrs.ctx, mm_got.ctx)) + } + + if mm_want_ptrs.kbUIDs != nil && !minimock.Equal(*mm_want_ptrs.kbUIDs, mm_got.kbUIDs) { + mmGetKnowledgeBasesByUIDs.t.Errorf("RepositoryIMock.GetKnowledgeBasesByUIDs got unexpected parameter kbUIDs, want: %#v, got: %#v%s\n", *mm_want_ptrs.kbUIDs, mm_got.kbUIDs, minimock.Diff(*mm_want_ptrs.kbUIDs, mm_got.kbUIDs)) + } + + } else if mm_want != nil && !minimock.Equal(*mm_want, mm_got) { + mmGetKnowledgeBasesByUIDs.t.Errorf("RepositoryIMock.GetKnowledgeBasesByUIDs got unexpected parameters, want: %#v, got: %#v%s\n", *mm_want, mm_got, minimock.Diff(*mm_want, mm_got)) + } + + mm_results := mmGetKnowledgeBasesByUIDs.GetKnowledgeBasesByUIDsMock.defaultExpectation.results + if mm_results == nil { + mmGetKnowledgeBasesByUIDs.t.Fatal("No results are set for the RepositoryIMock.GetKnowledgeBasesByUIDs") + } + return (*mm_results).ka1, (*mm_results).err + } + if mmGetKnowledgeBasesByUIDs.funcGetKnowledgeBasesByUIDs != nil { + return mmGetKnowledgeBasesByUIDs.funcGetKnowledgeBasesByUIDs(ctx, kbUIDs) + } + mmGetKnowledgeBasesByUIDs.t.Fatalf("Unexpected call to RepositoryIMock.GetKnowledgeBasesByUIDs. %v %v", ctx, kbUIDs) + return +} + +// GetKnowledgeBasesByUIDsAfterCounter returns a count of finished RepositoryIMock.GetKnowledgeBasesByUIDs invocations +func (mmGetKnowledgeBasesByUIDs *RepositoryIMock) GetKnowledgeBasesByUIDsAfterCounter() uint64 { + return mm_atomic.LoadUint64(&mmGetKnowledgeBasesByUIDs.afterGetKnowledgeBasesByUIDsCounter) +} + +// GetKnowledgeBasesByUIDsBeforeCounter returns a count of RepositoryIMock.GetKnowledgeBasesByUIDs invocations +func (mmGetKnowledgeBasesByUIDs *RepositoryIMock) GetKnowledgeBasesByUIDsBeforeCounter() uint64 { + return mm_atomic.LoadUint64(&mmGetKnowledgeBasesByUIDs.beforeGetKnowledgeBasesByUIDsCounter) +} + +// Calls returns a list of arguments used in each call to RepositoryIMock.GetKnowledgeBasesByUIDs. +// The list is in the same order as the calls were made (i.e. recent calls have a higher index) +func (mmGetKnowledgeBasesByUIDs *mRepositoryIMockGetKnowledgeBasesByUIDs) Calls() []*RepositoryIMockGetKnowledgeBasesByUIDsParams { + mmGetKnowledgeBasesByUIDs.mutex.RLock() + + argCopy := make([]*RepositoryIMockGetKnowledgeBasesByUIDsParams, len(mmGetKnowledgeBasesByUIDs.callArgs)) + copy(argCopy, mmGetKnowledgeBasesByUIDs.callArgs) + + mmGetKnowledgeBasesByUIDs.mutex.RUnlock() + + return argCopy +} + +// MinimockGetKnowledgeBasesByUIDsDone returns true if the count of the GetKnowledgeBasesByUIDs invocations corresponds +// the number of defined expectations +func (m *RepositoryIMock) MinimockGetKnowledgeBasesByUIDsDone() bool { + for _, e := range m.GetKnowledgeBasesByUIDsMock.expectations { + if mm_atomic.LoadUint64(&e.Counter) < 1 { + return false + } + } + + return m.GetKnowledgeBasesByUIDsMock.invocationsDone() +} + +// MinimockGetKnowledgeBasesByUIDsInspect logs each unmet expectation +func (m *RepositoryIMock) MinimockGetKnowledgeBasesByUIDsInspect() { + for _, e := range m.GetKnowledgeBasesByUIDsMock.expectations { + if mm_atomic.LoadUint64(&e.Counter) < 1 { + m.t.Errorf("Expected call to RepositoryIMock.GetKnowledgeBasesByUIDs with params: %#v", *e.params) + } + } + + afterGetKnowledgeBasesByUIDsCounter := mm_atomic.LoadUint64(&m.afterGetKnowledgeBasesByUIDsCounter) + // if default expectation was set then invocations count should be greater than zero + if m.GetKnowledgeBasesByUIDsMock.defaultExpectation != nil && afterGetKnowledgeBasesByUIDsCounter < 1 { + if m.GetKnowledgeBasesByUIDsMock.defaultExpectation.params == nil { + m.t.Error("Expected call to RepositoryIMock.GetKnowledgeBasesByUIDs") + } else { + m.t.Errorf("Expected call to RepositoryIMock.GetKnowledgeBasesByUIDs with params: %#v", *m.GetKnowledgeBasesByUIDsMock.defaultExpectation.params) + } + } + // if func was set then invocations count should be greater than zero + if m.funcGetKnowledgeBasesByUIDs != nil && afterGetKnowledgeBasesByUIDsCounter < 1 { + m.t.Error("Expected call to RepositoryIMock.GetKnowledgeBasesByUIDs") + } + + if !m.GetKnowledgeBasesByUIDsMock.invocationsDone() && afterGetKnowledgeBasesByUIDsCounter > 0 { + m.t.Errorf("Expected %d calls to RepositoryIMock.GetKnowledgeBasesByUIDs but found %d calls", + mm_atomic.LoadUint64(&m.GetKnowledgeBasesByUIDsMock.expectedInvocations), afterGetKnowledgeBasesByUIDsCounter) + } +} + type mRepositoryIMockGetKnowledgebaseFileByKbUIDAndFileID struct { mock *RepositoryIMock defaultExpectation *RepositoryIMockGetKnowledgebaseFileByKbUIDAndFileIDExpectation @@ -21801,6 +22115,8 @@ func (m *RepositoryIMock) MinimockFinish() { m.MinimockGetKnowledgeBaseFilesByFileUIDsInspect() + m.MinimockGetKnowledgeBasesByUIDsInspect() + m.MinimockGetKnowledgebaseFileByKbUIDAndFileIDInspect() m.MinimockGetNeedProcessFilesInspect() @@ -21930,6 +22246,7 @@ func (m *RepositoryIMock) minimockDone() bool { m.MinimockGetKnowledgeBaseByOwnerAndKbIDDone() && m.MinimockGetKnowledgeBaseCountByOwnerDone() && m.MinimockGetKnowledgeBaseFilesByFileUIDsDone() && + m.MinimockGetKnowledgeBasesByUIDsDone() && m.MinimockGetKnowledgebaseFileByKbUIDAndFileIDDone() && m.MinimockGetNeedProcessFilesDone() && m.MinimockGetObjectByUIDDone() && diff --git a/pkg/repository/knowledgebase.go b/pkg/repository/knowledgebase.go index db710ed..016d083 100644 --- a/pkg/repository/knowledgebase.go +++ b/pkg/repository/knowledgebase.go @@ -23,6 +23,7 @@ type KnowledgeBaseI interface { GetKnowledgeBaseByOwnerAndKbID(ctx context.Context, ownerUID uuid.UUID, kbID string) (*KnowledgeBase, error) GetKnowledgeBaseCountByOwner(ctx context.Context, ownerUID string, catalogType artifactpb.CatalogType) (int64, error) IncreaseKnowledgeBaseUsage(ctx context.Context, tx *gorm.DB, kbUID string, amount int) error + GetKnowledgeBasesByUIDs(ctx context.Context, kbUIDs []uuid.UUID) ([]KnowledgeBase, error) } type KnowledgeBase struct { @@ -297,3 +298,13 @@ func (r *Repository) IncreaseKnowledgeBaseUsage(ctx context.Context, tx *gorm.DB } return nil } + +// get the knowledge bases by uids +func (r *Repository) GetKnowledgeBasesByUIDs(ctx context.Context, kbUIDs []uuid.UUID) ([]KnowledgeBase, error) { + var knowledgeBases []KnowledgeBase + whereString := fmt.Sprintf("%v IN (?) AND %v IS NULL", KnowledgeBaseColumn.UID, KnowledgeBaseColumn.DeleteTime) + if err := r.db.WithContext(ctx).Where(whereString, kbUIDs).Find(&knowledgeBases).Error; err != nil { + return nil, err + } + return knowledgeBases, nil +} diff --git a/pkg/repository/knowledgebasefile.go b/pkg/repository/knowledgebasefile.go index df7692d..fe77392 100644 --- a/pkg/repository/knowledgebasefile.go +++ b/pkg/repository/knowledgebasefile.go @@ -559,9 +559,12 @@ func (r *Repository) GetKnowledgebaseFileByKbUIDAndFileID(ctx context.Context, k } type SourceMeta struct { - KbUID uuid.UUID - Dest string - CreateTime time.Time + OriginalFileUID uuid.UUID + OriginalFileName string + KbUID uuid.UUID + Dest string + CreateTime time.Time + UpdateTime time.Time } // GetTruthSourceByFileUID returns the truth source file destination of minIO by file UID @@ -579,15 +582,19 @@ func (r *Repository) GetTruthSourceByFileUID(ctx context.Context, fileUID uuid.U return nil, err } // assign truth source file destination and create time + originalFileUID := file.UID + originalFileName := file.Name var kbUID uuid.UUID var dest string var createTime time.Time + var updateTime time.Time switch file.Type { // if the file type is text or markdown, the destination is the file destination case artifactpb.FileType_FILE_TYPE_TEXT.String(), artifactpb.FileType_FILE_TYPE_MARKDOWN.String(): kbUID = file.KnowledgeBaseUID dest = file.Destination createTime = *file.CreateTime + updateTime = *file.UpdateTime // if the file type is pdf, get the converted file destination case artifactpb.FileType_FILE_TYPE_PDF.String(), artifactpb.FileType_FILE_TYPE_HTML.String(), @@ -613,12 +620,16 @@ func (r *Repository) GetTruthSourceByFileUID(ctx context.Context, fileUID uuid.U kbUID = convertedFile.KbUID dest = convertedFile.Destination createTime = *convertedFile.CreateTime + updateTime = *convertedFile.UpdateTime } return &SourceMeta{ - Dest: dest, - CreateTime: createTime, - KbUID: kbUID, + OriginalFileUID: originalFileUID, + OriginalFileName: originalFileName, + Dest: dest, + CreateTime: createTime, + UpdateTime: updateTime, + KbUID: kbUID, }, nil } diff --git a/pkg/service/permission.go b/pkg/service/permission.go index 254d861..9f30e14 100644 --- a/pkg/service/permission.go +++ b/pkg/service/permission.go @@ -44,3 +44,15 @@ func (s *Service) CheckCatalogUserPermission(ctx context.Context, nsID, catalogI return ns, catalog, nil } + +func (s *Service) GetNamespaceAndCheckPermission(ctx context.Context, nsID string) (*resource.Namespace, error) { + ns, err := s.GetNamespaceByNsID(ctx, nsID) + if err != nil { + return nil, err + } + err = s.CheckNamespacePermission(ctx, ns) + if err != nil { + return nil, err + } + return ns, nil +}