Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(artifact): implement search chunks and sources #133

Merged
merged 1 commit into from
Nov 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ require (
github.com/grpc-ecosystem/go-grpc-middleware v1.3.0
github.com/grpc-ecosystem/grpc-gateway/v2 v2.19.1
github.com/influxdata/influxdb-client-go/v2 v2.12.3
github.com/instill-ai/protogen-go v0.3.3-alpha.0.20241125163328-c29704e47ba4
github.com/instill-ai/protogen-go v0.3.3-alpha.0.20241129082755-59b3c0c34fe0
github.com/instill-ai/usage-client v0.3.0-alpha.0.20240319060111-4a3a39f2fd61
github.com/instill-ai/x v0.3.0-alpha.0.20231219052200-6230a89e386c
github.com/knadh/koanf v1.5.0
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -344,8 +344,8 @@ github.com/influxdata/influxdb-client-go/v2 v2.12.3 h1:28nRlNMRIV4QbtIUvxhWqaxn0
github.com/influxdata/influxdb-client-go/v2 v2.12.3/go.mod h1:IrrLUbCjjfkmRuaCiGQg4m2GbkaeJDcuWoxiWdQEbA0=
github.com/influxdata/line-protocol v0.0.0-20200327222509-2487e7298839 h1:W9WBk7wlPfJLvMCdtV4zPulc4uCPrlywQOmbFOhgQNU=
github.com/influxdata/line-protocol v0.0.0-20200327222509-2487e7298839/go.mod h1:xaLFMmpvUxqXtVkUJfg9QmT88cDaCJ3ZKgdZ78oO8Qo=
github.com/instill-ai/protogen-go v0.3.3-alpha.0.20241125163328-c29704e47ba4 h1:k8X9gMiCwHWShB1FITaWwmlzthFnor1Jj0tSaFG+9x8=
github.com/instill-ai/protogen-go v0.3.3-alpha.0.20241125163328-c29704e47ba4/go.mod h1:rf0UY7VpEgpaLudYEcjx5rnbuwlBaaLyD4FQmWLtgAY=
github.com/instill-ai/protogen-go v0.3.3-alpha.0.20241129082755-59b3c0c34fe0 h1:Fok/s7GQoNMUA++1WbDdiZ6Ut8AXSuWNkTU2Q/0G9QA=
github.com/instill-ai/protogen-go v0.3.3-alpha.0.20241129082755-59b3c0c34fe0/go.mod h1:rf0UY7VpEgpaLudYEcjx5rnbuwlBaaLyD4FQmWLtgAY=
github.com/instill-ai/usage-client v0.3.0-alpha.0.20240319060111-4a3a39f2fd61 h1:smPTvmXDhn/QC7y/TPXyMTqbbRd0gvzmFgWBChwTfhE=
github.com/instill-ai/usage-client v0.3.0-alpha.0.20240319060111-4a3a39f2fd61/go.mod h1:/TAHs4ybuylk5icuy+MQtHRc4XUnIyXzeNKxX9qDFhw=
github.com/instill-ai/x v0.3.0-alpha.0.20231219052200-6230a89e386c h1:a2RVkpIV2QcrGnSHAou+t/L+vBsaIfFvk5inVg5Uh4s=
Expand Down
164 changes: 146 additions & 18 deletions pkg/handler/chunks.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,19 @@ import (
"google.golang.org/protobuf/types/known/timestamppb"
)

// convertToProtoChunk
func convertToProtoChunk(chunk repository.TextChunk) *artifactpb.Chunk {
return &artifactpb.Chunk{
ChunkUid: chunk.UID.String(),
Retrievable: chunk.Retrievable,
StartPos: uint32(chunk.StartPos),
EndPos: uint32(chunk.EndPos),
Tokens: uint32(chunk.Tokens),
CreateTime: timestamppb.New(*chunk.CreateTime),
OriginalFileUid: chunk.KbFileUID.String(),
}
}

func (ph *PublicHandler) ListChunks(ctx context.Context, req *artifactpb.ListChunksRequest) (*artifactpb.ListChunksResponse, error) {
log, _ := logger.GetZapLogger(ctx)
authUID, err := getUserUIDFromContext(ctx)
Expand Down Expand Up @@ -73,22 +86,77 @@ func (ph *PublicHandler) ListChunks(ctx context.Context, req *artifactpb.ListChu

res := make([]*artifactpb.Chunk, 0, len(chunks))
for _, chunk := range chunks {
res = append(res, &artifactpb.Chunk{
ChunkUid: chunk.UID.String(),
Retrievable: chunk.Retrievable,
StartPos: uint32(chunk.StartPos),
EndPos: uint32(chunk.EndPos),
Tokens: uint32(chunk.Tokens),
CreateTime: timestamppb.New(*chunk.CreateTime),
OriginalFileUid: kbf.UID.String(),
})
res = append(res, convertToProtoChunk(chunk))
}

return &artifactpb.ListChunksResponse{
Chunks: res,
}, nil
}

func (ph *PublicHandler) SearchChunks(ctx context.Context, req *artifactpb.SearchChunksRequest) (*artifactpb.SearchChunksResponse, error) {
log, _ := logger.GetZapLogger(ctx)
_, err := getUserUIDFromContext(ctx)
if err != nil {
log.Error("failed to get user id from header", zap.Error(err))
return nil, fmt.Errorf("failed to get user id from header: %v. err: %w", err, customerror.ErrUnauthenticated)
}
// check if user can access the namespace
ns, err := ph.service.GetNamespaceAndCheckPermission(ctx, req.NamespaceId)
if err != nil {
log.Error("failed to get namespace and check permission", zap.Error(err))
return nil, fmt.Errorf("failed to get namespace and check permission: %w", err)
}

chunkUIDs := make([]uuid.UUID, 0, len(req.ChunkUids))
for _, chunkUID := range req.ChunkUids {
chunkUID, err := uuid.FromString(chunkUID)
if err != nil {
log.Error("failed to parse chunk uid", zap.Error(err))
return nil, fmt.Errorf("failed to parse chunk uid: %w", err)
}
chunkUIDs = append(chunkUIDs, chunkUID)
}
// check if the chunkUIs is more than 20
if len(chunkUIDs) > 25 {
log.Error("chunk uids is more than 20", zap.Int("chunk_uids_count", len(chunkUIDs)))
return nil, fmt.Errorf("chunk uids is more than 20")
}
chunks, err := ph.service.Repository.GetChunksByUIDs(ctx, chunkUIDs)
if err != nil {
log.Error("failed to get chunks by uids", zap.Error(err))
return nil, fmt.Errorf("failed to get chunks by uids: %w", err)
}

// get the kbUIDs from chunks
kbUIDs := make([]uuid.UUID, 0, len(chunks))
for _, chunk := range chunks {
kbUIDs = append(kbUIDs, chunk.KbUID)
}
// use kbUIDs to get the knowledge bases
knowledgeBases, err := ph.service.Repository.GetKnowledgeBasesByUIDs(ctx, kbUIDs)
if err != nil {
log.Error("failed to get knowledge bases by uids", zap.Error(err))
return nil, fmt.Errorf("failed to get knowledge bases by uids: %w", err)
}
// check if the chunks's knowledge base's owner(namespace uid) is the same as namespace uuid in path
for _, knowledgeBase := range knowledgeBases {
if knowledgeBase.Owner != ns.NsUID.String() {
log.Error("chunks's namespace is not the same as namespace in path", zap.String("namespace_id_in_path", ns.NsUID.String()), zap.String("namespace_id_in_chunks", knowledgeBase.Owner))
return nil, fmt.Errorf("chunks's namespace is not the same as namespace in path")
}
}

// populate the response
protoChunks := make([]*artifactpb.Chunk, 0, len(chunks))
for _, chunk := range chunks {
protoChunks = append(protoChunks, convertToProtoChunk(chunk))
}
return &artifactpb.SearchChunksResponse{
Chunks: protoChunks,
}, nil
}

func (ph *PublicHandler) UpdateChunk(ctx context.Context, req *artifactpb.UpdateChunkRequest) (*artifactpb.UpdateChunkResponse, error) {
log, _ := logger.GetZapLogger(ctx)
authUID, err := getUserUIDFromContext(ctx)
Expand Down Expand Up @@ -131,15 +199,7 @@ func (ph *PublicHandler) UpdateChunk(ctx context.Context, req *artifactpb.Update

return &artifactpb.UpdateChunkResponse{
// Populate the response fields appropriately
Chunk: &artifactpb.Chunk{
ChunkUid: chunk.UID.String(),
Retrievable: chunk.Retrievable,
StartPos: uint32(chunk.StartPos),
EndPos: uint32(chunk.EndPos),
Tokens: uint32(chunk.Tokens),
CreateTime: timestamppb.New(*chunk.CreateTime),
// OriginalFileUid: chunk.FileUID.String(),
},
Chunk: convertToProtoChunk(*chunk),
}, nil
}

Expand Down Expand Up @@ -189,3 +249,71 @@ func (ph *PublicHandler) GetSourceFile(ctx context.Context, req *artifactpb.GetS
},
}, nil
}

// SearchSourceFiles
func (ph *PublicHandler) SearchSourceFiles(ctx context.Context, req *artifactpb.SearchSourceFilesRequest) (*artifactpb.SearchSourceFilesResponse, error) {
log, _ := logger.GetZapLogger(ctx)
authUID, err := getUserUIDFromContext(ctx)
if err != nil {
log.Error("failed to get user id from header", zap.Error(err))
return nil, fmt.Errorf("failed to get user id from header: %v. err: %w", err, customerror.ErrUnauthenticated)
}

// Check if user can access the namespace
_, err = ph.service.GetNamespaceAndCheckPermission(ctx, req.NamespaceId)
if err != nil {
log.Error("failed to get namespace and check permission", zap.Error(err))
return nil, fmt.Errorf("failed to get namespace and check permission: %w", err)
}

fileUIDs := make([]uuid.UUID, 0, len(req.FileUids))
for _, fileUID := range req.FileUids {
uid, err := uuid.FromString(fileUID)
if err != nil {
log.Error("failed to parse file uid", zap.Error(err))
return nil, fmt.Errorf("failed to parse file uid: %v. err: %w", err, customerror.ErrInvalidArgument)
}
fileUIDs = append(fileUIDs, uid)
}

sources := make([]*artifactpb.SourceFile, 0, len(fileUIDs))
for _, fileUID := range fileUIDs {
source, err := ph.service.Repository.GetTruthSourceByFileUID(ctx, fileUID)
if err != nil {
log.Error("failed to get truth source by file uid", zap.Error(err))
return nil, fmt.Errorf("failed to get truth source by file uid. err: %w", err)
}

// ACL check for each source file
granted, err := ph.service.ACLClient.CheckPermission(ctx, "knowledgebase", source.KbUID, "reader")
if err != nil {
log.Error("failed to check permission", zap.Error(err))
return nil, fmt.Errorf("failed to check permission. err: %w", err)
}
if !granted {
log.Error("no permission to access source file",
zap.String("user_uid", authUID),
zap.String("kb_uid", source.KbUID.String()))
return nil, fmt.Errorf("no permission to access source file. err: %w. user_uid: %s. kb_uid: %s", customerror.ErrNoPermission, authUID, source.KbUID.String())
}

// Get file content from MinIO
content, err := ph.service.MinIO.GetFile(ctx, minio.KnowledgeBaseBucketName, source.Dest)
if err != nil {
log.Error("failed to get file from minio", zap.Error(err))
continue
}

sources = append(sources, &artifactpb.SourceFile{
OriginalFileUid: source.OriginalFileUID.String(),
OriginalFileName: source.OriginalFileName,
Content: string(content),
CreateTime: timestamppb.New(source.CreateTime),
UpdateTime: timestamppb.New(source.UpdateTime),
})
}

return &artifactpb.SearchSourceFilesResponse{
SourceFiles: sources,
}, nil
}
2 changes: 1 addition & 1 deletion pkg/handler/knowledgebase.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func (ph *PublicHandler) CreateCatalog(ctx context.Context, req *artifactpb.Crea
ns, err := ph.service.GetNamespaceByNsID(ctx, req.GetNamespaceId())
if err != nil {
log.Error(
"failed to check namespace permission",
"failed to get namespace",
zap.Error(err),
zap.String("owner_id(ns_id)", req.GetNamespaceId()),
zap.String("auth_uid", authUID))
Expand Down
Loading
Loading