-
Notifications
You must be signed in to change notification settings - Fork 2.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: ai-content-moderation plugin #11541
Changes from 18 commits
c80f932
d29f928
8ac0738
7ffa489
5214d0d
2fe1ea2
5cecf2a
cf08f04
460081e
d475eb0
6350d15
f713f87
e16a823
b21b64b
ee34e37
12529f0
57c59ab
093d7a9
7b52fa5
6e3bee2
6a2d575
6bb399c
1f4528d
ef16068
f6f3451
f3672fa
8447d6d
0949327
3a616e1
4c1f2a6
5b1be91
a3e47b2
81958e4
3da00a2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,170 @@ | ||||||||
-- | ||||||||
-- Licensed to the Apache Software Foundation (ASF) under one or more | ||||||||
-- contributor license agreements. See the NOTICE file distributed with | ||||||||
-- this work for additional information regarding copyright ownership. | ||||||||
-- The ASF licenses this file to You under the Apache License, Version 2.0 | ||||||||
-- (the "License"); you may not use this file except in compliance with | ||||||||
-- the License. You may obtain a copy of the License at | ||||||||
-- | ||||||||
-- http://www.apache.org/licenses/LICENSE-2.0 | ||||||||
-- | ||||||||
-- Unless required by applicable law or agreed to in writing, software | ||||||||
-- distributed under the License is distributed on an "AS IS" BASIS, | ||||||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||||
-- See the License for the specific language governing permissions and | ||||||||
-- limitations under the License. | ||||||||
-- | ||||||||
local core = require("apisix.core") | ||||||||
local aws = require("resty.aws") | ||||||||
local http = require("resty.http") | ||||||||
local fetch_secrets = require("apisix.secret").fetch_secrets | ||||||||
|
||||||||
local aws_instance = aws() | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
then we can remove There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||||||||
local next = next | ||||||||
local pairs = pairs | ||||||||
local unpack = unpack | ||||||||
local type = type | ||||||||
local ipairs = ipairs | ||||||||
|
||||||||
-- Schema for the AWS Comprehend provider settings: credentials and region
-- are mandatory, the endpoint override is optional and must be an http(s) URL.
local aws_comprehend_schema = {
    type = "object",
    required = { "access_key_id", "secret_access_key", "region" },
    properties = {
        region = { type = "string" },
        access_key_id = { type = "string" },
        secret_access_key = { type = "string" },
        endpoint = { type = "string", pattern = "^https?://" },
    },
}
|
||||||||
-- Plugin configuration schema.
local schema = {
    type = "object",
    properties = {
        -- moderation backend, keyed by provider name
        provider = {
            type = "object",
            properties = {
                aws_comprehend = aws_comprehend_schema
            },
            -- rewrite() selects the provider with next(), whose iteration order
            -- is unspecified, so the table must hold exactly one entry
            maxProperties = 1,
            -- change to oneOf/enum while implementing support for other services
            required = { "aws_comprehend" }
        },
        -- optional per-category thresholds; rewrite() rejects the request when
        -- a label's Score is greater than the threshold configured for its Name
        moderation_categories = {
            type = "object",
            patternProperties = {
                -- luacheck: push max code line length 300
                ["^(PROFANITY|HATE_SPEECH|INSULT|HARASSMENT_OR_ABUSE|SEXUAL|VIOLENCE_OR_THREAT)$"] = {
                -- luacheck: pop
                    type = "number",
                    minimum = 0,
                    maximum = 1
                }
            },
            additionalProperties = false
        },
        -- overall threshold compared against each result's Toxicity in rewrite()
        toxicity_level = {
            type = "number",
            minimum = 0,
            maximum = 1,
            default = 0.5
        },
        -- name of the module under apisix.plugins.ai that rewrite() loads to
        -- turn the request messages into text segments
        type = {
            type = "string",
            enum = { "openai" },
        }
    },
    required = { "provider", "type" },
}
|
||||||||
|
||||||||
-- Module table returned at the end of the file; carries the plugin's name,
-- version, priority and configuration schema.
local _M = {
    name = "ai-content-moderation",
    version = 0.1,
    -- TODO: might change
    priority = 1040,
    schema = schema,
}
|
||||||||
|
||||||||
--- Validate a plugin configuration against the plugin schema.
-- @param conf configuration table to validate
-- @return the result of core.schema.check: a success flag and, on failure,
--         an error
function _M.check_schema(conf)
    local ok, err = core.schema.check(schema, conf)
    return ok, err
end
|
||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Two blank lines between functions? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. fixed. |
||||||||
--- Rewrite-phase handler: sends the chat messages found in the request body
-- to AWS Comprehend's toxic-content detection and rejects the request when a
-- configured threshold is exceeded.
--
-- @param conf plugin configuration
-- @param ctx  request context (not used by this function)
-- @return nothing when the request passes moderation; otherwise an HTTP
--         status code and a message: 400 for a missing/invalid body or a
--         threshold violation, 500 for secret, endpoint or provider failures
function _M.rewrite(conf, ctx)
    conf = fetch_secrets(conf, true, conf, "")
    if not conf then
        return 500, "failed to retrieve secrets from conf"
    end

    local body, err = core.request.get_body_table()
    if not body then
        return 400, err
    end

    -- type() already rejects nil, so no separate nil check is needed
    local msgs = body.messages
    if type(msgs) ~= "table" or #msgs < 1 then
        return 400, "messages not found in request body"
    end

    -- conf.provider is keyed by provider name; keep the name for log messages.
    -- NOTE(review): next() returns an arbitrary key, so this relies on the
    -- schema admitting a single provider entry
    local provider_name = next(conf.provider)
    local provider = conf.provider[provider_name]

    -- TODO support secret
    local credentials = aws_instance:Credentials({
        accessKeyId = provider.access_key_id,
        secretAccessKey = provider.secret_access_key,
        sessionToken = provider.session_token,
    })

    local default_endpoint = "https://comprehend." .. provider.region .. ".amazonaws.com"
    -- parse_uri returns nil plus an error for a URI it cannot parse; check the
    -- result before unpacking so a bad endpoint yields a 500 instead of a
    -- Lua runtime error
    local parsed_uri, parse_err = http:parse_uri(provider.endpoint or default_endpoint)
    if not parsed_uri then
        core.log.error("failed to parse endpoint of ", provider_name, ": ", parse_err)
        return 500, "invalid provider endpoint"
    end

    local scheme, host, port = unpack(parsed_uri)
    local endpoint = scheme .. "://" .. host
    -- NOTE(review): aws_instance is created once at module level, so these
    -- writes are shared by every request that runs this plugin — confirm
    aws_instance.config.endpoint = endpoint
    -- NOTE(review): TLS certificate verification is disabled unconditionally
    aws_instance.config.ssl_verify = false

    local comprehend = aws_instance:Comprehend({
        credentials = credentials,
        endpoint = endpoint,
        region = provider.region,
        port = port,
    })

    local ai_module = require("apisix.plugins.ai." .. conf.type)
    local text_segments = ai_module.create_request_text_segments(msgs)

    local res
    res, err = comprehend:detectToxicContent({
        LanguageCode = "en",
        TextSegments = text_segments,
    })

    if not res then
        -- log the provider name, not the provider table: the table holds the
        -- access keys and a plain table is not a valid log argument
        core.log.error("failed to send request to ", provider_name, ": ", err)
        return 500, err
    end

    local results = res.body and res.body.ResultList
    if type(results) ~= "table" or #results < 1 then
        return 500, "failed to get moderation results from response"
    end

    local categories = conf.moderation_categories
    for _, result in ipairs(results) do
        if categories then
            for _, item in pairs(result.Labels) do
                -- labels without a configured threshold are ignored
                local threshold = categories[item.Name]
                if threshold and item.Score > threshold then
                    return 400, "request body exceeds " .. item.Name .. " threshold"
                end
            end
        end

        if result.Toxicity > conf.toxicity_level then
            return 400, "request body exceeds toxicity threshold"
        end
    end
end
|
||||||||
return _M |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
-- | ||
-- Licensed to the Apache Software Foundation (ASF) under one or more | ||
-- contributor license agreements. See the NOTICE file distributed with | ||
-- this work for additional information regarding copyright ownership. | ||
-- The ASF licenses this file to You under the Apache License, Version 2.0 | ||
-- (the "License"); you may not use this file except in compliance with | ||
-- the License. You may obtain a copy of the License at | ||
-- | ||
-- http://www.apache.org/licenses/LICENSE-2.0 | ||
-- | ||
-- Unless required by applicable law or agreed to in writing, software | ||
-- distributed under the License is distributed on an "AS IS" BASIS, | ||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
-- See the License for the specific language governing permissions and | ||
-- limitations under the License. | ||
-- | ||
local core = require("apisix.core") | ||
local ipairs = ipairs | ||
|
||
local _M = {} | ||
|
||
|
||
--- Build the list of text segments for a moderation request.
-- Each message contributes one segment whose Text is the message's content.
-- @param msgs array of message tables, visited in order with ipairs
-- @return array of { Text = <content> } tables, one per visited message
function _M.create_request_text_segments(msgs)
    local segments = {}
    for _, message in ipairs(msgs) do
        local segment = { Text = message.content }
        core.table.insert_tail(segments, segment)
    end
    return segments
end
|
||
return _M |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ai-proxy PR also has this code so later we can merge from master after ai-proxy is merged.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
note that the name of the method there changed :D