diff --git a/.gitignore b/.gitignore index a552964..5714632 100644 --- a/.gitignore +++ b/.gitignore @@ -15,3 +15,4 @@ vendor/ glide.lock .idea/ +image/test_result \ No newline at end of file diff --git a/README.md b/README.md index 364d47c..f8dee36 100644 --- a/README.md +++ b/README.md @@ -74,6 +74,17 @@ #### 文本审核 - [ ] 通用类文本反作弊 +### 图像增强与特效 + +#### 图像特效 +- [x] 黑白图片上色 + +#### 图像增强 +- [x] 图像对比度增强 +- [x] 图像无损放大 +- [x] 图像清晰度增强 +- [x] 图像色彩增强 + # LISENCE the project is licensed under the [Apache License 2.0](https://github.com/chenqinghe/baidu-ai-go-sdk/blob/master/LICENSE) diff --git a/example/image/enhance_01.jpg b/example/image/enhance_01.jpg new file mode 100644 index 0000000..04677a8 Binary files /dev/null and b/example/image/enhance_01.jpg differ diff --git a/go.mod b/go.mod index 9ecd97a..e2af284 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,7 @@ module github.com/chenqinghe/baidu-ai-go-sdk go 1.13 -require github.com/imroc/req v0.2.4 +require ( + github.com/imroc/req v0.2.4 + github.com/stretchr/testify v1.9.0 +) diff --git a/go.sum b/go.sum index 497e9af..5941e13 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,21 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/imroc/req v0.2.4 h1:8XbvaQpERLAJV6as/cB186DtH5f0m5zAOtHEaTQ4ac0= github.com/imroc/req v0.2.4/go.mod h1:J9FsaNHDTIVyW/b5r6/Df5qKEEEq2WzZKIgKSajd1AE= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/image/client.go b/image/client.go new file mode 100644 index 0000000..6ccea41 --- /dev/null +++ b/image/client.go @@ -0,0 +1,70 @@ +package image + +import ( + "context" + "fmt" + gosdk "github.com/chenqinghe/baidu-ai-go-sdk" + "github.com/imroc/req" + "net/http" +) + +type ImageClient struct { + *gosdk.Client +} + +func NewEnhanceClient(apiKey, secretKey string) *ImageClient { + return &ImageClient{ + Client: gosdk.NewClient(apiKey, secretKey), + } +} + +func (e *ImageClient) requestUrl(url string) (string, error) { + if err := e.Auth(); err != nil { + return "", err + } + + return fmt.Sprintf("%s?access_token=%s", url, e.AccessToken), nil +} + +func (e *ImageClient) posUrlEncode(ctx context.Context, url string, + inputParam *Input, maxSize int) (*EnhanceResponse, error) { + options, err := inputParam.encode(maxSize) + if err != nil { + return nil, err + } + + url, err = e.requestUrl(url) + if err != nil { + return nil, err + } + + header := req.Header{ + "Content-Type": "application/x-www-form-urlencoded", + } + + resp, err := req. + Post(url, req.Param(options), header, ctx) + if err != nil { + return nil, err + } + + if resp.Response().StatusCode != http.StatusOK { + return nil, APIError{ + ErrorCode: 0, + ErrorMsg: resp.String(), + StatusCode: resp.Response().Status, + } + } + + var response generalResponse + err = resp.ToJSON(&response) + if err != nil { + return nil, fmt.Errorf("decode enhance response: %w", err) + } + + if response.success() { + return &response.EnhanceResponse, nil + } + + return nil, response.APIError +} diff --git a/image/const.go b/image/const.go new file mode 100644 index 0000000..11a2f71 --- /dev/null +++ b/image/const.go @@ -0,0 +1,16 @@ +package image + +import "github.com/chenqinghe/baidu-ai-go-sdk/voice" + +const ( + urlContrastEnhance = "https://aip.baidubce.com/rest/2.0/image-process/v1/contrast_enhance" + urlColorEnhance = "https://aip.baidubce.com/rest/2.0/image-process/v1/color_enhance" + urlColorize = "https://aip.baidubce.com/rest/2.0/image-process/v1/colourize" + urlQualityEnhance = "https://aip.baidubce.com/rest/2.0/image-process/v1/image_quality_enhance" + urlDefinitionEnhance = "https://aip.baidubce.com/rest/2.0/image-process/v1/image_definition_enhance" +) + +const ( + image8M = 8 * voice.MB + image4M = 4 * voice.MB +) diff --git a/image/effect.go b/image/effect.go new file mode 100644 index 0000000..bc86a4f --- /dev/null +++ b/image/effect.go @@ -0,0 +1,7 @@ +package image + +import "context" + +func (e *ImageClient) Colourize(ctx context.Context, param *Input) (*EnhanceResponse, error) { + return e.posUrlEncode(ctx, urlColorize, param, image8M) +} diff --git a/image/enhance.go b/image/enhance.go new file mode 100644 index 0000000..b6d6b27 --- /dev/null +++ b/image/enhance.go @@ -0,0 +1,22 @@ +package image + +import ( + "context" +) + +// ContrastEnhance https://ai.baidu.com/ai-doc/IMAGEPROCESS/ek3bclnzn +func (e *ImageClient) ContrastEnhance(ctx context.Context, param *Input) (*EnhanceResponse, error) { + return e.posUrlEncode(ctx, urlContrastEnhance, param, image8M) +} + +func (e *ImageClient) ColorEnhance(ctx context.Context, param *Input) (*EnhanceResponse, error) { + return e.posUrlEncode(ctx, urlColorEnhance, param, image8M) +} + +func (e *ImageClient) QualityEnhance(ctx context.Context, param *Input) (*EnhanceResponse, error) { + return e.posUrlEncode(ctx, urlQualityEnhance, param, image4M) +} + +func (e *ImageClient) DefinitionEnhance(ctx context.Context, param *Input) (*EnhanceResponse, error) { + return e.posUrlEncode(ctx, urlDefinitionEnhance, param, image8M) +} diff --git a/image/enhance_test.go b/image/enhance_test.go new file mode 100644 index 0000000..baa2af4 --- /dev/null +++ b/image/enhance_test.go @@ -0,0 +1,110 @@ +package image + +import ( + "bytes" + "context" + "encoding/base64" + "github.com/stretchr/testify/require" + "io/ioutil" + "os" + "testing" +) + +var ( + client *ImageClient +) + +func preTest() { + if client != nil { + return + } + client = NewEnhanceClient(os.Getenv("BAIDU_API_KEY"), + os.Getenv("BAIDU_SECRET_KEY")) + err := os.MkdirAll("./test_result", os.ModePerm) + if err != nil { + panic(err) + } +} + +func TestAllAPI(t *testing.T) { + preTest() + + t.Run("contrastEnhance", TestContrastEnhance) + + t.Run("color_enhance", func(t *testing.T) { + err := generalCall(client.ColorEnhance, &Input{ + ImageUrl: "https://ai.bdstatic.com/file/75F5ABC751594F55B23AC4168F4A919A", + File: nil, + ImageBase64: "", + }) + require.NoError(t, err) + }) + + t.Run("quality_enhance", func(t *testing.T) { + err := generalCall(client.QualityEnhance, &Input{ + ImageUrl: "https://ai-public-console.cdn.bcebos.com/portal-pc-static/1720085025198/images/technology/imageprocess/image_quality_enhance/1.jpg", + File: nil, + ImageBase64: "", + }) + require.NoError(t, err) + }) + + t.Run("definition_enhance", func(t *testing.T) { + err := generalCall(client.DefinitionEnhance, &Input{ + ImageUrl: "https://ai-public-console.cdn.bcebos.com/portal-pc-static/1720085025198/images/technology/imageprocess/image_definition_enhance/1-1.jpg", + File: nil, + ImageBase64: "", + }) + require.NoError(t, err) + }) + + t.Run("colourize", func(t *testing.T) { + err := generalCall(client.Colourize, &Input{ + ImageUrl: "https://ai-public-console.cdn.bcebos.com/portal-pc-static/1720085025198/images/technology/imageprocess/colourize/1.jpg", + File: nil, + ImageBase64: "", + }) + require.NoError(t, err) + }) +} + +func TestContrastEnhance(t *testing.T) { + preTest() + + err := generalCall(client.ContrastEnhance, &Input{ + ImageUrl: "https://ai-public-console.cdn.bcebos.com/portal-pc-static/1720085025198/images/technology/imageprocess/contrast_enhance/1.jpg", + File: nil, + ImageBase64: "", + }) + require.NoError(t, err) + + binary, err := ioutil.ReadFile("../example/image/enhance_01.jpg") + require.NoError(t, err) + + err = generalCall(client.ContrastEnhance, &Input{ + ImageBase64: base64.StdEncoding.EncodeToString(binary), + }) + require.NoError(t, err) + + err = generalCall(client.ContrastEnhance, &Input{ + File: bytes.NewBuffer(binary), + }) + require.NoError(t, err) + + err = generalCall(client.ContrastEnhance, &Input{ + ImageUrl: "https://www.baidu.com", + }) + require.NotNil(t, err, err) +} + +func generalCall(fn func(ctx context.Context, input *Input) (*EnhanceResponse, error), input *Input) error { + resp, err := fn(context.TODO(), input) + if err != nil { + return err + } + err = decodeToLocal(resp) + if err != nil { + return err + } + return nil +} diff --git a/image/error.go b/image/error.go new file mode 100644 index 0000000..c81fbd2 --- /dev/null +++ b/image/error.go @@ -0,0 +1,22 @@ +package image + +import ( + "errors" + "fmt" +) + +var ( + ErrImageTooLarge = errors.New("image to large") + ErrInvalidImage = errors.New("invalid image") +) + +// APIError https://ai.baidu.com/ai-doc/IMAGEPROCESS/Ek3bclpgv +type APIError struct { + ErrorCode int `json:"error_code"` + ErrorMsg string `json:"error_msg"` + StatusCode string `json:"-"` +} + +func (e APIError) Error() string { + return fmt.Sprintf("%#v", e) +} diff --git a/image/model.go b/image/model.go new file mode 100644 index 0000000..8142b66 --- /dev/null +++ b/image/model.go @@ -0,0 +1,53 @@ +package image + +import ( + "encoding/base64" + "io" +) + +// Input 三选一 +type Input struct { + ImageUrl string `json:"image_url"` + File io.Reader `json:"file"` + ImageBase64 string `json:"image_base_64"` +} + +func (i *Input) encode(maxSize int) (map[string]interface{}, error) { + var invokeParam = make(map[string]interface{}) + switch { + case i.File != nil: + binary, err := io.ReadAll(i.File) + if err != nil { + return nil, err + } + base64Encode := base64.StdEncoding.EncodeToString(binary) + if len(base64Encode) > maxSize { + return nil, ErrImageTooLarge + } + invokeParam["image"] = base64Encode + case i.ImageBase64 != "": + if len(i.ImageBase64) > maxSize { + return nil, ErrImageTooLarge + } + invokeParam["image"] = i.ImageBase64 + case i.ImageUrl != "": + invokeParam["url"] = i.ImageUrl + default: + return nil, ErrInvalidImage + } + return invokeParam, nil +} + +type EnhanceResponse struct { + Image string `json:"image"` + LogID int `json:"log_id"` +} + +type generalResponse struct { + EnhanceResponse + APIError +} + +func (g *generalResponse) success() bool { + return g.ErrorCode == 0 +} diff --git a/image/test_helper.go b/image/test_helper.go new file mode 100644 index 0000000..1e9a345 --- /dev/null +++ b/image/test_helper.go @@ -0,0 +1,22 @@ +package image + +import ( + "encoding/base64" + "fmt" + "io/ioutil" + "os" + "time" +) + +func decodeToLocal(response *EnhanceResponse) error { + binary, err := base64.StdEncoding.DecodeString(response.Image) + if err != nil { + panic(err) + } + + err = ioutil.WriteFile(fmt.Sprintf("./test_result/%d.png", time.Now().UnixMilli()), binary, os.ModePerm) + if err != nil { + panic(err) + } + return nil +}