Skip to content

Commit

Permalink
code review
Browse files Browse the repository at this point in the history
  • Loading branch information
shreemaan-abhishek committed Sep 25, 2024
1 parent edc3ccc commit d49665d
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 28 deletions.
1 change: 0 additions & 1 deletion apisix/plugins/ai-prompt-decorator.lua
Original file line number Diff line number Diff line change
Expand Up @@ -114,5 +114,4 @@ function _M.rewrite(conf, ctx)
end


_M.__decorate = decorate -- for ai-rag plugin
return _M
32 changes: 14 additions & 18 deletions apisix/plugins/ai-rag.lua
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,12 @@ local ngx_req = ngx.req

local http = require("resty.http")
local core = require("apisix.core")
local decorate = require("apisix.plugins.ai-prompt-decorator").__decorate

local azure_openai_embeddings = require("apisix.plugins.ai-rag.embeddings.azure_openai").schema
local azure_ai_search_schema = require("apisix.plugins.ai-rag.vector-search.azure_ai_search").schema

local INTERNAL_SERVER_ERROR = ngx.HTTP_INTERNAL_SERVER_ERROR
local BAD_REQUEST = ngx.HTTP_BAD_REQUEST
local HTTP_INTERNAL_SERVER_ERROR = ngx.HTTP_INTERNAL_SERVER_ERROR
local HTTP_BAD_REQUEST = ngx.HTTP_BAD_REQUEST

local schema = {
type = "object",
Expand Down Expand Up @@ -70,7 +69,7 @@ local request_schema = {

local _M = {
version = 0.1,
priority = 1060, -- TODO check with other ai plugins
priority = 1060,
name = "ai-rag",
schema = schema,
}
Expand All @@ -85,11 +84,11 @@ function _M.access(conf, ctx)
local httpc = http.new()
local body_tab, err = core.request.get_json_request_body_table()
if not body_tab then
return BAD_REQUEST, err
return HTTP_BAD_REQUEST, err
end
if not body_tab["ai_rag"] then
core.log.error("request body must have \"ai-rag\" field")
return BAD_REQUEST
return HTTP_BAD_REQUEST
end

local embeddings_provider = next(conf.embeddings_provider)
Expand All @@ -110,7 +109,7 @@ function _M.access(conf, ctx)
local ok, err = core.schema.check(request_schema, body_tab)
if not ok then
core.log.error("request body fails schema check: ", err)
return BAD_REQUEST
return HTTP_BAD_REQUEST
end

local embeddings, status, err = embeddings_driver.get_embeddings(embeddings_provider_conf,
Expand All @@ -133,22 +132,19 @@ function _M.access(conf, ctx)
-- also, these values will cause failure when proxying requests to LLM.
body_tab["ai_rag"] = nil

local prepend = {
{
role = "user",
content = res
}
}
local decorator_conf = {
prepend = prepend
}
if not body_tab.messages then
body_tab.messages = {}
end
decorate(decorator_conf, body_tab)

local augment = {
role = "user",
content = res
}
core.table.insert_tail(body_tab.messages, augment)

local req_body_json, err = core.json.encode(body_tab)
if not req_body_json then
return INTERNAL_SERVER_ERROR, err
return HTTP_INTERNAL_SERVER_ERROR, err
end

ngx_req.set_body_data(req_body_json)
Expand Down
12 changes: 6 additions & 6 deletions apisix/plugins/ai-rag/embeddings/azure_openai.lua
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
-- limitations under the License.
--
local core = require("apisix.core")
local INTERNAL_SERVER_ERROR = ngx.HTTP_INTERNAL_SERVER_ERROR
local HTTP_INTERNAL_SERVER_ERROR = ngx.HTTP_INTERNAL_SERVER_ERROR
local type = type

local _M = {}
Expand All @@ -36,7 +36,7 @@ _M.schema = {
function _M.get_embeddings(conf, body, httpc)
local body_tab, err = core.json.encode(body)
if not body_tab then
return nil, INTERNAL_SERVER_ERROR, err
return nil, HTTP_INTERNAL_SERVER_ERROR, err
end

local res, err = httpc:request_uri(conf.endpoint, {
Expand All @@ -49,7 +49,7 @@ function _M.get_embeddings(conf, body, httpc)
})

if not res or not res.body then
return nil, INTERNAL_SERVER_ERROR, err
return nil, HTTP_INTERNAL_SERVER_ERROR, err
end

if res.status ~= 200 then
Expand All @@ -58,16 +58,16 @@ function _M.get_embeddings(conf, body, httpc)

local res_tab, err = core.json.decode(res.body)
if not res_tab then
return nil, INTERNAL_SERVER_ERROR, err
return nil, HTTP_INTERNAL_SERVER_ERROR, err
end

if type(res_tab.data) ~= "table" or core.table.isempty(res_tab.data) then
return nil, INTERNAL_SERVER_ERROR, res.body
return nil, HTTP_INTERNAL_SERVER_ERROR, res.body
end

local embeddings, err = core.json.encode(res_tab.data[1].embedding)
if not embeddings then
return nil, INTERNAL_SERVER_ERROR, err
return nil, HTTP_INTERNAL_SERVER_ERROR, err
end

return res_tab.data[1].embedding
Expand Down
6 changes: 3 additions & 3 deletions apisix/plugins/ai-rag/vector-search/azure_ai_search.lua
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
-- limitations under the License.
--
local core = require("apisix.core")
local INTERNAL_SERVER_ERROR = ngx.HTTP_INTERNAL_SERVER_ERROR
local HTTP_INTERNAL_SERVER_ERROR = ngx.HTTP_INTERNAL_SERVER_ERROR

local _M = {}

Expand Down Expand Up @@ -44,7 +44,7 @@ function _M.search(conf, search_body, httpc)
}
local final_body, err = core.json.encode(body)
if not final_body then
return nil, INTERNAL_SERVER_ERROR, err
return nil, HTTP_INTERNAL_SERVER_ERROR, err
end

local res, err = httpc:request_uri(conf.endpoint, {
Expand All @@ -57,7 +57,7 @@ function _M.search(conf, search_body, httpc)
})

if not res or not res.body then
return nil, INTERNAL_SERVER_ERROR, err
return nil, HTTP_INTERNAL_SERVER_ERROR, err
end

if res.status ~= 200 then
Expand Down

0 comments on commit d49665d

Please sign in to comment.