diff --git a/.env.example b/.env.example new file mode 100644 index 000000000..12437de38 --- /dev/null +++ b/.env.example @@ -0,0 +1,27 @@ +MAIL_HOST= +MAIL_PORT= +MAIL_USERNAME= +MAIL_PASSWORD= +MAIL_FROM_ADDRESS= +MAIL_FROM_NAME= + +MAIL_TO= +MAIL_CC= +MAIL_BCC= + +AWS_ACCESS_KEY_ID= +AWS_ACCESS_KEY_SECRET= +AWS_DEFAULT_REGION= +AWS_BUCKET= +AWS_URL= + +ALIYUN_ACCESS_KEY_ID= +ALIYUN_ACCESS_KEY_SECRET= +ALIYUN_BUCKET= +ALIYUN_URL= +ALIYUN_ENDPOINT= + +TENCENT_ACCESS_KEY_ID= +TENCENT_ACCESS_KEY_SECRET= +TENCENT_BUCKET= +TENCENT_URL= \ No newline at end of file diff --git a/.gitignore b/.gitignore index fc14e6ad4..b05c439cc 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,4 @@ go.sum -config/.env +.env .idea .DS_Store \ No newline at end of file diff --git a/README.md b/README.md index e17766591..024c949e0 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -

+

English | [中文](./README_zh.md) @@ -29,10 +29,9 @@ Golang developers quickly build their own applications. ## Roadmap -- [ ] Optimize migration -- [ ] Orm relationships - [ ] Custom .env path - [ ] Database read-write separation +- [ ] Extend Redis driver ## Documentation diff --git a/README_zh.md b/README_zh.md index e9dafbb6e..c19cc6737 100644 --- a/README_zh.md +++ b/README_zh.md @@ -1,4 +1,4 @@ -

+

[English](./README.md) | 中文 @@ -28,10 +28,9 @@ Goravel 是一个功能完备、具有良好扩展能力的 Web 应用程序框 ## 路线图 -- [ ] 优化迁移 -- [ ] Orm 关联关系 - [ ] 自定义 .env 路径 - [ ] 数据库读写分离 +- [ ] 扩展 Redis 驱动 ## 文档 diff --git a/auth/auth.go b/auth/auth.go index 64b34cff9..b68e46cfa 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -2,13 +2,13 @@ package auth import ( "errors" - "reflect" "strings" "time" contractauth "github.com/goravel/framework/contracts/auth" "github.com/goravel/framework/contracts/http" "github.com/goravel/framework/facades" + "github.com/goravel/framework/support/database" supporttime "github.com/goravel/framework/support/time" "github.com/golang-jwt/jwt/v4" @@ -20,15 +20,6 @@ const ctxKey = "GoravelAuth" var ( unit = time.Minute - - ErrorRefreshTimeExceeded = errors.New("refresh time exceeded") - ErrorTokenExpired = errors.New("token expired") - ErrorNoPrimaryKeyField = errors.New("the primaryKey field was not found in the model, set primaryKey like orm.Model") - ErrorEmptySecret = errors.New("secret is required") - ErrorTokenDisabled = errors.New("token is disabled") - ErrorParseTokenFirst = errors.New("parse token first") - ErrorInvalidClaims = errors.New("invalid claims") - ErrorInvalidToken = errors.New("invalid token") ) type Claims struct { @@ -47,7 +38,7 @@ type Auth struct { guard string } -func NewAuth(guard string) contractauth.Auth { +func NewAuth(guard string) *Auth { return &Auth{ guard: guard, } @@ -66,10 +57,12 @@ func (app *Auth) User(ctx http.Context, user any) error { if auth[app.guard].Claims == nil { return ErrorParseTokenFirst } + if auth[app.guard].Claims.Key == "" { + return ErrorInvalidKey + } if auth[app.guard].Token == "" { return ErrorTokenExpired } - //todo Unit test if err := facades.Orm.Query().Find(user, clause.Eq{Column: clause.PrimaryColumn, Value: auth[app.guard].Claims.Key}); err != nil { return err } @@ -116,25 +109,12 @@ func (app *Auth) Parse(ctx http.Context, token string) error { } func (app *Auth) Login(ctx http.Context, user any) (token string, err error) { - t := reflect.TypeOf(user).Elem() - v := reflect.ValueOf(user).Elem() - for i := 0; i < t.NumField(); i++ { - if t.Field(i).Name == "Model" { - if v.Field(i).Type().Kind() == reflect.Struct { - structField := v.Field(i).Type() - for j := 0; j < structField.NumField(); j++ { - if structField.Field(j).Tag.Get("gorm") == "primaryKey" { - return app.LoginUsingID(ctx, v.Field(i).Field(j).Interface()) - } - } - } - } - if t.Field(i).Tag.Get("gorm") == "primaryKey" { - return app.LoginUsingID(ctx, v.Field(i).Interface()) - } + id := database.GetID(user) + if id == nil { + return "", ErrorNoPrimaryKeyField } - return "", ErrorNoPrimaryKeyField + return app.LoginUsingID(ctx, id) } func (app *Auth) LoginUsingID(ctx http.Context, id any) (token string, err error) { @@ -146,8 +126,12 @@ func (app *Auth) LoginUsingID(ctx http.Context, id any) (token string, err error nowTime := supporttime.Now() ttl := facades.Config.GetInt("jwt.ttl") expireTime := nowTime.Add(time.Duration(ttl) * unit) + key := cast.ToString(id) + if key == "" { + return "", ErrorInvalidKey + } claims := Claims{ - cast.ToString(id), + key, jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(expireTime), IssuedAt: jwt.NewNumericDate(nowTime), diff --git a/auth/auth_test.go b/auth/auth_test.go index 92469b94f..91c399a67 100644 --- a/auth/auth_test.go +++ b/auth/auth_test.go @@ -5,15 +5,13 @@ import ( "testing" "time" - contractauth "github.com/goravel/framework/contracts/auth" - "github.com/goravel/framework/database/orm" - "github.com/goravel/framework/http" - "github.com/goravel/framework/testing/mock" - - "github.com/stretchr/testify/assert" testifymock "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" "gorm.io/gorm/clause" + + "github.com/goravel/framework/database/orm" + "github.com/goravel/framework/http" + "github.com/goravel/framework/testing/mock" ) var guard = "user" @@ -25,15 +23,14 @@ type User struct { type AuthTestSuite struct { suite.Suite + auth *Auth } -var app contractauth.Auth - func TestAuthTestSuite(t *testing.T) { unit = time.Second - app = NewAuth(guard) - - suite.Run(t, new(AuthTestSuite)) + suite.Run(t, &AuthTestSuite{ + auth: NewAuth(guard), + }) } func (s *AuthTestSuite) SetupTest() { @@ -44,9 +41,21 @@ func (s *AuthTestSuite) TestLoginUsingID_EmptySecret() { mockConfig := mock.Config() mockConfig.On("GetString", "jwt.secret").Return("").Once() - token, err := app.LoginUsingID(http.Background(), 1) - assert.Empty(s.T(), token) - assert.ErrorIs(s.T(), err, ErrorEmptySecret) + token, err := s.auth.LoginUsingID(http.Background(), 1) + s.Empty(token) + s.ErrorIs(err, ErrorEmptySecret) + + mockConfig.AssertExpectations(s.T()) +} + +func (s *AuthTestSuite) TestLoginUsingID_InvalidKey() { + mockConfig := mock.Config() + mockConfig.On("GetString", "jwt.secret").Return("Goravel").Once() + mockConfig.On("GetInt", "jwt.ttl").Return(2).Once() + + token, err := s.auth.LoginUsingID(http.Background(), "") + s.Empty(token) + s.ErrorIs(err, ErrorInvalidKey) mockConfig.AssertExpectations(s.T()) } @@ -56,9 +65,9 @@ func (s *AuthTestSuite) TestLoginUsingID() { mockConfig.On("GetString", "jwt.secret").Return("Goravel").Once() mockConfig.On("GetInt", "jwt.ttl").Return(2).Once() - token, err := app.LoginUsingID(http.Background(), 1) - assert.NotEmpty(s.T(), token) - assert.Nil(s.T(), err) + token, err := s.auth.LoginUsingID(http.Background(), 1) + s.NotEmpty(token) + s.Nil(err) mockConfig.AssertExpectations(s.T()) } @@ -71,9 +80,9 @@ func (s *AuthTestSuite) TestLogin_Model() { var user User user.ID = 1 user.Name = "Goravel" - token, err := app.Login(http.Background(), &user) - assert.NotEmpty(s.T(), token) - assert.Nil(s.T(), err) + token, err := s.auth.Login(http.Background(), &user) + s.NotEmpty(token) + s.Nil(err) mockConfig.AssertExpectations(s.T()) } @@ -91,9 +100,9 @@ func (s *AuthTestSuite) TestLogin_CustomModel() { var user CustomUser user.ID = 1 user.Name = "Goravel" - token, err := app.Login(http.Background(), &user) - assert.NotEmpty(s.T(), token) - assert.Nil(s.T(), err) + token, err := s.auth.Login(http.Background(), &user) + s.NotEmpty(token) + s.Nil(err) mockConfig.AssertExpectations(s.T()) } @@ -107,9 +116,24 @@ func (s *AuthTestSuite) TestLogin_ErrorModel() { var errorUser ErrorUser errorUser.ID = 1 errorUser.Name = "Goravel" - token, err := app.Login(http.Background(), &errorUser) - assert.Empty(s.T(), token) - assert.EqualError(s.T(), err, "the primaryKey field was not found in the model, set primaryKey like orm.Model") + token, err := s.auth.Login(http.Background(), &errorUser) + s.Empty(token) + s.EqualError(err, "the primaryKey field was not found in the model, set primaryKey like orm.Model") +} + +func (s *AuthTestSuite) TestLogin_NoPrimaryKey() { + type User struct { + ID uint + Name string + } + + ctx := http.Background() + var user User + user.ID = 1 + user.Name = "Goravel" + token, err := s.auth.Login(ctx, &user) + s.Empty(token) + s.ErrorIs(err, ErrorNoPrimaryKeyField) } func (s *AuthTestSuite) TestParse_TokenDisabled() { @@ -117,8 +141,8 @@ func (s *AuthTestSuite) TestParse_TokenDisabled() { mockCache := mock.Cache() mockCache.On("GetBool", "jwt:disabled:"+token, false).Return(true).Once() - err := app.Parse(http.Background(), token) - assert.EqualError(s.T(), err, "token is disabled") + err := s.auth.Parse(http.Background(), token) + s.EqualError(err, "token is disabled") } func (s *AuthTestSuite) TestParse_TokenInvalid() { @@ -129,8 +153,8 @@ func (s *AuthTestSuite) TestParse_TokenInvalid() { mockCache := mock.Cache() mockCache.On("GetBool", "jwt:disabled:"+token, false).Return(false).Once() - err := app.Parse(http.Background(), token) - assert.NotNil(s.T(), err) + err := s.auth.Parse(http.Background(), token) + s.NotNil(err) mockConfig.AssertExpectations(s.T()) } @@ -141,16 +165,16 @@ func (s *AuthTestSuite) TestParse_TokenExpired() { mockConfig.On("GetInt", "jwt.ttl").Return(2).Once() ctx := http.Background() - token, err := app.LoginUsingID(ctx, 1) - assert.Nil(s.T(), err) + token, err := s.auth.LoginUsingID(ctx, 1) + s.Nil(err) time.Sleep(2 * unit) mockCache := mock.Cache() mockCache.On("GetBool", "jwt:disabled:"+token, false).Return(false).Once() - err = app.Parse(ctx, token) - assert.ErrorIs(s.T(), err, ErrorTokenExpired) + err = s.auth.Parse(ctx, token) + s.ErrorIs(err, ErrorTokenExpired) mockConfig.AssertExpectations(s.T()) } @@ -161,14 +185,14 @@ func (s *AuthTestSuite) TestParse_Success() { mockConfig.On("GetInt", "jwt.ttl").Return(2).Once() ctx := http.Background() - token, err := app.LoginUsingID(ctx, 1) - assert.Nil(s.T(), err) + token, err := s.auth.LoginUsingID(ctx, 1) + s.Nil(err) mockCache := mock.Cache() mockCache.On("GetBool", "jwt:disabled:"+token, false).Return(false).Once() - err = app.Parse(ctx, token) - assert.Nil(s.T(), err) + err = s.auth.Parse(ctx, token) + s.Nil(err) mockConfig.AssertExpectations(s.T()) } @@ -179,14 +203,14 @@ func (s *AuthTestSuite) TestParse_SuccessWithPrefix() { mockConfig.On("GetInt", "jwt.ttl").Return(2).Once() ctx := http.Background() - token, err := app.LoginUsingID(ctx, 1) - assert.Nil(s.T(), err) + token, err := s.auth.LoginUsingID(ctx, 1) + s.Nil(err) mockCache := mock.Cache() mockCache.On("GetBool", "jwt:disabled:"+token, false).Return(false).Once() - err = app.Parse(ctx, "Bearer "+token) - assert.Nil(s.T(), err) + err = s.auth.Parse(ctx, "Bearer "+token) + s.Nil(err) mockConfig.AssertExpectations(s.T()) } @@ -196,8 +220,8 @@ func (s *AuthTestSuite) TestUser_NoParse() { ctx := http.Background() var user User - err := app.User(ctx, user) - assert.EqualError(s.T(), err, "parse token first") + err := s.auth.User(ctx, user) + s.EqualError(err, "parse token first") mockConfig.AssertExpectations(s.T()) } @@ -208,23 +232,23 @@ func (s *AuthTestSuite) TestUser_DBError() { mockConfig.On("GetInt", "jwt.ttl").Return(2).Once() ctx := http.Background() - token, err := app.LoginUsingID(ctx, 1) - assert.Nil(s.T(), err) + token, err := s.auth.LoginUsingID(ctx, 1) + s.Nil(err) mockCache := mock.Cache() mockCache.On("GetBool", "jwt:disabled:"+token, false).Return(false).Once() - err = app.Parse(ctx, token) - assert.Nil(s.T(), err) + err = s.auth.Parse(ctx, token) + s.Nil(err) var user User - mockOrm, mockDB, _ := mock.Orm() + mockOrm, mockDB, _, _ := mock.Orm() mockOrm.On("Query").Return(mockDB) mockDB.On("Find", &user, clause.Eq{Column: clause.PrimaryColumn, Value: "1"}).Return(errors.New("error")).Once() - err = app.User(ctx, &user) - assert.EqualError(s.T(), err, "error") + err = s.auth.User(ctx, &user) + s.EqualError(err, "error") mockConfig.AssertExpectations(s.T()) } @@ -235,34 +259,34 @@ func (s *AuthTestSuite) TestUser_Expired() { mockConfig.On("GetInt", "jwt.ttl").Return(2) ctx := http.Background() - token, err := app.LoginUsingID(ctx, 1) - assert.NotEmpty(s.T(), token) - assert.Nil(s.T(), err) + token, err := s.auth.LoginUsingID(ctx, 1) + s.NotEmpty(token) + s.Nil(err) mockCache := mock.Cache() mockCache.On("GetBool", "jwt:disabled:"+token, false).Return(false).Once() time.Sleep(2 * unit) - err = app.Parse(ctx, token) - assert.ErrorIs(s.T(), err, ErrorTokenExpired) + err = s.auth.Parse(ctx, token) + s.ErrorIs(err, ErrorTokenExpired) var user User - err = app.User(ctx, &user) - assert.EqualError(s.T(), err, "token expired") + err = s.auth.User(ctx, &user) + s.EqualError(err, "token expired") mockConfig.On("GetInt", "jwt.refresh_ttl").Return(2).Once() - token, err = app.Refresh(ctx) - assert.NotEmpty(s.T(), token) - assert.Nil(s.T(), err) + token, err = s.auth.Refresh(ctx) + s.NotEmpty(token) + s.Nil(err) - mockOrm, mockDB, _ := mock.Orm() + mockOrm, mockDB, _, _ := mock.Orm() mockOrm.On("Query").Return(mockDB) mockDB.On("Find", &user, clause.Eq{Column: clause.PrimaryColumn, Value: "1"}).Return(nil).Once() - err = app.User(ctx, &user) - assert.Nil(s.T(), err) + err = s.auth.User(ctx, &user) + s.Nil(err) mockConfig.AssertExpectations(s.T()) } @@ -273,29 +297,29 @@ func (s *AuthTestSuite) TestUser_RefreshExpired() { mockConfig.On("GetInt", "jwt.ttl").Return(2).Once() ctx := http.Background() - token, err := app.LoginUsingID(ctx, 1) - assert.NotEmpty(s.T(), token) - assert.Nil(s.T(), err) + token, err := s.auth.LoginUsingID(ctx, 1) + s.NotEmpty(token) + s.Nil(err) mockCache := mock.Cache() mockCache.On("GetBool", "jwt:disabled:"+token, false).Return(false).Once() time.Sleep(2 * unit) - err = app.Parse(ctx, token) - assert.ErrorIs(s.T(), err, ErrorTokenExpired) + err = s.auth.Parse(ctx, token) + s.ErrorIs(err, ErrorTokenExpired) var user User - err = app.User(ctx, &user) - assert.EqualError(s.T(), err, "token expired") + err = s.auth.User(ctx, &user) + s.EqualError(err, "token expired") mockConfig.On("GetInt", "jwt.refresh_ttl").Return(1).Once() time.Sleep(2 * unit) - token, err = app.Refresh(ctx) - assert.Empty(s.T(), token) - assert.EqualError(s.T(), err, "refresh time exceeded") + token, err = s.auth.Refresh(ctx) + s.Empty(token) + s.EqualError(err, "refresh time exceeded") mockConfig.AssertExpectations(s.T()) } @@ -306,22 +330,22 @@ func (s *AuthTestSuite) TestUser_Success() { mockConfig.On("GetInt", "jwt.ttl").Return(2).Once() ctx := http.Background() - token, err := app.LoginUsingID(ctx, 1) - assert.Nil(s.T(), err) + token, err := s.auth.LoginUsingID(ctx, 1) + s.Nil(err) mockCache := mock.Cache() mockCache.On("GetBool", "jwt:disabled:"+token, false).Return(false).Once() - err = app.Parse(ctx, token) - assert.Nil(s.T(), err) + err = s.auth.Parse(ctx, token) + s.Nil(err) var user User - mockOrm, mockDB, _ := mock.Orm() + mockOrm, mockDB, _, _ := mock.Orm() mockOrm.On("Query").Return(mockDB) mockDB.On("Find", &user, clause.Eq{Column: clause.PrimaryColumn, Value: "1"}).Return(nil).Once() - err = app.User(ctx, &user) - assert.Nil(s.T(), err) + err = s.auth.User(ctx, &user) + s.Nil(err) mockConfig.AssertExpectations(s.T()) } @@ -330,9 +354,9 @@ func (s *AuthTestSuite) TestRefresh_NotParse() { mockConfig := mock.Config() ctx := http.Background() - token, err := app.Refresh(ctx) - assert.Empty(s.T(), token) - assert.EqualError(s.T(), err, "parse token first") + token, err := s.auth.Refresh(ctx) + s.Empty(token) + s.EqualError(err, "parse token first") mockConfig.AssertExpectations(s.T()) } @@ -343,21 +367,21 @@ func (s *AuthTestSuite) TestRefresh_RefreshTimeExceeded() { mockConfig.On("GetInt", "jwt.ttl").Return(2).Once() ctx := http.Background() - token, err := app.LoginUsingID(ctx, 1) - assert.Nil(s.T(), err) + token, err := s.auth.LoginUsingID(ctx, 1) + s.Nil(err) mockCache := mock.Cache() mockCache.On("GetBool", "jwt:disabled:"+token, false).Return(false).Once() - err = app.Parse(ctx, token) - assert.Nil(s.T(), err) + err = s.auth.Parse(ctx, token) + s.Nil(err) mockConfig.On("GetInt", "jwt.refresh_ttl").Return(1).Once() time.Sleep(4 * unit) - token, err = app.Refresh(ctx) - assert.Empty(s.T(), token) - assert.EqualError(s.T(), err, "refresh time exceeded") + token, err = s.auth.Refresh(ctx) + s.Empty(token) + s.EqualError(err, "refresh time exceeded") mockConfig.AssertExpectations(s.T()) } @@ -368,21 +392,21 @@ func (s *AuthTestSuite) TestRefresh_Success() { mockConfig.On("GetInt", "jwt.ttl").Return(2) ctx := http.Background() - token, err := app.LoginUsingID(ctx, 1) - assert.Nil(s.T(), err) + token, err := s.auth.LoginUsingID(ctx, 1) + s.Nil(err) mockCache := mock.Cache() mockCache.On("GetBool", "jwt:disabled:"+token, false).Return(false).Once() - err = app.Parse(ctx, token) - assert.Nil(s.T(), err) + err = s.auth.Parse(ctx, token) + s.Nil(err) mockConfig.On("GetInt", "jwt.refresh_ttl").Return(1).Once() time.Sleep(2 * unit) - token, err = app.Refresh(ctx) - assert.NotEmpty(s.T(), token) - assert.Nil(s.T(), err) + token, err = s.auth.Refresh(ctx) + s.NotEmpty(token) + s.Nil(err) mockConfig.AssertExpectations(s.T()) } @@ -393,16 +417,16 @@ func (s *AuthTestSuite) TestLogout_CacheUnsupported() { mockConfig.On("GetInt", "jwt.ttl").Return(2) ctx := http.Background() - token, err := app.LoginUsingID(ctx, 1) - assert.NotEmpty(s.T(), token) - assert.Nil(s.T(), err) - assert.EqualError(s.T(), app.Logout(ctx), "cache support is required") + token, err := s.auth.LoginUsingID(ctx, 1) + s.NotEmpty(token) + s.Nil(err) + s.EqualError(s.auth.Logout(ctx), "cache support is required") mockConfig.AssertExpectations(s.T()) } func (s *AuthTestSuite) TestLogout_NotParse() { - assert.Nil(s.T(), app.Logout(http.Background())) + s.Nil(s.auth.Logout(http.Background())) } func (s *AuthTestSuite) TestLogout_SetDisabledCacheError() { @@ -411,18 +435,18 @@ func (s *AuthTestSuite) TestLogout_SetDisabledCacheError() { mockConfig.On("GetInt", "jwt.ttl").Return(2) ctx := http.Background() - token, err := app.LoginUsingID(ctx, 1) - assert.Nil(s.T(), err) + token, err := s.auth.LoginUsingID(ctx, 1) + s.Nil(err) mockCache := mock.Cache() mockCache.On("GetBool", "jwt:disabled:"+token, false).Return(false).Once() - err = app.Parse(ctx, token) - assert.Nil(s.T(), err) + err = s.auth.Parse(ctx, token) + s.Nil(err) mockCache.On("Put", testifymock.Anything, true, 2*unit).Return(errors.New("error")).Once() - assert.EqualError(s.T(), app.Logout(ctx), "error") + s.EqualError(s.auth.Logout(ctx), "error") mockConfig.AssertExpectations(s.T()) } @@ -433,19 +457,19 @@ func (s *AuthTestSuite) TestLogout_Success() { mockConfig.On("GetInt", "jwt.ttl").Return(2) ctx := http.Background() - token, err := app.LoginUsingID(ctx, 1) - assert.NotEmpty(s.T(), token) - assert.Nil(s.T(), err) + token, err := s.auth.LoginUsingID(ctx, 1) + s.NotEmpty(token) + s.Nil(err) mockCache := mock.Cache() mockCache.On("GetBool", "jwt:disabled:"+token, false).Return(false).Once() - err = app.Parse(ctx, token) - assert.Nil(s.T(), err) + err = s.auth.Parse(ctx, token) + s.Nil(err) mockCache.On("Put", testifymock.Anything, true, 2*unit).Return(nil).Once() - assert.Nil(s.T(), app.Logout(ctx)) + s.Nil(s.auth.Logout(ctx)) mockConfig.AssertExpectations(s.T()) } diff --git a/auth/errors.go b/auth/errors.go new file mode 100644 index 000000000..888976707 --- /dev/null +++ b/auth/errors.go @@ -0,0 +1,17 @@ +package auth + +import ( + "errors" +) + +var ( + ErrorRefreshTimeExceeded = errors.New("refresh time exceeded") + ErrorTokenExpired = errors.New("token expired") + ErrorNoPrimaryKeyField = errors.New("the primaryKey field was not found in the model, set primaryKey like orm.Model") + ErrorEmptySecret = errors.New("secret is required") + ErrorTokenDisabled = errors.New("token is disabled") + ErrorParseTokenFirst = errors.New("parse token first") + ErrorInvalidClaims = errors.New("invalid claims") + ErrorInvalidToken = errors.New("invalid token") + ErrorInvalidKey = errors.New("invalid key") +) diff --git a/cache/application.go b/cache/application.go index 105fb4c98..d12a342ad 100644 --- a/cache/application.go +++ b/cache/application.go @@ -2,7 +2,9 @@ package cache import ( "context" + "github.com/gookit/color" + "github.com/goravel/framework/contracts/cache" "github.com/goravel/framework/facades" ) @@ -14,18 +16,25 @@ func (app *Application) Init() cache.Store { defaultStore := facades.Config.GetString("cache.default") driver := facades.Config.GetString("cache.stores." + defaultStore + ".driver") if driver == "redis" { - return NewRedis(context.Background()) + redis, err := NewRedis(context.Background()) + if err != nil { + color.Redf("[Cache] %v\n", err) + return nil + } + + return redis } + if driver == "custom" { if custom, ok := facades.Config.Get("cache.stores." + defaultStore + ".via").(cache.Store); ok { return custom } - color.Redln("%s doesn't implement contracts/cache/store", defaultStore) + color.Redf("[Cache] %s doesn't implement contracts/cache/store\n", defaultStore) return nil } - color.Redln("Not supported cache store: %s", defaultStore) + color.Redf("[Cache] Not supported cache store: %s\n", defaultStore) return nil } diff --git a/cache/application_test.go b/cache/application_test.go index 21003ccae..40ed9b76c 100644 --- a/cache/application_test.go +++ b/cache/application_test.go @@ -1,74 +1,348 @@ package cache import ( - "os" + "context" + "log" "testing" "time" - "github.com/goravel/framework/config" - "github.com/goravel/framework/console" - "github.com/goravel/framework/facades" - "github.com/goravel/framework/testing/file" - "github.com/stretchr/testify/assert" -) + "github.com/goravel/framework/contracts/cache" + testingdocker "github.com/goravel/framework/testing/docker" + "github.com/goravel/framework/testing/mock" -func TestInit(t *testing.T) { - initConfig() + "github.com/ory/dockertest/v3" + "github.com/stretchr/testify/suite" +) - assert.NotPanics(t, func() { - app := Application{} - app.Init() - }) +type ApplicationTestSuite struct { + suite.Suite + stores []cache.Store + redisDocker *dockertest.Resource } -func TestClearCommand(t *testing.T) { - initConfig() - - consoleServiceProvider := console.ServiceProvider{} - consoleServiceProvider.Register() - - cacheServiceProvider := ServiceProvider{} - cacheServiceProvider.Register() - cacheServiceProvider.Boot() - - err := facades.Cache.Put("test-clear-command", "goravel", 5*time.Second) - assert.Nil(t, err) - assert.True(t, facades.Cache.Has("test-clear-command")) +func TestApplicationTestSuite(t *testing.T) { + redisPool, redisDocker, redisStore, err := getRedisDocker() + if err != nil { + log.Fatalf("Get redis error: %s", err) + } - assert.NotPanics(t, func() { - facades.Artisan.Call("cache:clear") + suite.Run(t, &ApplicationTestSuite{ + stores: []cache.Store{ + redisStore, + }, + redisDocker: redisDocker, }) - assert.False(t, facades.Cache.Has("test-clear-command")) + if err := redisPool.Purge(redisDocker); err != nil { + log.Fatalf("Could not purge resource: %s", err) + } +} + +func (s *ApplicationTestSuite) SetupTest() { } -func initConfig() { - file.CreateEnv() - configServiceProvider := config.ServiceProvider{} - configServiceProvider.Register() +func (s *ApplicationTestSuite) TestInitRedis() { + tests := []struct { + description string + setup func(description string) + }{ + { + description: "success", + setup: func(description string) { + mockConfig := mock.Config() + mockConfig.On("GetString", "cache.default").Return("redis").Twice() + mockConfig.On("GetString", "cache.stores.redis.driver").Return("redis").Once() + mockConfig.On("GetString", "cache.stores.redis.connection").Return("default").Once() + mockConfig.On("GetString", "database.redis.default.host").Return("localhost").Once() + mockConfig.On("GetString", "database.redis.default.port").Return(s.redisDocker.GetPort("6379/tcp")).Once() + mockConfig.On("GetString", "database.redis.default.password").Return("").Once() + mockConfig.On("GetInt", "database.redis.default.database").Return(0).Once() + mockConfig.On("GetString", "cache.prefix").Return("goravel_cache").Once() - facadesConfig := facades.Config - facadesConfig.Add("cache", map[string]interface{}{ - "default": facadesConfig.Env("CACHE_DRIVER", "redis"), - "stores": map[string]interface{}{ - "redis": map[string]interface{}{ - "driver": "redis", - "connection": "default", + app := Application{} + s.NotNil(app.Init(), description) + + mockConfig.AssertExpectations(s.T()) }, }, - "prefix": "goravel_cache", - }) + { + description: "error", + setup: func(description string) { + mockConfig := mock.Config() + mockConfig.On("GetString", "cache.default").Return("redis").Twice() + mockConfig.On("GetString", "cache.stores.redis.driver").Return("redis").Once() + mockConfig.On("GetString", "cache.stores.redis.connection").Return("default").Once() + mockConfig.On("GetString", "database.redis.default.host").Return("").Once() - facadesConfig.Add("database", map[string]interface{}{ - "redis": map[string]interface{}{ - "default": map[string]interface{}{ - "host": facadesConfig.Env("REDIS_HOST", "127.0.0.1"), - "password": facadesConfig.Env("REDIS_PASSWORD", ""), - "port": facadesConfig.Env("REDIS_PORT", 6379), - "database": facadesConfig.Env("REDIS_DB", 0), + app := Application{} + s.Nil(app.Init(), description) + + mockConfig.AssertExpectations(s.T()) }, }, - }) + } + + for _, test := range tests { + test.setup(test.description) + } +} + +func (s *ApplicationTestSuite) TestAdd() { + for _, store := range s.stores { + s.Nil(store.Put("name", "Goravel", 1*time.Second)) + s.False(store.Add("name", "World", 1*time.Second)) + s.True(store.Add("name1", "World", 1*time.Second)) + s.True(store.Has("name1")) + time.Sleep(2 * time.Second) + s.False(store.Has("name1")) + s.True(store.Flush()) + } +} + +func (s *ApplicationTestSuite) TestForever() { + for _, store := range s.stores { + s.True(store.Forever("name", "Goravel")) + s.Equal("Goravel", store.Get("name", "").(string)) + s.True(store.Flush()) + } +} + +func (s *ApplicationTestSuite) TestForget() { + for _, store := range s.stores { + val := store.Forget("test-forget") + s.True(val) + + err := store.Put("test-forget", "goravel", 5*time.Second) + s.Nil(err) + s.True(store.Forget("test-forget")) + } +} + +func (s *ApplicationTestSuite) TestFlush() { + for _, store := range s.stores { + s.Nil(store.Put("test-flush", "goravel", 5*time.Second)) + s.Equal("goravel", store.Get("test-flush", nil).(string)) + + s.True(store.Flush()) + s.False(store.Has("test-flush")) + } +} + +func (s *ApplicationTestSuite) TestGet() { + for _, store := range s.stores { + s.Nil(store.Put("name", "Goravel", 1*time.Second)) + s.Equal("Goravel", store.Get("name", "").(string)) + s.Equal("World", store.Get("name1", "World").(string)) + s.Equal("World1", store.Get("name2", func() interface{} { + return "World1" + }).(string)) + s.True(store.Forget("name")) + s.True(store.Flush()) + } +} + +func (s *ApplicationTestSuite) TestGetBool() { + for _, store := range s.stores { + s.Equal(true, store.GetBool("test-get-bool", true)) + s.Nil(store.Put("test-get-bool", true, 2*time.Second)) + s.Equal(true, store.GetBool("test-get-bool", false)) + } +} + +func (s *ApplicationTestSuite) TestGetInt() { + for _, store := range s.stores { + s.Equal(2, store.GetInt("test-get-int", 2)) + s.Nil(store.Put("test-get-int", 3, 2*time.Second)) + s.Equal(3, store.GetInt("test-get-int", 2)) + } +} + +func (s *ApplicationTestSuite) TestGetString() { + for _, store := range s.stores { + s.Equal("2", store.GetString("test-get-string", "2")) + s.Nil(store.Put("test-get-string", "3", 2*time.Second)) + s.Equal("3", store.GetString("test-get-string", "2")) + } +} + +func (s *ApplicationTestSuite) TestHas() { + for _, store := range s.stores { + s.False(store.Has("test-has")) + s.Nil(store.Put("test-has", "goravel", 5*time.Second)) + s.True(store.Has("test-has")) + } +} + +func (s *ApplicationTestSuite) TestPull() { + for _, store := range s.stores { + s.Nil(store.Put("name", "Goravel", 1*time.Second)) + s.True(store.Has("name")) + s.Equal("Goravel", store.Pull("name", "").(string)) + s.False(store.Has("name")) + } +} + +func (s *ApplicationTestSuite) TestPut() { + for _, store := range s.stores { + s.Nil(store.Put("name", "Goravel", 1*time.Second)) + s.True(store.Has("name")) + s.Equal("Goravel", store.Get("name", "").(string)) + time.Sleep(2 * time.Second) + s.False(store.Has("name")) + } +} + +func (s *ApplicationTestSuite) TestRemember() { + for _, store := range s.stores { + s.Nil(store.Put("name", "Goravel", 1*time.Second)) + value, err := store.Remember("name", 1*time.Second, func() interface{} { + return "World" + }) + s.Nil(err) + s.Equal("Goravel", value) + + value, err = store.Remember("name1", 1*time.Second, func() interface{} { + return "World1" + }) + s.Nil(err) + s.Equal("World1", value) + time.Sleep(2 * time.Second) + s.False(store.Has("name1")) + s.True(store.Flush()) + } +} - os.Remove(".env") +func (s *ApplicationTestSuite) TestRememberForever() { + for _, store := range s.stores { + s.Nil(store.Put("name", "Goravel", 1*time.Second)) + value, err := store.RememberForever("name", func() interface{} { + return "World" + }) + s.Nil(err) + s.Equal("Goravel", value) + + value, err = store.RememberForever("name1", func() interface{} { + return "World1" + }) + s.Nil(err) + s.Equal("World1", value) + s.True(store.Flush()) + } +} + +func (s *ApplicationTestSuite) TestCustomDriver() { + mockConfig := mock.Config() + mockConfig.On("GetString", "cache.default").Return("store").Once() + mockConfig.On("GetString", "cache.stores.store.driver").Return("custom").Once() + mockConfig.On("Get", "cache.stores.store.via").Return(&Store{}).Once() + + app := Application{} + store := app.Init() + s.NotNil(store) + s.Equal("Goravel", store.Get("name", "Goravel").(string)) + + mockConfig.AssertExpectations(s.T()) +} + +func getRedisDocker() (*dockertest.Pool, *dockertest.Resource, cache.Store, error) { + pool, resource, err := testingdocker.Redis() + if err != nil { + return nil, nil, nil, err + } + + _ = resource.Expire(60) + + var store cache.Store + if err := pool.Retry(func() error { + var err error + mockConfig := mock.Config() + mockConfig.On("GetString", "cache.default").Return("redis").Once() + mockConfig.On("GetString", "cache.stores.redis.connection").Return("default").Once() + mockConfig.On("GetString", "database.redis.default.host").Return("localhost").Once() + mockConfig.On("GetString", "database.redis.default.port").Return(resource.GetPort("6379/tcp")).Once() + mockConfig.On("GetString", "database.redis.default.password").Return(resource.GetPort("")).Once() + mockConfig.On("GetInt", "database.redis.default.database").Return(0).Once() + mockConfig.On("GetString", "cache.prefix").Return("goravel_cache").Once() + store, err = NewRedis(context.Background()) + + return err + }); err != nil { + return nil, nil, nil, err + } + + return pool, resource, store, nil +} + +type Store struct { +} + +func (r *Store) WithContext(ctx context.Context) cache.Store { + return r +} + +//Get Retrieve an item from the cache by key. +func (r *Store) Get(key string, def interface{}) interface{} { + return def +} + +//Get Retrieve an item from the cache by key. +func (r *Store) GetInt(key string, def int) int { + return def +} + +//Get Retrieve an item from the cache by key. +func (r *Store) GetBool(key string, def bool) bool { + return def } + +//Get Retrieve an item from the cache by key. +func (r *Store) GetString(key string, def string) string { + return def +} + +//Has Check an item exists in the cache. +func (r *Store) Has(key string) bool { + return true +} + +//Put Store an item in the cache for a given number of seconds. +func (r *Store) Put(key string, value interface{}, seconds time.Duration) error { + return nil +} + +//Pull Retrieve an item from the cache and delete it. +func (r *Store) Pull(key string, def interface{}) interface{} { + return def +} + +//Add Store an item in the cache if the key does not exist. +func (r *Store) Add(key string, value interface{}, seconds time.Duration) bool { + return true +} + +//Remember Get an item from the cache, or execute the given Closure and store the result. +func (r *Store) Remember(key string, ttl time.Duration, callback func() interface{}) (interface{}, error) { + return "", nil +} + +//RememberForever Get an item from the cache, or execute the given Closure and store the result forever. +func (r *Store) RememberForever(key string, callback func() interface{}) (interface{}, error) { + return "", nil +} + +//Forever Store an item in the cache indefinitely. +func (r *Store) Forever(key string, value interface{}) bool { + return true +} + +//Forget Remove an item from the cache. +func (r *Store) Forget(key string) bool { + return true +} + +//Flush Remove all items from the cache. +func (r *Store) Flush() bool { + return true +} + +var _ cache.Store = &Store{} diff --git a/cache/console/clear_command.go b/cache/console/clear_command.go index b5568ad45..eb8cad485 100644 --- a/cache/console/clear_command.go +++ b/cache/console/clear_command.go @@ -2,6 +2,7 @@ package console import ( "github.com/gookit/color" + "github.com/goravel/framework/contracts/console" "github.com/goravel/framework/contracts/console/command" "github.com/goravel/framework/facades" diff --git a/cache/redis_store.go b/cache/redis.go similarity index 89% rename from cache/redis_store.go rename to cache/redis.go index b1d0f3243..abd42dc1d 100644 --- a/cache/redis_store.go +++ b/cache/redis.go @@ -10,7 +10,7 @@ import ( "github.com/goravel/framework/facades" "github.com/go-redis/redis/v8" - "github.com/gookit/color" + "github.com/pkg/errors" ) type Redis struct { @@ -19,15 +19,15 @@ type Redis struct { redis *redis.Client } -func NewRedis(ctx context.Context) cache.Store { - connection := facades.Config.GetString("cache.stores." + facades.Config.GetString("cache.default") + ".connection") +func NewRedis(ctx context.Context) (*Redis, error) { + connection := facades.Config.GetString(fmt.Sprintf("cache.stores.%s.connection", facades.Config.GetString("cache.default"))) if connection == "" { connection = "default" } host := facades.Config.GetString("database.redis." + connection + ".host") if host == "" { - return nil + return nil, nil } client := redis.NewClient(&redis.Options{ @@ -36,22 +36,62 @@ func NewRedis(ctx context.Context) cache.Store { DB: facades.Config.GetInt("database.redis." + connection + ".database"), }) - _, err := client.Ping(context.Background()).Result() - if err != nil { - color.Redln(fmt.Sprintf("[Cache] Init connection error, %s", err.Error())) - - return nil + if _, err := client.Ping(context.Background()).Result(); err != nil { + return nil, errors.WithMessage(err, "init connection error") } return &Redis{ ctx: ctx, - prefix: facades.Config.GetString("cache.prefix" + ":"), + prefix: facades.Config.GetString("cache.prefix") + ":", redis: client, - } + }, nil } func (r *Redis) WithContext(ctx context.Context) cache.Store { - return NewRedis(ctx) + store, _ := NewRedis(ctx) + + return store +} + +//Add Store an item in the cache if the key does not exist. +func (r *Redis) Add(key string, value interface{}, seconds time.Duration) bool { + val, err := r.redis.SetNX(r.ctx, r.prefix+key, value, seconds).Result() + if err != nil { + return false + } + + return val +} + +//Forever Store an item in the cache indefinitely. +func (r *Redis) Forever(key string, value interface{}) bool { + if err := r.Put(key, value, 0); err != nil { + return false + } + + return true +} + +//Forget Remove an item from the cache. +func (r *Redis) Forget(key string) bool { + _, err := r.redis.Del(r.ctx, r.prefix+key).Result() + + if err != nil { + return false + } + + return true +} + +//Flush Remove all items from the cache. +func (r *Redis) Flush() bool { + res, err := r.redis.FlushAll(r.ctx).Result() + + if err != nil || res != "OK" { + return false + } + + return true } //Get Retrieve an item from the cache by key. @@ -107,16 +147,6 @@ func (r *Redis) Has(key string) bool { return true } -//Put Store an item in the cache for a given number of seconds. -func (r *Redis) Put(key string, value interface{}, seconds time.Duration) error { - err := r.redis.Set(r.ctx, r.prefix+key, value, seconds).Err() - if err != nil { - return err - } - - return nil -} - //Pull Retrieve an item from the cache and delete it. func (r *Redis) Pull(key string, def interface{}) interface{} { val, err := r.redis.Get(r.ctx, r.prefix+key).Result() @@ -129,14 +159,14 @@ func (r *Redis) Pull(key string, def interface{}) interface{} { return val } -//Add Store an item in the cache if the key does not exist. -func (r *Redis) Add(key string, value interface{}, seconds time.Duration) bool { - val, err := r.redis.SetNX(r.ctx, r.prefix+key, value, seconds).Result() +//Put Store an item in the cache for a given number of seconds. +func (r *Redis) Put(key string, value interface{}, seconds time.Duration) error { + err := r.redis.Set(r.ctx, r.prefix+key, value, seconds).Err() if err != nil { - return false + return err } - return val + return nil } //Remember Get an item from the cache, or execute the given Closure and store the result. @@ -172,34 +202,3 @@ func (r *Redis) RememberForever(key string, callback func() interface{}) (interf return val, nil } - -//Forever Store an item in the cache indefinitely. -func (r *Redis) Forever(key string, value interface{}) bool { - if err := r.Put(key, value, 0); err != nil { - return false - } - - return true -} - -//Forget Remove an item from the cache. -func (r *Redis) Forget(key string) bool { - _, err := r.redis.Del(r.ctx, r.prefix+key).Result() - - if err != nil { - return false - } - - return true -} - -//Flush Remove all items from the cache. -func (r *Redis) Flush() bool { - res, err := r.redis.FlushAll(r.ctx).Result() - - if err != nil || res != "OK" { - return false - } - - return true -} diff --git a/cache/redis_store_test.go b/cache/redis_store_test.go deleted file mode 100644 index 5efa67ec7..000000000 --- a/cache/redis_store_test.go +++ /dev/null @@ -1,142 +0,0 @@ -package cache - -import ( - "context" - "testing" - "time" - - "github.com/goravel/framework/contracts/cache" - - "github.com/go-redis/redis/v8" - "github.com/stretchr/testify/assert" -) - -func instance() cache.Store { - return &Redis{redis: redis.NewClient(&redis.Options{ - Addr: "127.0.0.1:6379", - Password: "", - DB: 0, - }), - prefix: "goravel_cache:", - ctx: context.Background(), - } -} - -func TestGet(t *testing.T) { - r := instance() - - assert.Equal(t, "default", r.Get("test-get", "default").(string)) - assert.Equal(t, "default", r.Get("test-get", func() interface{} { - return "default" - }).(string)) -} - -func TestGetBool(t *testing.T) { - r := instance() - - assert.Equal(t, true, r.GetBool("test-get-bool", true)) - assert.Nil(t, r.Put("test-get-bool", true, 2*time.Second)) - assert.Equal(t, true, r.GetBool("test-get-bool", false)) -} - -func TestGetInt(t *testing.T) { - r := instance() - - assert.Equal(t, 2, r.GetInt("test-get-int", 2)) - assert.Nil(t, r.Put("test-get-int", 3, 2*time.Second)) - assert.Equal(t, 3, r.GetInt("test-get-int", 2)) -} - -func TestGetString(t *testing.T) { - r := instance() - - assert.Equal(t, "2", r.GetString("test-get-string", "2")) - assert.Nil(t, r.Put("test-get-string", "3", 2*time.Second)) - assert.Equal(t, "3", r.GetString("test-get-string", "2")) -} - -func TestHas(t *testing.T) { - r := instance() - - assert.False(t, r.Has("test-has")) - err := r.Put("test-has", "goravel", 5*time.Second) - assert.Nil(t, err) - assert.True(t, r.Has("test-has")) -} - -func TestPut(t *testing.T) { - r := instance() - - assert.Nil(t, r.Put("test-put", "goravel", 5*time.Second)) - assert.True(t, r.Has("test-put")) - assert.Equal(t, "goravel", r.Get("test-put", "default")) -} - -func TestPull(t *testing.T) { - r := instance() - - assert.Nil(t, r.Put("test-put", "goravel", 5*time.Second)) - assert.True(t, r.Has("test-put")) - assert.Equal(t, "goravel", r.Get("test-put", "default")) -} - -func TestAdd(t *testing.T) { - r := instance() - - assert.True(t, r.Add("test-add", "goravel", 5*time.Second)) - assert.True(t, r.Has("test-put")) - assert.False(t, r.Add("test-add", "goravel", 5*time.Second)) -} - -func TestRemember(t *testing.T) { - r := instance() - - val, err := r.Remember("test-remember", 5*time.Second, func() interface{} { - return "goravel" - }) - - assert.Nil(t, err) - assert.Equal(t, "goravel", val.(string)) -} - -func TestRememberForever(t *testing.T) { - r := instance() - - val, err := r.RememberForever("test-remember-forever", func() interface{} { - return "goravel" - }) - - assert.Nil(t, err) - assert.Equal(t, "goravel", val.(string)) -} - -func TestForever(t *testing.T) { - r := instance() - - val := r.Forever("test-forever", "goravel") - - assert.True(t, val) - assert.Equal(t, "goravel", r.Get("test-forever", nil)) -} - -func TestForget(t *testing.T) { - r := instance() - - val := r.Forget("test-forget") - assert.True(t, val) - - err := r.Put("test-forget", "goravel", 5*time.Second) - assert.Nil(t, err) - assert.True(t, r.Forget("test-forget")) -} - -func TestFlush(t *testing.T) { - r := instance() - - err := r.Put("test-flush", "goravel", 5*time.Second) - assert.Nil(t, err) - assert.Equal(t, "goravel", r.Get("test-flush", nil).(string)) - - r.Flush() - assert.False(t, r.Has("test-flush")) -} diff --git a/config/application.go b/config/application.go index 0711e5ddb..922680654 100644 --- a/config/application.go +++ b/config/application.go @@ -3,27 +3,27 @@ package config import ( "os" - "github.com/goravel/framework/contracts/config" - "github.com/goravel/framework/support/file" - "github.com/goravel/framework/testing" - "github.com/gookit/color" "github.com/spf13/cast" "github.com/spf13/viper" + + "github.com/goravel/framework/support/file" + "github.com/goravel/framework/testing" ) type Application struct { vip *viper.Viper } -func (app *Application) Init() config.Config { - if !file.Exists(".env") { +func NewApplication(envPath string) *Application { + if !file.Exists(envPath) { color.Redln("Please create .env and initialize it first\nRun command: \ncp .env.example .env && go run . artisan key:generate") os.Exit(0) } + app := &Application{} app.vip = viper.New() - app.vip.SetConfigName(".env") + app.vip.SetConfigName(envPath) app.vip.SetConfigType("env") app.vip.AddConfigPath(".") err := app.vip.ReadInConfig() diff --git a/config/application_test.go b/config/application_test.go index b57db6085..d25625a36 100644 --- a/config/application_test.go +++ b/config/application_test.go @@ -1,78 +1,70 @@ package config import ( - "os" "testing" - "github.com/goravel/framework/testing/file" - "github.com/stretchr/testify/assert" + "github.com/gookit/color" + "github.com/stretchr/testify/suite" + + "github.com/goravel/framework/support/file" ) -func TestInit(t *testing.T) { - err := file.CreateEnv() - assert.Nil(t, err) - assert.NotPanics(t, func() { - app := Application{} - app.Init() +type ApplicationTestSuite struct { + suite.Suite + config *Application +} + +func TestApplicationTestSuite(t *testing.T) { + if !file.Exists("../.env") { + color.Redln("No config tests run, need create .env based on .env.example, then initialize it") + return + } + + suite.Run(t, &ApplicationTestSuite{ + config: NewApplication("../.env"), }) } -func TestEnv(t *testing.T) { - app := Application{} - app.Init() +func (s *ApplicationTestSuite) SetupTest() { + +} - assert.Equal(t, "goravel", app.Env("APP_NAME").(string)) - assert.Equal(t, "127.0.0.1", app.Env("DB_HOST", "127.0.0.1").(string)) +func (s *ApplicationTestSuite) TestEnv() { + s.Equal("goravel", s.config.Env("APP_NAME", "goravel").(string)) + s.Equal("127.0.0.1", s.config.Env("DB_HOST", "127.0.0.1").(string)) } -func TestAdd(t *testing.T) { - app := Application{} - app.Init() - app.Add("app", map[string]interface{}{ +func (s *ApplicationTestSuite) TestAdd() { + s.config.Add("app", map[string]any{ "env": "local", }) - assert.Equal(t, "local", app.GetString("app.env")) + s.Equal("local", s.config.GetString("app.env")) } -func TestGet(t *testing.T) { - app := Application{} - app.Init() - - assert.Equal(t, "goravel", app.Get("APP_NAME").(string)) +func (s *ApplicationTestSuite) TestGet() { + s.Equal("goravel", s.config.Get("APP_NAME", "goravel").(string)) } -func TestGetString(t *testing.T) { - app := Application{} - app.Init() - - app.Add("database", map[string]interface{}{ - "default": app.Env("DB_CONNECTION", "mysql"), - "connections": map[string]interface{}{ - "mysql": map[string]interface{}{ - "host": app.Env("DB_HOST", "127.0.0.1"), +func (s *ApplicationTestSuite) TestGetString() { + s.config.Add("database", map[string]any{ + "default": s.config.Env("DB_CONNECTION", "mysql"), + "connections": map[string]any{ + "mysql": map[string]any{ + "host": s.config.Env("DB_HOST", "127.0.0.1"), }, }, }) - assert.Equal(t, "goravel", app.GetString("APP_NAME")) - assert.Equal(t, "127.0.0.1", app.GetString("database.connections.mysql.host")) - assert.Equal(t, "mysql", app.GetString("database.default")) + s.Equal("goravel", s.config.GetString("APP_NAME", "goravel")) + s.Equal("127.0.0.1", s.config.GetString("database.connections.mysql.host")) + s.Equal("mysql", s.config.GetString("database.default")) } -func TestGetInt(t *testing.T) { - app := Application{} - app.Init() - - assert.Equal(t, app.GetInt("DB_PORT"), 3306) +func (s *ApplicationTestSuite) TestGetInt() { + s.Equal(s.config.GetInt("DB_PORT", 3306), 3306) } -func TestGetBool(t *testing.T) { - app := Application{} - app.Init() - - assert.Equal(t, true, app.GetBool("APP_DEBUG")) - - err := os.Remove(".env") - assert.Nil(t, err) +func (s *ApplicationTestSuite) TestGetBool() { + s.Equal(true, s.config.GetBool("APP_DEBUG", true)) } diff --git a/config/service_provider.go b/config/service_provider.go index 1ca42fb99..59224a4c2 100644 --- a/config/service_provider.go +++ b/config/service_provider.go @@ -8,8 +8,7 @@ type ServiceProvider struct { } func (config *ServiceProvider) Register() { - app := Application{} - facades.Config = app.Init() + facades.Config = NewApplication(".env") } func (config *ServiceProvider) Boot() { diff --git a/console/application_test.go b/console/cli_test.go similarity index 75% rename from console/application_test.go rename to console/cli_test.go index 4d4a4294c..8ee0cf9d4 100644 --- a/console/application_test.go +++ b/console/cli_test.go @@ -3,11 +3,24 @@ package console import ( "testing" + "github.com/stretchr/testify/assert" + "github.com/goravel/framework/contracts/console" "github.com/goravel/framework/contracts/console/command" - "github.com/stretchr/testify/assert" ) +var testCommand = 0 + +func TestRun(t *testing.T) { + cli := NewCli() + cli.Register([]console.Command{ + &TestCommand{}, + }) + + cli.Call("test") + assert.Equal(t, 1, testCommand) +} + type TestCommand struct { } @@ -24,24 +37,7 @@ func (receiver *TestCommand) Extend() command.Extend { } func (receiver *TestCommand) Handle(ctx console.Context) error { - return nil -} - -func TestInit(t *testing.T) { - assert.NotPanics(t, func() { - app := Application{} - app.Init() - }) -} + testCommand++ -func TestRun(t *testing.T) { - app := Application{} - cli := app.Init() - cli.Register([]console.Command{ - &TestCommand{}, - }) - - assert.NotPanics(t, func() { - cli.Call("test") - }) + return nil } diff --git a/console/service_provider.go b/console/service_provider.go index e8399b733..6550a5012 100644 --- a/console/service_provider.go +++ b/console/service_provider.go @@ -14,8 +14,7 @@ func (receiver *ServiceProvider) Boot() { } func (receiver *ServiceProvider) Register() { - app := Application{} - facades.Artisan = app.Init() + facades.Artisan = NewCli() } func (receiver *ServiceProvider) registerCommands() { diff --git a/contracts/auth/access/mocks/Gate.go b/contracts/auth/access/mocks/Gate.go index cd24c496a..29fa923e3 100644 --- a/contracts/auth/access/mocks/Gate.go +++ b/contracts/auth/access/mocks/Gate.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.12.3. DO NOT EDIT. +// Code generated by mockery v2.14.0. DO NOT EDIT. package mocks @@ -118,13 +118,13 @@ func (_m *Gate) WithContext(ctx context.Context) access.Gate { return r0 } -type NewGateT interface { +type mockConstructorTestingTNewGate interface { mock.TestingT Cleanup(func()) } // NewGate creates a new instance of Gate. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func NewGate(t NewGateT) *Gate { +func NewGate(t mockConstructorTestingTNewGate) *Gate { mock := &Gate{} mock.Mock.Test(t) diff --git a/contracts/auth/mocks/Auth.go b/contracts/auth/mocks/Auth.go index 5a4c6fa84..198441045 100644 --- a/contracts/auth/mocks/Auth.go +++ b/contracts/auth/mocks/Auth.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.12.3. DO NOT EDIT. +// Code generated by mockery v2.14.0. DO NOT EDIT. package mocks @@ -135,13 +135,13 @@ func (_m *Auth) User(ctx http.Context, user interface{}) error { return r0 } -type NewAuthT interface { +type mockConstructorTestingTNewAuth interface { mock.TestingT Cleanup(func()) } // NewAuth creates a new instance of Auth. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func NewAuth(t NewAuthT) *Auth { +func NewAuth(t mockConstructorTestingTNewAuth) *Auth { mock := &Auth{} mock.Mock.Test(t) diff --git a/contracts/cache/mocks/Store.go b/contracts/cache/mocks/Store.go index de0104855..c4a5f5ff2 100644 --- a/contracts/cache/mocks/Store.go +++ b/contracts/cache/mocks/Store.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.12.3. DO NOT EDIT. +// Code generated by mockery v2.14.0. DO NOT EDIT. package mocks @@ -237,13 +237,13 @@ func (_m *Store) WithContext(ctx context.Context) cache.Store { return r0 } -type NewStoreT interface { +type mockConstructorTestingTNewStore interface { mock.TestingT Cleanup(func()) } // NewStore creates a new instance of Store. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func NewStore(t NewStoreT) *Store { +func NewStore(t mockConstructorTestingTNewStore) *Store { mock := &Store{} mock.Mock.Test(t) diff --git a/contracts/config/config.go b/contracts/config/config.go index 46353cbcb..90c91d868 100644 --- a/contracts/config/config.go +++ b/contracts/config/config.go @@ -3,15 +3,15 @@ package config //go:generate mockery --name=Config type Config interface { //Env Get config from env. - Env(envName string, defaultValue ...interface{}) interface{} + Env(envName string, defaultValue ...any) any //Add config to application. - Add(name string, configuration map[string]interface{}) + Add(name string, configuration map[string]any) //Get config from application. - Get(path string, defaultValue ...interface{}) interface{} + Get(path string, defaultValue ...any) any //GetString Get string type config from application. - GetString(path string, defaultValue ...interface{}) string + GetString(path string, defaultValue ...any) string //GetInt Get int type config from application. - GetInt(path string, defaultValue ...interface{}) int + GetInt(path string, defaultValue ...any) int //GetBool Get bool type config from application. - GetBool(path string, defaultValue ...interface{}) bool + GetBool(path string, defaultValue ...any) bool } diff --git a/contracts/config/mocks/Config.go b/contracts/config/mocks/Config.go index a728a4f5b..3502041ce 100644 --- a/contracts/config/mocks/Config.go +++ b/contracts/config/mocks/Config.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.12.3. DO NOT EDIT. +// Code generated by mockery v2.14.0. DO NOT EDIT. package mocks @@ -103,13 +103,13 @@ func (_m *Config) GetString(path string, defaultValue ...interface{}) string { return r0 } -type NewConfigT interface { +type mockConstructorTestingTNewConfig interface { mock.TestingT Cleanup(func()) } // NewConfig creates a new instance of Config. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func NewConfig(t NewConfigT) *Config { +func NewConfig(t mockConstructorTestingTNewConfig) *Config { mock := &Config{} mock.Mock.Test(t) diff --git a/contracts/console/mocks/Artisan.go b/contracts/console/mocks/Artisan.go index 92d41bc16..15a5d588f 100644 --- a/contracts/console/mocks/Artisan.go +++ b/contracts/console/mocks/Artisan.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.12.3. DO NOT EDIT. +// Code generated by mockery v2.14.0. DO NOT EDIT. package mocks @@ -32,13 +32,13 @@ func (_m *Artisan) Run(args []string, exitIfArtisan bool) { _m.Called(args, exitIfArtisan) } -type NewArtisanT interface { +type mockConstructorTestingTNewArtisan interface { mock.TestingT Cleanup(func()) } // NewArtisan creates a new instance of Artisan. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func NewArtisan(t NewArtisanT) *Artisan { +func NewArtisan(t mockConstructorTestingTNewArtisan) *Artisan { mock := &Artisan{} mock.Mock.Test(t) diff --git a/contracts/console/mocks/Context.go b/contracts/console/mocks/Context.go index f190b1449..ab9841340 100644 --- a/contracts/console/mocks/Context.go +++ b/contracts/console/mocks/Context.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.12.3. DO NOT EDIT. +// Code generated by mockery v2.14.0. DO NOT EDIT. package mocks @@ -53,13 +53,13 @@ func (_m *Context) Option(key string) string { return r0 } -type NewContextT interface { +type mockConstructorTestingTNewContext interface { mock.TestingT Cleanup(func()) } // NewContext creates a new instance of Context. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func NewContext(t NewContextT) *Context { +func NewContext(t mockConstructorTestingTNewContext) *Context { mock := &Context{} mock.Mock.Test(t) diff --git a/contracts/database/orm/mocks/Association.go b/contracts/database/orm/mocks/Association.go new file mode 100644 index 000000000..f90514b28 --- /dev/null +++ b/contracts/database/orm/mocks/Association.go @@ -0,0 +1,118 @@ +// Code generated by mockery v2.14.0. DO NOT EDIT. + +package mocks + +import mock "github.com/stretchr/testify/mock" + +// Association is an autogenerated mock type for the Association type +type Association struct { + mock.Mock +} + +// Append provides a mock function with given fields: values +func (_m *Association) Append(values ...interface{}) error { + var _ca []interface{} + _ca = append(_ca, values...) + ret := _m.Called(_ca...) + + var r0 error + if rf, ok := ret.Get(0).(func(...interface{}) error); ok { + r0 = rf(values...) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Clear provides a mock function with given fields: +func (_m *Association) Clear() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Count provides a mock function with given fields: +func (_m *Association) Count() int64 { + ret := _m.Called() + + var r0 int64 + if rf, ok := ret.Get(0).(func() int64); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(int64) + } + + return r0 +} + +// Delete provides a mock function with given fields: values +func (_m *Association) Delete(values ...interface{}) error { + var _ca []interface{} + _ca = append(_ca, values...) + ret := _m.Called(_ca...) + + var r0 error + if rf, ok := ret.Get(0).(func(...interface{}) error); ok { + r0 = rf(values...) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Find provides a mock function with given fields: out, conds +func (_m *Association) Find(out interface{}, conds ...interface{}) error { + var _ca []interface{} + _ca = append(_ca, out) + _ca = append(_ca, conds...) + ret := _m.Called(_ca...) + + var r0 error + if rf, ok := ret.Get(0).(func(interface{}, ...interface{}) error); ok { + r0 = rf(out, conds...) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Replace provides a mock function with given fields: values +func (_m *Association) Replace(values ...interface{}) error { + var _ca []interface{} + _ca = append(_ca, values...) + ret := _m.Called(_ca...) + + var r0 error + if rf, ok := ret.Get(0).(func(...interface{}) error); ok { + r0 = rf(values...) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +type mockConstructorTestingTNewAssociation interface { + mock.TestingT + Cleanup(func()) +} + +// NewAssociation creates a new instance of Association. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +func NewAssociation(t mockConstructorTestingTNewAssociation) *Association { + mock := &Association{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/contracts/database/orm/mocks/DB.go b/contracts/database/orm/mocks/DB.go index 70334f528..34fe6fe3f 100644 --- a/contracts/database/orm/mocks/DB.go +++ b/contracts/database/orm/mocks/DB.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.12.3. DO NOT EDIT. +// Code generated by mockery v2.14.0. DO NOT EDIT. package mocks @@ -12,6 +12,22 @@ type DB struct { mock.Mock } +// Association provides a mock function with given fields: association +func (_m *DB) Association(association string) orm.Association { + ret := _m.Called(association) + + var r0 orm.Association + if rf, ok := ret.Get(0).(func(string) orm.Association); ok { + r0 = rf(association) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(orm.Association) + } + } + + return r0 +} + // Begin provides a mock function with given fields: func (_m *DB) Begin() (orm.Transaction, error) { ret := _m.Called() @@ -278,6 +294,40 @@ func (_m *DB) Limit(limit int) orm.Query { return r0 } +// Load provides a mock function with given fields: dest, relation, args +func (_m *DB) Load(dest interface{}, relation string, args ...interface{}) error { + var _ca []interface{} + _ca = append(_ca, dest, relation) + _ca = append(_ca, args...) + ret := _m.Called(_ca...) + + var r0 error + if rf, ok := ret.Get(0).(func(interface{}, string, ...interface{}) error); ok { + r0 = rf(dest, relation, args...) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// LoadMissing provides a mock function with given fields: dest, relation, args +func (_m *DB) LoadMissing(dest interface{}, relation string, args ...interface{}) error { + var _ca []interface{} + _ca = append(_ca, dest, relation) + _ca = append(_ca, args...) + ret := _m.Called(_ca...) + + var r0 error + if rf, ok := ret.Get(0).(func(interface{}, string, ...interface{}) error); ok { + r0 = rf(dest, relation, args...) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // Model provides a mock function with given fields: value func (_m *DB) Model(value interface{}) orm.Query { ret := _m.Called(value) @@ -310,6 +360,28 @@ func (_m *DB) Offset(offset int) orm.Query { return r0 } +// Omit provides a mock function with given fields: columns +func (_m *DB) Omit(columns ...string) orm.Query { + _va := make([]interface{}, len(columns)) + for _i := range columns { + _va[_i] = columns[_i] + } + var _ca []interface{} + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 orm.Query + if rf, ok := ret.Get(0).(func(...string) orm.Query); ok { + r0 = rf(columns...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(orm.Query) + } + } + + return r0 +} + // OrWhere provides a mock function with given fields: query, args func (_m *DB) OrWhere(query interface{}, args ...interface{}) orm.Query { var _ca []interface{} @@ -513,6 +585,25 @@ func (_m *DB) Where(query interface{}, args ...interface{}) orm.Query { return r0 } +// With provides a mock function with given fields: query, args +func (_m *DB) With(query string, args ...interface{}) orm.Query { + var _ca []interface{} + _ca = append(_ca, query) + _ca = append(_ca, args...) + ret := _m.Called(_ca...) + + var r0 orm.Query + if rf, ok := ret.Get(0).(func(string, ...interface{}) orm.Query); ok { + r0 = rf(query, args...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(orm.Query) + } + } + + return r0 +} + // WithTrashed provides a mock function with given fields: func (_m *DB) WithTrashed() orm.Query { ret := _m.Called() @@ -529,13 +620,13 @@ func (_m *DB) WithTrashed() orm.Query { return r0 } -type NewDBT interface { +type mockConstructorTestingTNewDB interface { mock.TestingT Cleanup(func()) } // NewDB creates a new instance of DB. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func NewDB(t NewDBT) *DB { +func NewDB(t mockConstructorTestingTNewDB) *DB { mock := &DB{} mock.Mock.Test(t) diff --git a/contracts/database/orm/mocks/Orm.go b/contracts/database/orm/mocks/Orm.go index 29fa4799b..5f2cec934 100644 --- a/contracts/database/orm/mocks/Orm.go +++ b/contracts/database/orm/mocks/Orm.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.12.3. DO NOT EDIT. +// Code generated by mockery v2.14.0. DO NOT EDIT. package mocks @@ -7,6 +7,8 @@ import ( orm "github.com/goravel/framework/contracts/database/orm" mock "github.com/stretchr/testify/mock" + + sql "database/sql" ) // Orm is an autogenerated mock type for the Orm type @@ -30,6 +32,29 @@ func (_m *Orm) Connection(name string) orm.Orm { return r0 } +// DB provides a mock function with given fields: +func (_m *Orm) DB() (*sql.DB, error) { + ret := _m.Called() + + var r0 *sql.DB + if rf, ok := ret.Get(0).(func() *sql.DB); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*sql.DB) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // Query provides a mock function with given fields: func (_m *Orm) Query() orm.DB { ret := _m.Called() @@ -76,13 +101,13 @@ func (_m *Orm) WithContext(ctx context.Context) orm.Orm { return r0 } -type NewOrmT interface { +type mockConstructorTestingTNewOrm interface { mock.TestingT Cleanup(func()) } // NewOrm creates a new instance of Orm. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func NewOrm(t NewOrmT) *Orm { +func NewOrm(t mockConstructorTestingTNewOrm) *Orm { mock := &Orm{} mock.Mock.Test(t) diff --git a/contracts/database/orm/mocks/Transaction.go b/contracts/database/orm/mocks/Transaction.go index 40433811b..0512f0ace 100644 --- a/contracts/database/orm/mocks/Transaction.go +++ b/contracts/database/orm/mocks/Transaction.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.12.3. DO NOT EDIT. +// Code generated by mockery v2.14.0. DO NOT EDIT. package mocks @@ -12,6 +12,22 @@ type Transaction struct { mock.Mock } +// Association provides a mock function with given fields: association +func (_m *Transaction) Association(association string) orm.Association { + ret := _m.Called(association) + + var r0 orm.Association + if rf, ok := ret.Get(0).(func(string) orm.Association); ok { + r0 = rf(association) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(orm.Association) + } + } + + return r0 +} + // Commit provides a mock function with given fields: func (_m *Transaction) Commit() error { ret := _m.Called() @@ -269,6 +285,40 @@ func (_m *Transaction) Limit(limit int) orm.Query { return r0 } +// Load provides a mock function with given fields: dest, relation, args +func (_m *Transaction) Load(dest interface{}, relation string, args ...interface{}) error { + var _ca []interface{} + _ca = append(_ca, dest, relation) + _ca = append(_ca, args...) + ret := _m.Called(_ca...) + + var r0 error + if rf, ok := ret.Get(0).(func(interface{}, string, ...interface{}) error); ok { + r0 = rf(dest, relation, args...) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// LoadMissing provides a mock function with given fields: dest, relation, args +func (_m *Transaction) LoadMissing(dest interface{}, relation string, args ...interface{}) error { + var _ca []interface{} + _ca = append(_ca, dest, relation) + _ca = append(_ca, args...) + ret := _m.Called(_ca...) + + var r0 error + if rf, ok := ret.Get(0).(func(interface{}, string, ...interface{}) error); ok { + r0 = rf(dest, relation, args...) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // Model provides a mock function with given fields: value func (_m *Transaction) Model(value interface{}) orm.Query { ret := _m.Called(value) @@ -301,6 +351,28 @@ func (_m *Transaction) Offset(offset int) orm.Query { return r0 } +// Omit provides a mock function with given fields: columns +func (_m *Transaction) Omit(columns ...string) orm.Query { + _va := make([]interface{}, len(columns)) + for _i := range columns { + _va[_i] = columns[_i] + } + var _ca []interface{} + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 orm.Query + if rf, ok := ret.Get(0).(func(...string) orm.Query); ok { + r0 = rf(columns...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(orm.Query) + } + } + + return r0 +} + // OrWhere provides a mock function with given fields: query, args func (_m *Transaction) OrWhere(query interface{}, args ...interface{}) orm.Query { var _ca []interface{} @@ -518,6 +590,25 @@ func (_m *Transaction) Where(query interface{}, args ...interface{}) orm.Query { return r0 } +// With provides a mock function with given fields: query, args +func (_m *Transaction) With(query string, args ...interface{}) orm.Query { + var _ca []interface{} + _ca = append(_ca, query) + _ca = append(_ca, args...) + ret := _m.Called(_ca...) + + var r0 orm.Query + if rf, ok := ret.Get(0).(func(string, ...interface{}) orm.Query); ok { + r0 = rf(query, args...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(orm.Query) + } + } + + return r0 +} + // WithTrashed provides a mock function with given fields: func (_m *Transaction) WithTrashed() orm.Query { ret := _m.Called() @@ -534,13 +625,13 @@ func (_m *Transaction) WithTrashed() orm.Query { return r0 } -type NewTransactionT interface { +type mockConstructorTestingTNewTransaction interface { mock.TestingT Cleanup(func()) } // NewTransaction creates a new instance of Transaction. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func NewTransaction(t NewTransactionT) *Transaction { +func NewTransaction(t mockConstructorTestingTNewTransaction) *Transaction { mock := &Transaction{} mock.Mock.Test(t) diff --git a/contracts/database/orm/orm.go b/contracts/database/orm/orm.go index d55a42d77..4b17825d4 100644 --- a/contracts/database/orm/orm.go +++ b/contracts/database/orm/orm.go @@ -1,10 +1,14 @@ package orm -import "context" +import ( + "context" + "database/sql" +) //go:generate mockery --name=Orm type Orm interface { Connection(name string) Orm + DB() (*sql.DB, error) Query() DB Transaction(txFunc func(tx Transaction) error) error WithContext(ctx context.Context) Orm @@ -24,34 +28,49 @@ type Transaction interface { } type Query interface { + Association(association string) Association Driver() Driver Count(count *int64) error - Create(value interface{}) error - Delete(value interface{}, conds ...interface{}) error - Distinct(args ...interface{}) Query - Exec(sql string, values ...interface{}) error - Find(dest interface{}, conds ...interface{}) error - First(dest interface{}) error - FirstOrCreate(dest interface{}, conds ...interface{}) error - ForceDelete(value interface{}, conds ...interface{}) error - Get(dest interface{}) error + Create(value any) error + Delete(value any, conds ...any) error + Distinct(args ...any) Query + Exec(sql string, values ...any) error + Find(dest any, conds ...any) error + First(dest any) error + FirstOrCreate(dest any, conds ...any) error + ForceDelete(value any, conds ...any) error + Get(dest any) error Group(name string) Query - Having(query interface{}, args ...interface{}) Query - Join(query string, args ...interface{}) Query + Having(query any, args ...any) Query + Join(query string, args ...any) Query Limit(limit int) Query - Model(value interface{}) Query + Load(dest any, relation string, args ...any) error + LoadMissing(dest any, relation string, args ...any) error + Model(value any) Query Offset(offset int) Query - Order(value interface{}) Query - OrWhere(query interface{}, args ...interface{}) Query - Pluck(column string, dest interface{}) error - Raw(sql string, values ...interface{}) Query - Save(value interface{}) error - Scan(dest interface{}) error + Omit(columns ...string) Query + Order(value any) Query + OrWhere(query any, args ...any) Query + Pluck(column string, dest any) error + Raw(sql string, values ...any) Query + Save(value any) error + Scan(dest any) error Scopes(funcs ...func(Query) Query) Query - Select(query interface{}, args ...interface{}) Query - Table(name string, args ...interface{}) Query - Update(column string, value interface{}) error - Updates(values interface{}) error - Where(query interface{}, args ...interface{}) Query + Select(query any, args ...any) Query + Table(name string, args ...any) Query + Update(column string, value any) error + Updates(values any) error + Where(query any, args ...any) Query WithTrashed() Query + With(query string, args ...any) Query +} + +//go:generate mockery --name=Association +type Association interface { + Find(out any, conds ...any) error + Append(values ...any) error + Replace(values ...any) error + Delete(values ...any) error + Clear() error + Count() int64 } diff --git a/contracts/event/mocks/Instance.go b/contracts/event/mocks/Instance.go index deb25f978..fa32f2f23 100644 --- a/contracts/event/mocks/Instance.go +++ b/contracts/event/mocks/Instance.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.12.3. DO NOT EDIT. +// Code generated by mockery v2.14.0. DO NOT EDIT. package mocks @@ -49,13 +49,13 @@ func (_m *Instance) Register(_a0 map[event.Event][]event.Listener) { _m.Called(_a0) } -type NewInstanceT interface { +type mockConstructorTestingTNewInstance interface { mock.TestingT Cleanup(func()) } // NewInstance creates a new instance of Instance. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func NewInstance(t NewInstanceT) *Instance { +func NewInstance(t mockConstructorTestingTNewInstance) *Instance { mock := &Instance{} mock.Mock.Test(t) diff --git a/contracts/event/mocks/Task.go b/contracts/event/mocks/Task.go index cd70ee491..d7672d2f1 100644 --- a/contracts/event/mocks/Task.go +++ b/contracts/event/mocks/Task.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.12.3. DO NOT EDIT. +// Code generated by mockery v2.14.0. DO NOT EDIT. package mocks @@ -23,13 +23,13 @@ func (_m *Task) Dispatch() error { return r0 } -type NewTaskT interface { +type mockConstructorTestingTNewTask interface { mock.TestingT Cleanup(func()) } // NewTask creates a new instance of Task. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func NewTask(t NewTaskT) *Task { +func NewTask(t mockConstructorTestingTNewTask) *Task { mock := &Task{} mock.Mock.Test(t) diff --git a/contracts/filesystem/mocks/Driver.go b/contracts/filesystem/mocks/Driver.go index afba2cebf..4da62c7ca 100644 --- a/contracts/filesystem/mocks/Driver.go +++ b/contracts/filesystem/mocks/Driver.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.12.3. DO NOT EDIT. +// Code generated by mockery v2.14.0. DO NOT EDIT. package mocks @@ -375,13 +375,13 @@ func (_m *Driver) WithContext(ctx context.Context) filesystem.Driver { return r0 } -type NewDriverT interface { +type mockConstructorTestingTNewDriver interface { mock.TestingT Cleanup(func()) } // NewDriver creates a new instance of Driver. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func NewDriver(t NewDriverT) *Driver { +func NewDriver(t mockConstructorTestingTNewDriver) *Driver { mock := &Driver{} mock.Mock.Test(t) diff --git a/contracts/filesystem/mocks/File.go b/contracts/filesystem/mocks/File.go index 9a5adf687..b2f3fe86c 100644 --- a/contracts/filesystem/mocks/File.go +++ b/contracts/filesystem/mocks/File.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.12.3. DO NOT EDIT. +// Code generated by mockery v2.14.0. DO NOT EDIT. package mocks @@ -153,13 +153,13 @@ func (_m *File) StoreAs(path string, name string) (string, error) { return r0, r1 } -type NewFileT interface { +type mockConstructorTestingTNewFile interface { mock.TestingT Cleanup(func()) } // NewFile creates a new instance of File. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func NewFile(t NewFileT) *File { +func NewFile(t mockConstructorTestingTNewFile) *File { mock := &File{} mock.Mock.Test(t) diff --git a/contracts/filesystem/mocks/Storage.go b/contracts/filesystem/mocks/Storage.go index da6c47c22..ac2369ec3 100644 --- a/contracts/filesystem/mocks/Storage.go +++ b/contracts/filesystem/mocks/Storage.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.12.3. DO NOT EDIT. +// Code generated by mockery v2.14.0. DO NOT EDIT. package mocks @@ -391,13 +391,13 @@ func (_m *Storage) WithContext(ctx context.Context) filesystem.Driver { return r0 } -type NewStorageT interface { +type mockConstructorTestingTNewStorage interface { mock.TestingT Cleanup(func()) } // NewStorage creates a new instance of Storage. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func NewStorage(t NewStorageT) *Storage { +func NewStorage(t mockConstructorTestingTNewStorage) *Storage { mock := &Storage{} mock.Mock.Test(t) diff --git a/contracts/grpc/mocks/Grpc.go b/contracts/grpc/mocks/Grpc.go index c8e65e12a..60c525bc2 100644 --- a/contracts/grpc/mocks/Grpc.go +++ b/contracts/grpc/mocks/Grpc.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.12.3. DO NOT EDIT. +// Code generated by mockery v2.14.0. DO NOT EDIT. package mocks @@ -78,13 +78,13 @@ func (_m *Grpc) UnaryServerInterceptors(_a0 []grpc.UnaryServerInterceptor) { _m.Called(_a0) } -type NewGrpcT interface { +type mockConstructorTestingTNewGrpc interface { mock.TestingT Cleanup(func()) } // NewGrpc creates a new instance of Grpc. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func NewGrpc(t NewGrpcT) *Grpc { +func NewGrpc(t mockConstructorTestingTNewGrpc) *Grpc { mock := &Grpc{} mock.Mock.Test(t) diff --git a/contracts/http/mocks/Context.go b/contracts/http/mocks/Context.go index 559bf70a8..83ae14173 100644 --- a/contracts/http/mocks/Context.go +++ b/contracts/http/mocks/Context.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.12.3. DO NOT EDIT. +// Code generated by mockery v2.14.0. DO NOT EDIT. package mocks @@ -136,13 +136,13 @@ func (_m *Context) WithValue(key string, value interface{}) { _m.Called(key, value) } -type NewContextT interface { +type mockConstructorTestingTNewContext interface { mock.TestingT Cleanup(func()) } // NewContext creates a new instance of Context. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func NewContext(t NewContextT) *Context { +func NewContext(t mockConstructorTestingTNewContext) *Context { mock := &Context{} mock.Mock.Test(t) diff --git a/contracts/http/mocks/Request.go b/contracts/http/mocks/Request.go index e71eca3c9..787d17e93 100644 --- a/contracts/http/mocks/Request.go +++ b/contracts/http/mocks/Request.go @@ -1,15 +1,16 @@ -// Code generated by mockery v2.12.3. DO NOT EDIT. +// Code generated by mockery v2.14.0. DO NOT EDIT. package mocks import ( filesystem "github.com/goravel/framework/contracts/filesystem" http "github.com/goravel/framework/contracts/http" - validate "github.com/goravel/framework/contracts/validation" mock "github.com/stretchr/testify/mock" nethttp "net/http" + + validation "github.com/goravel/framework/contracts/validation" ) // Request is an autogenerated mock type for the Request type @@ -213,16 +214,32 @@ func (_m *Request) Query(key string, defaultValue string) string { return r0 } -// Response provides a mock function with given fields: -func (_m *Request) Response() http.Response { - ret := _m.Called() +// QueryArray provides a mock function with given fields: key +func (_m *Request) QueryArray(key string) []string { + ret := _m.Called(key) - var r0 http.Response - if rf, ok := ret.Get(0).(func() http.Response); ok { - r0 = rf() + var r0 []string + if rf, ok := ret.Get(0).(func(string) []string); ok { + r0 = rf(key) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + return r0 +} + +// QueryMap provides a mock function with given fields: key +func (_m *Request) QueryMap(key string) map[string]string { + ret := _m.Called(key) + + var r0 map[string]string + if rf, ok := ret.Get(0).(func(string) map[string]string); ok { + r0 = rf(key) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(http.Response) + r0 = ret.Get(0).(map[string]string) } } @@ -244,7 +261,7 @@ func (_m *Request) Url() string { } // Validate provides a mock function with given fields: rules, options -func (_m *Request) Validate(rules map[string]string, options ...validate.Option) (validate.Validator, error) { +func (_m *Request) Validate(rules map[string]string, options ...validation.Option) (validation.Validator, error) { _va := make([]interface{}, len(options)) for _i := range options { _va[_i] = options[_i] @@ -254,17 +271,17 @@ func (_m *Request) Validate(rules map[string]string, options ...validate.Option) _ca = append(_ca, _va...) ret := _m.Called(_ca...) - var r0 validate.Validator - if rf, ok := ret.Get(0).(func(map[string]string, ...validate.Option) validate.Validator); ok { + var r0 validation.Validator + if rf, ok := ret.Get(0).(func(map[string]string, ...validation.Option) validation.Validator); ok { r0 = rf(rules, options...) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(validate.Validator) + r0 = ret.Get(0).(validation.Validator) } } var r1 error - if rf, ok := ret.Get(1).(func(map[string]string, ...validate.Option) error); ok { + if rf, ok := ret.Get(1).(func(map[string]string, ...validation.Option) error); ok { r1 = rf(rules, options...) } else { r1 = ret.Error(1) @@ -274,15 +291,15 @@ func (_m *Request) Validate(rules map[string]string, options ...validate.Option) } // ValidateRequest provides a mock function with given fields: request -func (_m *Request) ValidateRequest(request http.FormRequest) (validate.Errors, error) { +func (_m *Request) ValidateRequest(request http.FormRequest) (validation.Errors, error) { ret := _m.Called(request) - var r0 validate.Errors - if rf, ok := ret.Get(0).(func(http.FormRequest) validate.Errors); ok { + var r0 validation.Errors + if rf, ok := ret.Get(0).(func(http.FormRequest) validation.Errors); ok { r0 = rf(request) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(validate.Errors) + r0 = ret.Get(0).(validation.Errors) } } @@ -296,13 +313,13 @@ func (_m *Request) ValidateRequest(request http.FormRequest) (validate.Errors, e return r0, r1 } -type NewRequestT interface { +type mockConstructorTestingTNewRequest interface { mock.TestingT Cleanup(func()) } // NewRequest creates a new instance of Request. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func NewRequest(t NewRequestT) *Request { +func NewRequest(t mockConstructorTestingTNewRequest) *Request { mock := &Request{} mock.Mock.Test(t) diff --git a/contracts/http/mocks/Response.go b/contracts/http/mocks/Response.go index cd9ee9d12..316ecaef8 100644 --- a/contracts/http/mocks/Response.go +++ b/contracts/http/mocks/Response.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.12.3. DO NOT EDIT. +// Code generated by mockery v2.14.0. DO NOT EDIT. package mocks @@ -43,6 +43,22 @@ func (_m *Response) Json(code int, obj interface{}) { _m.Called(code, obj) } +// Origin provides a mock function with given fields: +func (_m *Response) Origin() http.ResponseOrigin { + ret := _m.Called() + + var r0 http.ResponseOrigin + if rf, ok := ret.Get(0).(func() http.ResponseOrigin); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(http.ResponseOrigin) + } + } + + return r0 +} + // String provides a mock function with given fields: code, format, values func (_m *Response) String(code int, format string, values ...interface{}) { var _ca []interface{} @@ -67,13 +83,13 @@ func (_m *Response) Success() http.ResponseSuccess { return r0 } -type NewResponseT interface { +type mockConstructorTestingTNewResponse interface { mock.TestingT Cleanup(func()) } // NewResponse creates a new instance of Response. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func NewResponse(t NewResponseT) *Response { +func NewResponse(t mockConstructorTestingTNewResponse) *Response { mock := &Response{} mock.Mock.Test(t) diff --git a/contracts/http/mocks/ResponseOrigin.go b/contracts/http/mocks/ResponseOrigin.go new file mode 100644 index 000000000..9aa62653f --- /dev/null +++ b/contracts/http/mocks/ResponseOrigin.go @@ -0,0 +1,91 @@ +// Code generated by mockery v2.14.0. DO NOT EDIT. + +package mocks + +import ( + bytes "bytes" + + mock "github.com/stretchr/testify/mock" + + nethttp "net/http" +) + +// ResponseOrigin is an autogenerated mock type for the ResponseOrigin type +type ResponseOrigin struct { + mock.Mock +} + +// Body provides a mock function with given fields: +func (_m *ResponseOrigin) Body() *bytes.Buffer { + ret := _m.Called() + + var r0 *bytes.Buffer + if rf, ok := ret.Get(0).(func() *bytes.Buffer); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*bytes.Buffer) + } + } + + return r0 +} + +// Header provides a mock function with given fields: +func (_m *ResponseOrigin) Header() nethttp.Header { + ret := _m.Called() + + var r0 nethttp.Header + if rf, ok := ret.Get(0).(func() nethttp.Header); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(nethttp.Header) + } + } + + return r0 +} + +// Size provides a mock function with given fields: +func (_m *ResponseOrigin) Size() int { + ret := _m.Called() + + var r0 int + if rf, ok := ret.Get(0).(func() int); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(int) + } + + return r0 +} + +// Status provides a mock function with given fields: +func (_m *ResponseOrigin) Status() int { + ret := _m.Called() + + var r0 int + if rf, ok := ret.Get(0).(func() int); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(int) + } + + return r0 +} + +type mockConstructorTestingTNewResponseOrigin interface { + mock.TestingT + Cleanup(func()) +} + +// NewResponseOrigin creates a new instance of ResponseOrigin. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +func NewResponseOrigin(t mockConstructorTestingTNewResponseOrigin) *ResponseOrigin { + mock := &ResponseOrigin{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/contracts/http/mocks/ResponseSuccess.go b/contracts/http/mocks/ResponseSuccess.go index a13afe00a..4ef1ff074 100644 --- a/contracts/http/mocks/ResponseSuccess.go +++ b/contracts/http/mocks/ResponseSuccess.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.12.3. DO NOT EDIT. +// Code generated by mockery v2.14.0. DO NOT EDIT. package mocks @@ -22,13 +22,13 @@ func (_m *ResponseSuccess) String(format string, values ...interface{}) { _m.Called(_ca...) } -type NewResponseSuccessT interface { +type mockConstructorTestingTNewResponseSuccess interface { mock.TestingT Cleanup(func()) } // NewResponseSuccess creates a new instance of ResponseSuccess. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func NewResponseSuccess(t NewResponseSuccessT) *ResponseSuccess { +func NewResponseSuccess(t mockConstructorTestingTNewResponseSuccess) *ResponseSuccess { mock := &ResponseSuccess{} mock.Mock.Test(t) diff --git a/contracts/http/request.go b/contracts/http/request.go index d80ef3d68..215150504 100644 --- a/contracts/http/request.go +++ b/contracts/http/request.go @@ -21,6 +21,8 @@ type Request interface { Input(key string) string // Query Retrieve a query string item form the request: /users?id=1 Query(key, defaultValue string) string + QueryArray(key string) []string + QueryMap(key string) map[string]string // Form Retrieve a form string item form the post: /users POST:id=1 Form(key, defaultValue string) string Bind(obj any) error @@ -31,7 +33,6 @@ type Request interface { Next() Origin() *http.Request - Response() Response Validate(rules map[string]string, options ...validation.Option) (validation.Validator, error) ValidateRequest(request FormRequest) (validation.Errors, error) diff --git a/contracts/http/response.go b/contracts/http/response.go index 20761db8b..afc03a6f5 100644 --- a/contracts/http/response.go +++ b/contracts/http/response.go @@ -1,19 +1,33 @@ package http -type Json map[string]interface{} +import ( + "bytes" + "net/http" +) + +type Json map[string]any //go:generate mockery --name=Response type Response interface { - String(code int, format string, values ...interface{}) - Json(code int, obj interface{}) + String(code int, format string, values ...any) + Json(code int, obj any) File(filepath string) Download(filepath, filename string) Success() ResponseSuccess Header(key, value string) Response + Origin() ResponseOrigin } //go:generate mockery --name=ResponseSuccess type ResponseSuccess interface { - String(format string, values ...interface{}) - Json(obj interface{}) + String(format string, values ...any) + Json(obj any) +} + +//go:generate mockery --name=ResponseOrigin +type ResponseOrigin interface { + Body() *bytes.Buffer + Header() http.Header + Size() int + Status() int } diff --git a/contracts/log/mocks/Entry.go b/contracts/log/mocks/Entry.go index cd7a1df9d..392be4ab6 100644 --- a/contracts/log/mocks/Entry.go +++ b/contracts/log/mocks/Entry.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.12.3. DO NOT EDIT. +// Code generated by mockery v2.14.0. DO NOT EDIT. package mocks @@ -116,13 +116,13 @@ func (_m *Entry) Time() time.Time { return r0 } -type NewEntryT interface { +type mockConstructorTestingTNewEntry interface { mock.TestingT Cleanup(func()) } // NewEntry creates a new instance of Entry. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func NewEntry(t NewEntryT) *Entry { +func NewEntry(t mockConstructorTestingTNewEntry) *Entry { mock := &Entry{} mock.Mock.Test(t) diff --git a/contracts/log/mocks/Hook.go b/contracts/log/mocks/Hook.go index e3284a4db..8053ed776 100644 --- a/contracts/log/mocks/Hook.go +++ b/contracts/log/mocks/Hook.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.12.3. DO NOT EDIT. +// Code generated by mockery v2.14.0. DO NOT EDIT. package mocks @@ -42,13 +42,13 @@ func (_m *Hook) Levels() []log.Level { return r0 } -type NewHookT interface { +type mockConstructorTestingTNewHook interface { mock.TestingT Cleanup(func()) } // NewHook creates a new instance of Hook. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func NewHook(t NewHookT) *Hook { +func NewHook(t mockConstructorTestingTNewHook) *Hook { mock := &Hook{} mock.Mock.Test(t) diff --git a/contracts/log/mocks/Log.go b/contracts/log/mocks/Log.go index 420347d0a..f669b2563 100644 --- a/contracts/log/mocks/Log.go +++ b/contracts/log/mocks/Log.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.12.3. DO NOT EDIT. +// Code generated by mockery v2.14.0. DO NOT EDIT. package mocks @@ -120,13 +120,13 @@ func (_m *Log) WithContext(ctx context.Context) log.Writer { return r0 } -type NewLogT interface { +type mockConstructorTestingTNewLog interface { mock.TestingT Cleanup(func()) } // NewLog creates a new instance of Log. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func NewLog(t NewLogT) *Log { +func NewLog(t mockConstructorTestingTNewLog) *Log { mock := &Log{} mock.Mock.Test(t) diff --git a/contracts/log/mocks/Logger.go b/contracts/log/mocks/Logger.go index 1869f226e..2a83ee485 100644 --- a/contracts/log/mocks/Logger.go +++ b/contracts/log/mocks/Logger.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.12.3. DO NOT EDIT. +// Code generated by mockery v2.14.0. DO NOT EDIT. package mocks @@ -35,13 +35,13 @@ func (_m *Logger) Handle(channel string) (log.Hook, error) { return r0, r1 } -type NewLoggerT interface { +type mockConstructorTestingTNewLogger interface { mock.TestingT Cleanup(func()) } // NewLogger creates a new instance of Logger. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func NewLogger(t NewLoggerT) *Logger { +func NewLogger(t mockConstructorTestingTNewLogger) *Logger { mock := &Logger{} mock.Mock.Test(t) diff --git a/contracts/log/mocks/Writer.go b/contracts/log/mocks/Writer.go index eb5799b85..7b9c4df5b 100644 --- a/contracts/log/mocks/Writer.go +++ b/contracts/log/mocks/Writer.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.12.3. DO NOT EDIT. +// Code generated by mockery v2.14.0. DO NOT EDIT. package mocks @@ -99,13 +99,13 @@ func (_m *Writer) Warningf(format string, args ...interface{}) { _m.Called(_ca...) } -type NewWriterT interface { +type mockConstructorTestingTNewWriter interface { mock.TestingT Cleanup(func()) } // NewWriter creates a new instance of Writer. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func NewWriter(t NewWriterT) *Writer { +func NewWriter(t mockConstructorTestingTNewWriter) *Writer { mock := &Writer{} mock.Mock.Test(t) diff --git a/contracts/mail/mocks/Mail.go b/contracts/mail/mocks/Mail.go index 747211423..40842efd0 100644 --- a/contracts/mail/mocks/Mail.go +++ b/contracts/mail/mocks/Mail.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.12.3. DO NOT EDIT. +// Code generated by mockery v2.14.0. DO NOT EDIT. package mocks @@ -136,13 +136,13 @@ func (_m *Mail) To(addresses []string) mail.Mail { return r0 } -type NewMailT interface { +type mockConstructorTestingTNewMail interface { mock.TestingT Cleanup(func()) } // NewMail creates a new instance of Mail. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func NewMail(t NewMailT) *Mail { +func NewMail(t mockConstructorTestingTNewMail) *Mail { mock := &Mail{} mock.Mock.Test(t) diff --git a/contracts/queue/mocks/Queue.go b/contracts/queue/mocks/Queue.go index 68fa94902..91765e66b 100644 --- a/contracts/queue/mocks/Queue.go +++ b/contracts/queue/mocks/Queue.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.12.3. DO NOT EDIT. +// Code generated by mockery v2.14.0. DO NOT EDIT. package mocks @@ -81,13 +81,13 @@ func (_m *Queue) Worker(args *queue.Args) queue.Worker { return r0 } -type NewQueueT interface { +type mockConstructorTestingTNewQueue interface { mock.TestingT Cleanup(func()) } // NewQueue creates a new instance of Queue. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func NewQueue(t NewQueueT) *Queue { +func NewQueue(t mockConstructorTestingTNewQueue) *Queue { mock := &Queue{} mock.Mock.Test(t) diff --git a/contracts/queue/mocks/Task.go b/contracts/queue/mocks/Task.go index 35685eb72..e7cdd25cb 100644 --- a/contracts/queue/mocks/Task.go +++ b/contracts/queue/mocks/Task.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.12.3. DO NOT EDIT. +// Code generated by mockery v2.14.0. DO NOT EDIT. package mocks @@ -72,13 +72,13 @@ func (_m *Task) OnQueue(_a0 string) queue.Task { return r0 } -type NewTaskT interface { +type mockConstructorTestingTNewTask interface { mock.TestingT Cleanup(func()) } // NewTask creates a new instance of Task. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func NewTask(t NewTaskT) *Task { +func NewTask(t mockConstructorTestingTNewTask) *Task { mock := &Task{} mock.Mock.Test(t) diff --git a/contracts/route/mocks/Engine.go b/contracts/route/mocks/Engine.go index 76afb0c54..c49fb2ca8 100644 --- a/contracts/route/mocks/Engine.go +++ b/contracts/route/mocks/Engine.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.12.3. DO NOT EDIT. +// Code generated by mockery v2.14.0. DO NOT EDIT. package mocks @@ -139,13 +139,13 @@ func (_m *Engine) StaticFile(_a0 string, _a1 string) { _m.Called(_a0, _a1) } -type NewEngineT interface { +type mockConstructorTestingTNewEngine interface { mock.TestingT Cleanup(func()) } // NewEngine creates a new instance of Engine. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func NewEngine(t NewEngineT) *Engine { +func NewEngine(t mockConstructorTestingTNewEngine) *Engine { mock := &Engine{} mock.Mock.Test(t) diff --git a/contracts/route/mocks/Route.go b/contracts/route/mocks/Route.go index efea60e12..42e35b897 100644 --- a/contracts/route/mocks/Route.go +++ b/contracts/route/mocks/Route.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.12.3. DO NOT EDIT. +// Code generated by mockery v2.14.0. DO NOT EDIT. package mocks @@ -109,13 +109,13 @@ func (_m *Route) StaticFile(_a0 string, _a1 string) { _m.Called(_a0, _a1) } -type NewRouteT interface { +type mockConstructorTestingTNewRoute interface { mock.TestingT Cleanup(func()) } // NewRoute creates a new instance of Route. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func NewRoute(t NewRouteT) *Route { +func NewRoute(t mockConstructorTestingTNewRoute) *Route { mock := &Route{} mock.Mock.Test(t) diff --git a/contracts/schedule/mocks/Event.go b/contracts/schedule/mocks/Event.go index 3ea392b31..6a1ced64e 100644 --- a/contracts/schedule/mocks/Event.go +++ b/contracts/schedule/mocks/Event.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.12.3. DO NOT EDIT. +// Code generated by mockery v2.14.0. DO NOT EDIT. package mocks @@ -404,13 +404,13 @@ func (_m *Event) SkipIfStillRunning() schedule.Event { return r0 } -type NewEventT interface { +type mockConstructorTestingTNewEvent interface { mock.TestingT Cleanup(func()) } // NewEvent creates a new instance of Event. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func NewEvent(t NewEventT) *Event { +func NewEvent(t mockConstructorTestingTNewEvent) *Event { mock := &Event{} mock.Mock.Test(t) diff --git a/contracts/schedule/mocks/Schedule.go b/contracts/schedule/mocks/Schedule.go index 8223b6e48..cf6e04cfe 100644 --- a/contracts/schedule/mocks/Schedule.go +++ b/contracts/schedule/mocks/Schedule.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.12.3. DO NOT EDIT. +// Code generated by mockery v2.14.0. DO NOT EDIT. package mocks @@ -54,13 +54,13 @@ func (_m *Schedule) Run() { _m.Called() } -type NewScheduleT interface { +type mockConstructorTestingTNewSchedule interface { mock.TestingT Cleanup(func()) } // NewSchedule creates a new instance of Schedule. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func NewSchedule(t NewScheduleT) *Schedule { +func NewSchedule(t mockConstructorTestingTNewSchedule) *Schedule { mock := &Schedule{} mock.Mock.Test(t) diff --git a/contracts/validation/mocks/Errors.go b/contracts/validation/mocks/Errors.go index df929919b..238fa7d53 100644 --- a/contracts/validation/mocks/Errors.go +++ b/contracts/validation/mocks/Errors.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.12.3. DO NOT EDIT. +// Code generated by mockery v2.14.0. DO NOT EDIT. package mocks @@ -75,13 +75,13 @@ func (_m *Errors) One(key ...string) string { return r0 } -type NewErrorsT interface { +type mockConstructorTestingTNewErrors interface { mock.TestingT Cleanup(func()) } // NewErrors creates a new instance of Errors. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func NewErrors(t NewErrorsT) *Errors { +func NewErrors(t mockConstructorTestingTNewErrors) *Errors { mock := &Errors{} mock.Mock.Test(t) diff --git a/contracts/validation/mocks/Validation.go b/contracts/validation/mocks/Validation.go index c85ea503b..1aee37e62 100644 --- a/contracts/validation/mocks/Validation.go +++ b/contracts/validation/mocks/Validation.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.12.3. DO NOT EDIT. +// Code generated by mockery v2.14.0. DO NOT EDIT. package mocks @@ -72,13 +72,13 @@ func (_m *Validation) Rules() []validation.Rule { return r0 } -type NewValidationT interface { +type mockConstructorTestingTNewValidation interface { mock.TestingT Cleanup(func()) } // NewValidation creates a new instance of Validation. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func NewValidation(t NewValidationT) *Validation { +func NewValidation(t mockConstructorTestingTNewValidation) *Validation { mock := &Validation{} mock.Mock.Test(t) diff --git a/contracts/validation/mocks/Validator.go b/contracts/validation/mocks/Validator.go index 1bf4f6e7e..5d4491fcc 100644 --- a/contracts/validation/mocks/Validator.go +++ b/contracts/validation/mocks/Validator.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.12.3. DO NOT EDIT. +// Code generated by mockery v2.14.0. DO NOT EDIT. package mocks @@ -56,13 +56,13 @@ func (_m *Validator) Fails() bool { return r0 } -type NewValidatorT interface { +type mockConstructorTestingTNewValidator interface { mock.TestingT Cleanup(func()) } // NewValidator creates a new instance of Validator. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func NewValidator(t NewValidatorT) *Validator { +func NewValidator(t mockConstructorTestingTNewValidator) *Validator { mock := &Validator{} mock.Mock.Test(t) diff --git a/database/application.go b/database/application.go deleted file mode 100644 index 7d09effbc..000000000 --- a/database/application.go +++ /dev/null @@ -1,14 +0,0 @@ -package database - -import ( - "context" - - "github.com/goravel/framework/contracts/database/orm" -) - -type Application struct { -} - -func (app *Application) Init() orm.Orm { - return NewOrm(context.Background()) -} diff --git a/database/console/migrate.go b/database/console/migrate.go index 86bf617e4..b81f4a9c8 100644 --- a/database/console/migrate.go +++ b/database/console/migrate.go @@ -10,6 +10,7 @@ import ( "github.com/golang-migrate/migrate/v4/database/sqlite3" "github.com/golang-migrate/migrate/v4/database/sqlserver" + "github.com/goravel/framework/contracts/database/orm" "github.com/goravel/framework/database/support" "github.com/goravel/framework/facades" ) @@ -18,8 +19,8 @@ func getMigrate() (*migrate.Migrate, error) { connection := facades.Config.GetString("database.default") driver := facades.Config.GetString("database.connections." + connection + ".driver") dir := "file://./database/migrations" - switch driver { - case support.Mysql: + switch orm.Driver(driver) { + case orm.DriverMysql: dsn := support.GetMysqlDsn(connection) if dsn == "" { return nil, nil @@ -42,7 +43,7 @@ func getMigrate() (*migrate.Migrate, error) { } return migrate.NewWithDatabaseInstance(dir, "mysql", instance) - case support.Postgresql: + case orm.DriverPostgresql: dsn := support.GetPostgresqlDsn(connection) if dsn == "" { return nil, nil @@ -61,7 +62,7 @@ func getMigrate() (*migrate.Migrate, error) { } return migrate.NewWithDatabaseInstance(dir, "postgres", instance) - case support.Sqlite: + case orm.DriverSqlite: dsn := support.GetSqliteDsn(connection) if dsn == "" { return nil, nil @@ -80,7 +81,7 @@ func getMigrate() (*migrate.Migrate, error) { } return migrate.NewWithDatabaseInstance(dir, "sqlite3", instance) - case support.Sqlserver: + case orm.DriverSqlserver: dsn := support.GetSqlserverDsn(connection) if dsn == "" { return nil, nil diff --git a/database/console/migrate_creator.go b/database/console/migrate_creator.go index 972183132..c1d52ac4a 100644 --- a/database/console/migrate_creator.go +++ b/database/console/migrate_creator.go @@ -5,6 +5,7 @@ import ( "strings" "time" + "github.com/goravel/framework/contracts/database/orm" "github.com/goravel/framework/facades" "github.com/goravel/framework/support/file" ) @@ -32,11 +33,33 @@ func (receiver MigrateCreator) getStub(table string, create bool) (string, strin return "", "" } - if create { - return MigrateStubs{}.CreateUp(), MigrateStubs{}.CreateDown() - } + driver := facades.Config.GetString("database.connections." + facades.Config.GetString("database.default") + ".driver") + switch orm.Driver(driver) { + case orm.DriverPostgresql: + if create { + return PostgresqlStubs{}.CreateUp(), PostgresqlStubs{}.CreateDown() + } + + return PostgresqlStubs{}.UpdateUp(), PostgresqlStubs{}.UpdateDown() + case orm.DriverSqlite: + if create { + return SqliteStubs{}.CreateUp(), SqliteStubs{}.CreateDown() + } + + return SqliteStubs{}.UpdateUp(), SqliteStubs{}.UpdateDown() + case orm.DriverSqlserver: + if create { + return SqlserverStubs{}.CreateUp(), SqlserverStubs{}.CreateDown() + } - return MigrateStubs{}.UpdateUp(), MigrateStubs{}.UpdateDown() + return SqlserverStubs{}.UpdateUp(), SqlserverStubs{}.UpdateDown() + default: + if create { + return MysqlStubs{}.CreateUp(), MysqlStubs{}.CreateDown() + } + + return MysqlStubs{}.UpdateUp(), MysqlStubs{}.UpdateDown() + } } //populateStub Populate the place-holders in the migration stub. diff --git a/database/console/migrate_stubs.go b/database/console/migrate_stubs.go index 5f64ba80c..434d30a87 100644 --- a/database/console/migrate_stubs.go +++ b/database/console/migrate_stubs.go @@ -1,14 +1,14 @@ package console -type MigrateStubs struct { +type MysqlStubs struct { } //CreateUp Create up migration content. -func (receiver MigrateStubs) CreateUp() string { +func (receiver MysqlStubs) CreateUp() string { return `CREATE TABLE DummyTable ( id bigint(20) unsigned NOT NULL AUTO_INCREMENT, - created_at datetime(3) DEFAULT NULL, - updated_at datetime(3) DEFAULT NULL, + created_at datetime(3) NOT NULL, + updated_at datetime(3) NOT NULL, PRIMARY KEY (id), KEY idx_DummyTable_created_at (created_at), KEY idx_DummyTable_updated_at (updated_at) @@ -17,19 +17,113 @@ func (receiver MigrateStubs) CreateUp() string { } //CreateDown Create down migration content. -func (receiver MigrateStubs) CreateDown() string { +func (receiver MysqlStubs) CreateDown() string { return `DROP TABLE IF EXISTS DummyTable; ` } //UpdateUp Update up migration content. -func (receiver MigrateStubs) UpdateUp() string { +func (receiver MysqlStubs) UpdateUp() string { return `ALTER TABLE DummyTable ADD column varchar(255) COMMENT ''; ` } //UpdateDown Update down migration content. -func (receiver MigrateStubs) UpdateDown() string { +func (receiver MysqlStubs) UpdateDown() string { + return `ALTER TABLE DummyTable DROP COLUMN column; +` +} + +type PostgresqlStubs struct { +} + +//CreateUp Create up migration content. +func (receiver PostgresqlStubs) CreateUp() string { + return `CREATE TABLE DummyTable ( + id SERIAL PRIMARY KEY NOT NULL, + created_at timestamp NOT NULL, + updated_at timestamp NOT NULL +); +` +} + +//CreateDown Create down migration content. +func (receiver PostgresqlStubs) CreateDown() string { + return `DROP TABLE IF EXISTS DummyTable; +` +} + +//UpdateUp Update up migration content. +func (receiver PostgresqlStubs) UpdateUp() string { + return `ALTER TABLE DummyTable ADD column varchar(255) NOT NULL; +` +} + +//UpdateDown Update down migration content. +func (receiver PostgresqlStubs) UpdateDown() string { + return `ALTER TABLE DummyTable DROP COLUMN column; +` +} + +type SqliteStubs struct { +} + +//CreateUp Create up migration content. +func (receiver SqliteStubs) CreateUp() string { + return `CREATE TABLE DummyTable ( + id integer PRIMARY KEY AUTOINCREMENT NOT NULL, + created_at datetime NOT NULL, + updated_at datetime NOT NULL +); +` +} + +//CreateDown Create down migration content. +func (receiver SqliteStubs) CreateDown() string { + return `DROP TABLE IF EXISTS DummyTable; +` +} + +//UpdateUp Update up migration content. +func (receiver SqliteStubs) UpdateUp() string { + return `ALTER TABLE DummyTable ADD column text; +` +} + +//UpdateDown Update down migration content. +func (receiver SqliteStubs) UpdateDown() string { + return `ALTER TABLE DummyTable DROP COLUMN column; +` +} + +type SqlserverStubs struct { +} + +//CreateUp Create up migration content. +func (receiver SqlserverStubs) CreateUp() string { + return `CREATE TABLE DummyTable ( + id bigint NOT NULL IDENTITY(1,1), + created_at datetime NOT NULL, + updated_at datetime NOT NULL, + PRIMARY KEY (id) +); +` +} + +//CreateDown Create down migration content. +func (receiver SqlserverStubs) CreateDown() string { + return `DROP TABLE IF EXISTS DummyTable; +` +} + +//UpdateUp Update up migration content. +func (receiver SqlserverStubs) UpdateUp() string { + return `ALTER TABLE DummyTable ADD column varchar(255); +` +} + +//UpdateDown Update down migration content. +func (receiver SqlserverStubs) UpdateDown() string { return `ALTER TABLE DummyTable DROP COLUMN column; ` } diff --git a/database/db.go b/database/db.go index 03bf6f1a0..a0566b288 100644 --- a/database/db.go +++ b/database/db.go @@ -7,6 +7,7 @@ import ( "github.com/pkg/errors" "github.com/goravel/framework/contracts/database" + "github.com/goravel/framework/contracts/database/orm" "github.com/goravel/framework/database/support" "github.com/goravel/framework/facades" ) @@ -93,14 +94,14 @@ func (r *DB) Transaction(ctx context.Context, txFunc func(tx *sqlx.Tx) error) er func GetDsn(connection string) (string, error) { driver := facades.Config.GetString("database.connections." + connection + ".driver") - switch driver { - case support.Mysql: + switch orm.Driver(driver) { + case orm.DriverMysql: return support.GetMysqlDsn(connection), nil - case support.Postgresql: + case orm.DriverPostgresql: return support.GetPostgresqlDsn(connection), nil - case support.Sqlite: + case orm.DriverSqlite: return support.GetSqliteDsn(connection), nil - case support.Sqlserver: + case orm.DriverSqlserver: return support.GetSqlserverDsn(connection), nil default: return "", errors.New("database driver only support mysql, postgresql, sqlite and sqlserver") diff --git a/database/errors.go b/database/errors.go new file mode 100644 index 000000000..a92ee90bb --- /dev/null +++ b/database/errors.go @@ -0,0 +1,7 @@ +package database + +import "github.com/pkg/errors" + +var ( + ErrorMissingWhereClause = errors.New("WHERE conditions required") +) diff --git a/database/gorm.go b/database/gorm.go deleted file mode 100644 index 5474ee4b6..000000000 --- a/database/gorm.go +++ /dev/null @@ -1,336 +0,0 @@ -package database - -import ( - "context" - "errors" - "fmt" - "log" - "os" - "time" - - "gorm.io/driver/mysql" - "gorm.io/driver/postgres" - "gorm.io/driver/sqlite" - "gorm.io/driver/sqlserver" - "gorm.io/gorm" - gormLogger "gorm.io/gorm/logger" - - contractsorm "github.com/goravel/framework/contracts/database/orm" - "github.com/goravel/framework/database/support" - "github.com/goravel/framework/facades" -) - -type GormDB struct { - contractsorm.Query - instance *gorm.DB -} - -func NewGormDB(ctx context.Context, connection string) (contractsorm.DB, error) { - db, err := NewGormInstance(connection) - if err != nil { - return nil, err - } - if db == nil { - return nil, nil - } - - if ctx != nil { - db = db.WithContext(ctx) - } - - return &GormDB{ - Query: NewGormQuery(db), - instance: db, - }, nil -} - -func NewGormInstance(connection string) (*gorm.DB, error) { - gormConfig, err := getGormConfig(connection) - if err != nil { - return nil, errors.New(fmt.Sprintf("init gorm config error: %v", err)) - } - if gormConfig == nil { - return nil, nil - } - - var logLevel gormLogger.LogLevel - if facades.Config.GetBool("app.debug") { - logLevel = gormLogger.Info - } else { - logLevel = gormLogger.Error - } - - logger := New(log.New(os.Stdout, "\r\n", log.LstdFlags), gormLogger.Config{ - SlowThreshold: 200 * time.Millisecond, - LogLevel: gormLogger.Info, - IgnoreRecordNotFoundError: true, - Colorful: true, - }) - - return gorm.Open(gormConfig, &gorm.Config{ - DisableForeignKeyConstraintWhenMigrating: true, - SkipDefaultTransaction: true, - Logger: logger.LogMode(logLevel), - }) -} - -func (r *GormDB) Begin() (contractsorm.Transaction, error) { - tx := r.instance.Begin() - - return NewGormTransaction(tx), tx.Error -} - -type GormTransaction struct { - contractsorm.Query - instance *gorm.DB -} - -func NewGormTransaction(instance *gorm.DB) contractsorm.Transaction { - return &GormTransaction{Query: NewGormQuery(instance), instance: instance} -} - -func (r *GormTransaction) Commit() error { - return r.instance.Commit().Error -} - -func (r *GormTransaction) Rollback() error { - return r.instance.Rollback().Error -} - -type GormQuery struct { - instance *gorm.DB -} - -func NewGormQuery(instance *gorm.DB) contractsorm.Query { - return &GormQuery{instance} -} - -func (r *GormQuery) Driver() contractsorm.Driver { - return contractsorm.Driver(r.instance.Dialector.Name()) -} - -func (r *GormQuery) Count(count *int64) error { - return r.instance.Count(count).Error -} - -func (r *GormQuery) Create(value interface{}) error { - return r.instance.Create(value).Error -} - -func (r *GormQuery) Delete(value interface{}, conds ...interface{}) error { - return r.instance.Delete(value, conds...).Error -} - -func (r *GormQuery) Distinct(args ...interface{}) contractsorm.Query { - tx := r.instance.Distinct(args...) - - return NewGormQuery(tx) -} - -func (r *GormQuery) Exec(sql string, values ...interface{}) error { - return r.instance.Exec(sql, values...).Error -} - -func (r *GormQuery) Find(dest interface{}, conds ...interface{}) error { - return r.instance.Find(dest, conds...).Error -} - -func (r *GormQuery) First(dest interface{}) error { - err := r.instance.First(dest).Error - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil - } - - return err -} - -func (r *GormQuery) FirstOrCreate(dest interface{}, conds ...interface{}) error { - var err error - if len(conds) > 1 { - err = r.instance.Attrs([]interface{}{conds[1]}...).FirstOrCreate(dest, []interface{}{conds[0]}...).Error - } else { - err = r.instance.FirstOrCreate(dest, conds...).Error - } - - return err -} - -func (r *GormQuery) ForceDelete(value interface{}, conds ...interface{}) error { - return r.instance.Unscoped().Delete(value, conds...).Error -} - -func (r *GormQuery) Get(dest interface{}) error { - return r.instance.Find(dest).Error -} - -func (r *GormQuery) Group(name string) contractsorm.Query { - tx := r.instance.Group(name) - - return NewGormQuery(tx) -} - -func (r *GormQuery) Having(query interface{}, args ...interface{}) contractsorm.Query { - tx := r.instance.Having(query, args...) - - return NewGormQuery(tx) -} - -func (r *GormQuery) Join(query string, args ...interface{}) contractsorm.Query { - tx := r.instance.Joins(query, args...) - - return NewGormQuery(tx) -} - -func (r *GormQuery) Limit(limit int) contractsorm.Query { - tx := r.instance.Limit(limit) - - return NewGormQuery(tx) -} - -func (r *GormQuery) Model(value interface{}) contractsorm.Query { - tx := r.instance.Model(value) - - return NewGormQuery(tx) -} - -func (r *GormQuery) Offset(offset int) contractsorm.Query { - tx := r.instance.Offset(offset) - - return NewGormQuery(tx) -} - -func (r *GormQuery) Order(value interface{}) contractsorm.Query { - tx := r.instance.Order(value) - - return NewGormQuery(tx) -} - -func (r *GormQuery) OrWhere(query interface{}, args ...interface{}) contractsorm.Query { - tx := r.instance.Or(query, args...) - - return NewGormQuery(tx) -} - -func (r *GormQuery) Pluck(column string, dest interface{}) error { - return r.instance.Pluck(column, dest).Error -} - -func (r *GormQuery) Raw(sql string, values ...interface{}) contractsorm.Query { - tx := r.instance.Raw(sql, values...) - - return NewGormQuery(tx) -} - -func (r *GormQuery) Save(value interface{}) error { - return r.instance.Save(value).Error -} - -func (r *GormQuery) Scan(dest interface{}) error { - return r.instance.Scan(dest).Error -} - -func (r *GormQuery) Select(query interface{}, args ...interface{}) contractsorm.Query { - tx := r.instance.Select(query, args...) - - return NewGormQuery(tx) -} - -func (r *GormQuery) Table(name string, args ...interface{}) contractsorm.Query { - tx := r.instance.Table(name, args...) - - return NewGormQuery(tx) -} - -func (r *GormQuery) Update(column string, value interface{}) error { - return r.instance.Update(column, value).Error -} - -func (r *GormQuery) Updates(values interface{}) error { - return r.instance.Updates(values).Error -} - -func (r *GormQuery) Where(query interface{}, args ...interface{}) contractsorm.Query { - tx := r.instance.Where(query, args...) - - return NewGormQuery(tx) -} - -func (r *GormQuery) WithTrashed() contractsorm.Query { - tx := r.instance.Unscoped() - - return NewGormQuery(tx) -} - -func (r *GormQuery) Scopes(funcs ...func(contractsorm.Query) contractsorm.Query) contractsorm.Query { - var gormFuncs []func(*gorm.DB) *gorm.DB - for _, item := range funcs { - gormFuncs = append(gormFuncs, func(db *gorm.DB) *gorm.DB { - item(&GormQuery{db}) - - return db - }) - } - - tx := r.instance.Scopes(gormFuncs...) - - return NewGormQuery(tx) -} - -func getGormConfig(connection string) (gorm.Dialector, error) { - driver := facades.Config.GetString("database.connections." + connection + ".driver") - - switch driver { - case support.Mysql: - return getMysqlGormConfig(connection), nil - case support.Postgresql: - return getPostgresqlGormConfig(connection), nil - case support.Sqlite: - return getSqliteGormConfig(connection), nil - case support.Sqlserver: - return getSqlserverGormConfig(connection), nil - default: - return nil, errors.New(fmt.Sprintf("err database driver: %s, only support mysql, postgresql, sqlite and sqlserver", driver)) - } -} - -func getMysqlGormConfig(connection string) gorm.Dialector { - dsn := support.GetMysqlDsn(connection) - if dsn == "" { - return nil - } - - return mysql.New(mysql.Config{ - DSN: dsn, - }) -} - -func getPostgresqlGormConfig(connection string) gorm.Dialector { - dsn := support.GetPostgresqlDsn(connection) - if dsn == "" { - return nil - } - - return postgres.New(postgres.Config{ - DSN: dsn, - }) -} - -func getSqliteGormConfig(connection string) gorm.Dialector { - dsn := support.GetSqliteDsn(connection) - if dsn == "" { - return nil - } - - return sqlite.Open(dsn) -} - -func getSqlserverGormConfig(connection string) gorm.Dialector { - dsn := support.GetSqlserverDsn(connection) - if dsn == "" { - return nil - } - - return sqlserver.New(sqlserver.Config{ - DSN: dsn, - }) -} diff --git a/database/gorm/config.go b/database/gorm/config.go new file mode 100644 index 000000000..aefe50e0a --- /dev/null +++ b/database/gorm/config.go @@ -0,0 +1,75 @@ +package gorm + +import ( + "errors" + "fmt" + + "gorm.io/driver/mysql" + "gorm.io/driver/postgres" + "gorm.io/driver/sqlite" + "gorm.io/driver/sqlserver" + "gorm.io/gorm" + + "github.com/goravel/framework/contracts/database/orm" + "github.com/goravel/framework/database/support" + "github.com/goravel/framework/facades" +) + +func config(connection string) (gorm.Dialector, error) { + driver := facades.Config.GetString(fmt.Sprintf("database.connections.%s.driver", connection)) + + switch orm.Driver(driver) { + case orm.DriverMysql: + return mysqlConfig(connection), nil + case orm.DriverPostgresql: + return postgresqlConfig(connection), nil + case orm.DriverSqlite: + return sqliteConfig(connection), nil + case orm.DriverSqlserver: + return sqlserverConfig(connection), nil + default: + return nil, errors.New(fmt.Sprintf("err database driver: %s, only support mysql, postgresql, sqlite and sqlserver", driver)) + } +} + +func mysqlConfig(connection string) gorm.Dialector { + dsn := support.GetMysqlDsn(connection) + if dsn == "" { + return nil + } + + return mysql.New(mysql.Config{ + DSN: dsn, + }) +} + +func postgresqlConfig(connection string) gorm.Dialector { + dsn := support.GetPostgresqlDsn(connection) + if dsn == "" { + return nil + } + + return postgres.New(postgres.Config{ + DSN: dsn, + }) +} + +func sqliteConfig(connection string) gorm.Dialector { + dsn := support.GetSqliteDsn(connection) + if dsn == "" { + return nil + } + + return sqlite.Open(dsn) +} + +func sqlserverConfig(connection string) gorm.Dialector { + dsn := support.GetSqlserverDsn(connection) + if dsn == "" { + return nil + } + + return sqlserver.New(sqlserver.Config{ + DSN: dsn, + }) +} diff --git a/database/gorm_test.go b/database/gorm/config_test.go similarity index 89% rename from database/gorm_test.go rename to database/gorm/config_test.go index d4410e543..23a8058a3 100644 --- a/database/gorm_test.go +++ b/database/gorm/config_test.go @@ -1,20 +1,20 @@ -package database +package gorm import ( "errors" "fmt" "testing" - "github.com/goravel/framework/contracts/config/mocks" - "github.com/goravel/framework/database/support" - "github.com/goravel/framework/testing/mock" - "github.com/stretchr/testify/assert" "gorm.io/driver/mysql" "gorm.io/driver/postgres" "gorm.io/driver/sqlite" "gorm.io/driver/sqlserver" "gorm.io/gorm" + + "github.com/goravel/framework/contracts/config/mocks" + "github.com/goravel/framework/contracts/database/orm" + "github.com/goravel/framework/testing/mock" ) func TestGetGormConfig(t *testing.T) { @@ -22,17 +22,17 @@ func TestGetGormConfig(t *testing.T) { tests := []struct { name string - connection string + connection orm.Driver setup func() expectDialector gorm.Dialector expectErr error }{ { name: "mysql", - connection: "mysql", + connection: orm.DriverMysql, setup: func() { mockConfig.On("GetString", "database.connections.mysql.driver"). - Return(support.Mysql).Once() + Return(orm.DriverMysql.String()).Once() mockConfig.On("GetString", "database.connections.mysql.host"). Return("127.0.0.1").Once() mockConfig.On("GetString", "database.connections.mysql.port"). @@ -55,10 +55,10 @@ func TestGetGormConfig(t *testing.T) { }, { name: "postgresql", - connection: support.Postgresql, + connection: orm.DriverPostgresql, setup: func() { mockConfig.On("GetString", "database.connections.postgresql.driver"). - Return(support.Postgresql).Once() + Return(orm.DriverPostgresql.String()).Once() mockConfig.On("GetString", "database.connections.postgresql.host"). Return("127.0.0.1").Once() mockConfig.On("GetString", "database.connections.postgresql.port"). @@ -81,10 +81,10 @@ func TestGetGormConfig(t *testing.T) { }, { name: "sqlite", - connection: support.Sqlite, + connection: orm.DriverSqlite, setup: func() { mockConfig.On("GetString", "database.connections.sqlite.driver"). - Return(support.Sqlite).Once() + Return(orm.DriverSqlite.String()).Once() mockConfig.On("GetString", "database.connections.sqlite.database"). Return("goravel").Once() }, @@ -92,10 +92,10 @@ func TestGetGormConfig(t *testing.T) { }, { name: "sqlserver", - connection: support.Sqlserver, + connection: orm.DriverSqlserver, setup: func() { mockConfig.On("GetString", "database.connections.sqlserver.driver"). - Return(support.Sqlserver).Once() + Return(orm.DriverSqlserver.String()).Once() mockConfig.On("GetString", "database.connections.sqlserver.host"). Return("127.0.0.1").Once() mockConfig.On("GetString", "database.connections.sqlserver.port"). @@ -126,7 +126,7 @@ func TestGetGormConfig(t *testing.T) { for _, test := range tests { mockConfig = mock.Config() test.setup() - dialector, err := getGormConfig(test.connection) + dialector, err := config(test.connection.String()) assert.Equal(t, test.expectDialector, dialector) assert.Equal(t, test.expectErr, err) } diff --git a/database/gorm/gorm.go b/database/gorm/gorm.go new file mode 100644 index 000000000..d482a2c3c --- /dev/null +++ b/database/gorm/gorm.go @@ -0,0 +1,473 @@ +package gorm + +import ( + "context" + "errors" + "fmt" + "log" + "os" + "reflect" + "time" + + "github.com/spf13/cast" + "gorm.io/gorm" + gormLogger "gorm.io/gorm/logger" + + ormcontract "github.com/goravel/framework/contracts/database/orm" + "github.com/goravel/framework/database/orm" + databasesupport "github.com/goravel/framework/database/support" + "github.com/goravel/framework/facades" + "github.com/goravel/framework/support/database" +) + +func New(connection string) (*gorm.DB, error) { + gormConfig, err := config(connection) + if err != nil { + return nil, errors.New(fmt.Sprintf("init gorm config error: %v", err)) + } + if gormConfig == nil { + return nil, nil + } + + var logLevel gormLogger.LogLevel + if facades.Config.GetBool("app.debug") { + logLevel = gormLogger.Info + } else { + logLevel = gormLogger.Error + } + + logger := NewLogger(log.New(os.Stdout, "\r\n", log.LstdFlags), gormLogger.Config{ + SlowThreshold: 200 * time.Millisecond, + LogLevel: gormLogger.Info, + IgnoreRecordNotFoundError: true, + Colorful: true, + }) + + return gorm.Open(gormConfig, &gorm.Config{ + DisableForeignKeyConstraintWhenMigrating: true, + SkipDefaultTransaction: true, + Logger: logger.LogMode(logLevel), + }) +} + +type DB struct { + ormcontract.Query + instance *gorm.DB +} + +func NewDB(ctx context.Context, connection string) (*DB, error) { + db, err := New(connection) + if err != nil { + return nil, err + } + if db == nil { + return nil, nil + } + + if ctx != nil { + db = db.WithContext(ctx) + } + + return &DB{ + Query: NewQuery(db), + instance: db, + }, nil +} + +func (r *DB) Begin() (ormcontract.Transaction, error) { + tx := r.instance.Begin() + + return NewTransaction(tx), tx.Error +} + +func (r *DB) Instance() *gorm.DB { + return r.instance +} + +type Transaction struct { + ormcontract.Query + instance *gorm.DB +} + +func NewTransaction(instance *gorm.DB) *Transaction { + return &Transaction{Query: NewQuery(instance), instance: instance} +} + +func (r *Transaction) Commit() error { + return r.instance.Commit().Error +} + +func (r *Transaction) Rollback() error { + return r.instance.Rollback().Error +} + +type Query struct { + instance *gorm.DB +} + +func NewQuery(instance *gorm.DB) *Query { + return &Query{instance} +} + +func (r *Query) Association(association string) ormcontract.Association { + return r.instance.Association(association) +} + +func (r *Query) Driver() ormcontract.Driver { + return ormcontract.Driver(r.instance.Dialector.Name()) +} + +func (r *Query) Count(count *int64) error { + return r.instance.Count(count).Error +} + +func (r *Query) Create(value any) error { + if len(r.instance.Statement.Selects) > 0 && len(r.instance.Statement.Omits) > 0 { + return errors.New("cannot set Select and Omits at the same time") + } + + if len(r.instance.Statement.Selects) > 0 { + if len(r.instance.Statement.Selects) == 1 && r.instance.Statement.Selects[0] == orm.Associations { + r.instance.Statement.Selects = []string{} + return r.instance.Create(value).Error + } + + for _, val := range r.instance.Statement.Selects { + if val == orm.Associations { + return errors.New("cannot set orm.Associations and other fields at the same time") + } + } + + return r.instance.Create(value).Error + } + + if len(r.instance.Statement.Omits) > 0 { + if len(r.instance.Statement.Omits) == 1 && r.instance.Statement.Omits[0] == orm.Associations { + r.instance.Statement.Selects = []string{} + return r.instance.Omit(orm.Associations).Create(value).Error + } + + for _, val := range r.instance.Statement.Omits { + if val == orm.Associations { + return errors.New("cannot set orm.Associations and other fields at the same time") + } + } + + return r.instance.Create(value).Error + } + + return r.instance.Omit(orm.Associations).Create(value).Error +} + +func (r *Query) Delete(value any, conds ...any) error { + return r.instance.Delete(value, conds...).Error +} + +func (r *Query) Distinct(args ...any) ormcontract.Query { + tx := r.instance.Distinct(args...) + + return NewQuery(tx) +} + +func (r *Query) Exec(sql string, values ...any) error { + return r.instance.Exec(sql, values...).Error +} + +func (r *Query) Find(dest any, conds ...any) error { + if len(conds) == 1 { + switch conds[0].(type) { + case string: + if conds[0].(string) == "" { + return databasesupport.ErrorMissingWhereClause + } + default: + reflectValue := reflect.Indirect(reflect.ValueOf(conds[0])) + switch reflectValue.Kind() { + case reflect.Slice, reflect.Array: + if reflectValue.Len() == 0 { + return databasesupport.ErrorMissingWhereClause + } + } + } + } + + return r.instance.Find(dest, conds...).Error +} + +func (r *Query) First(dest any) error { + err := r.instance.First(dest).Error + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil + } + + return err +} + +func (r *Query) FirstOrCreate(dest any, conds ...any) error { + var err error + if len(conds) > 1 { + err = r.instance.Attrs([]any{conds[1]}...).FirstOrCreate(dest, []any{conds[0]}...).Error + } else { + err = r.instance.FirstOrCreate(dest, conds...).Error + } + + return err +} + +func (r *Query) ForceDelete(value any, conds ...any) error { + return r.instance.Unscoped().Delete(value, conds...).Error +} + +func (r *Query) Get(dest any) error { + return r.instance.Find(dest).Error +} + +func (r *Query) Group(name string) ormcontract.Query { + tx := r.instance.Group(name) + + return NewQuery(tx) +} + +func (r *Query) Having(query any, args ...any) ormcontract.Query { + tx := r.instance.Having(query, args...) + + return NewQuery(tx) +} + +func (r *Query) Join(query string, args ...any) ormcontract.Query { + tx := r.instance.Joins(query, args...) + + return NewQuery(tx) +} + +func (r *Query) Limit(limit int) ormcontract.Query { + tx := r.instance.Limit(limit) + + return NewQuery(tx) +} + +func (r *Query) Load(model any, relation string, args ...any) error { + if relation == "" { + return errors.New("relation cannot be empty") + } + + destType := reflect.TypeOf(model) + if destType.Kind() != reflect.Pointer { + return errors.New("model must be pointer") + } + + if id := database.GetID(model); id == nil { + return errors.New("id cannot be empty") + } + + copyDest := copyStruct(model) + query := r.With(relation, args...) + err := query.Find(model) + + t := destType.Elem() + v := reflect.ValueOf(model).Elem() + for i := 0; i < t.NumField(); i++ { + if t.Field(i).Name != relation { + v.Field(i).Set(copyDest.Field(i)) + } + } + + return err +} + +func (r *Query) LoadMissing(model any, relation string, args ...any) error { + destType := reflect.TypeOf(model) + if destType.Kind() != reflect.Pointer { + return errors.New("model must be pointer") + } + + t := reflect.TypeOf(model).Elem() + v := reflect.ValueOf(model).Elem() + for i := 0; i < t.NumField(); i++ { + if t.Field(i).Name == relation { + var id any + if v.Field(i).Kind() == reflect.Pointer { + if !v.Field(i).IsNil() { + id = database.GetIDByReflect(v.Field(i).Type().Elem(), v.Field(i).Elem()) + } + } else if v.Field(i).Kind() == reflect.Slice { + if v.Field(i).Len() > 0 { + return nil + } + } else { + id = database.GetIDByReflect(v.Field(i).Type(), v.Field(i)) + } + if cast.ToString(id) != "" { + return nil + } + } + } + + return r.Load(model, relation, args...) +} + +func (r *Query) Model(value any) ormcontract.Query { + tx := r.instance.Model(value) + + return NewQuery(tx) +} + +func (r *Query) Offset(offset int) ormcontract.Query { + tx := r.instance.Offset(offset) + + return NewQuery(tx) +} + +func (r *Query) Omit(columns ...string) ormcontract.Query { + tx := r.instance.Omit(columns...) + + return NewQuery(tx) +} + +func (r *Query) Order(value any) ormcontract.Query { + tx := r.instance.Order(value) + + return NewQuery(tx) +} + +func (r *Query) OrWhere(query any, args ...any) ormcontract.Query { + tx := r.instance.Or(query, args...) + + return NewQuery(tx) +} + +func (r *Query) Pluck(column string, dest any) error { + return r.instance.Pluck(column, dest).Error +} + +func (r *Query) Raw(sql string, values ...any) ormcontract.Query { + tx := r.instance.Raw(sql, values...) + + return NewQuery(tx) +} + +func (r *Query) Save(value any) error { + if len(r.instance.Statement.Selects) > 0 && len(r.instance.Statement.Omits) > 0 { + return errors.New("cannot set Select and Omits at the same time") + } + + if len(r.instance.Statement.Selects) > 0 { + for _, val := range r.instance.Statement.Selects { + if val == orm.Associations { + return r.instance.Session(&gorm.Session{FullSaveAssociations: true}).Save(value).Error + } + } + + return r.instance.Save(value).Error + } + + if len(r.instance.Statement.Omits) > 0 { + for _, val := range r.instance.Statement.Omits { + if val == orm.Associations { + return r.instance.Omit(orm.Associations).Save(value).Error + } + } + + return r.instance.Save(value).Error + } + + return r.instance.Omit(orm.Associations).Save(value).Error +} + +func (r *Query) Scan(dest any) error { + return r.instance.Scan(dest).Error +} + +func (r *Query) Select(query any, args ...any) ormcontract.Query { + tx := r.instance.Select(query, args...) + + return NewQuery(tx) +} + +func (r *Query) Table(name string, args ...any) ormcontract.Query { + tx := r.instance.Table(name, args...) + + return NewQuery(tx) +} + +func (r *Query) Update(column string, value any) error { + return r.instance.Update(column, value).Error +} + +func (r *Query) Updates(values any) error { + if len(r.instance.Statement.Selects) > 0 && len(r.instance.Statement.Omits) > 0 { + return errors.New("cannot set Select and Omits at the same time") + } + + if len(r.instance.Statement.Selects) > 0 { + for _, val := range r.instance.Statement.Selects { + if val == orm.Associations { + return r.instance.Session(&gorm.Session{FullSaveAssociations: true}).Updates(values).Error + } + } + + return r.instance.Updates(values).Error + } + + if len(r.instance.Statement.Omits) > 0 { + for _, val := range r.instance.Statement.Omits { + if val == orm.Associations { + return r.instance.Omit(orm.Associations).Updates(values).Error + } + } + + return r.instance.Updates(values).Error + } + + return r.instance.Omit(orm.Associations).Updates(values).Error +} + +func (r *Query) Where(query any, args ...any) ormcontract.Query { + tx := r.instance.Where(query, args...) + + return NewQuery(tx) +} + +func (r *Query) WithTrashed() ormcontract.Query { + tx := r.instance.Unscoped() + + return NewQuery(tx) +} + +func (r *Query) With(query string, args ...any) ormcontract.Query { + if len(args) == 1 { + switch args[0].(type) { + case func(ormcontract.Query) ormcontract.Query: + newArgs := []any{ + func(db *gorm.DB) *gorm.DB { + query := args[0].(func(query ormcontract.Query) ormcontract.Query)(NewQuery(db)) + + return query.(*Query).instance + }, + } + + tx := r.instance.Preload(query, newArgs...) + + return NewQuery(tx) + } + } + + tx := r.instance.Preload(query, args...) + + return NewQuery(tx) +} + +func (r *Query) Scopes(funcs ...func(ormcontract.Query) ormcontract.Query) ormcontract.Query { + var gormFuncs []func(*gorm.DB) *gorm.DB + for _, item := range funcs { + gormFuncs = append(gormFuncs, func(db *gorm.DB) *gorm.DB { + item(&Query{db}) + + return db + }) + } + + tx := r.instance.Scopes(gormFuncs...) + + return NewQuery(tx) +} diff --git a/database/gorm/gorm_test.go b/database/gorm/gorm_test.go new file mode 100644 index 000000000..1862595ba --- /dev/null +++ b/database/gorm/gorm_test.go @@ -0,0 +1,1216 @@ +package gorm + +import ( + "log" + "testing" + + "github.com/stretchr/testify/suite" + _ "gorm.io/driver/postgres" + + ormcontract "github.com/goravel/framework/contracts/database/orm" + "github.com/goravel/framework/database/orm" + "github.com/goravel/framework/support/file" +) + +type User struct { + orm.Model + orm.SoftDeletes + Name string + Avatar string + Address *Address + Books []*Book + House *House `gorm:"polymorphic:Houseable"` + Phones []*Phone `gorm:"polymorphic:Phoneable"` + Roles []*Role `gorm:"many2many:role_user"` +} + +type Role struct { + orm.Model + Name string + Users []*User `gorm:"many2many:role_user"` +} + +type Address struct { + orm.Model + UserID uint + Name string + Province string + User *User +} + +type Book struct { + orm.Model + UserID uint + Name string + User *User + Author *Author +} + +type Author struct { + orm.Model + BookID uint + Name string +} + +type House struct { + orm.Model + Name string + HouseableID uint + HouseableType string +} + +type Phone struct { + orm.Model + Name string + PhoneableID uint + PhoneableType string +} + +type GormQueryTestSuite struct { + suite.Suite + dbs []ormcontract.DB +} + +func TestGormQueryTestSuite(t *testing.T) { + mysqlPool, mysqlDocker, mysqlDB, err := MysqlDocker() + if err != nil { + log.Fatalf("Get gorm mysql error: %s", err) + } + + postgresqlPool, postgresqlDocker, postgresqlDB, err := PostgresqlDocker() + if err != nil { + log.Fatalf("Get gorm postgresql error: %s", err) + } + + _, _, sqliteDB, err := SqliteDocker() + if err != nil { + log.Fatalf("Get gorm sqlite error: %s", err) + } + + sqlserverPool, sqlserverDocker, sqlserverDB, err := SqlserverDocker() + if err != nil { + log.Fatalf("Get gorm postgresql error: %s", err) + } + + suite.Run(t, &GormQueryTestSuite{ + dbs: []ormcontract.DB{ + mysqlDB, + postgresqlDB, + sqliteDB, + sqlserverDB, + }, + }) + + file.Remove("goravel") + + if err := mysqlPool.Purge(mysqlDocker); err != nil { + log.Fatalf("Could not purge resource: %s", err) + } + if err := postgresqlPool.Purge(postgresqlDocker); err != nil { + log.Fatalf("Could not purge resource: %s", err) + } + if err := sqlserverPool.Purge(sqlserverDocker); err != nil { + log.Fatalf("Could not purge resource: %s", err) + } +} + +func (s *GormQueryTestSuite) SetupTest() { +} + +func (s *GormQueryTestSuite) TestAssociation() { + for _, db := range s.dbs { + tests := []struct { + description string + setup func(description string) + }{ + { + description: "Find", + setup: func(description string) { + user := &User{ + Name: "association_find_name", + Address: &Address{ + Name: "association_find_address", + }, + } + + s.Nil(db.Select(orm.Associations).Create(&user), description) + s.True(user.ID > 0, description) + s.True(user.Address.ID > 0, description) + + var user1 User + s.Nil(db.Find(&user1, user.ID), description) + s.True(user1.ID > 0, description) + + var userAddress Address + s.Nil(db.Model(&user1).Association("Address").Find(&userAddress), description) + s.True(userAddress.ID > 0, description) + s.Equal("association_find_address", userAddress.Name, description) + }, + }, + { + description: "hasOne Append", + setup: func(description string) { + user := &User{ + Name: "association_has_one_append_name", + Address: &Address{ + Name: "association_has_one_append_address", + }, + } + + s.Nil(db.Select(orm.Associations).Create(&user), description) + s.True(user.ID > 0, description) + s.True(user.Address.ID > 0, description) + + var user1 User + s.Nil(db.Find(&user1, user.ID), description) + s.True(user1.ID > 0, description) + s.Nil(db.Model(&user1).Association("Address").Append(&Address{Name: "association_has_one_append_address1"}), description) + + s.Nil(db.Load(&user1, "Address"), description) + s.True(user1.Address.ID > 0, description) + s.Equal("association_has_one_append_address1", user1.Address.Name, description) + }, + }, + { + description: "hasMany Append", + setup: func(description string) { + user := &User{ + Name: "association_has_many_append_name", + Books: []*Book{ + {Name: "association_has_many_append_address1"}, + {Name: "association_has_many_append_address2"}, + }, + } + + s.Nil(db.Select(orm.Associations).Create(&user), description) + s.True(user.ID > 0, description) + s.True(user.Books[0].ID > 0, description) + s.True(user.Books[1].ID > 0, description) + + var user1 User + s.Nil(db.Find(&user1, user.ID), description) + s.True(user1.ID > 0, description) + s.Nil(db.Model(&user1).Association("Books").Append(&Book{Name: "association_has_many_append_address3"}), description) + + s.Nil(db.Load(&user1, "Books"), description) + s.Equal(3, len(user1.Books), description) + s.Equal("association_has_many_append_address3", user1.Books[2].Name, description) + }, + }, + { + description: "hasOne Replace", + setup: func(description string) { + user := &User{ + Name: "association_has_one_append_name", + Address: &Address{ + Name: "association_has_one_append_address", + }, + } + + s.Nil(db.Select(orm.Associations).Create(&user), description) + s.True(user.ID > 0, description) + s.True(user.Address.ID > 0, description) + + var user1 User + s.Nil(db.Find(&user1, user.ID), description) + s.True(user1.ID > 0, description) + s.Nil(db.Model(&user1).Association("Address").Replace(&Address{Name: "association_has_one_append_address1"}), description) + + s.Nil(db.Load(&user1, "Address"), description) + s.True(user1.Address.ID > 0, description) + s.Equal("association_has_one_append_address1", user1.Address.Name, description) + }, + }, + { + description: "hasMany Replace", + setup: func(description string) { + user := &User{ + Name: "association_has_many_replace_name", + Books: []*Book{ + {Name: "association_has_many_replace_address1"}, + {Name: "association_has_many_replace_address2"}, + }, + } + + s.Nil(db.Select(orm.Associations).Create(&user), description) + s.True(user.ID > 0, description) + s.True(user.Books[0].ID > 0, description) + s.True(user.Books[1].ID > 0, description) + + var user1 User + s.Nil(db.Find(&user1, user.ID), description) + s.True(user1.ID > 0, description) + s.Nil(db.Model(&user1).Association("Books").Replace(&Book{Name: "association_has_many_replace_address3"}), description) + + s.Nil(db.Load(&user1, "Books"), description) + s.Equal(1, len(user1.Books), description) + s.Equal("association_has_many_replace_address3", user1.Books[0].Name, description) + }, + }, + { + description: "Delete", + setup: func(description string) { + user := &User{ + Name: "association_delete_name", + Address: &Address{ + Name: "association_delete_address", + }, + } + + s.Nil(db.Select(orm.Associations).Create(&user), description) + s.True(user.ID > 0, description) + s.True(user.Address.ID > 0, description) + + // No ID when Delete + var user1 User + s.Nil(db.Find(&user1, user.ID), description) + s.True(user1.ID > 0, description) + s.Nil(db.Model(&user1).Association("Address").Delete(&Address{Name: "association_delete_address"}), description) + + s.Nil(db.Load(&user1, "Address"), description) + s.True(user1.Address.ID > 0, description) + s.Equal("association_delete_address", user1.Address.Name, description) + + // Has ID when Delete + var user2 User + s.Nil(db.Find(&user2, user.ID), description) + s.True(user2.ID > 0, description) + var userAddress Address + userAddress.ID = user1.Address.ID + s.Nil(db.Model(&user2).Association("Address").Delete(&userAddress), description) + + s.Nil(db.Load(&user2, "Address"), description) + s.Nil(user2.Address, description) + }, + }, + { + description: "Clear", + setup: func(description string) { + user := &User{ + Name: "association_clear_name", + Address: &Address{ + Name: "association_clear_address", + }, + } + + s.Nil(db.Select(orm.Associations).Create(&user), description) + s.True(user.ID > 0, description) + s.True(user.Address.ID > 0, description) + + // No ID when Delete + var user1 User + s.Nil(db.Find(&user1, user.ID), description) + s.True(user1.ID > 0, description) + s.Nil(db.Model(&user1).Association("Address").Clear(), description) + + s.Nil(db.Load(&user1, "Address"), description) + s.Nil(user1.Address, description) + }, + }, + { + description: "Count", + setup: func(description string) { + user := &User{ + Name: "association_count_name", + Books: []*Book{ + {Name: "association_count_address1"}, + {Name: "association_count_address2"}, + }, + } + + s.Nil(db.Select(orm.Associations).Create(&user), description) + s.True(user.ID > 0, description) + s.True(user.Books[0].ID > 0, description) + s.True(user.Books[1].ID > 0, description) + + var user1 User + s.Nil(db.Find(&user1, user.ID), description) + s.True(user1.ID > 0, description) + s.Equal(int64(2), db.Model(&user1).Association("Books").Count(), description) + }, + }, + } + + for _, test := range tests { + test.setup(test.description) + } + } +} + +func (s *GormQueryTestSuite) TestCount() { + for _, db := range s.dbs { + user := User{Name: "count_user", Avatar: "count_avatar"} + s.Nil(db.Create(&user)) + s.True(user.ID > 0) + + user1 := User{Name: "count_user", Avatar: "count_avatar1"} + s.Nil(db.Create(&user1)) + s.True(user1.ID > 0) + + var count int64 + s.Nil(db.Model(&User{}).Where("name = ?", "count_user").Count(&count)) + s.True(count > 0) + + var count1 int64 + s.Nil(db.Table("users").Where("name = ?", "count_user").Count(&count1)) + s.True(count1 > 0) + } +} + +func (s *GormQueryTestSuite) TestCreate() { + for _, db := range s.dbs { + tests := []struct { + description string + setup func(description string) + }{ + { + description: "success when create with no relationships", + setup: func(description string) { + user := User{Name: "create_user", Address: &Address{}, Books: []*Book{&Book{}, &Book{}}} + user.Address.Name = "create_address" + user.Books[0].Name = "create_book0" + user.Books[1].Name = "create_book1" + s.Nil(db.Create(&user), description) + s.True(user.ID > 0, description) + s.True(user.Address.ID == 0, description) + s.True(user.Books[0].ID == 0, description) + s.True(user.Books[1].ID == 0, description) + }, + }, + { + description: "success when create with select orm.Associations", + setup: func(description string) { + user := User{Name: "create_user", Address: &Address{}, Books: []*Book{&Book{}, &Book{}}} + user.Address.Name = "create_address" + user.Books[0].Name = "create_book0" + user.Books[1].Name = "create_book1" + s.Nil(db.Select(orm.Associations).Create(&user), description) + s.True(user.ID > 0, description) + s.True(user.Address.ID > 0, description) + s.True(user.Books[0].ID > 0, description) + s.True(user.Books[1].ID > 0, description) + }, + }, + { + description: "success when create with select fields", + setup: func(description string) { + user := User{Name: "create_user", Avatar: "create_avatar", Address: &Address{}, Books: []*Book{&Book{}, &Book{}}} + user.Address.Name = "create_address" + user.Books[0].Name = "create_book0" + user.Books[1].Name = "create_book1" + s.Nil(db.Select("Name", "Avatar", "Address").Create(&user), description) + s.True(user.ID > 0, description) + s.True(user.Address.ID > 0, description) + s.True(user.Books[0].ID == 0, description) + s.True(user.Books[1].ID == 0, description) + }, + }, + { + description: "success when create with omit fields", + setup: func(description string) { + user := User{Name: "create_user", Avatar: "create_avatar", Address: &Address{}, Books: []*Book{&Book{}, &Book{}}} + user.Address.Name = "create_address" + user.Books[0].Name = "create_book0" + user.Books[1].Name = "create_book1" + s.Nil(db.Omit("Address").Create(&user), description) + s.True(user.ID > 0, description) + s.True(user.Address.ID == 0, description) + s.True(user.Books[0].ID > 0, description) + s.True(user.Books[1].ID > 0, description) + }, + }, + { + description: "success create with omit orm.Associations", + setup: func(description string) { + user := User{Name: "create_user", Avatar: "create_avatar", Address: &Address{}, Books: []*Book{&Book{}, &Book{}}} + user.Address.Name = "create_address" + user.Books[0].Name = "create_book0" + user.Books[1].Name = "create_book1" + s.Nil(db.Omit(orm.Associations).Create(&user), description) + s.True(user.ID > 0, description) + s.True(user.Address.ID == 0, description) + s.True(user.Books[0].ID == 0, description) + s.True(user.Books[1].ID == 0, description) + }, + }, + { + description: "error when set select and omit at the same time", + setup: func(description string) { + user := User{Name: "create_user", Avatar: "create_avatar", Address: &Address{}, Books: []*Book{&Book{}, &Book{}}} + user.Address.Name = "create_address" + user.Books[0].Name = "create_book0" + user.Books[1].Name = "create_book1" + s.EqualError(db.Omit(orm.Associations).Select("Name").Create(&user), "cannot set Select and Omits at the same time", description) + }, + }, + { + description: "error when select that set fields and orm.Associations at the same time", + setup: func(description string) { + user := User{Name: "create_user", Avatar: "create_avatar", Address: &Address{}, Books: []*Book{&Book{}, &Book{}}} + user.Address.Name = "create_address" + user.Books[0].Name = "create_book0" + user.Books[1].Name = "create_book1" + s.EqualError(db.Select("Name", orm.Associations).Create(&user), "cannot set orm.Associations and other fields at the same time", description) + }, + }, + { + description: "error when omit that set fields and orm.Associations at the same time", + setup: func(description string) { + user := User{Name: "create_user", Avatar: "create_avatar", Address: &Address{}, Books: []*Book{&Book{}, &Book{}}} + user.Address.Name = "create_address" + user.Books[0].Name = "create_book0" + user.Books[1].Name = "create_book1" + s.EqualError(db.Omit("Name", orm.Associations).Create(&user), "cannot set orm.Associations and other fields at the same time", description) + }, + }, + } + for _, test := range tests { + test.setup(test.description) + } + } +} + +func (s *GormQueryTestSuite) TestDelete() { + for _, db := range s.dbs { + user := User{Name: "delete_user", Avatar: "delete_avatar"} + s.Nil(db.Create(&user)) + s.True(user.ID > 0) + + s.Nil(db.Delete(&user)) + + var user1 User + s.Nil(db.Find(&user1, user.ID)) + s.Equal(uint(0), user1.ID) + + user2 := User{Name: "delete_user", Avatar: "delete_avatar"} + s.Nil(db.Create(&user2)) + s.True(user2.ID > 0) + + s.Nil(db.Delete(&User{}, user2.ID)) + + var user3 User + s.Nil(db.Find(&user3, user2.ID)) + s.Equal(uint(0), user3.ID) + + users := []User{{Name: "delete_user", Avatar: "delete_avatar"}, {Name: "delete_user1", Avatar: "delete_avatar1"}} + s.Nil(db.Create(&users)) + s.True(users[0].ID > 0) + s.True(users[1].ID > 0) + + s.Nil(db.Delete(&User{}, []uint{users[0].ID, users[1].ID})) + + var count int64 + s.Nil(db.Model(&User{}).Where("name", "delete_user").OrWhere("name", "delete_user1").Count(&count)) + s.True(count == 0) + } +} + +func (s *GormQueryTestSuite) TestDistinct() { + for _, db := range s.dbs { + user := User{Name: "distinct_user", Avatar: "distinct_avatar"} + s.Nil(db.Create(&user)) + s.True(user.ID > 0) + + user1 := User{Name: "distinct_user", Avatar: "distinct_avatar1"} + s.Nil(db.Create(&user1)) + s.True(user1.ID > 0) + + var users []User + s.Nil(db.Distinct("name").Find(&users, []uint{user.ID, user1.ID})) + s.Equal(1, len(users)) + } +} + +func (s *GormQueryTestSuite) TestFirst() { + for _, db := range s.dbs { + user := User{Name: "first_user"} + s.Nil(db.Create(&user)) + s.True(user.ID > 0) + + var user1 User + s.Nil(db.Where("name = ?", "first_user").First(&user1)) + s.True(user1.ID > 0) + } +} + +func (s *GormQueryTestSuite) TestFirstOrCreate() { + for _, db := range s.dbs { + var user User + s.Nil(db.Where("avatar = ?", "first_or_create_avatar").FirstOrCreate(&user, User{Name: "first_or_create_user"})) + s.True(user.ID > 0) + + var user1 User + s.Nil(db.Where("avatar = ?", "first_or_create_avatar").FirstOrCreate(&user1, User{Name: "user"}, User{Avatar: "first_or_create_avatar1"})) + s.True(user1.ID > 0) + s.True(user1.Avatar == "first_or_create_avatar1") + } +} + +func (s *GormQueryTestSuite) TestJoin() { + for _, db := range s.dbs { + user := User{Name: "join_user", Avatar: "join_avatar"} + s.Nil(db.Create(&user)) + s.True(user.ID > 0) + + userAddress := Address{UserID: user.ID, Name: "join_address", Province: "join_province"} + s.Nil(db.Create(&userAddress)) + s.True(userAddress.ID > 0) + + type Result struct { + UserName string + UserAddressName string + } + var result []Result + s.Nil(db.Model(&User{}).Where("users.id = ?", user.ID).Join("left join addresses ua on users.id = ua.user_id"). + Select("users.name user_name, ua.name user_address_name").Get(&result)) + s.Equal(1, len(result)) + s.Equal("join_user", result[0].UserName) + s.Equal("join_address", result[0].UserAddressName) + } +} + +func (s *GormQueryTestSuite) TestOffset() { + for _, db := range s.dbs { + user := User{Name: "offset_user", Avatar: "offset_avatar"} + s.Nil(db.Create(&user)) + s.True(user.ID > 0) + + user1 := User{Name: "offset_user", Avatar: "offset_avatar1"} + s.Nil(db.Create(&user1)) + s.True(user1.ID > 0) + + var user2 []User + s.Nil(db.Where("name = ?", "offset_user").Offset(1).Limit(1).Get(&user2)) + s.True(len(user2) > 0) + s.True(user2[0].ID > 0) + } +} + +func (s *GormQueryTestSuite) TestOmit() { + // todo +} + +func (s *GormQueryTestSuite) TestOrder() { + for _, db := range s.dbs { + user := User{Name: "order_user", Avatar: "order_avatar"} + s.Nil(db.Create(&user)) + s.True(user.ID > 0) + + user1 := User{Name: "order_user", Avatar: "order_avatar1"} + s.Nil(db.Create(&user1)) + s.True(user1.ID > 0) + + var user2 []User + s.Nil(db.Where("name = ?", "order_user").Order("id desc").Order("name asc").Get(&user2)) + s.True(len(user2) > 0) + s.True(user2[0].ID > 0) + } +} + +func (s *GormQueryTestSuite) TestPluck() { + for _, db := range s.dbs { + user := User{Name: "pluck_user", Avatar: "pluck_avatar"} + s.Nil(db.Create(&user)) + s.True(user.ID > 0) + + user1 := User{Name: "pluck_user", Avatar: "pluck_avatar1"} + s.Nil(db.Create(&user1)) + s.True(user1.ID > 0) + + var avatars []string + s.Nil(db.Model(&User{}).Where("name = ?", "pluck_user").Pluck("avatar", &avatars)) + s.True(len(avatars) > 0) + s.True(avatars[0] == "pluck_avatar") + } +} + +func (s *GormQueryTestSuite) TestHasOne() { + for _, db := range s.dbs { + user := &User{ + Name: "has_one_name", + Address: &Address{ + Name: "has_one_address", + }, + } + + s.Nil(db.Select(orm.Associations).Create(&user)) + s.True(user.ID > 0) + s.True(user.Address.ID > 0) + + var user1 User + s.Nil(db.With("Address").Where("name = ?", "has_one_name").First(&user1)) + s.True(user.ID > 0) + s.True(user.Address.ID > 0) + } +} + +func (s *GormQueryTestSuite) TestHasOneMorph() { + for _, db := range s.dbs { + user := &User{ + Name: "has_one_morph_name", + House: &House{ + Name: "has_one_morph_house", + }, + } + s.Nil(db.Select(orm.Associations).Create(&user)) + s.True(user.ID > 0) + s.True(user.House.ID > 0) + + var user1 User + s.Nil(db.With("House").Where("name = ?", "has_one_morph_name").First(&user1)) + s.True(user.ID > 0) + s.True(user.Name == "has_one_morph_name") + s.True(user.House.ID > 0) + s.True(user.House.Name == "has_one_morph_house") + + var house House + s.Nil(db.Where("name = ?", "has_one_morph_house").Where("houseable_type = ?", "users").Where("houseable_id = ?", user.ID).First(&house)) + s.True(house.ID > 0) + } +} + +func (s *GormQueryTestSuite) TestHasMany() { + for _, db := range s.dbs { + user := &User{ + Name: "has_many_name", + Books: []*Book{ + {Name: "has_many_book1"}, + {Name: "has_many_book2"}, + }, + } + + s.Nil(db.Select(orm.Associations).Create(&user)) + s.True(user.ID > 0) + s.True(user.Books[0].ID > 0) + s.True(user.Books[1].ID > 0) + + var user1 User + s.Nil(db.With("Books").Where("name = ?", "has_many_name").First(&user1)) + s.True(user.ID > 0) + s.True(len(user.Books) == 2) + } +} + +func (s *GormQueryTestSuite) TestHasManyMorph() { + for _, db := range s.dbs { + user := &User{ + Name: "has_many_morph_name", + Phones: []*Phone{ + {Name: "has_many_morph_phone1"}, + {Name: "has_many_morph_phone2"}, + }, + } + s.Nil(db.Select(orm.Associations).Create(&user)) + s.True(user.ID > 0) + s.True(user.Phones[0].ID > 0) + s.True(user.Phones[1].ID > 0) + + var user1 User + s.Nil(db.With("Phones").Where("name = ?", "has_many_morph_name").First(&user1)) + s.True(user.ID > 0) + s.True(user.Name == "has_many_morph_name") + s.True(len(user.Phones) == 2) + s.True(user.Phones[0].Name == "has_many_morph_phone1") + s.True(user.Phones[1].Name == "has_many_morph_phone2") + + var phones []Phone + s.Nil(db.Where("name like ?", "has_many_morph_phone%").Where("phoneable_type = ?", "users").Where("phoneable_id = ?", user.ID).Find(&phones)) + s.True(len(phones) == 2) + } +} + +func (s *GormQueryTestSuite) TestBelongsTo() { + for _, db := range s.dbs { + user := &User{ + Name: "belongs_to_name", + Address: &Address{ + Name: "belongs_to_address", + }, + } + + s.Nil(db.Select(orm.Associations).Create(&user)) + s.True(user.ID > 0) + s.True(user.Address.ID > 0) + + var userAddress Address + s.Nil(db.With("User").Where("name = ?", "belongs_to_address").First(&userAddress)) + s.True(userAddress.ID > 0) + s.True(userAddress.User.ID > 0) + } +} + +func (s *GormQueryTestSuite) TestManyToMany() { + for _, db := range s.dbs { + user := &User{ + Name: "many_to_many_name", + Roles: []*Role{ + {Name: "many_to_many_role1"}, + {Name: "many_to_many_role2"}, + }, + } + + s.Nil(db.Select(orm.Associations).Create(&user)) + s.True(user.ID > 0) + s.True(user.Roles[0].ID > 0) + s.True(user.Roles[1].ID > 0) + + var user1 User + s.Nil(db.With("Roles").Where("name = ?", "many_to_many_name").First(&user1)) + s.True(user.ID > 0) + s.True(len(user.Roles) == 2) + + var role Role + s.Nil(db.With("Users").Where("name = ?", "many_to_many_role1").First(&role)) + s.True(role.ID > 0) + s.True(len(role.Users) == 1) + s.Equal("many_to_many_name", role.Users[0].Name) + } +} + +func (s *GormQueryTestSuite) TestLimit() { + for _, db := range s.dbs { + user := User{Name: "limit_user", Avatar: "limit_avatar"} + s.Nil(db.Create(&user)) + s.True(user.ID > 0) + + user1 := User{Name: "limit_user", Avatar: "limit_avatar1"} + s.Nil(db.Create(&user1)) + s.True(user1.ID > 0) + + var user2 []User + s.Nil(db.Where("name = ?", "limit_user").Limit(1).Get(&user2)) + s.True(len(user2) > 0) + s.True(user2[0].ID > 0) + } +} + +func (s *GormQueryTestSuite) TestLoad() { + for _, db := range s.dbs { + user := User{Name: "load_user", Address: &Address{}, Books: []*Book{&Book{}, &Book{}}} + user.Address.Name = "load_address" + user.Books[0].Name = "load_book0" + user.Books[1].Name = "load_book1" + s.Nil(db.Select(orm.Associations).Create(&user)) + s.True(user.ID > 0) + s.True(user.Address.ID > 0) + s.True(user.Books[0].ID > 0) + s.True(user.Books[1].ID > 0) + + tests := []struct { + description string + setup func(description string) + }{ + { + description: "simple load relationship", + setup: func(description string) { + var user1 User + s.Nil(db.Find(&user1, user.ID)) + s.True(user1.ID > 0) + s.Nil(user1.Address) + s.True(len(user1.Books) == 0) + s.Nil(db.Load(&user1, "Address")) + s.True(user1.Address.ID > 0) + s.True(len(user1.Books) == 0) + s.Nil(db.Load(&user1, "Books")) + s.True(user1.Address.ID > 0) + s.True(len(user1.Books) == 2) + }, + }, + { + description: "load relationship with simple condition", + setup: func(description string) { + var user1 User + s.Nil(db.Find(&user1, user.ID)) + s.True(user1.ID > 0) + s.Nil(user1.Address) + s.Equal(0, len(user1.Books)) + s.Nil(db.Load(&user1, "Books", "name = ?", "load_book0")) + s.True(user1.ID > 0) + s.Nil(user1.Address) + s.Equal(1, len(user1.Books)) + s.Equal("load_book0", user.Books[0].Name) + }, + }, + { + description: "load relationship with func condition", + setup: func(description string) { + var user1 User + s.Nil(db.Find(&user1, user.ID)) + s.True(user1.ID > 0) + s.Nil(user1.Address) + s.Equal(0, len(user1.Books)) + s.Nil(db.Load(&user1, "Books", func(query ormcontract.Query) ormcontract.Query { + return query.Where("name = ?", "load_book0") + })) + s.True(user1.ID > 0) + s.Nil(user1.Address) + s.Equal(1, len(user1.Books)) + s.Equal("load_book0", user.Books[0].Name) + }, + }, + { + description: "error when relation is empty", + setup: func(description string) { + var user1 User + s.Nil(db.Find(&user1, user.ID)) + s.True(user1.ID > 0) + s.Nil(user1.Address) + s.Equal(0, len(user1.Books)) + s.EqualError(db.Load(&user1, ""), "relation cannot be empty") + }, + }, + { + description: "error when id is nil", + setup: func(description string) { + type UserNoID struct { + Name string + Avatar string + } + var userNoID UserNoID + s.EqualError(db.Load(&userNoID, "Book"), "id cannot be empty") + }, + }, + } + for _, test := range tests { + test.setup(test.description) + } + } +} + +func (s *GormQueryTestSuite) TestLoadMissing() { + for _, db := range s.dbs { + user := User{Name: "load_missing_user", Address: &Address{}, Books: []*Book{&Book{}, &Book{}}} + user.Address.Name = "load_missing_address" + user.Books[0].Name = "load_missing_book0" + user.Books[1].Name = "load_missing_book1" + s.Nil(db.Select(orm.Associations).Create(&user)) + s.True(user.ID > 0) + s.True(user.Address.ID > 0) + s.True(user.Books[0].ID > 0) + s.True(user.Books[1].ID > 0) + + tests := []struct { + description string + setup func(description string) + }{ + { + description: "load when missing", + setup: func(description string) { + var user1 User + s.Nil(db.Find(&user1, user.ID)) + s.True(user1.ID > 0) + s.Nil(user1.Address) + s.True(len(user1.Books) == 0) + s.Nil(db.LoadMissing(&user1, "Address")) + s.True(user1.Address.ID > 0) + s.True(len(user1.Books) == 0) + s.Nil(db.LoadMissing(&user1, "Books")) + s.True(user1.Address.ID > 0) + s.True(len(user1.Books) == 2) + }, + }, + { + description: "don't load when not missing", + setup: func(description string) { + var user1 User + s.Nil(db.With("Books", "name = ?", "load_missing_book0").Find(&user1, user.ID)) + s.True(user1.ID > 0) + s.Nil(user1.Address) + s.True(len(user1.Books) == 1) + s.Nil(db.LoadMissing(&user1, "Address")) + s.True(user1.Address.ID > 0) + s.Nil(db.LoadMissing(&user1, "Books")) + s.True(len(user1.Books) == 1) + }, + }, + } + for _, test := range tests { + test.setup(test.description) + } + } +} + +func (s *GormQueryTestSuite) TestRaw() { + for _, db := range s.dbs { + user := User{Name: "raw_user", Avatar: "raw_avatar"} + s.Nil(db.Create(&user)) + s.True(user.ID > 0) + + var user1 User + s.Nil(db.Raw("SELECT id, name FROM users WHERE name = ?", "raw_user").Scan(&user1)) + s.True(user1.ID > 0) + s.Equal("raw_user", user1.Name) + s.Equal("", user1.Avatar) + } +} + +func (s *GormQueryTestSuite) TestScope() { + for _, db := range s.dbs { + users := []User{{Name: "scope_user", Avatar: "scope_avatar"}, {Name: "scope_user1", Avatar: "scope_avatar1"}} + s.Nil(db.Create(&users)) + s.True(users[0].ID > 0) + s.True(users[1].ID > 0) + + var users1 []User + s.Nil(db.Scopes(paginator("1", "1")).Find(&users1)) + + s.Equal(1, len(users1)) + s.True(users1[0].ID > 0) + } +} + +func (s *GormQueryTestSuite) TestFind() { + for _, db := range s.dbs { + user := User{Name: "find_user"} + s.Nil(db.Create(&user)) + s.True(user.ID > 0) + + var user2 User + s.Nil(db.Find(&user2, user.ID)) + s.True(user2.ID > 0) + + var user3 []User + s.Nil(db.Find(&user3, []uint{user.ID})) + s.Equal(1, len(user3)) + + var user4 []User + s.Nil(db.Where("id in ?", []uint{user.ID}).Find(&user4)) + s.Equal(1, len(user4)) + } +} + +func (s *GormQueryTestSuite) TestGet() { + for _, db := range s.dbs { + user := User{Name: "get_user"} + s.Nil(db.Create(&user)) + s.True(user.ID > 0) + + var user5 []User + s.Nil(db.Where("id in ?", []uint{user.ID}).Get(&user5)) + s.Equal(1, len(user5)) + } +} + +func (s *GormQueryTestSuite) TestSelect() { + for _, db := range s.dbs { + user := User{Name: "select_user", Avatar: "select_avatar"} + s.Nil(db.Create(&user)) + s.True(user.ID > 0) + + user1 := User{Name: "select_user", Avatar: "select_avatar1"} + s.Nil(db.Create(&user1)) + s.True(user1.ID > 0) + + user2 := User{Name: "select_user1", Avatar: "select_avatar1"} + s.Nil(db.Create(&user2)) + s.True(user2.ID > 0) + + type Result struct { + Name string + Count string + } + var result []Result + s.Nil(db.Model(&User{}).Select("name, count(avatar) as count").Where("id in ?", []uint{user.ID, user1.ID, user2.ID}).Group("name").Get(&result)) + s.Equal(2, len(result)) + s.Equal("select_user", result[0].Name) + s.Equal("2", result[0].Count) + s.Equal("select_user1", result[1].Name) + s.Equal("1", result[1].Count) + + var result1 []Result + s.Nil(db.Model(&User{}).Select("name, count(avatar) as count").Group("name").Having("name = ?", "select_user").Get(&result1)) + + s.Equal(1, len(result1)) + s.Equal("select_user", result1[0].Name) + s.Equal("2", result1[0].Count) + } +} + +func (s *GormQueryTestSuite) TestSoftDelete() { + for _, db := range s.dbs { + user := User{Name: "soft_delete_user", Avatar: "soft_delete_avatar"} + s.Nil(db.Create(&user)) + s.True(user.ID > 0) + + s.Nil(db.Where("name = ?", "soft_delete_user").Delete(&User{})) + + var user1 User + s.Nil(db.Find(&user1, user.ID)) + s.Equal(uint(0), user1.ID) + + var user2 User + s.Nil(db.WithTrashed().Find(&user2, user.ID)) + s.True(user2.ID > 0) + + s.Nil(db.Where("name = ?", "soft_delete_user").ForceDelete(&User{})) + + var user3 User + s.Nil(db.WithTrashed().Find(&user3, user.ID)) + s.Equal(uint(0), user3.ID) + } +} + +func (s *GormQueryTestSuite) TestTransactionSuccess() { + for _, db := range s.dbs { + user := User{Name: "transaction_success_user", Avatar: "transaction_success_avatar"} + user1 := User{Name: "transaction_success_user1", Avatar: "transaction_success_avatar1"} + tx, err := db.Begin() + s.Nil(err) + s.Nil(tx.Create(&user)) + s.Nil(tx.Create(&user1)) + s.Nil(tx.Commit()) + + var user2, user3 User + s.Nil(db.Find(&user2, user.ID)) + s.Nil(db.Find(&user3, user1.ID)) + } +} + +func (s *GormQueryTestSuite) TestTransactionError() { + for _, db := range s.dbs { + user := User{Name: "transaction_error_user", Avatar: "transaction_error_avatar"} + user1 := User{Name: "transaction_error_user1", Avatar: "transaction_error_avatar1"} + tx, err := db.Begin() + s.Nil(err) + s.Nil(tx.Create(&user)) + s.Nil(tx.Create(&user1)) + s.Nil(tx.Rollback()) + + var users []User + s.Nil(db.Where("name = ? or name = ?", "transaction_error_user", "transaction_error_user1").Find(&users)) + s.Equal(0, len(users)) + } +} + +func (s *GormQueryTestSuite) TestUpdate() { + for _, db := range s.dbs { + user := User{Name: "update_user", Avatar: "update_avatar"} + s.Nil(db.Create(&user)) + s.True(user.ID > 0) + + user.Name = "update_user1" + s.Nil(db.Save(&user)) + s.Nil(db.Model(&User{}).Where("id = ?", user.ID).Update("avatar", "update_avatar1")) + + var user1 User + s.Nil(db.Find(&user1, user.ID)) + s.Equal("update_user1", user1.Name) + s.Equal("update_avatar1", user1.Avatar) + } +} + +func (s *GormQueryTestSuite) TestWhere() { + for _, db := range s.dbs { + user := User{Name: "where_user", Avatar: "where_avatar"} + s.Nil(db.Create(&user)) + s.True(user.ID > 0) + + user1 := User{Name: "where_user1", Avatar: "where_avatar1"} + s.Nil(db.Create(&user1)) + s.True(user1.ID > 0) + + var user2 []User + s.Nil(db.Where("name = ?", "where_user").OrWhere("avatar = ?", "where_avatar1").Find(&user2)) + s.True(len(user2) > 0) + + var user3 User + s.Nil(db.Where("name = 'where_user'").Find(&user3)) + s.True(user3.ID > 0) + + var user4 User + s.Nil(db.Where("name", "where_user").Find(&user4)) + s.True(user4.ID > 0) + } +} + +func (s *GormQueryTestSuite) TestWith() { + for _, db := range s.dbs { + user := User{Name: "with_user", Address: &Address{ + Name: "with_address", + }, Books: []*Book{{ + Name: "with_book0", + }, { + Name: "with_book1", + }}} + s.Nil(db.Select(orm.Associations).Create(&user)) + s.True(user.ID > 0) + s.True(user.Address.ID > 0) + s.True(user.Books[0].ID > 0) + s.True(user.Books[1].ID > 0) + + tests := []struct { + description string + setup func(description string) + }{ + { + description: "simple", + setup: func(description string) { + var user1 User + s.Nil(db.With("Address").With("Books").Find(&user1, user.ID)) + s.True(user1.ID > 0) + s.True(user1.Address.ID > 0) + s.True(user1.Books[0].ID > 0) + s.True(user1.Books[1].ID > 0) + }, + }, + { + description: "with simple conditions", + setup: func(description string) { + var user1 User + s.Nil(db.With("Books", "name = ?", "with_book0").Find(&user1, user.ID)) + s.True(user1.ID > 0) + s.Nil(user1.Address) + s.Equal(1, len(user1.Books)) + s.Equal("with_book0", user1.Books[0].Name) + }, + }, + { + description: "with func conditions", + setup: func(description string) { + var user1 User + s.Nil(db.With("Books", func(query ormcontract.Query) ormcontract.Query { + return query.Where("name = ?", "with_book0") + }).Find(&user1, user.ID)) + s.True(user1.ID > 0) + s.Nil(user1.Address) + s.Equal(1, len(user1.Books)) + s.Equal("with_book0", user1.Books[0].Name) + }, + }, + } + for _, test := range tests { + test.setup(test.description) + } + } +} + +func (s *GormQueryTestSuite) TestWithNesting() { + for _, db := range s.dbs { + user := User{Name: "with_nesting_user", Books: []*Book{{ + Name: "with_nesting_book0", + Author: &Author{Name: "with_nesting_author0"}, + }, { + Name: "with_nesting_book1", + Author: &Author{Name: "with_nesting_author1"}, + }}} + s.Nil(db.Select(orm.Associations).Create(&user)) + s.True(user.ID > 0) + s.True(user.Books[0].ID > 0) + s.True(user.Books[0].Author.ID > 0) + s.True(user.Books[1].ID > 0) + s.True(user.Books[1].Author.ID > 0) + + var user1 User + s.Nil(db.With("Books.Author").Find(&user1, user.ID)) + s.True(user1.ID > 0) + s.Equal("with_nesting_user", user1.Name) + s.True(user1.Books[0].ID > 0) + s.Equal("with_nesting_book0", user1.Books[0].Name) + s.True(user1.Books[0].Author.ID > 0) + s.Equal("with_nesting_author0", user1.Books[0].Author.Name) + s.True(user1.Books[1].ID > 0) + s.Equal("with_nesting_book1", user1.Books[1].Name) + s.True(user1.Books[1].Author.ID > 0) + s.Equal("with_nesting_author1", user1.Books[1].Author.Name) + } +} diff --git a/database/gorm_logger.go b/database/gorm/logger.go similarity index 97% rename from database/gorm_logger.go rename to database/gorm/logger.go index d02774ac7..629d218f2 100644 --- a/database/gorm_logger.go +++ b/database/gorm/logger.go @@ -1,4 +1,4 @@ -package database +package gorm import ( "context" @@ -14,8 +14,8 @@ import ( "gorm.io/gorm/logger" ) -// New initialize Logger -func New(writer logger.Writer, config logger.Config) logger.Interface { +// NewLogger initialize Logger +func NewLogger(writer logger.Writer, config logger.Config) logger.Interface { var ( infoStr = "%s\n[Orm] " warnStr = "%s\n[Orm] " diff --git a/database/gorm/test_utils.go b/database/gorm/test_utils.go new file mode 100644 index 000000000..6d8e04753 --- /dev/null +++ b/database/gorm/test_utils.go @@ -0,0 +1,681 @@ +package gorm + +import ( + "context" + "strconv" + + "github.com/ory/dockertest/v3" + + ormcontract "github.com/goravel/framework/contracts/database/orm" + testingdocker "github.com/goravel/framework/testing/docker" + "github.com/goravel/framework/testing/mock" +) + +const ( + dbDatabase = "goravel" + dbPassword = "Goravel(!)" + dbUser = "root" +) + +func MysqlDocker() (*dockertest.Pool, *dockertest.Resource, ormcontract.DB, error) { + pool, err := testingdocker.Pool() + if err != nil { + return nil, nil, nil, err + } + resource, err := testingdocker.Resource(pool, &dockertest.RunOptions{ + Repository: "mysql", + Tag: "5.7", + Env: []string{ + "MYSQL_ROOT_PASSWORD=" + dbPassword, + }, + }) + if err != nil { + return nil, nil, nil, err + } + + _ = resource.Expire(60) + + if err := pool.Retry(func() error { + return initDatabase(ormcontract.DriverMysql, resource.GetPort("3306/tcp")) + }); err != nil { + return nil, nil, nil, err + } + + db, err := getDB(ormcontract.DriverMysql, dbDatabase, resource.GetPort("3306/tcp")) + if err != nil { + return nil, nil, nil, err + } + + if err := initTables(ormcontract.DriverMysql, db); err != nil { + return nil, nil, nil, err + } + + return pool, resource, db, nil +} + +func PostgresqlDocker() (*dockertest.Pool, *dockertest.Resource, ormcontract.DB, error) { + pool, err := testingdocker.Pool() + if err != nil { + return nil, nil, nil, err + } + resource, err := testingdocker.Resource(pool, &dockertest.RunOptions{ + Repository: "postgres", + Tag: "11", + Env: []string{ + "POSTGRES_USER=" + dbUser, + "POSTGRES_PASSWORD=" + dbPassword, + "listen_addresses = '*'", + }, + }) + if err != nil { + return nil, nil, nil, err + } + + _ = resource.Expire(60) + + if err := pool.Retry(func() error { + return initDatabase(ormcontract.DriverPostgresql, resource.GetPort("5432/tcp")) + }); err != nil { + return nil, nil, nil, err + } + + db, err := getDB(ormcontract.DriverPostgresql, dbDatabase, resource.GetPort("5432/tcp")) + if err != nil { + return nil, nil, nil, err + } + + if err := initTables(ormcontract.DriverPostgresql, db); err != nil { + return nil, nil, nil, err + } + + return pool, resource, db, nil +} + +func SqliteDocker() (*dockertest.Pool, *dockertest.Resource, ormcontract.DB, error) { + pool, err := testingdocker.Pool() + if err != nil { + return nil, nil, nil, err + } + resource, err := testingdocker.Resource(pool, &dockertest.RunOptions{ + Repository: "nouchka/sqlite3", + Tag: "latest", + Env: []string{}, + }) + if err != nil { + return nil, nil, nil, err + } + + _ = resource.Expire(60) + + var db ormcontract.DB + if err := pool.Retry(func() error { + var err error + db, err = getDB(ormcontract.DriverSqlite, dbDatabase, "") + + return err + }); err != nil { + return nil, nil, nil, err + } + + if err := initTables(ormcontract.DriverSqlite, db); err != nil { + return nil, nil, nil, err + } + + return pool, resource, db, nil +} + +func SqlserverDocker() (*dockertest.Pool, *dockertest.Resource, ormcontract.DB, error) { + pool, err := testingdocker.Pool() + if err != nil { + return nil, nil, nil, err + } + resource, err := testingdocker.Resource(pool, &dockertest.RunOptions{ + Repository: "mcr.microsoft.com/mssql/server", + Tag: "2022-latest", + Env: []string{ + "MSSQL_SA_PASSWORD=" + dbPassword, + "ACCEPT_EULA=Y", + }, + }) + if err != nil { + return nil, nil, nil, err + } + + _ = resource.Expire(60) + + if err := pool.Retry(func() error { + return initDatabase(ormcontract.DriverSqlserver, resource.GetPort("1433/tcp")) + }); err != nil { + return nil, nil, nil, err + } + + db, err := getDB(ormcontract.DriverSqlserver, dbDatabase, resource.GetPort("1433/tcp")) + if err != nil { + return nil, nil, nil, err + } + + if err := initTables(ormcontract.DriverSqlserver, db); err != nil { + return nil, nil, nil, err + } + + return pool, resource, db, nil +} + +func initDatabase(connection ormcontract.Driver, port string) error { + var ( + database = "" + createSql = "" + ) + + switch connection { + case ormcontract.DriverMysql: + database = "mysql" + createSql = "CREATE DATABASE `goravel` DEFAULT CHARACTER SET = `utf8mb4` DEFAULT COLLATE = `utf8mb4_general_ci`;" + case ormcontract.DriverPostgresql: + database = "postgres" + createSql = "CREATE DATABASE goravel;" + case ormcontract.DriverSqlserver: + database = "msdb" + createSql = "CREATE DATABASE goravel;" + } + + db, err := getDB(connection, database, port) + if err != nil { + return err + } + + if err := db.Exec(createSql); err != nil { + return err + } + + return nil +} + +func getDB(driver ormcontract.Driver, database, port string) (ormcontract.DB, error) { + mockConfig := mock.Config() + switch driver { + case ormcontract.DriverMysql: + mockConfig.On("GetBool", "app.debug").Return(true).Once() + mockConfig.On("GetString", "database.connections.mysql.driver").Return(ormcontract.DriverMysql.String()).Once() + mockConfig.On("GetString", "database.connections.mysql.host").Return("localhost").Once() + mockConfig.On("GetString", "database.connections.mysql.port").Return(port).Once() + mockConfig.On("GetString", "database.connections.mysql.database").Return(database).Once() + mockConfig.On("GetString", "database.connections.mysql.username").Return(dbUser).Once() + mockConfig.On("GetString", "database.connections.mysql.password").Return(dbPassword).Once() + mockConfig.On("GetString", "database.connections.mysql.charset").Return("utf8mb4").Once() + mockConfig.On("GetString", "database.connections.mysql.loc").Return("Local").Once() + case ormcontract.DriverPostgresql: + mockConfig.On("GetBool", "app.debug").Return(true).Once() + mockConfig.On("GetString", "database.connections.postgresql.driver").Return(ormcontract.DriverPostgresql.String()).Once() + mockConfig.On("GetString", "database.connections.postgresql.host").Return("localhost").Once() + mockConfig.On("GetString", "database.connections.postgresql.port").Return(port).Once() + mockConfig.On("GetString", "database.connections.postgresql.database").Return(database).Once() + mockConfig.On("GetString", "database.connections.postgresql.username").Return(dbUser).Once() + mockConfig.On("GetString", "database.connections.postgresql.password").Return(dbPassword).Once() + mockConfig.On("GetString", "database.connections.postgresql.sslmode").Return("disable").Once() + mockConfig.On("GetString", "database.connections.postgresql.timezone").Return("UTC").Once() + case ormcontract.DriverSqlite: + mockConfig.On("GetBool", "app.debug").Return(true).Once() + mockConfig.On("GetString", "database.connections.sqlite.driver").Return(ormcontract.DriverSqlite.String()).Once() + mockConfig.On("GetString", "database.connections.sqlite.database").Return(database).Once() + case ormcontract.DriverSqlserver: + mockConfig.On("GetBool", "app.debug").Return(true).Once() + mockConfig.On("GetString", "database.connections.sqlserver.driver").Return(ormcontract.DriverSqlserver.String()).Once() + mockConfig.On("GetString", "database.connections.sqlserver.host").Return("localhost").Once() + mockConfig.On("GetString", "database.connections.sqlserver.port").Return(port).Once() + mockConfig.On("GetString", "database.connections.sqlserver.database").Return(database).Once() + mockConfig.On("GetString", "database.connections.sqlserver.username").Return("sa").Once() + mockConfig.On("GetString", "database.connections.sqlserver.password").Return(dbPassword).Once() + } + + return NewDB(context.Background(), driver.String()) +} + +func initTables(driver ormcontract.Driver, db ormcontract.DB) error { + if err := db.Exec(createUserTable(driver)); err != nil { + return err + } + if err := db.Exec(createAddressTable(driver)); err != nil { + return err + } + if err := db.Exec(createBookTable(driver)); err != nil { + return err + } + if err := db.Exec(createRoleTable(driver)); err != nil { + return err + } + if err := db.Exec(createHouseTable(driver)); err != nil { + return err + } + if err := db.Exec(createPhoneTable(driver)); err != nil { + return err + } + if err := db.Exec(createRoleUserTable(driver)); err != nil { + return err + } + if err := db.Exec(createAuthorTable(driver)); err != nil { + return err + } + + return nil +} + +func createUserTable(driver ormcontract.Driver) string { + switch driver { + case ormcontract.DriverMysql: + return ` +CREATE TABLE users ( + id bigint(20) unsigned NOT NULL AUTO_INCREMENT, + name varchar(255) NOT NULL, + avatar varchar(255) NOT NULL, + created_at datetime(3) NOT NULL, + updated_at datetime(3) NOT NULL, + deleted_at datetime(3) DEFAULT NULL, + PRIMARY KEY (id), + KEY idx_users_created_at (created_at), + KEY idx_users_updated_at (updated_at) +) ENGINE = InnoDB DEFAULT CHARSET = utf8mb4; +` + case ormcontract.DriverPostgresql: + return ` +CREATE TABLE users ( + id SERIAL PRIMARY KEY NOT NULL, + name varchar(255) NOT NULL, + avatar varchar(255) NOT NULL, + created_at timestamp NOT NULL, + updated_at timestamp NOT NULL, + deleted_at timestamp DEFAULT NULL +); +` + case ormcontract.DriverSqlite: + return ` +CREATE TABLE users ( + id integer PRIMARY KEY AUTOINCREMENT NOT NULL, + name varchar(255) NOT NULL, + avatar varchar(255) NOT NULL, + created_at datetime NOT NULL, + updated_at datetime NOT NULL, + deleted_at datetime DEFAULT NULL +); +` + case ormcontract.DriverSqlserver: + return ` +CREATE TABLE users ( + id bigint NOT NULL IDENTITY(1,1), + name varchar(255) NOT NULL, + avatar varchar(255) NOT NULL, + created_at datetime NOT NULL, + updated_at datetime NOT NULL, + deleted_at datetime DEFAULT NULL, + PRIMARY KEY (id) +); +` + default: + return "" + } +} + +func createAddressTable(driver ormcontract.Driver) string { + switch driver { + case ormcontract.DriverMysql: + return ` +CREATE TABLE addresses ( + id bigint(20) unsigned NOT NULL AUTO_INCREMENT, + user_id bigint(20) unsigned DEFAULT NULL, + name varchar(255) NOT NULL, + province varchar(255) NOT NULL, + created_at datetime(3) NOT NULL, + updated_at datetime(3) NOT NULL, + PRIMARY KEY (id), + KEY idx_addresses_created_at (created_at), + KEY idx_addresses_updated_at (updated_at) +) ENGINE = InnoDB DEFAULT CHARSET = utf8mb4; +` + case ormcontract.DriverPostgresql: + return ` +CREATE TABLE addresses ( + id SERIAL PRIMARY KEY NOT NULL, + user_id int DEFAULT NULL, + name varchar(255) NOT NULL, + province varchar(255) NOT NULL, + created_at timestamp NOT NULL, + updated_at timestamp NOT NULL +); +` + case ormcontract.DriverSqlite: + return ` +CREATE TABLE addresses ( + id integer PRIMARY KEY AUTOINCREMENT NOT NULL, + user_id int DEFAULT NULL, + name varchar(255) NOT NULL, + province varchar(255) NOT NULL, + created_at datetime NOT NULL, + updated_at datetime NOT NULL +); +` + case ormcontract.DriverSqlserver: + return ` +CREATE TABLE addresses ( + id bigint NOT NULL IDENTITY(1,1), + user_id bigint DEFAULT NULL, + name varchar(255) NOT NULL, + province varchar(255) NOT NULL, + created_at datetime NOT NULL, + updated_at datetime NOT NULL, + PRIMARY KEY (id) +); +` + default: + return "" + } +} + +func createBookTable(driver ormcontract.Driver) string { + switch driver { + case ormcontract.DriverMysql: + return ` +CREATE TABLE books ( + id bigint(20) unsigned NOT NULL AUTO_INCREMENT, + user_id bigint(20) unsigned DEFAULT NULL, + name varchar(255) NOT NULL, + created_at datetime(3) NOT NULL, + updated_at datetime(3) NOT NULL, + PRIMARY KEY (id), + KEY idx_books_created_at (created_at), + KEY idx_books_updated_at (updated_at) +) ENGINE = InnoDB DEFAULT CHARSET = utf8mb4; +` + case ormcontract.DriverPostgresql: + return ` +CREATE TABLE books ( + id SERIAL PRIMARY KEY NOT NULL, + user_id int DEFAULT NULL, + name varchar(255) NOT NULL, + created_at timestamp NOT NULL, + updated_at timestamp NOT NULL +); +` + case ormcontract.DriverSqlite: + return ` +CREATE TABLE books ( + id integer PRIMARY KEY AUTOINCREMENT NOT NULL, + user_id int DEFAULT NULL, + name varchar(255) NOT NULL, + created_at datetime NOT NULL, + updated_at datetime NOT NULL +); +` + case ormcontract.DriverSqlserver: + return ` +CREATE TABLE books ( + id bigint NOT NULL IDENTITY(1,1), + user_id bigint DEFAULT NULL, + name varchar(255) NOT NULL, + created_at datetime NOT NULL, + updated_at datetime NOT NULL, + PRIMARY KEY (id) +); +` + default: + return "" + } +} + +func createAuthorTable(driver ormcontract.Driver) string { + switch driver { + case ormcontract.DriverMysql: + return ` +CREATE TABLE authors ( + id bigint(20) unsigned NOT NULL AUTO_INCREMENT, + book_id bigint(20) unsigned DEFAULT NULL, + name varchar(255) NOT NULL, + created_at datetime(3) NOT NULL, + updated_at datetime(3) NOT NULL, + PRIMARY KEY (id), + KEY idx_books_created_at (created_at), + KEY idx_books_updated_at (updated_at) +) ENGINE = InnoDB DEFAULT CHARSET = utf8mb4; +` + case ormcontract.DriverPostgresql: + return ` +CREATE TABLE authors ( + id SERIAL PRIMARY KEY NOT NULL, + book_id int DEFAULT NULL, + name varchar(255) NOT NULL, + created_at timestamp NOT NULL, + updated_at timestamp NOT NULL +); +` + case ormcontract.DriverSqlite: + return ` +CREATE TABLE authors ( + id integer PRIMARY KEY AUTOINCREMENT NOT NULL, + book_id int DEFAULT NULL, + name varchar(255) NOT NULL, + created_at datetime NOT NULL, + updated_at datetime NOT NULL +); +` + case ormcontract.DriverSqlserver: + return ` +CREATE TABLE authors ( + id bigint NOT NULL IDENTITY(1,1), + book_id bigint DEFAULT NULL, + name varchar(255) NOT NULL, + created_at datetime NOT NULL, + updated_at datetime NOT NULL, + PRIMARY KEY (id) +); +` + default: + return "" + } +} + +func createRoleTable(driver ormcontract.Driver) string { + switch driver { + case ormcontract.DriverMysql: + return ` +CREATE TABLE roles ( + id bigint(20) unsigned NOT NULL AUTO_INCREMENT, + name varchar(255) NOT NULL, + created_at datetime(3) NOT NULL, + updated_at datetime(3) NOT NULL, + PRIMARY KEY (id), + KEY idx_roles_created_at (created_at), + KEY idx_roles_updated_at (updated_at) +) ENGINE = InnoDB DEFAULT CHARSET = utf8mb4; +` + case ormcontract.DriverPostgresql: + return ` +CREATE TABLE roles ( + id SERIAL PRIMARY KEY NOT NULL, + name varchar(255) NOT NULL, + created_at timestamp NOT NULL, + updated_at timestamp NOT NULL +); +` + case ormcontract.DriverSqlite: + return ` +CREATE TABLE roles ( + id integer PRIMARY KEY AUTOINCREMENT NOT NULL, + name varchar(255) NOT NULL, + created_at datetime NOT NULL, + updated_at datetime NOT NULL +); +` + case ormcontract.DriverSqlserver: + return ` +CREATE TABLE roles ( + id bigint NOT NULL IDENTITY(1,1), + name varchar(255) NOT NULL, + created_at datetime NOT NULL, + updated_at datetime NOT NULL, + PRIMARY KEY (id) +); +` + default: + return "" + } +} + +func createHouseTable(driver ormcontract.Driver) string { + switch driver { + case ormcontract.DriverMysql: + return ` +CREATE TABLE houses ( + id bigint(20) unsigned NOT NULL AUTO_INCREMENT, + name varchar(255) NOT NULL, + houseable_id bigint(20) unsigned NOT NULL, + houseable_type varchar(255) NOT NULL, + created_at datetime(3) NOT NULL, + updated_at datetime(3) NOT NULL, + PRIMARY KEY (id), + KEY idx_houses_created_at (created_at), + KEY idx_houses_updated_at (updated_at) +) ENGINE = InnoDB DEFAULT CHARSET = utf8mb4; +` + case ormcontract.DriverPostgresql: + return ` +CREATE TABLE houses ( + id SERIAL PRIMARY KEY NOT NULL, + name varchar(255) NOT NULL, + houseable_id int NOT NULL, + houseable_type varchar(255) NOT NULL, + created_at timestamp NOT NULL, + updated_at timestamp NOT NULL +); +` + case ormcontract.DriverSqlite: + return ` +CREATE TABLE houses ( + id integer PRIMARY KEY AUTOINCREMENT NOT NULL, + name varchar(255) NOT NULL, + houseable_id int NOT NULL, + houseable_type varchar(255) NOT NULL, + created_at datetime NOT NULL, + updated_at datetime NOT NULL +); +` + case ormcontract.DriverSqlserver: + return ` +CREATE TABLE houses ( + id bigint NOT NULL IDENTITY(1,1), + name varchar(255) NOT NULL, + houseable_id bigint NOT NULL, + houseable_type varchar(255) NOT NULL, + created_at datetime NOT NULL, + updated_at datetime NOT NULL, + PRIMARY KEY (id) +); +` + default: + return "" + } +} + +func createPhoneTable(driver ormcontract.Driver) string { + switch driver { + case ormcontract.DriverMysql: + return ` +CREATE TABLE phones ( + id bigint(20) unsigned NOT NULL AUTO_INCREMENT, + name varchar(255) NOT NULL, + phoneable_id bigint(20) unsigned NOT NULL, + phoneable_type varchar(255) NOT NULL, + created_at datetime(3) NOT NULL, + updated_at datetime(3) NOT NULL, + PRIMARY KEY (id), + KEY idx_phones_created_at (created_at), + KEY idx_phones_updated_at (updated_at) +) ENGINE = InnoDB DEFAULT CHARSET = utf8mb4; +` + case ormcontract.DriverPostgresql: + return ` +CREATE TABLE phones ( + id SERIAL PRIMARY KEY NOT NULL, + name varchar(255) NOT NULL, + phoneable_id int NOT NULL, + phoneable_type varchar(255) NOT NULL, + created_at timestamp NOT NULL, + updated_at timestamp NOT NULL +); +` + case ormcontract.DriverSqlite: + return ` +CREATE TABLE phones ( + id integer PRIMARY KEY AUTOINCREMENT NOT NULL, + name varchar(255) NOT NULL, + phoneable_id int NOT NULL, + phoneable_type varchar(255) NOT NULL, + created_at datetime NOT NULL, + updated_at datetime NOT NULL +); +` + case ormcontract.DriverSqlserver: + return ` +CREATE TABLE phones ( + id bigint NOT NULL IDENTITY(1,1), + name varchar(255) NOT NULL, + phoneable_id bigint NOT NULL, + phoneable_type varchar(255) NOT NULL, + created_at datetime NOT NULL, + updated_at datetime NOT NULL, + PRIMARY KEY (id) +); +` + default: + return "" + } +} + +func createRoleUserTable(driver ormcontract.Driver) string { + switch driver { + case ormcontract.DriverMysql: + return ` +CREATE TABLE role_user ( + id bigint(20) unsigned NOT NULL AUTO_INCREMENT, + role_id bigint(20) unsigned NOT NULL, + user_id bigint(20) unsigned NOT NULL, + PRIMARY KEY (id) +) ENGINE = InnoDB DEFAULT CHARSET = utf8mb4; +` + case ormcontract.DriverPostgresql: + return ` +CREATE TABLE role_user ( + id SERIAL PRIMARY KEY NOT NULL, + role_id int NOT NULL, + user_id int NOT NULL +); +` + case ormcontract.DriverSqlite: + return ` +CREATE TABLE role_user ( + id integer PRIMARY KEY AUTOINCREMENT NOT NULL, + role_id int NOT NULL, + user_id int NOT NULL +); +` + case ormcontract.DriverSqlserver: + return ` +CREATE TABLE role_user ( + id bigint NOT NULL IDENTITY(1,1), + role_id bigint NOT NULL, + user_id bigint NOT NULL, + PRIMARY KEY (id) +); +` + default: + return "" + } +} + +func paginator(page string, limit string) func(methods ormcontract.Query) ormcontract.Query { + return func(query ormcontract.Query) ormcontract.Query { + page, _ := strconv.Atoi(page) + limit, _ := strconv.Atoi(limit) + offset := (page - 1) * limit + + return query.Offset(offset).Limit(limit) + } +} diff --git a/database/gorm/utils.go b/database/gorm/utils.go new file mode 100644 index 000000000..420d7dc3d --- /dev/null +++ b/database/gorm/utils.go @@ -0,0 +1,15 @@ +package gorm + +import "reflect" + +func copyStruct(dest any) reflect.Value { + t := reflect.TypeOf(dest).Elem() + v := reflect.ValueOf(dest).Elem() + destFields := make([]reflect.StructField, 0) + for i := 0; i < t.NumField(); i++ { + destFields = append(destFields, t.Field(i)) + } + copyDestStruct := reflect.StructOf(destFields) + + return v.Convert(copyDestStruct) +} diff --git a/database/orm.go b/database/orm.go index fc70d9723..a0717e836 100644 --- a/database/orm.go +++ b/database/orm.go @@ -2,29 +2,35 @@ package database import ( "context" + "database/sql" "fmt" - contractsorm "github.com/goravel/framework/contracts/database/orm" - "github.com/goravel/framework/facades" - "github.com/gookit/color" "github.com/pkg/errors" + "gorm.io/gorm" + + ormcontract "github.com/goravel/framework/contracts/database/orm" + databasegorm "github.com/goravel/framework/database/gorm" + "github.com/goravel/framework/facades" ) type Orm struct { ctx context.Context connection string - defaultInstance contractsorm.DB - instances map[string]contractsorm.DB + defaultInstance ormcontract.DB + instances map[string]ormcontract.DB } -func NewOrm(ctx context.Context) contractsorm.Orm { - orm := &Orm{ctx: ctx} +func NewOrm(ctx context.Context) *Orm { + return &Orm{ctx: ctx} +} - return orm.Connection("") +// DEPRECATED: use gorm.New() +func NewGormInstance(connection string) (*gorm.DB, error) { + return databasegorm.New(connection) } -func (r *Orm) Connection(name string) contractsorm.Orm { +func (r *Orm) Connection(name string) ormcontract.Orm { defaultConnection := facades.Config.GetString("database.default") if name == "" { name = defaultConnection @@ -32,33 +38,43 @@ func (r *Orm) Connection(name string) contractsorm.Orm { r.connection = name if r.instances == nil { - r.instances = make(map[string]contractsorm.DB) + r.instances = make(map[string]ormcontract.DB) } - if _, exist := r.instances[name]; exist { + if instance, exist := r.instances[name]; exist { + if name == defaultConnection && r.defaultInstance == nil { + r.defaultInstance = instance + } + return r } - gorm, err := NewGormDB(r.ctx, name) + gormDB, err := databasegorm.NewDB(r.ctx, name) if err != nil { color.Redln(fmt.Sprintf("[Orm] Init connection error, %v", err)) return nil } - if gorm == nil { + if gormDB == nil { return nil } - r.instances[name] = gorm + r.instances[name] = gormDB if name == defaultConnection { - r.defaultInstance = gorm + r.defaultInstance = gormDB } return r } -func (r *Orm) Query() contractsorm.DB { +func (r *Orm) DB() (*sql.DB, error) { + db := r.Query().(*databasegorm.DB) + + return db.Instance().DB() +} + +func (r *Orm) Query() ormcontract.DB { if r.connection == "" { if r.defaultInstance == nil { r.Connection("") @@ -77,7 +93,7 @@ func (r *Orm) Query() contractsorm.DB { return instance } -func (r *Orm) Transaction(txFunc func(tx contractsorm.Transaction) error) error { +func (r *Orm) Transaction(txFunc func(tx ormcontract.Transaction) error) error { tx, err := r.Query().Begin() if err != nil { return err @@ -94,6 +110,6 @@ func (r *Orm) Transaction(txFunc func(tx contractsorm.Transaction) error) error } } -func (r *Orm) WithContext(ctx context.Context) contractsorm.Orm { +func (r *Orm) WithContext(ctx context.Context) ormcontract.Orm { return NewOrm(ctx) } diff --git a/database/orm/model.go b/database/orm/model.go index c7922909c..f48ff6880 100644 --- a/database/orm/model.go +++ b/database/orm/model.go @@ -4,10 +4,11 @@ import ( "time" "gorm.io/gorm" - - "github.com/goravel/framework/facades" + "gorm.io/gorm/clause" ) +const Associations = clause.Associations + type Model struct { ID uint `gorm:"primaryKey"` Timestamps @@ -21,18 +22,3 @@ type Timestamps struct { CreatedAt time.Time UpdatedAt time.Time } - -type Relationship struct { -} - -func (r *Relationship) HasOne(dest, id interface{}, foreignKey string) error { - return facades.Orm.Query().Where(foreignKey+" = ?", id).Find(dest) -} - -func (r *Relationship) HasMany(dest, id interface{}, foreignKey string) error { - return facades.Orm.Query().Where(foreignKey+" in ?", id).Find(dest) -} - -func (r *Relationship) belongsTo(dest, id interface{}) error { - return facades.Orm.Query().Find(dest, id) -} diff --git a/database/orm_test.go b/database/orm_test.go new file mode 100644 index 000000000..1bf4990e3 --- /dev/null +++ b/database/orm_test.go @@ -0,0 +1,185 @@ +package database + +import ( + "errors" + "log" + "testing" + + "github.com/stretchr/testify/suite" + + ormcontract "github.com/goravel/framework/contracts/database/orm" + "github.com/goravel/framework/database/gorm" + "github.com/goravel/framework/database/orm" + "github.com/goravel/framework/support/file" + "github.com/goravel/framework/testing/mock" +) + +var connections = []ormcontract.Driver{ + ormcontract.DriverMysql, + ormcontract.DriverPostgresql, + ormcontract.DriverSqlite, + ormcontract.DriverSqlserver, +} + +type User struct { + orm.Model + orm.SoftDeletes + Name string + Avatar string +} + +type OrmSuite struct { + suite.Suite +} + +var ( + testMysqlDB ormcontract.DB + testPostgresqlDB ormcontract.DB + testSqliteDB ormcontract.DB + testSqlserverDB ormcontract.DB +) + +func TestOrmSuite(t *testing.T) { + mysqlPool, mysqlDocker, mysqlDB, err := gorm.MysqlDocker() + testMysqlDB = mysqlDB + if err != nil { + log.Fatalf("Get gorm mysql error: %s", err) + } + + postgresqlPool, postgresqlDocker, postgresqlDB, err := gorm.PostgresqlDocker() + testPostgresqlDB = postgresqlDB + if err != nil { + log.Fatalf("Get gorm postgresql error: %s", err) + } + + _, _, sqliteDB, err := gorm.SqliteDocker() + testSqliteDB = sqliteDB + if err != nil { + log.Fatalf("Get gorm sqlite error: %s", err) + } + + sqlserverPool, sqlserverDocker, sqlserverDB, err := gorm.SqlserverDocker() + testSqlserverDB = sqlserverDB + if err != nil { + log.Fatalf("Get gorm postgresql error: %s", err) + } + + suite.Run(t, new(OrmSuite)) + + file.Remove("goravel") + + if err := mysqlPool.Purge(mysqlDocker); err != nil { + log.Fatalf("Could not purge resource: %s", err) + } + if err := postgresqlPool.Purge(postgresqlDocker); err != nil { + log.Fatalf("Could not purge resource: %s", err) + } + if err := sqlserverPool.Purge(sqlserverDocker); err != nil { + log.Fatalf("Could not purge resource: %s", err) + } +} + +func (s *OrmSuite) SetupTest() { + +} + +func (s *OrmSuite) TestConnection() { + mockConfig := mock.Config() + mockConfig.On("GetString", "database.default").Return(ormcontract.DriverMysql.String()).Times(4) + testOrm := newTestOrm() + for _, connection := range connections { + s.NotNil(testOrm.Connection(connection.String())) + } + + mockConfig.AssertExpectations(s.T()) +} + +func (s *OrmSuite) TestDB() { + mockConfig := mock.Config() + mockConfig.On("GetString", "database.default").Return(ormcontract.DriverMysql.String()).Times(5) + + testOrm := newTestOrm() + db, err := testOrm.DB() + s.NotNil(db) + s.Nil(err) + + for _, connection := range connections { + db, err := testOrm.Connection(connection.String()).DB() + s.NotNil(db) + s.Nil(err) + } + + mockConfig.AssertExpectations(s.T()) +} + +func (s *OrmSuite) TestQuery() { + mockConfig := mock.Config() + mockConfig.On("GetString", "database.default").Return(ormcontract.DriverMysql.String()).Times(5) + testOrm := newTestOrm() + s.NotNil(testOrm.Query()) + + for _, connection := range connections { + s.NotNil(testOrm.Connection(connection.String()).Query()) + } + + mockConfig.AssertExpectations(s.T()) +} + +func (s *OrmSuite) TestTransactionSuccess() { + mockConfig := mock.Config() + mockConfig.On("GetString", "database.default").Return(ormcontract.DriverMysql.String()).Times(12) + + testOrm := newTestOrm() + for _, connection := range connections { + user := User{Name: "transaction_success_user", Avatar: "transaction_success_avatar"} + user1 := User{Name: "transaction_success_user1", Avatar: "transaction_success_avatar1"} + s.Nil(testOrm.Connection(connection.String()).Transaction(func(tx ormcontract.Transaction) error { + s.Nil(tx.Create(&user)) + s.Nil(tx.Create(&user1)) + + return nil + })) + + var user2, user3 User + s.Nil(testOrm.Connection(connection.String()).Query().Find(&user2, user.ID)) + s.Nil(testOrm.Connection(connection.String()).Query().Find(&user3, user1.ID)) + + } + + mockConfig.AssertExpectations(s.T()) +} + +func (s *OrmSuite) TestTransactionError() { + mockConfig := mock.Config() + mockConfig.On("GetString", "database.default").Return(ormcontract.DriverMysql.String()).Times(8) + + testOrm := newTestOrm() + for _, connection := range connections { + s.NotNil(testOrm.Connection(connection.String()).Transaction(func(tx ormcontract.Transaction) error { + user := User{Name: "transaction_error_user", Avatar: "transaction_error_avatar"} + s.Nil(tx.Create(&user)) + + user1 := User{Name: "transaction_error_user1", Avatar: "transaction_error_avatar1"} + s.Nil(tx.Create(&user1)) + + return errors.New("error") + })) + + var users []User + s.Nil(testOrm.Connection(connection.String()).Query().Find(&users)) + s.Equal(0, len(users)) + } + + mockConfig.AssertExpectations(s.T()) +} + +func newTestOrm() *Orm { + return &Orm{ + instances: map[string]ormcontract.DB{ + ormcontract.DriverMysql.String(): testMysqlDB, + ormcontract.DriverPostgresql.String(): testPostgresqlDB, + ormcontract.DriverSqlite.String(): testSqliteDB, + ormcontract.DriverSqlserver.String(): testSqlserverDB, + }, + } +} diff --git a/database/service_provider.go b/database/service_provider.go index 6a6d252cf..77506ff60 100644 --- a/database/service_provider.go +++ b/database/service_provider.go @@ -1,6 +1,8 @@ package database import ( + "context" + consolecontract "github.com/goravel/framework/contracts/console" "github.com/goravel/framework/database/console" "github.com/goravel/framework/facades" @@ -10,8 +12,7 @@ type ServiceProvider struct { } func (database *ServiceProvider) Register() { - app := Application{} - facades.Orm = app.Init() + facades.Orm = NewOrm(context.Background()) } func (database *ServiceProvider) Boot() { diff --git a/database/support/dsn.go b/database/support/dsn.go index 9557b986e..3ed261c11 100644 --- a/database/support/dsn.go +++ b/database/support/dsn.go @@ -6,13 +6,6 @@ import ( "github.com/goravel/framework/facades" ) -const ( - Mysql = "mysql" - Postgresql = "postgresql" - Sqlite = "sqlite" - Sqlserver = "sqlserver" -) - func GetMysqlDsn(connection string) string { host := facades.Config.GetString("database.connections." + connection + ".host") if host == "" { diff --git a/database/support/errors.go b/database/support/errors.go new file mode 100644 index 000000000..26e1b60f0 --- /dev/null +++ b/database/support/errors.go @@ -0,0 +1,7 @@ +package support + +import "github.com/pkg/errors" + +var ( + ErrorMissingWhereClause = errors.New("WHERE conditions required") +) diff --git a/event/application.go b/event/application.go index 3e9a4c542..634ee72ee 100644 --- a/event/application.go +++ b/event/application.go @@ -9,6 +9,10 @@ type Application struct { events map[event.Event][]event.Listener } +func NewApplication() *Application { + return &Application{} +} + func (app *Application) Register(events map[event.Event][]event.Listener) { app.events = events } diff --git a/event/application_test.go b/event/application_test.go new file mode 100644 index 000000000..2f27eab1a --- /dev/null +++ b/event/application_test.go @@ -0,0 +1,242 @@ +package event + +import ( + "context" + "errors" + "log" + "testing" + "time" + + "github.com/stretchr/testify/suite" + + "github.com/goravel/framework/config" + "github.com/goravel/framework/contracts/event" + eventcontract "github.com/goravel/framework/contracts/event" + "github.com/goravel/framework/facades" + "github.com/goravel/framework/queue" + testingdocker "github.com/goravel/framework/testing/docker" +) + +var ( + testSyncListener = 0 + testAsyncListener = 0 + testCancelListener = 0 + testCancelAfterListener = 0 +) + +type EventTestSuite struct { + suite.Suite +} + +func TestEventTestSuite(t *testing.T) { + redisPool, redisResource, err := testingdocker.Redis() + if err != nil { + log.Fatalf("Get redis error: %s", err) + } + + initConfig(redisResource.GetPort("6379/tcp")) + facades.Queue = queue.NewApplication() + facades.Event = NewApplication() + + suite.Run(t, new(EventTestSuite)) + + if err := redisPool.Purge(redisResource); err != nil { + log.Fatalf("Could not purge resource: %s", err) + } +} + +func (s *EventTestSuite) SetupTest() { + +} + +func (s *EventTestSuite) TestEvent() { + facades.Event.Register(map[event.Event][]event.Listener{ + &TestEvent{}: { + &TestSyncListener{}, + &TestAsyncListener{}, + }, + }) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + go func(ctx context.Context) { + s.Nil(facades.Queue.Worker(nil).Run()) + + for { + select { + case <-ctx.Done(): + return + } + } + }(ctx) + + time.Sleep(3 * time.Second) + s.Nil(facades.Event.Job(&TestEvent{}, []eventcontract.Arg{ + {Type: "string", Value: "Goravel"}, + {Type: "int", Value: 1}, + }).Dispatch()) + time.Sleep(1 * time.Second) + s.Equal(1, testSyncListener) + s.Equal(1, testAsyncListener) +} + +func (s *EventTestSuite) TestCancelEvent() { + facades.Event.Register(map[event.Event][]event.Listener{ + &TestCancelEvent{}: { + &TestCancelListener{}, + &TestCancelAfterListener{}, + }, + }) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + go func(ctx context.Context) { + s.Nil(facades.Queue.Worker(nil).Run()) + + for { + select { + case <-ctx.Done(): + return + } + } + }(ctx) + + time.Sleep(3 * time.Second) + s.EqualError(facades.Event.Job(&TestCancelEvent{}, []eventcontract.Arg{ + {Type: "string", Value: "Goravel"}, + {Type: "int", Value: 1}, + }).Dispatch(), "cancel") + time.Sleep(1 * time.Second) + s.Equal(1, testCancelListener) + s.Equal(0, testCancelAfterListener) +} + +func initConfig(redisPort string) { + application := config.NewApplication("../.env") + application.Add("app", map[string]interface{}{ + "name": "goravel", + }) + application.Add("queue", map[string]interface{}{ + "default": "redis", + "connections": map[string]interface{}{ + "sync": map[string]interface{}{ + "driver": "sync", + }, + "redis": map[string]interface{}{ + "driver": "redis", + "connection": "default", + "queue": "default", + }, + }, + }) + application.Add("database", map[string]interface{}{ + "redis": map[string]interface{}{ + "default": map[string]interface{}{ + "host": "localhost", + "password": "", + "port": redisPort, + "database": 0, + }, + }, + }) + + facades.Config = application +} + +type TestEvent struct { +} + +func (receiver *TestEvent) Handle(args []event.Arg) ([]event.Arg, error) { + return args, nil +} + +type TestCancelEvent struct { +} + +func (receiver *TestCancelEvent) Handle(args []event.Arg) ([]event.Arg, error) { + return args, nil +} + +type TestAsyncListener struct { +} + +func (receiver *TestAsyncListener) Signature() string { + return "test_async_listener" +} + +func (receiver *TestAsyncListener) Queue(args ...interface{}) event.Queue { + return event.Queue{ + Enable: true, + Connection: "", + Queue: "", + } +} + +func (receiver *TestAsyncListener) Handle(args ...interface{}) error { + testAsyncListener++ + + return nil +} + +type TestSyncListener struct { +} + +func (receiver *TestSyncListener) Signature() string { + return "test_sync_listener" +} + +func (receiver *TestSyncListener) Queue(args ...interface{}) event.Queue { + return event.Queue{ + Enable: false, + Connection: "", + Queue: "", + } +} + +func (receiver *TestSyncListener) Handle(args ...interface{}) error { + testSyncListener++ + + return nil +} + +type TestCancelListener struct { +} + +func (receiver *TestCancelListener) Signature() string { + return "test_cancel_listener" +} + +func (receiver *TestCancelListener) Queue(args ...interface{}) event.Queue { + return event.Queue{ + Enable: false, + Connection: "", + Queue: "", + } +} + +func (receiver *TestCancelListener) Handle(args ...interface{}) error { + testCancelListener++ + + return errors.New("cancel") +} + +type TestCancelAfterListener struct { +} + +func (receiver *TestCancelAfterListener) Signature() string { + return "test_cancel_after_listener" +} + +func (receiver *TestCancelAfterListener) Queue(args ...interface{}) event.Queue { + return event.Queue{ + Enable: false, + Connection: "", + Queue: "", + } +} + +func (receiver *TestCancelAfterListener) Handle(args ...interface{}) error { + testCancelAfterListener++ + + return nil +} diff --git a/event/service_provider.go b/event/service_provider.go index 0becd10b5..4d9549486 100644 --- a/event/service_provider.go +++ b/event/service_provider.go @@ -10,7 +10,7 @@ type ServiceProvider struct { } func (receiver *ServiceProvider) Register() { - facades.Event = &Application{} + facades.Event = NewApplication() } func (receiver *ServiceProvider) Boot() { diff --git a/event/support/task_test.go b/event/support/task_test.go index 2e23c9130..3a3736be6 100644 --- a/event/support/task_test.go +++ b/event/support/task_test.go @@ -4,8 +4,9 @@ import ( "errors" "testing" - "github.com/goravel/framework/contracts/event" "github.com/stretchr/testify/assert" + + "github.com/goravel/framework/contracts/event" ) type TestEvent struct { diff --git a/filesystem/application_test.go b/filesystem/application_test.go new file mode 100644 index 000000000..814dc3895 --- /dev/null +++ b/filesystem/application_test.go @@ -0,0 +1,411 @@ +package filesystem + +import ( + "io/ioutil" + "net/http" + "testing" + "time" + + "github.com/gookit/color" + "github.com/stretchr/testify/assert" + + "github.com/goravel/framework/config" + "github.com/goravel/framework/contracts/filesystem" + "github.com/goravel/framework/facades" + "github.com/goravel/framework/support/file" +) + +type TestDisk struct { + disk string + url string +} + +func TestStorage(t *testing.T) { + if !file.Exists("../.env") { + color.Redln("No filesystem tests run, need create .env based on .env.example, then initialize it") + return + } + + file.Create("test.txt", "Goravel") + initConfig() + + var driver filesystem.Driver + + disks := []TestDisk{ + { + disk: "local", + url: "http://localhost/storage/", + }, + { + disk: "oss", + url: "https://goravel.oss-cn-beijing.aliyuncs.com/", + }, + { + disk: "cos", + url: "https://goravel-1257814968.cos.ap-beijing.myqcloud.com/", + }, + { + disk: "s3", + url: "https://goravel.s3.us-east-2.amazonaws.com/", + }, + { + disk: "custom", + url: "http://localhost/storage/", + }, + } + + tests := []struct { + name string + setup func(name string, disk TestDisk) + }{ + { + name: "Put", + setup: func(name string, disk TestDisk) { + assert.Nil(t, driver.Put("Put/1.txt", "Goravel"), name) + assert.True(t, driver.Exists("Put/1.txt"), name) + assert.True(t, driver.Missing("Put/2.txt"), name) + }, + }, + { + name: "Get", + setup: func(name string, disk TestDisk) { + assert.Nil(t, driver.Put("Get/1.txt", "Goravel"), name) + assert.True(t, driver.Exists("Get/1.txt"), name) + data, err := driver.Get("Get/1.txt") + assert.Nil(t, err, name) + assert.Equal(t, "Goravel", data, name) + length, err := driver.Size("Get/1.txt") + assert.Nil(t, err, name) + assert.Equal(t, int64(7), length, name) + }, + }, + { + name: "PutFile_Text", + setup: func(name string, disk TestDisk) { + fileInfo, err := NewFile("./test.txt") + assert.Nil(t, err, name) + path, err := driver.PutFile("PutFile", fileInfo) + assert.Nil(t, err, name) + assert.True(t, driver.Exists(path), name) + data, err := driver.Get(path) + assert.Nil(t, err, name) + assert.Equal(t, "Goravel", data, name) + }, + }, + { + name: "PutFile_Image", + setup: func(name string, disk TestDisk) { + fileInfo, err := NewFile("../logo.png") + assert.Nil(t, err, name) + path, err := driver.PutFile("PutFile", fileInfo) + assert.Nil(t, err, name) + assert.True(t, driver.Exists(path), name) + }, + }, + { + name: "PutFileAs_Text", + setup: func(name string, disk TestDisk) { + fileInfo, err := NewFile("./test.txt") + assert.Nil(t, err, name) + path, err := driver.PutFileAs("PutFileAs", fileInfo, "text") + assert.Nil(t, err, name) + assert.Equal(t, "PutFileAs/text.txt", path, name) + assert.True(t, driver.Exists(path), name) + data, err := driver.Get(path) + assert.Nil(t, err, name) + assert.Equal(t, "Goravel", data, name) + + path, err = driver.PutFileAs("PutFileAs", fileInfo, "text1.txt") + assert.Nil(t, err, name) + assert.Equal(t, "PutFileAs/text1.txt", path, name) + assert.True(t, driver.Exists(path), name) + data, err = driver.Get(path) + assert.Nil(t, err, name) + assert.Equal(t, "Goravel", data, name) + }, + }, + { + name: "PutFileAs_Image", + setup: func(name string, disk TestDisk) { + fileInfo, err := NewFile("../logo.png") + assert.Nil(t, err, name) + path, err := driver.PutFileAs("PutFileAs", fileInfo, "image") + assert.Nil(t, err, name) + assert.Equal(t, "PutFileAs/image.png", path, name) + assert.True(t, driver.Exists(path), name) + + path, err = driver.PutFileAs("PutFileAs", fileInfo, "image1.png") + assert.Nil(t, err, name) + assert.Equal(t, "PutFileAs/image1.png", path, name) + assert.True(t, driver.Exists(path), name) + }, + }, + { + name: "Url", + setup: func(name string, disk TestDisk) { + assert.Nil(t, driver.Put("Url/1.txt", "Goravel"), name) + assert.True(t, driver.Exists("Url/1.txt"), name) + assert.Equal(t, disk.url+"Url/1.txt", driver.Url("Url/1.txt"), name) + if disk.disk != "local" && disk.disk != "custom" { + resp, err := http.Get(disk.url + "Url/1.txt") + assert.Nil(t, err, name) + content, err := ioutil.ReadAll(resp.Body) + resp.Body.Close() + assert.Nil(t, err, name) + assert.Equal(t, "Goravel", string(content), name) + } + }, + }, + { + name: "TemporaryUrl", + setup: func(name string, disk TestDisk) { + assert.Nil(t, driver.Put("TemporaryUrl/1.txt", "Goravel"), name) + assert.True(t, driver.Exists("TemporaryUrl/1.txt"), name) + url, err := driver.TemporaryUrl("TemporaryUrl/1.txt", time.Now().Add(5*time.Second)) + assert.Nil(t, err, name) + assert.NotEmpty(t, url, name) + if disk.disk != "local" && disk.disk != "custom" { + resp, err := http.Get(url) + assert.Nil(t, err, name) + content, err := ioutil.ReadAll(resp.Body) + resp.Body.Close() + assert.Nil(t, err, name) + assert.Equal(t, "Goravel", string(content), name) + } + }, + }, + { + name: "Copy", + setup: func(name string, disk TestDisk) { + assert.Nil(t, driver.Put("Copy/1.txt", "Goravel"), name) + assert.True(t, driver.Exists("Copy/1.txt"), name) + assert.Nil(t, driver.Copy("Copy/1.txt", "Copy1/1.txt"), name) + assert.True(t, driver.Exists("Copy/1.txt"), name) + assert.True(t, driver.Exists("Copy1/1.txt"), name) + }, + }, + { + name: "Move", + setup: func(name string, disk TestDisk) { + assert.Nil(t, driver.Put("Move/1.txt", "Goravel"), name) + assert.True(t, driver.Exists("Move/1.txt"), name) + assert.Nil(t, driver.Move("Move/1.txt", "Move1/1.txt"), name) + assert.True(t, driver.Missing("Move/1.txt"), name) + assert.True(t, driver.Exists("Move1/1.txt"), name) + }, + }, + { + name: "Delete", + setup: func(name string, disk TestDisk) { + assert.Nil(t, driver.Put("Delete/1.txt", "Goravel"), name) + assert.True(t, driver.Exists("Delete/1.txt"), name) + assert.Nil(t, driver.Delete("Delete/1.txt"), name) + assert.True(t, driver.Missing("Delete/1.txt"), name) + }, + }, + { + name: "MakeDirectory", + setup: func(name string, disk TestDisk) { + assert.Nil(t, driver.MakeDirectory("MakeDirectory1/"), name) + assert.Nil(t, driver.MakeDirectory("MakeDirectory2"), name) + assert.Nil(t, driver.MakeDirectory("MakeDirectory3/MakeDirectory4"), name) + }, + }, + { + name: "DeleteDirectory", + setup: func(name string, disk TestDisk) { + assert.Nil(t, driver.Put("DeleteDirectory/1.txt", "Goravel"), name) + assert.True(t, driver.Exists("DeleteDirectory/1.txt"), name) + assert.Nil(t, driver.DeleteDirectory("DeleteDirectory"), name) + assert.True(t, driver.Missing("DeleteDirectory/1.txt"), name) + }, + }, + { + name: "Files", + setup: func(name string, disk TestDisk) { + assert.Nil(t, driver.Put("Files/1.txt", "Goravel"), name) + assert.Nil(t, driver.Put("Files/2.txt", "Goravel"), name) + assert.Nil(t, driver.Put("Files/3/3.txt", "Goravel"), name) + assert.Nil(t, driver.Put("Files/3/4/4.txt", "Goravel"), name) + assert.True(t, driver.Exists("Files/1.txt"), name) + assert.True(t, driver.Exists("Files/2.txt"), name) + assert.True(t, driver.Exists("Files/3/3.txt"), name) + assert.True(t, driver.Exists("Files/3/4/4.txt"), name) + files, err := driver.Files("Files") + assert.Nil(t, err, name) + assert.Equal(t, []string{"1.txt", "2.txt"}, files, name) + files, err = driver.Files("./Files") + assert.Nil(t, err, name) + assert.Equal(t, []string{"1.txt", "2.txt"}, files, name) + files, err = driver.Files("/Files") + assert.Nil(t, err, name) + assert.Equal(t, []string{"1.txt", "2.txt"}, files, name) + files, err = driver.Files("./Files/") + assert.Nil(t, err, name) + assert.Equal(t, []string{"1.txt", "2.txt"}, files, name) + }, + }, + { + name: "AllFiles", + setup: func(name string, disk TestDisk) { + assert.Nil(t, driver.Put("AllFiles/1.txt", "Goravel"), name) + assert.Nil(t, driver.Put("AllFiles/2.txt", "Goravel"), name) + assert.Nil(t, driver.Put("AllFiles/3/3.txt", "Goravel"), name) + assert.Nil(t, driver.Put("AllFiles/3/4/4.txt", "Goravel"), name) + assert.True(t, driver.Exists("AllFiles/1.txt"), name) + assert.True(t, driver.Exists("AllFiles/2.txt"), name) + assert.True(t, driver.Exists("AllFiles/3/3.txt"), name) + assert.True(t, driver.Exists("AllFiles/3/4/4.txt"), name) + files, err := driver.AllFiles("AllFiles") + assert.Nil(t, err, name) + assert.Equal(t, []string{"1.txt", "2.txt", "3/3.txt", "3/4/4.txt"}, files, name) + files, err = driver.AllFiles("./AllFiles") + assert.Nil(t, err, name) + assert.Equal(t, []string{"1.txt", "2.txt", "3/3.txt", "3/4/4.txt"}, files, name) + files, err = driver.AllFiles("/AllFiles") + assert.Nil(t, err, name) + assert.Equal(t, []string{"1.txt", "2.txt", "3/3.txt", "3/4/4.txt"}, files, name) + files, err = driver.AllFiles("./AllFiles/") + assert.Nil(t, err, name) + assert.Equal(t, []string{"1.txt", "2.txt", "3/3.txt", "3/4/4.txt"}, files, name) + }, + }, + { + name: "Directories", + setup: func(name string, disk TestDisk) { + assert.Nil(t, driver.Put("Directories/1.txt", "Goravel"), name) + assert.Nil(t, driver.Put("Directories/2.txt", "Goravel"), name) + assert.Nil(t, driver.Put("Directories/3/3.txt", "Goravel"), name) + assert.Nil(t, driver.Put("Directories/3/5/5.txt", "Goravel"), name) + assert.Nil(t, driver.MakeDirectory("Directories/3/4"), name) + assert.True(t, driver.Exists("Directories/1.txt"), name) + assert.True(t, driver.Exists("Directories/2.txt"), name) + assert.True(t, driver.Exists("Directories/3/3.txt"), name) + assert.True(t, driver.Exists("Directories/3/4/"), name) + assert.True(t, driver.Exists("Directories/3/5/5.txt"), name) + files, err := driver.Directories("Directories") + assert.Nil(t, err, name) + assert.Equal(t, []string{"3/"}, files, name) + files, err = driver.Directories("./Directories") + assert.Nil(t, err, name) + assert.Equal(t, []string{"3/"}, files, name) + files, err = driver.Directories("/Directories") + assert.Nil(t, err, name) + assert.Equal(t, []string{"3/"}, files, name) + files, err = driver.Directories("./Directories/") + assert.Nil(t, err, name) + assert.Equal(t, []string{"3/"}, files, name) + }, + }, + { + name: "AllDirectories", + setup: func(name string, disk TestDisk) { + assert.Nil(t, driver.Put("AllDirectories/1.txt", "Goravel"), name) + assert.Nil(t, driver.Put("AllDirectories/2.txt", "Goravel"), name) + assert.Nil(t, driver.Put("AllDirectories/3/3.txt", "Goravel"), name) + assert.Nil(t, driver.Put("AllDirectories/3/5/6/6.txt", "Goravel"), name) + assert.Nil(t, driver.MakeDirectory("AllDirectories/3/4"), name) + assert.True(t, driver.Exists("AllDirectories/1.txt"), name) + assert.True(t, driver.Exists("AllDirectories/2.txt"), name) + assert.True(t, driver.Exists("AllDirectories/3/3.txt"), name) + assert.True(t, driver.Exists("AllDirectories/3/4/"), name) + assert.True(t, driver.Exists("AllDirectories/3/5/6/6.txt"), name) + files, err := driver.AllDirectories("AllDirectories") + assert.Nil(t, err, name) + assert.Equal(t, []string{"3/", "3/4/", "3/5/", "3/5/6/"}, files, name) + files, err = driver.AllDirectories("./AllDirectories") + assert.Nil(t, err, name) + assert.Equal(t, []string{"3/", "3/4/", "3/5/", "3/5/6/"}, files, name) + files, err = driver.AllDirectories("/AllDirectories") + assert.Nil(t, err, name) + assert.Equal(t, []string{"3/", "3/4/", "3/5/", "3/5/6/"}, files, name) + files, err = driver.AllDirectories("./AllDirectories/") + assert.Nil(t, err, name) + assert.Equal(t, []string{"3/", "3/4/", "3/5/", "3/5/6/"}, files, name) + }, + }, + } + + for _, disk := range disks { + var err error + driver, err = NewDriver(disk.disk) + assert.NotNil(t, driver) + assert.Nil(t, err) + + for _, test := range tests { + test.setup(disk.disk+" "+test.name, disk) + } + + assert.Nil(t, driver.DeleteDirectory("Put"), disk.disk) + assert.Nil(t, driver.DeleteDirectory("Get"), disk.disk) + assert.Nil(t, driver.DeleteDirectory("PutFile"), disk.disk) + assert.Nil(t, driver.DeleteDirectory("PutFileAs"), disk.disk) + assert.Nil(t, driver.DeleteDirectory("Url"), disk.disk) + assert.Nil(t, driver.DeleteDirectory("TemporaryUrl"), disk.disk) + assert.Nil(t, driver.DeleteDirectory("Copy"), disk.disk) + assert.Nil(t, driver.DeleteDirectory("Copy1"), disk.disk) + assert.Nil(t, driver.DeleteDirectory("Move"), disk.disk) + assert.Nil(t, driver.DeleteDirectory("Move1"), disk.disk) + assert.Nil(t, driver.DeleteDirectory("Delete"), disk.disk) + assert.Nil(t, driver.DeleteDirectory("MakeDirectory1"), disk.disk) + assert.Nil(t, driver.DeleteDirectory("MakeDirectory2"), disk.disk) + assert.Nil(t, driver.DeleteDirectory("MakeDirectory3"), disk.disk) + assert.Nil(t, driver.DeleteDirectory("MakeDirectory4"), disk.disk) + assert.Nil(t, driver.DeleteDirectory("DeleteDirectory"), disk.disk) + assert.Nil(t, driver.DeleteDirectory("Files"), disk.disk) + assert.Nil(t, driver.DeleteDirectory("AllFiles"), disk.disk) + assert.Nil(t, driver.DeleteDirectory("Directories"), disk.disk) + assert.Nil(t, driver.DeleteDirectory("AllDirectories"), disk.disk) + + if disk.disk == "local" || disk.disk == "custom" { + assert.True(t, file.Remove("./storage")) + } + } + file.Remove("test.txt") +} + +func initConfig() { + application := config.NewApplication("../.env") + application.Add("filesystems", map[string]any{ + "default": "local", + "disks": map[string]any{ + "local": map[string]any{ + "driver": "local", + "root": "storage/app", + "url": "http://localhost/storage", + }, + "s3": map[string]any{ + "driver": "s3", + "key": application.Env("AWS_ACCESS_KEY_ID"), + "secret": application.Env("AWS_ACCESS_KEY_SECRET"), + "region": application.Env("AWS_DEFAULT_REGION"), + "bucket": application.Env("AWS_BUCKET"), + "url": application.Env("AWS_URL"), + }, + "oss": map[string]any{ + "driver": "oss", + "key": application.Env("ALIYUN_ACCESS_KEY_ID"), + "secret": application.Env("ALIYUN_ACCESS_KEY_SECRET"), + "bucket": application.Env("ALIYUN_BUCKET"), + "url": application.Env("ALIYUN_URL"), + "endpoint": application.Env("ALIYUN_ENDPOINT"), + }, + "cos": map[string]any{ + "driver": "cos", + "key": application.Env("TENCENT_ACCESS_KEY_ID"), + "secret": application.Env("TENCENT_ACCESS_KEY_SECRET"), + "bucket": application.Env("TENCENT_BUCKET"), + "url": application.Env("TENCENT_URL"), + }, + "custom": map[string]any{ + "driver": "custom", + "via": &Local{ + root: "storage/app/public", + url: "http://localhost/storage", + }, + }, + }, + }) + + facades.Config = application +} diff --git a/foundation/application.go b/foundation/application.go index a75081784..66f15c3a0 100644 --- a/foundation/application.go +++ b/foundation/application.go @@ -83,10 +83,3 @@ func (app *Application) bootServiceProviders(serviceProviders []contracts.Servic serviceProvider.Boot() } } - -//RunningInConsole Determine if the application is running in the console. -func (app *Application) RunningInConsole() bool { - args := os.Args - - return len(args) >= 2 && args[1] == "artisan" -} diff --git a/foundation/application_test.go b/foundation/application_test.go index a7c5f0a7e..3e73bd598 100644 --- a/foundation/application_test.go +++ b/foundation/application_test.go @@ -3,11 +3,12 @@ package foundation import ( "testing" + "github.com/stretchr/testify/assert" + "github.com/goravel/framework/config" "github.com/goravel/framework/console" "github.com/goravel/framework/contracts" "github.com/goravel/framework/facades" - "github.com/stretchr/testify/assert" ) func TestInit(t *testing.T) { diff --git a/go.mod b/go.mod index 6d49dbd7e..55cb4552e 100644 --- a/go.mod +++ b/go.mod @@ -10,9 +10,9 @@ require ( github.com/aws/aws-sdk-go-v2/service/s3 v1.29.3 github.com/gin-gonic/gin v1.7.3 github.com/go-redis/redis/v8 v8.11.4 - github.com/go-sql-driver/mysql v1.6.0 + github.com/go-sql-driver/mysql v1.7.0 github.com/golang-jwt/jwt/v4 v4.4.2 - github.com/golang-migrate/migrate/v4 v4.15.1 + github.com/golang-migrate/migrate/v4 v4.15.3-0.20230111075136-dc26c41ac9da github.com/gookit/color v1.5.2 github.com/gookit/validate v1.4.5 github.com/goravel/file-rotatelogs/v2 v2.4.1 @@ -21,35 +21,39 @@ require ( github.com/jmoiron/sqlx v1.3.5 github.com/jordan-wright/email v4.0.1-0.20210109023952-943e75fe5223+incompatible github.com/opentracing/opentracing-go v1.2.0 + github.com/ory/dockertest/v3 v3.9.1 github.com/pkg/errors v0.9.1 github.com/rifflock/lfshook v0.0.0-20180920164130-b9218ef580f5 github.com/robfig/cron/v3 v3.0.1 github.com/rs/cors v1.8.3-0.20221003140808-fcebdb403f4d - github.com/sirupsen/logrus v1.8.1 + github.com/sirupsen/logrus v1.9.0 github.com/spf13/cast v1.4.1 github.com/spf13/viper v1.9.0 - github.com/stretchr/testify v1.8.0 + github.com/stretchr/testify v1.8.1 github.com/tencentyun/cos-go-sdk-v5 v0.7.40 github.com/urfave/cli/v2 v2.3.0 - google.golang.org/grpc v1.44.0 - gorm.io/driver/mysql v1.3.6 - gorm.io/driver/postgres v1.3.10 - gorm.io/driver/sqlite v1.3.6 - gorm.io/driver/sqlserver v1.3.2 - gorm.io/gorm v1.23.8 + google.golang.org/grpc v1.50.1 + gorm.io/driver/mysql v1.4.5 + gorm.io/driver/postgres v1.4.6 + gorm.io/driver/sqlite v1.4.4 + gorm.io/driver/sqlserver v1.4.2 + gorm.io/gorm v1.24.3 ) require ( - cloud.google.com/go v0.100.2 // indirect - cloud.google.com/go/compute v1.3.0 // indirect - cloud.google.com/go/iam v0.1.0 // indirect - cloud.google.com/go/kms v1.4.0 // indirect - cloud.google.com/go/pubsub v1.10.0 // indirect + cloud.google.com/go v0.105.0 // indirect + cloud.google.com/go/compute v1.13.0 // indirect + cloud.google.com/go/compute/metadata v0.2.1 // indirect + cloud.google.com/go/iam v0.7.0 // indirect + cloud.google.com/go/pubsub v1.28.0 // indirect + github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect github.com/Azure/go-autorest v14.2.0+incompatible // indirect - github.com/Azure/go-autorest/autorest/adal v0.9.14 // indirect + github.com/Azure/go-autorest/autorest/adal v0.9.16 // indirect github.com/Azure/go-autorest/autorest/date v0.3.0 // indirect github.com/Azure/go-autorest/logger v0.2.1 // indirect github.com/Azure/go-autorest/tracing v0.6.0 // indirect + github.com/Microsoft/go-winio v0.6.0 // indirect + github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 // indirect github.com/RichardKnop/logging v0.0.0-20190827224416-1a693bdd4fae // indirect github.com/aws/aws-sdk-go v1.37.16 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.4.9 // indirect @@ -61,13 +65,17 @@ require ( github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.19 // indirect github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.13.19 // indirect github.com/aws/smithy-go v1.13.4 // indirect + github.com/cenkalti/backoff/v4 v4.2.0 // indirect github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/clbanning/mxj v1.8.4 // indirect + github.com/containerd/continuity v0.3.0 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect - github.com/denisenkom/go-mssqldb v0.12.0 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect - github.com/form3tech-oss/jwt-go v3.2.5+incompatible // indirect + github.com/docker/cli v20.10.22+incompatible // indirect + github.com/docker/docker v20.10.22+incompatible // indirect + github.com/docker/go-connections v0.4.0 // indirect + github.com/docker/go-units v0.5.0 // indirect github.com/fsnotify/fsnotify v1.5.1 // indirect github.com/gin-contrib/sse v0.1.0 // indirect github.com/go-playground/locales v0.14.0 // indirect @@ -75,30 +83,29 @@ require ( github.com/go-playground/validator/v10 v10.9.0 // indirect github.com/go-redsync/redsync/v4 v4.0.4 // indirect github.com/go-stack/stack v1.8.0 // indirect - github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe // indirect - github.com/golang-sql/sqlexp v0.0.0-20170517235910-f1bb20e5a188 // indirect - github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e // indirect + github.com/gogo/protobuf v1.3.2 // indirect + github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect + github.com/golang-sql/sqlexp v0.1.0 // indirect + github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/protobuf v1.5.2 // indirect github.com/golang/snappy v0.0.4 // indirect github.com/gomodule/redigo v2.0.0+incompatible // indirect - github.com/google/go-cmp v0.5.8 // indirect + github.com/google/go-cmp v0.5.9 // indirect github.com/google/go-querystring v1.1.0 // indirect + github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect github.com/google/uuid v1.3.0 // indirect - github.com/googleapis/gax-go/v2 v2.1.1 // indirect + github.com/googleapis/enterprise-certificate-proxy v0.2.0 // indirect + github.com/googleapis/gax-go/v2 v2.7.0 // indirect github.com/gookit/filter v1.1.4 // indirect github.com/gookit/goutil v0.5.15 // indirect github.com/goravel/file-rotatelogs v0.0.0-20211215053220-2ab31dd9575c // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-multierror v1.1.1 // indirect github.com/hashicorp/hcl v1.0.0 // indirect - github.com/jackc/chunkreader/v2 v2.0.1 // indirect - github.com/jackc/pgconn v1.13.0 // indirect - github.com/jackc/pgio v1.0.0 // indirect + github.com/imdario/mergo v0.3.13 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect - github.com/jackc/pgproto3/v2 v2.3.1 // indirect - github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b // indirect - github.com/jackc/pgtype v1.12.0 // indirect - github.com/jackc/pgx/v4 v4.17.2 // indirect + github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect + github.com/jackc/pgx/v5 v5.2.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect @@ -110,11 +117,16 @@ require ( github.com/lib/pq v1.10.2 // indirect github.com/magiconair/properties v1.8.5 // indirect github.com/mattn/go-isatty v0.0.16 // indirect - github.com/mattn/go-sqlite3 v1.14.12 // indirect + github.com/mattn/go-sqlite3 v1.14.15 // indirect + github.com/microsoft/go-mssqldb v0.19.0 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect + github.com/moby/term v0.0.0-20221205130635-1aeaba878587 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/mozillazg/go-httpheader v0.3.1 // indirect + github.com/opencontainers/go-digest v1.0.0 // indirect + github.com/opencontainers/image-spec v1.0.2 // indirect + github.com/opencontainers/runc v1.1.4 // indirect github.com/pelletier/go-toml v1.9.4 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect @@ -122,28 +134,33 @@ require ( github.com/spf13/jwalterweatherman v1.1.0 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/streadway/amqp v1.0.0 // indirect - github.com/stretchr/objx v0.4.0 // indirect + github.com/stretchr/objx v0.5.0 // indirect github.com/subosito/gotenv v1.2.0 // indirect github.com/ugorji/go/codec v1.2.6 // indirect github.com/xdg-go/pbkdf2 v1.0.0 // indirect github.com/xdg-go/scram v1.0.2 // indirect github.com/xdg-go/stringprep v1.0.2 // indirect + github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect + github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect + github.com/xeipuuv/gojsonschema v1.2.0 // indirect github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778 // indirect github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d // indirect go.mongodb.org/mongo-driver v1.7.0 // indirect - go.opencensus.io v0.23.0 // indirect - go.uber.org/atomic v1.9.0 // indirect - golang.org/x/crypto v0.0.0-20220829220503-c86fa9a7ed90 // indirect - golang.org/x/net v0.0.0-20220722155237-a158d28d115b // indirect - golang.org/x/oauth2 v0.0.0-20211104180415-d3ed0bb246c8 // indirect - golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 // indirect - golang.org/x/sys v0.0.0-20220829200755-d48e67d00261 // indirect - golang.org/x/text v0.3.8 // indirect - golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 // indirect - google.golang.org/api v0.70.0 // indirect + go.opencensus.io v0.24.0 // indirect + go.uber.org/atomic v1.10.0 // indirect + golang.org/x/crypto v0.4.0 // indirect + golang.org/x/mod v0.7.0 // indirect + golang.org/x/net v0.5.0 // indirect + golang.org/x/oauth2 v0.1.0 // indirect + golang.org/x/sync v0.1.0 // indirect + golang.org/x/sys v0.4.0 // indirect + golang.org/x/text v0.6.0 // indirect + golang.org/x/time v0.1.0 // indirect + golang.org/x/tools v0.5.0 // indirect + google.golang.org/api v0.103.0 // indirect google.golang.org/appengine v1.6.7 // indirect - google.golang.org/genproto v0.0.0-20220222213610-43724f9ea8cf // indirect - google.golang.org/protobuf v1.27.1 // indirect + google.golang.org/genproto v0.0.0-20221202195650-67e5cbc046fd // indirect + google.golang.org/protobuf v1.28.1 // indirect gopkg.in/ini.v1 v1.64.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/grpc/application.go b/grpc/application.go index d73ba71c1..73502b4af 100644 --- a/grpc/application.go +++ b/grpc/application.go @@ -4,12 +4,13 @@ import ( "context" "errors" "fmt" - "github.com/goravel/framework/facades" "net" grpcmiddleware "github.com/grpc-ecosystem/go-grpc-middleware" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" + + "github.com/goravel/framework/facades" ) type Application struct { diff --git a/grpc/application_test.go b/grpc/application_test.go index c185e9fe6..ce694e6d1 100644 --- a/grpc/application_test.go +++ b/grpc/application_test.go @@ -2,16 +2,98 @@ package grpc import ( "context" + "errors" "fmt" + "net/http" "testing" - configmocks "github.com/goravel/framework/contracts/config/mocks" - "github.com/goravel/framework/testing/mock" - "github.com/stretchr/testify/assert" "google.golang.org/grpc" + "google.golang.org/grpc/metadata" + + configmocks "github.com/goravel/framework/contracts/config/mocks" + "github.com/goravel/framework/testing/mock" ) +func TestRun(t *testing.T) { + var ( + app *Application + mockConfig *configmocks.Config + name = "test" + ) + + beforeEach := func() { + mockConfig = mock.Config() + mockConfig.On("Get", fmt.Sprintf("grpc.clients.%s.interceptors", name)).Return([]string{"test"}).Once() + + app = NewApplication() + app.UnaryServerInterceptors([]grpc.UnaryServerInterceptor{ + serverInterceptor, + }) + app.UnaryClientInterceptorGroups(map[string][]grpc.UnaryClientInterceptor{ + "test": { + clientInterceptor, + }, + }) + RegisterTestServiceServer(app.Server(), &TestController{}) + } + + tests := []struct { + name string + setup func() + expectErr bool + }{ + { + name: "success", + setup: func() { + host := "127.0.0.1:3001" + mockConfig.On("GetString", fmt.Sprintf("grpc.clients.%s.host", name)).Return(host).Once() + + go func() { + assert.Nil(t, app.Run(host)) + }() + + client, err := app.Client(context.Background(), name) + assert.Nil(t, err) + testServiceClient := NewTestServiceClient(client) + res, err := testServiceClient.Get(context.Background(), &TestRequest{ + Name: "success", + }) + + assert.Equal(t, &TestResponse{Code: http.StatusOK, Message: "Goravel: server: goravel-server, client: goravel-client"}, res) + assert.Nil(t, err) + }, + }, + { + name: "error", + setup: func() { + host := "127.0.0.1:3002" + mockConfig.On("GetString", fmt.Sprintf("grpc.clients.%s.host", name)).Return(host).Once() + + go func() { + assert.Nil(t, app.Run(host)) + }() + + client, err := app.Client(context.Background(), "test") + assert.Nil(t, err) + testServiceClient := NewTestServiceClient(client) + res, err := testServiceClient.Get(context.Background(), &TestRequest{ + Name: "error", + }) + + assert.Nil(t, res) + assert.EqualError(t, err, "rpc error: code = Unknown desc = error") + }, + }, + } + + for _, test := range tests { + beforeEach() + test.setup() + mockConfig.AssertExpectations(t) + } +} + func TestClient(t *testing.T) { var ( app *Application @@ -23,8 +105,6 @@ func TestClient(t *testing.T) { beforeEach := func() { mockConfig = mock.Config() app = NewApplication() - app.UnaryServerInterceptors([]grpc.UnaryServerInterceptor{}) - go app.Run(host) } tests := []struct { @@ -38,7 +118,7 @@ func TestClient(t *testing.T) { mockConfig.On("GetString", fmt.Sprintf("grpc.clients.%s.host", name)).Return(host).Once() mockConfig.On("Get", fmt.Sprintf("grpc.clients.%s.interceptors", name)).Return([]string{"trace"}).Once() app.UnaryClientInterceptorGroups(map[string][]grpc.UnaryClientInterceptor{ - "trace": {OpentracingClient}, + "trace": {opentracingClient}, }) }, }, @@ -79,6 +159,51 @@ func TestClient(t *testing.T) { } } -func OpentracingClient(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { +func opentracingClient(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { return nil } + +func serverInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + md = metadata.New(nil) + } + + ctx = context.WithValue(ctx, "server", "goravel-server") + if len(md["client"]) > 0 { + ctx = context.WithValue(ctx, "client", md["client"][0]) + } + + return handler(ctx, req) +} + +func clientInterceptor(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + md, ok := metadata.FromOutgoingContext(ctx) + if !ok { + md = metadata.New(nil) + } else { + md = md.Copy() + } + + md["client"] = []string{"goravel-client"} + + if err := invoker(metadata.NewOutgoingContext(ctx, md), method, req, reply, cc, opts...); err != nil { + return err + } + + return nil +} + +type TestController struct { +} + +func (r *TestController) Get(ctx context.Context, req *TestRequest) (*TestResponse, error) { + if req.GetName() == "success" { + return &TestResponse{ + Code: http.StatusOK, + Message: fmt.Sprintf("Goravel: server: %s, client: %s", ctx.Value("server"), ctx.Value("client")), + }, nil + } else { + return nil, errors.New("error") + } +} diff --git a/grpc/test.pb.go b/grpc/test.pb.go new file mode 100644 index 000000000..902f6ff9e --- /dev/null +++ b/grpc/test.pb.go @@ -0,0 +1,626 @@ +// Code generated by protoc-gen-gogo. DO NOT EDIT. +// source: test/test.proto + +package grpc + +import ( + context "context" + fmt "fmt" + proto "github.com/golang/protobuf/proto" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" + io "io" + math "math" + math_bits "math/bits" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package + +type TestRequest struct { + Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *TestRequest) Reset() { *m = TestRequest{} } +func (m *TestRequest) String() string { return proto.CompactTextString(m) } +func (*TestRequest) ProtoMessage() {} +func (*TestRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_84eb23d74a64bdab, []int{0} +} +func (m *TestRequest) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *TestRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_TestRequest.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *TestRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_TestRequest.Merge(m, src) +} +func (m *TestRequest) XXX_Size() int { + return m.Size() +} +func (m *TestRequest) XXX_DiscardUnknown() { + xxx_messageInfo_TestRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_TestRequest proto.InternalMessageInfo + +func (m *TestRequest) GetName() string { + if m != nil { + return m.Name + } + return "" +} + +type TestResponse struct { + Code int32 `protobuf:"varint,1,opt,name=code,proto3" json:"code,omitempty"` + Message string `protobuf:"bytes,2,opt,name=message,proto3" json:"message,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *TestResponse) Reset() { *m = TestResponse{} } +func (m *TestResponse) String() string { return proto.CompactTextString(m) } +func (*TestResponse) ProtoMessage() {} +func (*TestResponse) Descriptor() ([]byte, []int) { + return fileDescriptor_84eb23d74a64bdab, []int{1} +} +func (m *TestResponse) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *TestResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_TestResponse.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *TestResponse) XXX_Merge(src proto.Message) { + xxx_messageInfo_TestResponse.Merge(m, src) +} +func (m *TestResponse) XXX_Size() int { + return m.Size() +} +func (m *TestResponse) XXX_DiscardUnknown() { + xxx_messageInfo_TestResponse.DiscardUnknown(m) +} + +var xxx_messageInfo_TestResponse proto.InternalMessageInfo + +func (m *TestResponse) GetCode() int32 { + if m != nil { + return m.Code + } + return 0 +} + +func (m *TestResponse) GetMessage() string { + if m != nil { + return m.Message + } + return "" +} + +func init() { + proto.RegisterType((*TestRequest)(nil), "protos.TestRequest") + proto.RegisterType((*TestResponse)(nil), "protos.TestResponse") +} + +func init() { proto.RegisterFile("test/test.proto", fileDescriptor_84eb23d74a64bdab) } + +var fileDescriptor_84eb23d74a64bdab = []byte{ + // 173 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x2f, 0x49, 0x2d, 0x2e, + 0xd1, 0x07, 0x11, 0x7a, 0x05, 0x45, 0xf9, 0x25, 0xf9, 0x42, 0x6c, 0x60, 0xaa, 0x58, 0x49, 0x91, + 0x8b, 0x3b, 0x24, 0xb5, 0xb8, 0x24, 0x28, 0xb5, 0xb0, 0x34, 0xb5, 0xb8, 0x44, 0x48, 0x88, 0x8b, + 0x25, 0x2f, 0x31, 0x37, 0x55, 0x82, 0x51, 0x81, 0x51, 0x83, 0x33, 0x08, 0xcc, 0x56, 0xb2, 0xe1, + 0xe2, 0x81, 0x28, 0x29, 0x2e, 0xc8, 0xcf, 0x2b, 0x4e, 0x05, 0xa9, 0x49, 0xce, 0x4f, 0x81, 0xa8, + 0x61, 0x0d, 0x02, 0xb3, 0x85, 0x24, 0xb8, 0xd8, 0x73, 0x53, 0x8b, 0x8b, 0x13, 0xd3, 0x53, 0x25, + 0x98, 0xc0, 0x5a, 0x61, 0x5c, 0x23, 0x47, 0x88, 0x05, 0xc1, 0xa9, 0x45, 0x65, 0x99, 0xc9, 0xa9, + 0x42, 0x46, 0x5c, 0xcc, 0xee, 0xa9, 0x25, 0x42, 0xc2, 0x10, 0x67, 0x14, 0xeb, 0x21, 0x59, 0x2e, + 0x25, 0x82, 0x2a, 0x08, 0xb1, 0x4e, 0x89, 0xc1, 0x49, 0xe0, 0xc4, 0x23, 0x39, 0xc6, 0x0b, 0x8f, + 0xe4, 0x18, 0x1f, 0x3c, 0x92, 0x63, 0x9c, 0xf1, 0x58, 0x8e, 0x21, 0x09, 0xe2, 0x7a, 0x63, 0x40, + 0x00, 0x00, 0x00, 0xff, 0xff, 0x55, 0xf9, 0xec, 0xcf, 0xd7, 0x00, 0x00, 0x00, +} + +// Reference imports to suppress errors if they are not otherwise used. +var _ context.Context +var _ grpc.ClientConn + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +const _ = grpc.SupportPackageIsVersion4 + +// TestServiceClient is the client API for TestService service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream. +type TestServiceClient interface { + Get(ctx context.Context, in *TestRequest, opts ...grpc.CallOption) (*TestResponse, error) +} + +type testServiceClient struct { + cc *grpc.ClientConn +} + +func NewTestServiceClient(cc *grpc.ClientConn) TestServiceClient { + return &testServiceClient{cc} +} + +func (c *testServiceClient) Get(ctx context.Context, in *TestRequest, opts ...grpc.CallOption) (*TestResponse, error) { + out := new(TestResponse) + err := c.cc.Invoke(ctx, "/protos.TestService/Get", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +// TestServiceServer is the server API for TestService service. +type TestServiceServer interface { + Get(context.Context, *TestRequest) (*TestResponse, error) +} + +// UnimplementedTestServiceServer can be embedded to have forward compatible implementations. +type UnimplementedTestServiceServer struct { +} + +func (*UnimplementedTestServiceServer) Get(ctx context.Context, req *TestRequest) (*TestResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method Get not implemented") +} + +func RegisterTestServiceServer(s *grpc.Server, srv TestServiceServer) { + s.RegisterService(&_TestService_serviceDesc, srv) +} + +func _TestService_Get_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(TestRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(TestServiceServer).Get(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/protos.TestService/Get", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(TestServiceServer).Get(ctx, req.(*TestRequest)) + } + return interceptor(ctx, in, info, handler) +} + +var _TestService_serviceDesc = grpc.ServiceDesc{ + ServiceName: "protos.TestService", + HandlerType: (*TestServiceServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "Get", + Handler: _TestService_Get_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "test/test.proto", +} + +func (m *TestRequest) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *TestRequest) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *TestRequest) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if m.XXX_unrecognized != nil { + i -= len(m.XXX_unrecognized) + copy(dAtA[i:], m.XXX_unrecognized) + } + if len(m.Name) > 0 { + i -= len(m.Name) + copy(dAtA[i:], m.Name) + i = encodeVarintTest(dAtA, i, uint64(len(m.Name))) + i-- + dAtA[i] = 0xa + } + return len(dAtA) - i, nil +} + +func (m *TestResponse) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *TestResponse) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *TestResponse) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if m.XXX_unrecognized != nil { + i -= len(m.XXX_unrecognized) + copy(dAtA[i:], m.XXX_unrecognized) + } + if len(m.Message) > 0 { + i -= len(m.Message) + copy(dAtA[i:], m.Message) + i = encodeVarintTest(dAtA, i, uint64(len(m.Message))) + i-- + dAtA[i] = 0x12 + } + if m.Code != 0 { + i = encodeVarintTest(dAtA, i, uint64(m.Code)) + i-- + dAtA[i] = 0x8 + } + return len(dAtA) - i, nil +} + +func encodeVarintTest(dAtA []byte, offset int, v uint64) int { + offset -= sovTest(v) + base := offset + for v >= 1<<7 { + dAtA[offset] = uint8(v&0x7f | 0x80) + v >>= 7 + offset++ + } + dAtA[offset] = uint8(v) + return base +} +func (m *TestRequest) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + l = len(m.Name) + if l > 0 { + n += 1 + l + sovTest(uint64(l)) + } + if m.XXX_unrecognized != nil { + n += len(m.XXX_unrecognized) + } + return n +} + +func (m *TestResponse) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if m.Code != 0 { + n += 1 + sovTest(uint64(m.Code)) + } + l = len(m.Message) + if l > 0 { + n += 1 + l + sovTest(uint64(l)) + } + if m.XXX_unrecognized != nil { + n += len(m.XXX_unrecognized) + } + return n +} + +func sovTest(x uint64) (n int) { + return (math_bits.Len64(x|1) + 6) / 7 +} +func sozTest(x uint64) (n int) { + return sovTest(uint64((x << 1) ^ uint64((int64(x) >> 63)))) +} +func (m *TestRequest) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowTest + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: TestRequest: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: TestRequest: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Name", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowTest + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthTest + } + postIndex := iNdEx + intStringLen + if postIndex < 0 { + return ErrInvalidLengthTest + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Name = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipTest(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return ErrInvalidLengthTest + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + m.XXX_unrecognized = append(m.XXX_unrecognized, dAtA[iNdEx:iNdEx+skippy]...) + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *TestResponse) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowTest + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: TestResponse: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: TestResponse: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field Code", wireType) + } + m.Code = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowTest + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.Code |= int32(b&0x7F) << shift + if b < 0x80 { + break + } + } + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Message", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowTest + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthTest + } + postIndex := iNdEx + intStringLen + if postIndex < 0 { + return ErrInvalidLengthTest + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Message = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipTest(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return ErrInvalidLengthTest + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + m.XXX_unrecognized = append(m.XXX_unrecognized, dAtA[iNdEx:iNdEx+skippy]...) + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func skipTest(dAtA []byte) (n int, err error) { + l := len(dAtA) + iNdEx := 0 + depth := 0 + for iNdEx < l { + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowTest + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + wireType := int(wire & 0x7) + switch wireType { + case 0: + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowTest + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + iNdEx++ + if dAtA[iNdEx-1] < 0x80 { + break + } + } + case 1: + iNdEx += 8 + case 2: + var length int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowTest + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + length |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if length < 0 { + return 0, ErrInvalidLengthTest + } + iNdEx += length + case 3: + depth++ + case 4: + if depth == 0 { + return 0, ErrUnexpectedEndOfGroupTest + } + depth-- + case 5: + iNdEx += 4 + default: + return 0, fmt.Errorf("proto: illegal wireType %d", wireType) + } + if iNdEx < 0 { + return 0, ErrInvalidLengthTest + } + if depth == 0 { + return iNdEx, nil + } + } + return 0, io.ErrUnexpectedEOF +} + +var ( + ErrInvalidLengthTest = fmt.Errorf("proto: negative length found during unmarshaling") + ErrIntOverflowTest = fmt.Errorf("proto: integer overflow") + ErrUnexpectedEndOfGroupTest = fmt.Errorf("proto: unexpected end of group") +) diff --git a/http/gin_context.go b/http/gin_context.go index 858c59a0a..adbc87739 100644 --- a/http/gin_context.go +++ b/http/gin_context.go @@ -5,6 +5,7 @@ import ( "time" "github.com/gin-gonic/gin" + "github.com/goravel/framework/contracts/http" ) @@ -25,7 +26,12 @@ func (c *GinContext) Request() http.Request { } func (c *GinContext) Response() http.Response { - return NewGinResponse(c.instance) + responseOrigin := c.Value("responseOrigin") + if responseOrigin != nil { + return NewGinResponse(c.instance, responseOrigin.(http.ResponseOrigin)) + } + + return NewGinResponse(c.instance, &BodyWriter{ResponseWriter: c.instance.Writer}) } func (c *GinContext) WithValue(key string, value interface{}) { diff --git a/http/gin_request.go b/http/gin_request.go index 48105253f..a871dfc6b 100644 --- a/http/gin_request.go +++ b/http/gin_request.go @@ -4,15 +4,15 @@ import ( "errors" "net/http" + "github.com/gin-gonic/gin" + "github.com/gookit/validate" + contractsfilesystem "github.com/goravel/framework/contracts/filesystem" httpcontract "github.com/goravel/framework/contracts/http" validatecontract "github.com/goravel/framework/contracts/validation" "github.com/goravel/framework/facades" "github.com/goravel/framework/filesystem" "github.com/goravel/framework/validation" - - "github.com/gin-gonic/gin" - "github.com/gookit/validate" ) type GinRequest struct { @@ -32,6 +32,14 @@ func (r *GinRequest) Query(key, defaultValue string) string { return r.instance.DefaultQuery(key, defaultValue) } +func (r *GinRequest) QueryArray(key string) []string { + return r.instance.QueryArray(key) +} + +func (r *GinRequest) QueryMap(key string) map[string]string { + return r.instance.QueryMap(key) +} + func (r *GinRequest) Form(key, defaultValue string) string { return r.instance.DefaultPostForm(key, defaultValue) } @@ -107,10 +115,6 @@ func (r *GinRequest) Origin() *http.Request { return r.instance.Request } -func (r *GinRequest) Response() httpcontract.Response { - return NewGinResponse(r.instance) -} - func (r *GinRequest) Validate(rules map[string]string, options ...validatecontract.Option) (validatecontract.Validator, error) { if rules == nil || len(rules) == 0 { return nil, errors.New("rules can't be empty") diff --git a/http/gin_request_test.go b/http/gin_request_test.go deleted file mode 100644 index 06ed3764e..000000000 --- a/http/gin_request_test.go +++ /dev/null @@ -1,35 +0,0 @@ -package http - -import ( - "testing" - - "github.com/stretchr/testify/suite" -) - -type GinRequestSuite struct { - suite.Suite -} - -func TestGinRequestSuite(t *testing.T) { - suite.Run(t, new(GinRequestSuite)) -} - -func (s *GinRequestSuite) SetupTest() { - -} - -func (s *GinRequestSuite) TestInput() { - //r := gin.Default() - //r.GET("/input", func(c *gin.Context) { - // s.True(1 == 2) - //}) - // - //go func() { - // s.Nil(r.Run(":3000")) - // select {} - //}() - // - //w := httptest.NewRecorder() - //req, _ := http.NewRequest("GET", "/input", nil) - //r.ServeHTTP(w, req) -} diff --git a/http/gin_response.go b/http/gin_response.go index 854423bc4..512ae5467 100644 --- a/http/gin_response.go +++ b/http/gin_response.go @@ -1,6 +1,7 @@ package http import ( + "bytes" "net/http" "github.com/gin-gonic/gin" @@ -10,10 +11,11 @@ import ( type GinResponse struct { instance *gin.Context + origin httpcontract.ResponseOrigin } -func NewGinResponse(instance *gin.Context) httpcontract.Response { - return &GinResponse{instance: instance} +func NewGinResponse(instance *gin.Context, origin httpcontract.ResponseOrigin) *GinResponse { + return &GinResponse{instance, origin} } func (r *GinResponse) String(code int, format string, values ...interface{}) { @@ -42,6 +44,10 @@ func (r *GinResponse) Header(key, value string) httpcontract.Response { return r } +func (r *GinResponse) Origin() httpcontract.ResponseOrigin { + return r.origin +} + type GinSuccess struct { instance *gin.Context } @@ -57,3 +63,38 @@ func (r *GinSuccess) String(format string, values ...interface{}) { func (r *GinSuccess) Json(obj interface{}) { r.instance.JSON(http.StatusOK, obj) } + +func GinResponseMiddleware() httpcontract.Middleware { + return func(ctx httpcontract.Context) { + blw := &BodyWriter{body: bytes.NewBufferString("")} + switch ctx.(type) { + case *GinContext: + blw.ResponseWriter = ctx.(*GinContext).Instance().Writer + ctx.(*GinContext).Instance().Writer = blw + } + + ctx.WithValue("responseOrigin", blw) + ctx.Request().Next() + } +} + +type BodyWriter struct { + gin.ResponseWriter + body *bytes.Buffer +} + +func (w *BodyWriter) Write(b []byte) (int, error) { + w.body.Write(b) + + return w.ResponseWriter.Write(b) +} + +func (w *BodyWriter) WriteString(s string) (int, error) { + w.body.WriteString(s) + + return w.ResponseWriter.WriteString(s) +} + +func (w *BodyWriter) Body() *bytes.Buffer { + return w.body +} diff --git a/log/application_test.go b/log/application_test.go deleted file mode 100644 index b43c2afa6..000000000 --- a/log/application_test.go +++ /dev/null @@ -1,78 +0,0 @@ -package log - -import ( - "os" - "testing" - "time" - - "github.com/goravel/framework/config" - "github.com/goravel/framework/facades" - "github.com/goravel/framework/testing/file" - testingfile "github.com/goravel/framework/testing/file" - - "github.com/stretchr/testify/assert" -) - -func TestLog(t *testing.T) { - err := testingfile.CreateEnv() - assert.Nil(t, err) - - addDefaultConfig() - - app := Application{} - instance := app.Init() - - instance.Debug("debug") - instance.Error("error") - - dailyFile := "storage/logs/goravel-" + time.Now().Format("2006-01-02") + ".log" - singleFile := "storage/logs/goravel.log" - singleErrorFile := "storage/logs/goravel-error.log" - - assert.FileExists(t, dailyFile) - assert.FileExists(t, singleFile) - assert.FileExists(t, singleErrorFile) - - assert.Equal(t, 3, file.GetLineNum(dailyFile)) - assert.Equal(t, 3, file.GetLineNum(singleFile)) - assert.Equal(t, 2, file.GetLineNum(singleErrorFile)) - - err = os.Remove(".env") - assert.Nil(t, err) - - err = os.RemoveAll("storage") - assert.Nil(t, err) -} - -//addDefaultConfig Add default config for test. -func addDefaultConfig() { - configApp := config.ServiceProvider{} - configApp.Register() - - facadesConfig := facades.Config - facadesConfig.Add("logging", map[string]interface{}{ - "default": facadesConfig.Env("LOG_CHANNEL", "stack"), - "channels": map[string]interface{}{ - "stack": map[string]interface{}{ - "driver": "stack", - "channels": []string{"daily", "single", "single-error"}, - }, - "single": map[string]interface{}{ - "driver": "single", - "path": "storage/logs/goravel.log", - "level": "debug", - }, - "single-error": map[string]interface{}{ - "driver": "single", - "path": "storage/logs/goravel-error.log", - "level": "error", - }, - "daily": map[string]interface{}{ - "driver": "daily", - "path": "storage/logs/goravel.log", - "level": facadesConfig.Env("LOG_LEVEL", "debug"), - "days": 7, - }, - }, - }) -} diff --git a/log/entry.go b/log/entry.go index 2b2d9d7e9..9b2520238 100644 --- a/log/entry.go +++ b/log/entry.go @@ -35,12 +35,12 @@ func (r *Entry) GetLevel() log.Level { return r.Level() } -// DEPRECATED: use Level() +// DEPRECATED: use Time() func (r *Entry) GetTime() time.Time { return r.Time() } -// DEPRECATED: use Level() +// DEPRECATED: use Message() func (r *Entry) GetMessage() string { return r.Message() } diff --git a/log/formatter/general.go b/log/formatter/general.go index 09c575e92..d8881375d 100644 --- a/log/formatter/general.go +++ b/log/formatter/general.go @@ -6,8 +6,9 @@ import ( "fmt" "time" - "github.com/goravel/framework/facades" "github.com/sirupsen/logrus" + + "github.com/goravel/framework/facades" ) type General struct { diff --git a/log/logrus.go b/log/logrus.go index a68df55de..f16318413 100644 --- a/log/logrus.go +++ b/log/logrus.go @@ -3,6 +3,7 @@ package log import ( "context" "errors" + "github.com/goravel/framework/contracts/log" "github.com/goravel/framework/facades" "github.com/goravel/framework/log/logger" diff --git a/log/logrus_test.go b/log/logrus_test.go index d9eb64c80..89c449831 100644 --- a/log/logrus_test.go +++ b/log/logrus_test.go @@ -20,39 +20,11 @@ import ( var singleLog = "storage/logs/goravel.log" var dailyLog = fmt.Sprintf("storage/logs/goravel-%s.log", time.Now().Format("2006-01-02")) -func initMockConfig() *configmocks.Config { - mockConfig := &configmocks.Config{} - facades.Config = mockConfig - - mockConfig.On("GetString", "logging.default").Return("stack").Once() - mockConfig.On("GetString", "logging.channels.stack.driver").Return("stack").Once() - mockConfig.On("Get", "logging.channels.stack.channels").Return([]string{"single", "daily"}).Once() - mockConfig.On("GetString", "logging.channels.daily.driver").Return("daily").Once() - mockConfig.On("GetString", "logging.channels.daily.path").Return(singleLog).Once() - mockConfig.On("GetInt", "logging.channels.daily.days").Return(7).Once() - mockConfig.On("GetString", "logging.channels.single.driver").Return("single").Once() - mockConfig.On("GetString", "logging.channels.single.path").Return(singleLog).Once() - - return mockConfig -} - -func mockDriverConfig(mockConfig *configmocks.Config) { - mockConfig.On("GetString", "logging.channels.daily.level").Return("debug").Once() - mockConfig.On("GetString", "logging.channels.single.level").Return("debug").Once() - mockConfig.On("GetString", "app.timezone").Return("UTC") - mockConfig.On("GetString", "app.env").Return("test") -} - -func initFacadesLog() { - logrusInstance := logrusInstance() - facades.Log = NewLogrus(logrusInstance, NewWriter(logrusInstance.WithContext(context.Background()))) -} - type LogrusTestSuite struct { suite.Suite } -func TestAuthTestSuite(t *testing.T) { +func TestLogrusTestSuite(t *testing.T) { suite.Run(t, new(LogrusTestSuite)) } @@ -324,3 +296,31 @@ func TestLogrus_Fatalf(t *testing.T) { assert.True(t, file.Contain(dailyLog, "test.fatal: Goravel")) file.Remove("storage") } + +func initMockConfig() *configmocks.Config { + mockConfig := &configmocks.Config{} + facades.Config = mockConfig + + mockConfig.On("GetString", "logging.default").Return("stack").Once() + mockConfig.On("GetString", "logging.channels.stack.driver").Return("stack").Once() + mockConfig.On("Get", "logging.channels.stack.channels").Return([]string{"single", "daily"}).Once() + mockConfig.On("GetString", "logging.channels.daily.driver").Return("daily").Once() + mockConfig.On("GetString", "logging.channels.daily.path").Return(singleLog).Once() + mockConfig.On("GetInt", "logging.channels.daily.days").Return(7).Once() + mockConfig.On("GetString", "logging.channels.single.driver").Return("single").Once() + mockConfig.On("GetString", "logging.channels.single.path").Return(singleLog).Once() + + return mockConfig +} + +func mockDriverConfig(mockConfig *configmocks.Config) { + mockConfig.On("GetString", "logging.channels.daily.level").Return("debug").Once() + mockConfig.On("GetString", "logging.channels.single.level").Return("debug").Once() + mockConfig.On("GetString", "app.timezone").Return("UTC") + mockConfig.On("GetString", "app.env").Return("test") +} + +func initFacadesLog() { + logrusInstance := logrusInstance() + facades.Log = NewLogrus(logrusInstance, NewWriter(logrusInstance.WithContext(context.Background()))) +} diff --git a/logo.png b/logo.png new file mode 100644 index 000000000..830dda49f Binary files /dev/null and b/logo.png differ diff --git a/mail/application.go b/mail/application.go index 6b658c6af..f6a6b8acc 100644 --- a/mail/application.go +++ b/mail/application.go @@ -1,12 +1,154 @@ package mail import ( + "crypto/tls" + "fmt" + "net/smtp" + + "github.com/jordan-wright/email" + "github.com/goravel/framework/contracts/mail" + queuecontract "github.com/goravel/framework/contracts/queue" + "github.com/goravel/framework/facades" ) type Application struct { + clone int + content mail.Content + from mail.From + to []string + cc []string + bcc []string + attaches []string +} + +func NewApplication() *Application { + return &Application{} +} + +func (r *Application) Content(content mail.Content) mail.Mail { + instance := r.instance() + instance.content = content + + return instance +} + +func (r *Application) From(from mail.From) mail.Mail { + instance := r.instance() + instance.from = from + + return instance +} + +func (r *Application) To(to []string) mail.Mail { + instance := r.instance() + instance.to = to + + return instance +} + +func (r *Application) Cc(cc []string) mail.Mail { + instance := r.instance() + instance.cc = cc + + return instance +} + +func (r *Application) Bcc(bcc []string) mail.Mail { + instance := r.instance() + instance.bcc = bcc + + return instance +} + +func (r *Application) Attach(files []string) mail.Mail { + instance := r.instance() + instance.attaches = files + + return instance +} + +func (r *Application) Send() error { + return SendMail(r.content.Subject, r.content.Html, r.from.Address, r.from.Name, r.to, r.cc, r.bcc, r.attaches) +} + +func (r *Application) Queue(queue *mail.Queue) error { + job := facades.Queue.Job(&SendMailJob{}, []queuecontract.Arg{ + {Value: r.content.Subject, Type: "string"}, + {Value: r.content.Html, Type: "string"}, + {Value: r.from.Address, Type: "string"}, + {Value: r.from.Name, Type: "string"}, + {Value: r.to, Type: "[]string"}, + {Value: r.cc, Type: "[]string"}, + {Value: r.bcc, Type: "[]string"}, + {Value: r.attaches, Type: "[]string"}, + }) + if queue != nil { + if queue.Connection != "" { + job.OnConnection(queue.Connection) + } + if queue.Queue != "" { + job.OnQueue(queue.Queue) + } + } + + return job.Dispatch() +} + +func (r *Application) instance() *Application { + if r.clone == 0 { + return &Application{clone: 1} + } + + return r +} + +func SendMail(subject, html string, fromAddress, fromName string, to, cc, bcc, attaches []string) error { + e := email.NewEmail() + if fromAddress == "" { + e.From = fmt.Sprintf("%s <%s>", facades.Config.GetString("mail.from.name"), facades.Config.GetString("mail.from.address")) + } else { + e.From = fmt.Sprintf("%s <%s>", fromName, fromAddress) + } + + e.To = to + e.Bcc = bcc + e.Cc = cc + e.Subject = subject + e.HTML = []byte(html) + + for _, attach := range attaches { + if _, err := e.AttachFile(attach); err != nil { + return err + } + } + + return e.SendWithStartTLS(fmt.Sprintf("%s:%s", facades.Config.GetString("mail.host"), + facades.Config.GetString("mail.port")), + LoginAuth(facades.Config.GetString("mail.username"), + facades.Config.GetString("mail.password")), &tls.Config{ServerName: facades.Config.GetString("mail.host")}) +} + +type loginAuth struct { + username, password string +} + +func LoginAuth(username, password string) smtp.Auth { + return &loginAuth{username, password} +} + +func (a *loginAuth) Start(server *smtp.ServerInfo) (string, []byte, error) { + return "LOGIN", []byte(a.username), nil } -func (app *Application) Init() mail.Mail { - return NewEmail() +func (a *loginAuth) Next(fromServer []byte, more bool) ([]byte, error) { + if more { + switch string(fromServer) { + case "Username:": + return []byte(a.username), nil + case "Password:": + return []byte(a.password), nil + } + } + return nil, nil } diff --git a/mail/application_test.go b/mail/application_test.go new file mode 100644 index 000000000..1335f45fb --- /dev/null +++ b/mail/application_test.go @@ -0,0 +1,143 @@ +package mail + +import ( + "context" + "log" + "testing" + "time" + + "github.com/gookit/color" + "github.com/stretchr/testify/suite" + + "github.com/goravel/framework/config" + "github.com/goravel/framework/contracts/event" + "github.com/goravel/framework/contracts/mail" + queuecontract "github.com/goravel/framework/contracts/queue" + "github.com/goravel/framework/facades" + "github.com/goravel/framework/queue" + "github.com/goravel/framework/support/file" + testingdocker "github.com/goravel/framework/testing/docker" + "github.com/goravel/framework/testing/mock" +) + +type ApplicationTestSuite struct { + suite.Suite +} + +func TestApplicationTestSuite(t *testing.T) { + if !file.Exists("../.env") { + color.Redln("No mail tests run, need create .env based on .env.example, then initialize it") + return + } + + redisPool, redisResource, err := testingdocker.Redis() + if err != nil { + log.Fatalf("Get redis error: %s", err) + } + + initConfig(redisResource.GetPort("6379/tcp")) + facades.Mail = NewApplication() + suite.Run(t, new(ApplicationTestSuite)) + + if err := redisPool.Purge(redisResource); err != nil { + log.Fatalf("Could not purge resource: %s", err) + } +} + +func (s *ApplicationTestSuite) SetupTest() { + +} + +func (s *ApplicationTestSuite) TestSendMail() { + s.Nil(facades.Mail.To([]string{facades.Config.Env("MAIL_TO").(string)}). + Cc([]string{facades.Config.Env("MAIL_CC").(string)}). + Bcc([]string{facades.Config.Env("MAIL_BCC").(string)}). + Attach([]string{"../logo.png"}). + Content(mail.Content{Subject: "Goravel Test", Html: "

Hello Goravel

"}). + Send()) +} + +func (s *ApplicationTestSuite) TestSendMailWithFrom() { + s.Nil(facades.Mail.From(mail.From{Address: facades.Config.GetString("mail.from.address"), Name: facades.Config.GetString("mail.from.name")}). + To([]string{facades.Config.Env("MAIL_TO").(string)}). + Cc([]string{facades.Config.Env("MAIL_CC").(string)}). + Bcc([]string{facades.Config.Env("MAIL_BCC").(string)}). + Attach([]string{"../logo.png"}). + Content(mail.Content{Subject: "Goravel Test With From", Html: "

Hello Goravel

"}). + Send()) +} + +func (s *ApplicationTestSuite) TestQueueMail() { + facades.Queue = queue.NewApplication() + facades.Queue.Register([]queuecontract.Job{ + &SendMailJob{}, + }) + + mockEvent, _ := mock.Event() + mockEvent.On("GetEvents").Return(map[event.Event][]event.Listener{}).Once() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + go func(ctx context.Context) { + s.Nil(facades.Queue.Worker(nil).Run()) + + for { + select { + case <-ctx.Done(): + return + } + } + }(ctx) + time.Sleep(3 * time.Second) + s.Nil(facades.Mail.To([]string{facades.Config.Env("MAIL_TO").(string)}). + Cc([]string{facades.Config.Env("MAIL_CC").(string)}). + Bcc([]string{facades.Config.Env("MAIL_BCC").(string)}). + Attach([]string{"../logo.png"}). + Content(mail.Content{Subject: "Goravel Test Queue", Html: "

Hello Goravel

"}). + Queue(nil)) + time.Sleep(1 * time.Second) + + mockEvent.AssertExpectations(s.T()) +} + +func initConfig(redisPort string) { + application := config.NewApplication("../.env") + application.Add("app", map[string]interface{}{ + "name": "goravel", + }) + application.Add("mail", map[string]any{ + "host": application.Env("MAIL_HOST", ""), + "port": application.Env("MAIL_PORT", 587), + "from": map[string]interface{}{ + "address": application.Env("MAIL_FROM_ADDRESS", "hello@example.com"), + "name": application.Env("MAIL_FROM_NAME", "Example"), + }, + "username": application.Env("MAIL_USERNAME"), + "password": application.Env("MAIL_PASSWORD"), + }) + application.Add("queue", map[string]interface{}{ + "default": "redis", + "connections": map[string]interface{}{ + "sync": map[string]interface{}{ + "driver": "sync", + }, + "redis": map[string]interface{}{ + "driver": "redis", + "connection": "default", + "queue": "default", + }, + }, + }) + application.Add("database", map[string]interface{}{ + "redis": map[string]interface{}{ + "default": map[string]interface{}{ + "host": "localhost", + "password": "", + "port": redisPort, + "database": 0, + }, + }, + }) + + facades.Config = application +} diff --git a/mail/email.go b/mail/email.go deleted file mode 100644 index ed6155ee5..000000000 --- a/mail/email.go +++ /dev/null @@ -1,154 +0,0 @@ -package mail - -import ( - "crypto/tls" - "fmt" - "net/smtp" - - "github.com/jordan-wright/email" - - "github.com/goravel/framework/contracts/mail" - contractqueue "github.com/goravel/framework/contracts/queue" - "github.com/goravel/framework/facades" -) - -type Email struct { - clone int - content mail.Content - from mail.From - to []string - cc []string - bcc []string - attaches []string -} - -func NewEmail() mail.Mail { - return &Email{} -} - -func (r *Email) Content(content mail.Content) mail.Mail { - instance := r.instance() - instance.content = content - - return instance -} - -func (r *Email) From(from mail.From) mail.Mail { - instance := r.instance() - instance.from = from - - return instance -} - -func (r *Email) To(to []string) mail.Mail { - instance := r.instance() - instance.to = to - - return instance -} - -func (r *Email) Cc(cc []string) mail.Mail { - instance := r.instance() - instance.cc = cc - - return instance -} - -func (r *Email) Bcc(bcc []string) mail.Mail { - instance := r.instance() - instance.bcc = bcc - - return instance -} - -func (r *Email) Attach(files []string) mail.Mail { - instance := r.instance() - instance.attaches = files - - return instance -} - -func (r *Email) Send() error { - return SendMail(r.content.Subject, r.content.Html, r.from.Address, r.from.Name, r.to, r.cc, r.bcc, r.attaches) -} - -func (r *Email) Queue(queue *mail.Queue) error { - job := facades.Queue.Job(&SendMailJob{}, []contractqueue.Arg{ - {Value: r.content.Subject, Type: "string"}, - {Value: r.content.Html, Type: "string"}, - {Value: r.from.Address, Type: "string"}, - {Value: r.from.Name, Type: "string"}, - {Value: r.to, Type: "[]string"}, - {Value: r.cc, Type: "[]string"}, - {Value: r.bcc, Type: "[]string"}, - {Value: r.attaches, Type: "[]string"}, - }) - if queue != nil { - if queue.Connection != "" { - job.OnConnection(queue.Connection) - } - if queue.Queue != "" { - job.OnQueue(queue.Queue) - } - } - - return job.Dispatch() -} - -func (r *Email) instance() *Email { - if r.clone == 0 { - return &Email{clone: 1} - } - - return r -} - -func SendMail(subject, html string, fromAddress, fromName string, to, cc, bcc, attaches []string) error { - e := email.NewEmail() - if fromAddress == "" { - e.From = fmt.Sprintf("%s <%s>", facades.Config.GetString("mail.from.name"), facades.Config.GetString("mail.from.address")) - } else { - e.From = fmt.Sprintf("%s <%s>", fromName, fromAddress) - } - - e.To = to - e.Bcc = bcc - e.Cc = cc - e.Subject = subject - e.HTML = []byte(html) - - for _, attach := range attaches { - if _, err := e.AttachFile(attach); err != nil { - return err - } - } - - return e.SendWithStartTLS(fmt.Sprintf("%s:%s", facades.Config.GetString("mail.host"), - facades.Config.GetString("mail.port")), - LoginAuth(facades.Config.GetString("mail.username"), - facades.Config.GetString("mail.password")), &tls.Config{ServerName: facades.Config.GetString("mail.host")}) -} - -type loginAuth struct { - username, password string -} - -func LoginAuth(username, password string) smtp.Auth { - return &loginAuth{username, password} -} - -func (a *loginAuth) Start(server *smtp.ServerInfo) (string, []byte, error) { - return "LOGIN", []byte(a.username), nil -} - -func (a *loginAuth) Next(fromServer []byte, more bool) ([]byte, error) { - if more { - switch string(fromServer) { - case "Username:": - return []byte(a.username), nil - case "Password:": - return []byte(a.password), nil - } - } - return nil, nil -} diff --git a/mail/service_provider.go b/mail/service_provider.go index 2486385c5..29cc898e9 100644 --- a/mail/service_provider.go +++ b/mail/service_provider.go @@ -9,8 +9,7 @@ type ServiceProvider struct { } func (route *ServiceProvider) Register() { - app := Application{} - facades.Mail = app.Init() + facades.Mail = NewApplication() } func (route *ServiceProvider) Boot() { diff --git a/queue/application_test.go b/queue/application_test.go index dd2e720d6..e42212207 100644 --- a/queue/application_test.go +++ b/queue/application_test.go @@ -1,20 +1,59 @@ package queue import ( + "context" + "log" "testing" + "time" + + "github.com/ory/dockertest/v3" + "github.com/stretchr/testify/suite" configmocks "github.com/goravel/framework/contracts/config/mocks" + "github.com/goravel/framework/contracts/event" "github.com/goravel/framework/contracts/queue" "github.com/goravel/framework/queue/support" + testingdocker "github.com/goravel/framework/testing/docker" "github.com/goravel/framework/testing/mock" +) - "github.com/stretchr/testify/assert" +var ( + testSyncJob = 0 + testAsyncJob = 0 + testCustomAsyncJob = 0 + testErrorAsyncJob = 0 + testChainAsyncJob = 0 + testChainSyncJob = 0 ) -func TestWorker(t *testing.T) { +type QueueTestSuite struct { + suite.Suite + app *Application + redisResource *dockertest.Resource +} + +func TestQueueTestSuite(t *testing.T) { + redisPool, redisResource, err := testingdocker.Redis() + if err != nil { + log.Fatalf("Get redis error: %s", err) + } + + suite.Run(t, &QueueTestSuite{ + app: NewApplication(), + redisResource: redisResource, + }) + + if err := redisPool.Purge(redisResource); err != nil { + log.Fatalf("Could not purge resource: %s", err) + } +} + +func (s *QueueTestSuite) SetupTest() { +} + +func (s *QueueTestSuite) TestWorker() { var ( mockConfig *configmocks.Config - app = NewApplication() ) beforeEach := func() { @@ -33,7 +72,6 @@ func TestWorker(t *testing.T) { mockConfig.On("GetString", "queue.default").Return("redis").Once() mockConfig.On("GetString", "app.name").Return("app").Once() mockConfig.On("GetString", "queue.connections.redis.queue", "default").Return("queue").Once() - }, expectWorker: &support.Worker{ Connection: "redis", @@ -62,7 +100,329 @@ func TestWorker(t *testing.T) { for _, test := range tests { beforeEach() test.setup() - worker := app.Worker(test.args) - assert.Equal(t, test.expectWorker, worker, test.description) + worker := s.app.Worker(test.args) + s.Equal(test.expectWorker, worker, test.description) + mockConfig.AssertExpectations(s.T()) } } + +func (s *QueueTestSuite) TestSyncQueue() { + mockConfig := mock.Config() + mockConfig.On("GetString", "queue.default").Return("redis").Once() + mockConfig.On("GetString", "app.name").Return("goravel").Twice() + mockConfig.On("GetString", "queue.connections.redis.queue", "default").Return("default").Twice() + mockConfig.On("GetString", "queue.connections.redis.driver").Return("redis").Once() + mockConfig.On("GetString", "queue.connections.redis.connection").Return("default").Once() + mockConfig.On("GetString", "database.redis.default.host").Return("localhost").Once() + mockConfig.On("GetString", "database.redis.default.password").Return("").Once() + mockConfig.On("GetString", "database.redis.default.port").Return(s.redisResource.GetPort("6379/tcp")).Once() + mockConfig.On("GetInt", "database.redis.default.database").Return(0).Once() + + mockQueue, _ := mock.Queue() + mockQueue.On("GetJobs").Return([]queue.Job{&TestSyncJob{}}).Once() + + mockEvent, _ := mock.Event() + mockEvent.On("GetEvents").Return(map[event.Event][]event.Listener{}).Once() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + go func(ctx context.Context) { + s.Nil(s.app.Worker(nil).Run()) + + for { + select { + case <-ctx.Done(): + return + } + } + }(ctx) + time.Sleep(3 * time.Second) + + s.Nil(s.app.Job(&TestSyncJob{}, []queue.Arg{ + {Type: "string", Value: "TestSyncQueue"}, + {Type: "int", Value: 1}, + }).DispatchSync()) + s.Equal(1, testSyncJob) + + mockConfig.AssertExpectations(s.T()) + mockQueue.AssertExpectations(s.T()) + mockEvent.AssertExpectations(s.T()) +} + +func (s *QueueTestSuite) TestDefaultAsyncQueue() { + mockConfig := mock.Config() + mockConfig.On("GetString", "queue.default").Return("redis").Times(3) + mockConfig.On("GetString", "app.name").Return("goravel").Times(3) + mockConfig.On("GetString", "queue.connections.redis.queue", "default").Return("default").Times(3) + mockConfig.On("GetString", "queue.connections.redis.driver").Return("redis").Times(3) + mockConfig.On("GetString", "queue.connections.redis.connection").Return("default").Twice() + mockConfig.On("GetString", "database.redis.default.host").Return("localhost").Twice() + mockConfig.On("GetString", "database.redis.default.password").Return("").Twice() + mockConfig.On("GetString", "database.redis.default.port").Return(s.redisResource.GetPort("6379/tcp")).Twice() + mockConfig.On("GetInt", "database.redis.default.database").Return(0).Twice() + + mockQueue, _ := mock.Queue() + mockQueue.On("GetJobs").Return([]queue.Job{&TestAsyncJob{}}).Once() + + mockEvent, _ := mock.Event() + mockEvent.On("GetEvents").Return(map[event.Event][]event.Listener{}).Once() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + go func(ctx context.Context) { + s.Nil(s.app.Worker(nil).Run()) + + for { + select { + case <-ctx.Done(): + return + } + } + }(ctx) + time.Sleep(3 * time.Second) + s.Nil(s.app.Job(&TestAsyncJob{}, []queue.Arg{ + {Type: "string", Value: "TestDefaultAsyncQueue"}, + {Type: "int", Value: 1}, + }).Dispatch()) + time.Sleep(1 * time.Second) + s.Equal(1, testAsyncJob) + + mockConfig.AssertExpectations(s.T()) + mockQueue.AssertExpectations(s.T()) + mockEvent.AssertExpectations(s.T()) +} + +func (s *QueueTestSuite) TestCustomAsyncQueue() { + mockConfig := mock.Config() + mockConfig.On("GetString", "app.name").Return("goravel").Times(4) + mockConfig.On("GetString", "queue.connections.test.queue", "default").Return("default").Twice() + mockConfig.On("GetString", "queue.connections.test.driver").Return("redis").Times(3) + mockConfig.On("GetString", "queue.connections.test.connection").Return("default").Twice() + mockConfig.On("GetString", "database.redis.default.host").Return("localhost").Twice() + mockConfig.On("GetString", "database.redis.default.password").Return("").Twice() + mockConfig.On("GetString", "database.redis.default.port").Return(s.redisResource.GetPort("6379/tcp")).Twice() + mockConfig.On("GetInt", "database.redis.default.database").Return(0).Twice() + + mockQueue, _ := mock.Queue() + mockQueue.On("GetJobs").Return([]queue.Job{&TestCustomAsyncJob{}}).Once() + + mockEvent, _ := mock.Event() + mockEvent.On("GetEvents").Return(map[event.Event][]event.Listener{}).Once() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + go func(ctx context.Context) { + s.Nil(s.app.Worker(&queue.Args{ + Connection: "test", + Queue: "test1", + Concurrent: 2, + }).Run()) + + for { + select { + case <-ctx.Done(): + return + } + } + }(ctx) + time.Sleep(3 * time.Second) + s.Nil(s.app.Job(&TestCustomAsyncJob{}, []queue.Arg{ + {Type: "string", Value: "TestCustomAsyncQueue"}, + {Type: "int", Value: 1}, + }).OnConnection("test").OnQueue("test1").Dispatch()) + time.Sleep(1 * time.Second) + s.Equal(1, testCustomAsyncJob) + + mockConfig.AssertExpectations(s.T()) + mockQueue.AssertExpectations(s.T()) + mockEvent.AssertExpectations(s.T()) +} + +func (s *QueueTestSuite) TestErrorAsyncQueue() { + mockConfig := mock.Config() + mockConfig.On("GetString", "queue.default").Return("redis").Once() + mockConfig.On("GetString", "app.name").Return("goravel").Times(4) + mockConfig.On("GetString", "queue.connections.redis.queue", "default").Return("default").Times(3) + mockConfig.On("GetString", "queue.connections.redis.driver").Return("redis").Times(3) + mockConfig.On("GetString", "queue.connections.redis.connection").Return("default").Twice() + mockConfig.On("GetString", "database.redis.default.host").Return("localhost").Twice() + mockConfig.On("GetString", "database.redis.default.password").Return("").Twice() + mockConfig.On("GetString", "database.redis.default.port").Return(s.redisResource.GetPort("6379/tcp")).Twice() + mockConfig.On("GetInt", "database.redis.default.database").Return(0).Twice() + + mockQueue, _ := mock.Queue() + mockQueue.On("GetJobs").Return([]queue.Job{&TestErrorAsyncJob{}}).Once() + + mockEvent, _ := mock.Event() + mockEvent.On("GetEvents").Return(map[event.Event][]event.Listener{}).Once() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + go func(ctx context.Context) { + s.Nil(s.app.Worker(nil).Run()) + + for { + select { + case <-ctx.Done(): + return + } + } + }(ctx) + time.Sleep(3 * time.Second) + s.Nil(s.app.Job(&TestErrorAsyncJob{}, []queue.Arg{ + {Type: "string", Value: "TestErrorAsyncQueue"}, + {Type: "int", Value: 1}, + }).OnConnection("redis").OnQueue("test2").Dispatch()) + time.Sleep(1 * time.Second) + s.Equal(0, testErrorAsyncJob) + + mockConfig.AssertExpectations(s.T()) + mockQueue.AssertExpectations(s.T()) + mockEvent.AssertExpectations(s.T()) +} + +func (s *QueueTestSuite) TestChainAsyncQueue() { + mockConfig := mock.Config() + mockConfig.On("GetString", "queue.default").Return("redis").Times(3) + mockConfig.On("GetString", "app.name").Return("goravel").Times(3) + mockConfig.On("GetString", "queue.connections.redis.queue", "default").Return("default").Times(3) + mockConfig.On("GetString", "queue.connections.redis.driver").Return("redis").Times(3) + mockConfig.On("GetString", "queue.connections.redis.connection").Return("default").Twice() + mockConfig.On("GetString", "database.redis.default.host").Return("localhost").Twice() + mockConfig.On("GetString", "database.redis.default.password").Return("").Twice() + mockConfig.On("GetString", "database.redis.default.port").Return(s.redisResource.GetPort("6379/tcp")).Twice() + mockConfig.On("GetInt", "database.redis.default.database").Return(0).Twice() + + mockQueue, _ := mock.Queue() + mockQueue.On("GetJobs").Return([]queue.Job{&TestChainAsyncJob{}, &TestChainSyncJob{}}).Once() + + mockEvent, _ := mock.Event() + mockEvent.On("GetEvents").Return(map[event.Event][]event.Listener{}).Once() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + go func(ctx context.Context) { + s.Nil(s.app.Worker(nil).Run()) + + for { + select { + case <-ctx.Done(): + return + } + } + }(ctx) + time.Sleep(3 * time.Second) + s.Nil(s.app.Chain([]queue.Jobs{ + { + Job: &TestChainAsyncJob{}, + Args: []queue.Arg{ + {Type: "string", Value: "TestChainAsyncQueue"}, + {Type: "int", Value: 1}, + }, + }, + { + Job: &TestChainSyncJob{}, + Args: []queue.Arg{ + {Type: "string", Value: "TestChainSyncQueue"}, + {Type: "int", Value: 1}, + }, + }, + }).Dispatch()) + time.Sleep(1 * time.Second) + s.Equal(1, testChainAsyncJob) + s.Equal(1, testChainSyncJob) + + mockConfig.AssertExpectations(s.T()) + mockQueue.AssertExpectations(s.T()) + mockEvent.AssertExpectations(s.T()) +} + +type TestAsyncJob struct { +} + +//Signature The name and signature of the job. +func (receiver *TestAsyncJob) Signature() string { + return "test_async_job" +} + +//Handle Execute the job. +func (receiver *TestAsyncJob) Handle(args ...interface{}) error { + testAsyncJob++ + + return nil +} + +type TestSyncJob struct { +} + +//Signature The name and signature of the job. +func (receiver *TestSyncJob) Signature() string { + return "test_sync_job" +} + +//Handle Execute the job. +func (receiver *TestSyncJob) Handle(args ...interface{}) error { + testSyncJob++ + + return nil +} + +type TestCustomAsyncJob struct { +} + +//Signature The name and signature of the job. +func (receiver *TestCustomAsyncJob) Signature() string { + return "test_async_job" +} + +//Handle Execute the job. +func (receiver *TestCustomAsyncJob) Handle(args ...interface{}) error { + testCustomAsyncJob++ + + return nil +} + +type TestErrorAsyncJob struct { +} + +//Signature The name and signature of the job. +func (receiver *TestErrorAsyncJob) Signature() string { + return "test_async_job" +} + +//Handle Execute the job. +func (receiver *TestErrorAsyncJob) Handle(args ...interface{}) error { + testErrorAsyncJob++ + + return nil +} + +type TestChainAsyncJob struct { +} + +//Signature The name and signature of the job. +func (receiver *TestChainAsyncJob) Signature() string { + return "test_async_job" +} + +//Handle Execute the job. +func (receiver *TestChainAsyncJob) Handle(args ...interface{}) error { + testChainAsyncJob++ + + return nil +} + +type TestChainSyncJob struct { +} + +//Signature The name and signature of the job. +func (receiver *TestChainSyncJob) Signature() string { + return "test_sync_job" +} + +//Handle Execute the job. +func (receiver *TestChainSyncJob) Handle(args ...interface{}) error { + testChainSyncJob++ + + return nil +} diff --git a/route/application.go b/route/application.go deleted file mode 100644 index c31303a6a..000000000 --- a/route/application.go +++ /dev/null @@ -1,12 +0,0 @@ -package route - -import ( - "github.com/goravel/framework/contracts/route" -) - -type Application struct { -} - -func (app *Application) Init() route.Engine { - return NewGin() -} diff --git a/route/gin.go b/route/gin.go index f6949f5d1..9de04ff49 100644 --- a/route/gin.go +++ b/route/gin.go @@ -3,9 +3,6 @@ package route import ( "fmt" "net/http" - "regexp" - "strings" - "time" "github.com/gin-gonic/gin" "github.com/gookit/color" @@ -13,8 +10,7 @@ import ( httpcontract "github.com/goravel/framework/contracts/http" "github.com/goravel/framework/contracts/route" "github.com/goravel/framework/facades" - "github.com/goravel/framework/foundation" - frameworkhttp "github.com/goravel/framework/http" + goravelhttp "github.com/goravel/framework/http" ) type Gin struct { @@ -22,7 +18,7 @@ type Gin struct { instance *gin.Engine } -func NewGin() route.Engine { +func NewGin() *Gin { gin.SetMode(gin.ReleaseMode) engine := gin.New() if debugLog := getDebugLog(); debugLog != nil { @@ -33,12 +29,12 @@ func NewGin() route.Engine { engine.Group("/"), "", []httpcontract.Middleware{}, + []httpcontract.Middleware{goravelhttp.GinResponseMiddleware()}, )} } func (r *Gin) Run(addr string) error { - rootApp := foundation.Application{} - if facades.Config.GetBool("app.debug") && !rootApp.RunningInConsole() { + if facades.Config.GetBool("app.debug") && !runningInConsole() { routes := r.instance.Routes() for _, item := range routes { fmt.Printf("%-10s %s\n", item.Method, colonToBracket(item.Path)) @@ -55,193 +51,13 @@ func (r *Gin) ServeHTTP(w http.ResponseWriter, req *http.Request) { } func (r *Gin) GlobalMiddleware(handlers ...httpcontract.Middleware) { - r.instance.Use(middlewaresToGinHandlers(handlers)...) + if len(handlers) > 0 { + r.instance.Use(middlewaresToGinHandlers(handlers)...) + } r.Route = NewGinGroup( r.instance.Group("/"), "", []httpcontract.Middleware{}, + []httpcontract.Middleware{goravelhttp.GinResponseMiddleware()}, ) } - -type GinGroup struct { - instance gin.IRouter - originPrefix string - originMiddlewares []httpcontract.Middleware - prefix string - middlewares []httpcontract.Middleware -} - -func NewGinGroup(instance gin.IRouter, prefix string, originMiddlewares []httpcontract.Middleware) route.Route { - return &GinGroup{ - instance: instance, - originPrefix: prefix, - originMiddlewares: originMiddlewares, - } -} - -func (r *GinGroup) Group(handler route.GroupFunc) { - var middlewares []httpcontract.Middleware - middlewares = append(middlewares, r.originMiddlewares...) - middlewares = append(middlewares, r.middlewares...) - r.middlewares = []httpcontract.Middleware{} - prefix := pathToGinPath(r.originPrefix + "/" + r.prefix) - r.prefix = "" - - handler(NewGinGroup(r.instance, prefix, middlewares)) -} - -func (r *GinGroup) Prefix(addr string) route.Route { - r.prefix += "/" + addr - - return r -} - -func (r *GinGroup) Middleware(handlers ...httpcontract.Middleware) route.Route { - r.middlewares = append(r.middlewares, handlers...) - - return r -} - -func (r *GinGroup) Any(relativePath string, handler httpcontract.HandlerFunc) { - r.getGinRoutesWithMiddlewares().Any(pathToGinPath(relativePath), []gin.HandlerFunc{handlerToGinHandler(handler)}...) -} - -func (r *GinGroup) Get(relativePath string, handler httpcontract.HandlerFunc) { - r.getGinRoutesWithMiddlewares().GET(pathToGinPath(relativePath), []gin.HandlerFunc{handlerToGinHandler(handler)}...) -} - -func (r *GinGroup) Post(relativePath string, handler httpcontract.HandlerFunc) { - r.getGinRoutesWithMiddlewares().POST(pathToGinPath(relativePath), []gin.HandlerFunc{handlerToGinHandler(handler)}...) -} - -func (r *GinGroup) Delete(relativePath string, handler httpcontract.HandlerFunc) { - r.getGinRoutesWithMiddlewares().DELETE(pathToGinPath(relativePath), []gin.HandlerFunc{handlerToGinHandler(handler)}...) -} - -func (r *GinGroup) Patch(relativePath string, handler httpcontract.HandlerFunc) { - r.getGinRoutesWithMiddlewares().PATCH(pathToGinPath(relativePath), []gin.HandlerFunc{handlerToGinHandler(handler)}...) -} - -func (r *GinGroup) Put(relativePath string, handler httpcontract.HandlerFunc) { - r.getGinRoutesWithMiddlewares().PUT(pathToGinPath(relativePath), []gin.HandlerFunc{handlerToGinHandler(handler)}...) -} - -func (r *GinGroup) Options(relativePath string, handler httpcontract.HandlerFunc) { - r.getGinRoutesWithMiddlewares().OPTIONS(pathToGinPath(relativePath), []gin.HandlerFunc{handlerToGinHandler(handler)}...) -} - -func (r *GinGroup) Static(relativePath, root string) { - r.getGinRoutesWithMiddlewares().Static(pathToGinPath(relativePath), root) -} - -func (r *GinGroup) StaticFile(relativePath, filepath string) { - r.getGinRoutesWithMiddlewares().StaticFile(pathToGinPath(relativePath), filepath) -} - -func (r *GinGroup) StaticFS(relativePath string, fs http.FileSystem) { - r.getGinRoutesWithMiddlewares().StaticFS(pathToGinPath(relativePath), fs) -} - -func (r *GinGroup) getGinRoutesWithMiddlewares() gin.IRoutes { - prefix := pathToGinPath(r.originPrefix + "/" + r.prefix) - r.prefix = "" - ginGroup := r.instance.Group(prefix) - - var middlewares []gin.HandlerFunc - ginOriginMiddlewares := middlewaresToGinHandlers(r.originMiddlewares) - ginMiddlewares := middlewaresToGinHandlers(r.middlewares) - middlewares = append(middlewares, ginOriginMiddlewares...) - middlewares = append(middlewares, ginMiddlewares...) - r.middlewares = []httpcontract.Middleware{} - if len(middlewares) > 0 { - return ginGroup.Use(middlewares...) - } else { - return ginGroup - } -} - -func pathToGinPath(relativePath string) string { - return bracketToColon(mergeSlashForPath(relativePath)) -} - -func middlewaresToGinHandlers(middlewares []httpcontract.Middleware) []gin.HandlerFunc { - var ginHandlers []gin.HandlerFunc - for _, item := range middlewares { - ginHandlers = append(ginHandlers, middlewareToGinHandler(item)) - } - - return ginHandlers -} - -func handlerToGinHandler(handler httpcontract.HandlerFunc) gin.HandlerFunc { - return func(ginCtx *gin.Context) { - handler(frameworkhttp.NewGinContext(ginCtx)) - } -} - -func middlewareToGinHandler(handler httpcontract.Middleware) gin.HandlerFunc { - return func(ginCtx *gin.Context) { - handler(frameworkhttp.NewGinContext(ginCtx)) - } -} - -func getDebugLog() gin.HandlerFunc { - logFormatter := func(param gin.LogFormatterParams) string { - var statusColor, methodColor, resetColor string - if param.IsOutputColor() { - statusColor = param.StatusCodeColor() - methodColor = param.MethodColor() - resetColor = param.ResetColor() - } - - if param.Latency > time.Minute { - // Truncate in a golang < 1.8 safe way - param.Latency = param.Latency - param.Latency%time.Second - } - return fmt.Sprintf("[HTTP] %v |%s %3d %s| %13v | %15s |%s %-7s %s %#v\n%s", - param.TimeStamp.Format("2006/01/02 - 15:04:05"), - statusColor, param.StatusCode, resetColor, - param.Latency, - param.ClientIP, - methodColor, param.Method, resetColor, - param.Path, - param.ErrorMessage, - ) - } - - if facades.Config.GetBool("app.debug") { - return gin.LoggerWithFormatter(logFormatter) - } - - return nil -} - -func colonToBracket(relativePath string) string { - arr := strings.Split(relativePath, "/") - var newArr []string - for _, item := range arr { - if strings.HasPrefix(item, ":") { - item = "{" + strings.ReplaceAll(item, ":", "") + "}" - } - newArr = append(newArr, item) - } - - return strings.Join(newArr, "/") -} - -func bracketToColon(relativePath string) string { - compileRegex := regexp.MustCompile("\\{(.*?)\\}") - matchArr := compileRegex.FindAllStringSubmatch(relativePath, -1) - - for _, item := range matchArr { - relativePath = strings.ReplaceAll(relativePath, item[0], ":"+item[1]) - } - - return relativePath -} - -func mergeSlashForPath(path string) string { - path = strings.ReplaceAll(path, "//", "/") - - return strings.ReplaceAll(path, "//", "/") -} diff --git a/route/gin_group.go b/route/gin_group.go new file mode 100644 index 000000000..f68a22414 --- /dev/null +++ b/route/gin_group.go @@ -0,0 +1,111 @@ +package route + +import ( + "net/http" + + "github.com/gin-gonic/gin" + + httpcontract "github.com/goravel/framework/contracts/http" + "github.com/goravel/framework/contracts/route" +) + +type GinGroup struct { + instance gin.IRouter + originPrefix string + prefix string + originMiddlewares []httpcontract.Middleware + middlewares []httpcontract.Middleware + lastMiddlewares []httpcontract.Middleware +} + +func NewGinGroup(instance gin.IRouter, prefix string, originMiddlewares []httpcontract.Middleware, lastMiddlewares []httpcontract.Middleware) route.Route { + return &GinGroup{ + instance: instance, + originPrefix: prefix, + originMiddlewares: originMiddlewares, + lastMiddlewares: lastMiddlewares, + } +} + +func (r *GinGroup) Group(handler route.GroupFunc) { + var middlewares []httpcontract.Middleware + middlewares = append(middlewares, r.originMiddlewares...) + middlewares = append(middlewares, r.middlewares...) + r.middlewares = []httpcontract.Middleware{} + prefix := pathToGinPath(r.originPrefix + "/" + r.prefix) + r.prefix = "" + + handler(NewGinGroup(r.instance, prefix, middlewares, r.lastMiddlewares)) +} + +func (r *GinGroup) Prefix(addr string) route.Route { + r.prefix += "/" + addr + + return r +} + +func (r *GinGroup) Middleware(handlers ...httpcontract.Middleware) route.Route { + r.middlewares = append(r.middlewares, handlers...) + + return r +} + +func (r *GinGroup) Any(relativePath string, handler httpcontract.HandlerFunc) { + r.getGinRoutesWithMiddlewares().Any(pathToGinPath(relativePath), []gin.HandlerFunc{handlerToGinHandler(handler)}...) +} + +func (r *GinGroup) Get(relativePath string, handler httpcontract.HandlerFunc) { + r.getGinRoutesWithMiddlewares().GET(pathToGinPath(relativePath), []gin.HandlerFunc{handlerToGinHandler(handler)}...) +} + +func (r *GinGroup) Post(relativePath string, handler httpcontract.HandlerFunc) { + r.getGinRoutesWithMiddlewares().POST(pathToGinPath(relativePath), []gin.HandlerFunc{handlerToGinHandler(handler)}...) +} + +func (r *GinGroup) Delete(relativePath string, handler httpcontract.HandlerFunc) { + r.getGinRoutesWithMiddlewares().DELETE(pathToGinPath(relativePath), []gin.HandlerFunc{handlerToGinHandler(handler)}...) +} + +func (r *GinGroup) Patch(relativePath string, handler httpcontract.HandlerFunc) { + r.getGinRoutesWithMiddlewares().PATCH(pathToGinPath(relativePath), []gin.HandlerFunc{handlerToGinHandler(handler)}...) +} + +func (r *GinGroup) Put(relativePath string, handler httpcontract.HandlerFunc) { + r.getGinRoutesWithMiddlewares().PUT(pathToGinPath(relativePath), []gin.HandlerFunc{handlerToGinHandler(handler)}...) +} + +func (r *GinGroup) Options(relativePath string, handler httpcontract.HandlerFunc) { + r.getGinRoutesWithMiddlewares().OPTIONS(pathToGinPath(relativePath), []gin.HandlerFunc{handlerToGinHandler(handler)}...) +} + +func (r *GinGroup) Static(relativePath, root string) { + r.getGinRoutesWithMiddlewares().Static(pathToGinPath(relativePath), root) +} + +func (r *GinGroup) StaticFile(relativePath, filepath string) { + r.getGinRoutesWithMiddlewares().StaticFile(pathToGinPath(relativePath), filepath) +} + +func (r *GinGroup) StaticFS(relativePath string, fs http.FileSystem) { + r.getGinRoutesWithMiddlewares().StaticFS(pathToGinPath(relativePath), fs) +} + +func (r *GinGroup) getGinRoutesWithMiddlewares() gin.IRoutes { + prefix := pathToGinPath(r.originPrefix + "/" + r.prefix) + r.prefix = "" + ginGroup := r.instance.Group(prefix) + + var middlewares []gin.HandlerFunc + ginOriginMiddlewares := middlewaresToGinHandlers(r.originMiddlewares) + ginMiddlewares := middlewaresToGinHandlers(r.middlewares) + ginLastMiddlewares := middlewaresToGinHandlers(r.lastMiddlewares) + middlewares = append(middlewares, ginOriginMiddlewares...) + middlewares = append(middlewares, ginMiddlewares...) + middlewares = append(middlewares, ginLastMiddlewares...) + r.middlewares = []httpcontract.Middleware{} + if len(middlewares) > 0 { + return ginGroup.Use(middlewares...) + } else { + return ginGroup + } +} diff --git a/route/gin_group_test.go b/route/gin_group_test.go new file mode 100644 index 000000000..e9a6ab093 --- /dev/null +++ b/route/gin_group_test.go @@ -0,0 +1,399 @@ +package route + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + + mockconfig "github.com/goravel/framework/contracts/config/mocks" + httpcontract "github.com/goravel/framework/contracts/http" + "github.com/goravel/framework/contracts/route" + "github.com/goravel/framework/http/middleware" + "github.com/goravel/framework/testing/mock" +) + +func TestGinGroup(t *testing.T) { + var ( + gin *Gin + mockConfig *mockconfig.Config + ) + beforeEach := func() { + mockConfig = mock.Config() + mockConfig.On("GetBool", "app.debug").Return(true).Once() + + gin = NewGin() + } + tests := []struct { + name string + setup func(req *http.Request) + method string + url string + expectCode int + expectBody string + }{ + { + name: "Get", + setup: func(req *http.Request) { + gin.Get("/input/{id}", func(ctx httpcontract.Context) { + ctx.Response().Json(http.StatusOK, httpcontract.Json{ + "id": ctx.Request().Input("id"), + }) + }) + }, + method: "GET", + url: "/input/1", + expectCode: http.StatusOK, + expectBody: "{\"id\":\"1\"}", + }, + { + name: "Post", + setup: func(req *http.Request) { + gin.Post("/input/{id}", func(ctx httpcontract.Context) { + ctx.Response().Success().Json(httpcontract.Json{ + "id": ctx.Request().Input("id"), + }) + }) + }, + method: "POST", + url: "/input/1", + expectCode: http.StatusOK, + expectBody: "{\"id\":\"1\"}", + }, + { + name: "Put", + setup: func(req *http.Request) { + gin.Put("/input/{id}", func(ctx httpcontract.Context) { + ctx.Response().Success().Json(httpcontract.Json{ + "id": ctx.Request().Input("id"), + }) + }) + }, + method: "PUT", + url: "/input/1", + expectCode: http.StatusOK, + expectBody: "{\"id\":\"1\"}", + }, + { + name: "Delete", + setup: func(req *http.Request) { + gin.Delete("/input/{id}", func(ctx httpcontract.Context) { + ctx.Response().Success().Json(httpcontract.Json{ + "id": ctx.Request().Input("id"), + }) + }) + }, + method: "DELETE", + url: "/input/1", + expectCode: http.StatusOK, + expectBody: "{\"id\":\"1\"}", + }, + { + name: "Options", + setup: func(req *http.Request) { + gin.Options("/input/{id}", func(ctx httpcontract.Context) { + ctx.Response().Success().Json(httpcontract.Json{ + "id": ctx.Request().Input("id"), + }) + }) + }, + method: "OPTIONS", + url: "/input/1", + expectCode: http.StatusOK, + }, + { + name: "Patch", + setup: func(req *http.Request) { + gin.Patch("/input/{id}", func(ctx httpcontract.Context) { + ctx.Response().Success().Json(httpcontract.Json{ + "id": ctx.Request().Input("id"), + }) + }) + }, + method: "PATCH", + url: "/input/1", + expectCode: http.StatusOK, + expectBody: "{\"id\":\"1\"}", + }, + { + name: "Any Get", + setup: func(req *http.Request) { + gin.Any("/any/{id}", func(ctx httpcontract.Context) { + ctx.Response().Success().Json(httpcontract.Json{ + "id": ctx.Request().Input("id"), + }) + }) + }, + method: "GET", + url: "/any/1", + expectCode: http.StatusOK, + expectBody: "{\"id\":\"1\"}", + }, + { + name: "Any Post", + setup: func(req *http.Request) { + gin.Any("/any/{id}", func(ctx httpcontract.Context) { + ctx.Response().Success().Json(httpcontract.Json{ + "id": ctx.Request().Input("id"), + }) + }) + }, + method: "POST", + url: "/any/1", + expectCode: http.StatusOK, + expectBody: "{\"id\":\"1\"}", + }, + { + name: "Any Put", + setup: func(req *http.Request) { + gin.Any("/any/{id}", func(ctx httpcontract.Context) { + ctx.Response().Success().Json(httpcontract.Json{ + "id": ctx.Request().Input("id"), + }) + }) + }, + method: "PUT", + url: "/any/1", + expectCode: http.StatusOK, + expectBody: "{\"id\":\"1\"}", + }, + { + name: "Any Delete", + setup: func(req *http.Request) { + gin.Any("/any/{id}", func(ctx httpcontract.Context) { + ctx.Response().Success().Json(httpcontract.Json{ + "id": ctx.Request().Input("id"), + }) + }) + }, + method: "DELETE", + url: "/any/1", + expectCode: http.StatusOK, + expectBody: "{\"id\":\"1\"}", + }, + { + name: "Any Options", + setup: func(req *http.Request) { + mockConfig.On("Get", "cors.allowed_methods").Return([]string{"*"}).Once() + mockConfig.On("Get", "cors.allowed_origins").Return([]string{"*"}).Once() + mockConfig.On("Get", "cors.allowed_headers").Return([]string{"*"}).Once() + mockConfig.On("Get", "cors.exposed_headers").Return([]string{"*"}).Once() + mockConfig.On("GetInt", "cors.max_age").Return(0).Once() + mockConfig.On("GetBool", "cors.supports_credentials").Return(false).Once() + gin.GlobalMiddleware(middleware.Cors()) + gin.Any("/any/{id}", func(ctx httpcontract.Context) { + ctx.Response().Success().Json(httpcontract.Json{ + "id": ctx.Request().Input("id"), + }) + }) + req.Header.Set("Origin", "http://127.0.0.1") + req.Header.Set("Access-Control-Request-Method", "GET") + }, + method: "OPTIONS", + url: "/any/1", + expectCode: http.StatusNoContent, + }, + { + name: "Any Patch", + setup: func(req *http.Request) { + gin.Any("/any/{id}", func(ctx httpcontract.Context) { + ctx.Response().Success().Json(httpcontract.Json{ + "id": ctx.Request().Input("id"), + }) + }) + }, + method: "PATCH", + url: "/any/1", + expectCode: http.StatusOK, + expectBody: "{\"id\":\"1\"}", + }, + { + name: "Static", + setup: func(req *http.Request) { + gin.Static("static", "../") + }, + method: "GET", + url: "/static/README.md", + expectCode: http.StatusOK, + }, + { + name: "StaticFile", + setup: func(req *http.Request) { + gin.StaticFile("static-file", "../README.md") + }, + method: "GET", + url: "/static-file", + expectCode: http.StatusOK, + }, + { + name: "StaticFS", + setup: func(req *http.Request) { + gin.StaticFS("static-fs", http.Dir("./")) + }, + method: "GET", + url: "/static-fs", + expectCode: http.StatusMovedPermanently, + }, + { + name: "Abort Middleware", + setup: func(req *http.Request) { + gin.Middleware(abortMiddleware()).Get("/middleware/{id}", func(ctx httpcontract.Context) { + ctx.Response().Success().Json(httpcontract.Json{ + "id": ctx.Request().Input("id"), + }) + }) + }, + method: "GET", + url: "/middleware/1", + expectCode: http.StatusNonAuthoritativeInfo, + }, + { + name: "Multiple Middleware", + setup: func(req *http.Request) { + gin.Middleware(contextMiddleware(), contextMiddleware1()).Get("/middlewares/{id}", func(ctx httpcontract.Context) { + ctx.Response().Success().Json(httpcontract.Json{ + "id": ctx.Request().Input("id"), + "ctx": ctx.Value("ctx"), + "ctx1": ctx.Value("ctx1"), + }) + }) + }, + method: "GET", + url: "/middlewares/1", + expectCode: http.StatusOK, + expectBody: "{\"ctx\":\"Goravel\",\"ctx1\":\"Hello\",\"id\":\"1\"}", + }, + { + name: "Multiple Prefix", + setup: func(req *http.Request) { + gin.Prefix("prefix1").Prefix("prefix2").Get("input/{id}", func(ctx httpcontract.Context) { + ctx.Response().Success().Json(httpcontract.Json{ + "id": ctx.Request().Input("id"), + }) + }) + }, + method: "GET", + url: "/prefix1/prefix2/input/1", + expectCode: http.StatusOK, + expectBody: "{\"id\":\"1\"}", + }, + { + name: "Multiple Prefix Group Middleware", + setup: func(req *http.Request) { + gin.Prefix("group1").Middleware(contextMiddleware()).Group(func(route1 route.Route) { + route1.Prefix("group2").Middleware(contextMiddleware1()).Group(func(route2 route.Route) { + route2.Get("/middleware/{id}", func(ctx httpcontract.Context) { + ctx.Response().Success().Json(httpcontract.Json{ + "id": ctx.Request().Input("id"), + "ctx": ctx.Value("ctx").(string), + "ctx1": ctx.Value("ctx1").(string), + }) + }) + }) + route1.Middleware(contextMiddleware2()).Get("/middleware/{id}", func(ctx httpcontract.Context) { + ctx.Response().Success().Json(httpcontract.Json{ + "id": ctx.Request().Input("id"), + "ctx": ctx.Value("ctx").(string), + "ctx2": ctx.Value("ctx2").(string), + }) + }) + }) + }, + method: "GET", + url: "/group1/group2/middleware/1", + expectCode: http.StatusOK, + expectBody: "{\"ctx\":\"Goravel\",\"ctx1\":\"Hello\",\"id\":\"1\"}", + }, + { + name: "Multiple Group Middleware", + setup: func(req *http.Request) { + gin.Prefix("group1").Middleware(contextMiddleware()).Group(func(route1 route.Route) { + route1.Prefix("group2").Middleware(contextMiddleware1()).Group(func(route2 route.Route) { + route2.Get("/middleware/{id}", func(ctx httpcontract.Context) { + ctx.Response().Success().Json(httpcontract.Json{ + "id": ctx.Request().Input("id"), + "ctx": ctx.Value("ctx").(string), + "ctx1": ctx.Value("ctx1").(string), + }) + }) + }) + route1.Middleware(contextMiddleware2()).Get("/middleware/{id}", func(ctx httpcontract.Context) { + ctx.Response().Success().Json(httpcontract.Json{ + "id": ctx.Request().Input("id"), + "ctx": ctx.Value("ctx").(string), + "ctx2": ctx.Value("ctx2").(string), + }) + }) + }) + }, + method: "GET", + url: "/group1/middleware/1", + expectCode: http.StatusOK, + expectBody: "{\"ctx\":\"Goravel\",\"ctx2\":\"World\",\"id\":\"1\"}", + }, + { + name: "Global Middleware", + setup: func(req *http.Request) { + gin.GlobalMiddleware(func(ctx httpcontract.Context) { + ctx.WithValue("global", "goravel") + ctx.Request().Next() + }) + gin.Get("/global-middleware", func(ctx httpcontract.Context) { + ctx.Response().Json(http.StatusOK, httpcontract.Json{ + "global": ctx.Value("global"), + }) + }) + }, + method: "GET", + url: "/global-middleware", + expectCode: http.StatusOK, + expectBody: "{\"global\":\"goravel\"}", + }, + } + for _, test := range tests { + beforeEach() + w := httptest.NewRecorder() + req, _ := http.NewRequest(test.method, test.url, nil) + if test.setup != nil { + test.setup(req) + } + gin.ServeHTTP(w, req) + + if test.expectBody != "" { + assert.Equal(t, test.expectBody, w.Body.String(), test.name) + } + assert.Equal(t, test.expectCode, w.Code, test.name) + } +} + +func abortMiddleware() httpcontract.Middleware { + return func(ctx httpcontract.Context) { + ctx.Request().AbortWithStatus(http.StatusNonAuthoritativeInfo) + return + } +} + +func contextMiddleware() httpcontract.Middleware { + return func(ctx httpcontract.Context) { + ctx.WithValue("ctx", "Goravel") + + ctx.Request().Next() + } +} + +func contextMiddleware1() httpcontract.Middleware { + return func(ctx httpcontract.Context) { + ctx.WithValue("ctx1", "Hello") + + ctx.Request().Next() + } +} + +func contextMiddleware2() httpcontract.Middleware { + return func(ctx httpcontract.Context) { + ctx.WithValue("ctx2", "World") + + ctx.Request().Next() + } +} diff --git a/route/gin_test.go b/route/gin_test.go index d8850a01c..73ef12dfa 100644 --- a/route/gin_test.go +++ b/route/gin_test.go @@ -1,11 +1,876 @@ package route import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "io" + "mime/multipart" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" "testing" "github.com/stretchr/testify/assert" + + mockconfig "github.com/goravel/framework/contracts/config/mocks" + httpcontract "github.com/goravel/framework/contracts/http" + "github.com/goravel/framework/contracts/validation" + "github.com/goravel/framework/testing/mock" ) -func TestBracketToColon(t *testing.T) { - assert.Equal(t, "/:id/:name", bracketToColon("/{id}/{name}")) +func TestGinRequest(t *testing.T) { + var ( + gin *Gin + req *http.Request + mockConfig *mockconfig.Config + ) + beforeEach := func() { + mockConfig = mock.Config() + mockConfig.On("GetBool", "app.debug").Return(true).Once() + + gin = NewGin() + } + tests := []struct { + name string + method string + url string + setup func(method, url string) error + expectCode int + expectBody string + }{ + { + name: "Methods", + method: "GET", + url: "/get/1?name=Goravel", + setup: func(method, url string) error { + gin.Get("/get/{id}", func(ctx httpcontract.Context) { + ctx.Response().Success().Json(httpcontract.Json{ + "id": ctx.Request().Input("id"), + "name": ctx.Request().Query("name", "Hello"), + "header": ctx.Request().Header("Hello", "World"), + "method": ctx.Request().Method(), + "path": ctx.Request().Path(), + "url": ctx.Request().Url(), + "full_url": ctx.Request().FullUrl(), + "ip": ctx.Request().Ip(), + }) + }) + + var err error + req, err = http.NewRequest(method, url, nil) + if err != nil { + return err + } + req.Header.Set("Hello", "goravel") + + return nil + }, + expectCode: http.StatusOK, + expectBody: "{\"full_url\":\"\",\"header\":\"goravel\",\"id\":\"1\",\"ip\":\"\",\"method\":\"GET\",\"name\":\"Goravel\",\"path\":\"/get/1\",\"url\":\"\"}", + }, + { + name: "Headers", + method: "GET", + url: "/headers", + setup: func(method, url string) error { + gin.Get("/headers", func(ctx httpcontract.Context) { + str, _ := json.Marshal(ctx.Request().Headers()) + ctx.Response().Success().String(string(str)) + }) + + var err error + req, err = http.NewRequest(method, url, nil) + if err != nil { + return err + } + req.Header.Set("Hello", "Goravel") + + return nil + }, + expectCode: http.StatusOK, + expectBody: "{\"Hello\":[\"Goravel\"]}", + }, + { + name: "Form", + method: "POST", + url: "/post", + setup: func(method, url string) error { + gin.Post("/post", func(ctx httpcontract.Context) { + ctx.Response().Success().Json(httpcontract.Json{ + "name": ctx.Request().Form("name", "Hello"), + }) + }) + + payload := &bytes.Buffer{} + writer := multipart.NewWriter(payload) + if err := writer.WriteField("name", "Goravel"); err != nil { + return err + } + if err := writer.Close(); err != nil { + return err + } + + req, _ = http.NewRequest(method, url, payload) + req.Header.Set("Content-Type", writer.FormDataContentType()) + + return nil + }, + expectCode: http.StatusOK, + expectBody: "{\"name\":\"Goravel\"}", + }, + { + name: "Bind", + method: "POST", + url: "/bind", + setup: func(method, url string) error { + gin.Post("/bind", func(ctx httpcontract.Context) { + type Test struct { + Name string + } + var test Test + _ = ctx.Request().Bind(&test) + ctx.Response().Success().Json(httpcontract.Json{ + "name": test.Name, + }) + }) + + payload := strings.NewReader(`{ + "Name": "Goravel" + }`) + req, _ = http.NewRequest(method, url, payload) + req.Header.Set("Content-Type", "application/json") + + return nil + }, + expectCode: http.StatusOK, + expectBody: "{\"name\":\"Goravel\"}", + }, + { + name: "QueryArray", + method: "GET", + url: "/query-array?name=Goravel&name=Goravel1", + setup: func(method, url string) error { + gin.Get("/query-array", func(ctx httpcontract.Context) { + ctx.Response().Success().Json(httpcontract.Json{ + "name": ctx.Request().QueryArray("name"), + }) + }) + + req, _ = http.NewRequest(method, url, nil) + req.Header.Set("Content-Type", "application/json") + + return nil + }, + expectCode: http.StatusOK, + expectBody: "{\"name\":[\"Goravel\",\"Goravel1\"]}", + }, + { + name: "QueryMap", + method: "GET", + url: "/query-array?name[a]=Goravel&name[b]=Goravel1", + setup: func(method, url string) error { + gin.Get("/query-array", func(ctx httpcontract.Context) { + ctx.Response().Success().Json(httpcontract.Json{ + "name": ctx.Request().QueryMap("name"), + }) + }) + + req, _ = http.NewRequest(method, url, nil) + req.Header.Set("Content-Type", "application/json") + + return nil + }, + expectCode: http.StatusOK, + expectBody: "{\"name\":{\"a\":\"Goravel\",\"b\":\"Goravel1\"}}", + }, + { + name: "File", + method: "POST", + url: "/file", + setup: func(method, url string) error { + gin.Post("/file", func(ctx httpcontract.Context) { + mockConfig.On("GetString", "filesystems.default").Return("local").Once() + + fileInfo, err := ctx.Request().File("file") + + mockStorage, mockDriver, _ := mock.Storage() + mockStorage.On("Disk", "local").Return(mockDriver).Once() + mockDriver.On("PutFile", "test", fileInfo).Return("test/logo.png", nil).Once() + mockStorage.On("Exists", "test/logo.png").Return(true).Once() + + if err != nil { + ctx.Response().Success().String("get file error") + return + } + filePath, err := fileInfo.Store("test") + if err != nil { + ctx.Response().Success().String("store file error: " + err.Error()) + return + } + + extension, err := fileInfo.Extension() + if err != nil { + ctx.Response().Success().String("get file extension error: " + err.Error()) + return + } + + ctx.Response().Success().Json(httpcontract.Json{ + "exist": mockStorage.Exists(filePath), + "hash_name_length": len(fileInfo.HashName()), + "hash_name_length1": len(fileInfo.HashName("test")), + "file_path_length": len(filePath), + "extension": extension, + "original_name": fileInfo.GetClientOriginalName(), + "original_extension": fileInfo.GetClientOriginalExtension(), + }) + }) + + payload := &bytes.Buffer{} + writer := multipart.NewWriter(payload) + logo, errFile1 := os.Open("../logo.png") + defer logo.Close() + part1, errFile1 := writer.CreateFormFile("file", filepath.Base("../logo.png")) + _, errFile1 = io.Copy(part1, logo) + if errFile1 != nil { + return errFile1 + } + err := writer.Close() + if err != nil { + return err + } + + req, _ = http.NewRequest(method, url, payload) + req.Header.Set("Content-Type", writer.FormDataContentType()) + + return nil + }, + expectCode: http.StatusOK, + expectBody: "{\"exist\":true,\"extension\":\"png\",\"file_path_length\":13,\"hash_name_length\":44,\"hash_name_length1\":49,\"original_extension\":\"png\",\"original_name\":\"logo.png\"}", + }, + { + name: "GET with validator and validate pass", + method: "GET", + url: "/validator/validate/success?name=Goravel", + setup: func(method, url string) error { + gin.Get("/validator/validate/success", func(ctx httpcontract.Context) { + mockValication, _, _ := mock.Validation() + mockValication.On("Rules").Return([]validation.Rule{}).Once() + + validator, err := ctx.Request().Validate(map[string]string{ + "name": "required", + }) + if err != nil { + ctx.Response().String(400, "Validate error: "+err.Error()) + return + } + if validator.Fails() { + ctx.Response().String(400, fmt.Sprintf("Validate fail: %+v", validator.Errors().All())) + return + } + + type Test struct { + Name string `form:"name" json:"name"` + } + var test Test + if err := validator.Bind(&test); err != nil { + ctx.Response().String(400, "Validate bind error: "+err.Error()) + return + } + + ctx.Response().Success().Json(httpcontract.Json{ + "name": test.Name, + }) + }) + var err error + req, err = http.NewRequest(method, url, nil) + if err != nil { + return err + } + + return nil + }, + expectCode: http.StatusOK, + expectBody: "{\"name\":\"Goravel\"}", + }, + { + name: "GET with validator but validate fail", + method: "GET", + url: "/validator/validate/fail?name=Goravel", + setup: func(method, url string) error { + gin.Get("/validator/validate/fail", func(ctx httpcontract.Context) { + mockValication, _, _ := mock.Validation() + mockValication.On("Rules").Return([]validation.Rule{}).Once() + + validator, err := ctx.Request().Validate(map[string]string{ + "name1": "required", + }) + if err != nil { + ctx.Response().String(http.StatusBadRequest, "Validate error: "+err.Error()) + return + } + if validator.Fails() { + ctx.Response().String(http.StatusBadRequest, fmt.Sprintf("Validate fail: %+v", validator.Errors().All())) + return + } + + ctx.Response().Success().Json(httpcontract.Json{ + "name": "", + }) + }) + + var err error + req, err = http.NewRequest(method, url, nil) + if err != nil { + return err + } + + return nil + }, + expectCode: http.StatusBadRequest, + expectBody: "Validate fail: map[name1:map[required:name1 is required to not be empty]]", + }, + { + name: "GET with validator and validate request pass", + method: "GET", + url: "/validator/validate-request/success?name=Goravel", + setup: func(method, url string) error { + gin.Get("/validator/validate-request/success", func(ctx httpcontract.Context) { + mockValication, _, _ := mock.Validation() + mockValication.On("Rules").Return([]validation.Rule{}).Once() + + var createUser CreateUser + validateErrors, err := ctx.Request().ValidateRequest(&createUser) + if err != nil { + ctx.Response().String(http.StatusBadRequest, "Validate error: "+err.Error()) + return + } + if validateErrors != nil { + ctx.Response().String(http.StatusBadRequest, fmt.Sprintf("Validate fail: %+v", validateErrors.All())) + return + } + + ctx.Response().Success().Json(httpcontract.Json{ + "name": createUser.Name, + }) + }) + + var err error + req, err = http.NewRequest(method, url, nil) + if err != nil { + return err + } + + return nil + }, + expectCode: http.StatusOK, + expectBody: "{\"name\":\"Goravel1\"}", + }, + { + name: "GET with validator but validate request fail", + method: "GET", + url: "/validator/validate-request/fail?name1=Goravel", + setup: func(method, url string) error { + gin.Get("/validator/validate-request/fail", func(ctx httpcontract.Context) { + mockValication, _, _ := mock.Validation() + mockValication.On("Rules").Return([]validation.Rule{}).Once() + + var createUser CreateUser + validateErrors, err := ctx.Request().ValidateRequest(&createUser) + if err != nil { + ctx.Response().String(http.StatusBadRequest, "Validate error: "+err.Error()) + return + } + if validateErrors != nil { + ctx.Response().String(http.StatusBadRequest, fmt.Sprintf("Validate fail: %+v", validateErrors.All())) + return + } + + ctx.Response().Success().Json(httpcontract.Json{ + "name": createUser.Name, + }) + }) + + var err error + req, err = http.NewRequest(method, url, nil) + if err != nil { + return err + } + + return nil + }, + expectCode: http.StatusBadRequest, + expectBody: "Validate fail: map[name:map[required:name is required to not be empty]]", + }, + { + name: "POST with validator and validate pass", + method: "POST", + url: "/validator/validate/success", + setup: func(method, url string) error { + gin.Post("/validator/validate/success", func(ctx httpcontract.Context) { + mockValication, _, _ := mock.Validation() + mockValication.On("Rules").Return([]validation.Rule{}).Once() + + validator, err := ctx.Request().Validate(map[string]string{ + "name": "required", + }) + if err != nil { + ctx.Response().String(400, "Validate error: "+err.Error()) + return + } + if validator.Fails() { + ctx.Response().String(400, fmt.Sprintf("Validate fail: %+v", validator.Errors().All())) + return + } + + type Test struct { + Name string `form:"name" json:"name"` + } + var test Test + if err := validator.Bind(&test); err != nil { + ctx.Response().String(400, "Validate bind error: "+err.Error()) + return + } + + ctx.Response().Success().Json(httpcontract.Json{ + "name": test.Name, + }) + }) + + payload := strings.NewReader(`{ + "name": "Goravel" + }`) + req, _ = http.NewRequest(method, url, payload) + req.Header.Set("Content-Type", "application/json") + + return nil + }, + expectCode: http.StatusOK, + expectBody: "{\"name\":\"Goravel\"}", + }, + { + name: "POST with validator and validate fail", + method: "POST", + url: "/validator/validate/fail", + setup: func(method, url string) error { + gin.Post("/validator/validate/fail", func(ctx httpcontract.Context) { + mockValication, _, _ := mock.Validation() + mockValication.On("Rules").Return([]validation.Rule{}).Once() + + validator, err := ctx.Request().Validate(map[string]string{ + "name1": "required", + }) + if err != nil { + ctx.Response().String(400, "Validate error: "+err.Error()) + return + } + if validator.Fails() { + ctx.Response().String(400, fmt.Sprintf("Validate fail: %+v", validator.Errors().All())) + return + } + + ctx.Response().Success().Json(httpcontract.Json{ + "name": "", + }) + }) + payload := strings.NewReader(`{ + "name": "Goravel" + }`) + req, _ = http.NewRequest(method, url, payload) + req.Header.Set("Content-Type", "application/json") + + return nil + }, + expectCode: http.StatusBadRequest, + expectBody: "Validate fail: map[name1:map[required:name1 is required to not be empty]]", + }, + { + name: "POST with validator and validate request pass", + method: "POST", + url: "/validator/validate-request/success", + setup: func(method, url string) error { + gin.Post("/validator/validate-request/success", func(ctx httpcontract.Context) { + mockValication, _, _ := mock.Validation() + mockValication.On("Rules").Return([]validation.Rule{}).Once() + + var createUser CreateUser + validateErrors, err := ctx.Request().ValidateRequest(&createUser) + if err != nil { + ctx.Response().String(http.StatusBadRequest, "Validate error: "+err.Error()) + return + } + if validateErrors != nil { + ctx.Response().String(http.StatusBadRequest, fmt.Sprintf("Validate fail: %+v", validateErrors.All())) + return + } + + ctx.Response().Success().Json(httpcontract.Json{ + "name": createUser.Name, + }) + }) + + payload := strings.NewReader(`{ + "name": "Goravel" + }`) + req, _ = http.NewRequest(method, url, payload) + req.Header.Set("Content-Type", "application/json") + + return nil + }, + expectCode: http.StatusOK, + expectBody: "{\"name\":\"Goravel1\"}", + }, + { + name: "POST with validator and validate request fail", + method: "POST", + url: "/validator/validate-request/fail", + setup: func(method, url string) error { + gin.Post("/validator/validate-request/fail", func(ctx httpcontract.Context) { + mockValication, _, _ := mock.Validation() + mockValication.On("Rules").Return([]validation.Rule{}).Once() + + var createUser CreateUser + validateErrors, err := ctx.Request().ValidateRequest(&createUser) + if err != nil { + ctx.Response().String(http.StatusBadRequest, "Validate error: "+err.Error()) + return + } + if validateErrors != nil { + ctx.Response().String(http.StatusBadRequest, fmt.Sprintf("Validate fail: %+v", validateErrors.All())) + return + } + + ctx.Response().Success().Json(httpcontract.Json{ + "name": createUser.Name, + }) + }) + + payload := strings.NewReader(`{ + "name1": "Goravel" + }`) + req, _ = http.NewRequest(method, url, payload) + req.Header.Set("Content-Type", "application/json") + + return nil + }, + expectCode: http.StatusBadRequest, + expectBody: "Validate fail: map[name:map[required:name is required to not be empty]]", + }, + { + name: "POST with validator and validate request unauthorize", + method: "POST", + url: "/validator/validate-request/unauthorize", + setup: func(method, url string) error { + gin.Post("/validator/validate-request/unauthorize", func(ctx httpcontract.Context) { + var unauthorize Unauthorize + validateErrors, err := ctx.Request().ValidateRequest(&unauthorize) + if err != nil { + ctx.Response().String(http.StatusBadRequest, "Validate error: "+err.Error()) + return + } + if validateErrors != nil { + ctx.Response().String(http.StatusBadRequest, fmt.Sprintf("Validate fail: %+v", validateErrors.All())) + return + } + + ctx.Response().Success().Json(httpcontract.Json{ + "name": unauthorize.Name, + }) + }) + payload := strings.NewReader(`{ + "name": "Goravel" + }`) + req, _ = http.NewRequest(method, url, payload) + req.Header.Set("Content-Type", "application/json") + + return nil + }, + expectCode: http.StatusBadRequest, + expectBody: "Validate error: error", + }, + } + + for _, test := range tests { + beforeEach() + err := test.setup(test.method, test.url) + assert.Nil(t, err) + + w := httptest.NewRecorder() + gin.ServeHTTP(w, req) + + if test.expectBody != "" { + assert.Equal(t, test.expectBody, w.Body.String(), test.name) + } + assert.Equal(t, test.expectCode, w.Code, test.name) + } +} + +func TestGinResponse(t *testing.T) { + var ( + gin *Gin + req *http.Request + mockConfig *mockconfig.Config + ) + beforeEach := func() { + mockConfig = mock.Config() + mockConfig.On("GetBool", "app.debug").Return(true).Once() + + gin = NewGin() + } + tests := []struct { + name string + method string + url string + setup func(method, url string) error + expectCode int + expectBody string + expectHeader string + }{ + { + name: "Json", + method: "GET", + url: "/json", + setup: func(method, url string) error { + gin.Get("/json", func(ctx httpcontract.Context) { + ctx.Response().Json(http.StatusOK, httpcontract.Json{ + "id": "1", + }) + }) + + var err error + req, err = http.NewRequest(method, url, nil) + if err != nil { + return err + } + + return nil + }, + expectCode: http.StatusOK, + expectBody: "{\"id\":\"1\"}", + }, + { + name: "String", + method: "GET", + url: "/string", + setup: func(method, url string) error { + gin.Get("/string", func(ctx httpcontract.Context) { + ctx.Response().String(http.StatusCreated, "Goravel") + }) + + var err error + req, err = http.NewRequest(method, url, nil) + if err != nil { + return err + } + + return nil + }, + expectCode: http.StatusCreated, + expectBody: "Goravel", + }, + { + name: "Success Json", + method: "GET", + url: "/success/json", + setup: func(method, url string) error { + gin.Get("/success/json", func(ctx httpcontract.Context) { + ctx.Response().Success().Json(httpcontract.Json{ + "id": "1", + }) + }) + + var err error + req, err = http.NewRequest(method, url, nil) + if err != nil { + return err + } + + return nil + }, + expectCode: http.StatusOK, + expectBody: "{\"id\":\"1\"}", + }, + { + name: "Success String", + method: "GET", + url: "/success/string", + setup: func(method, url string) error { + gin.Get("/success/string", func(ctx httpcontract.Context) { + ctx.Response().Success().String("Goravel") + }) + + var err error + req, err = http.NewRequest(method, url, nil) + if err != nil { + return err + } + + return nil + }, + expectCode: http.StatusOK, + expectBody: "Goravel", + }, + { + name: "File", + method: "GET", + url: "/file", + setup: func(method, url string) error { + gin.Get("/file", func(ctx httpcontract.Context) { + ctx.Response().File("../logo.png") + }) + + var err error + req, err = http.NewRequest(method, url, nil) + if err != nil { + return err + } + + return nil + }, + expectCode: http.StatusOK, + }, + { + name: "Download", + method: "GET", + url: "/download", + setup: func(method, url string) error { + gin.Get("/download", func(ctx httpcontract.Context) { + ctx.Response().Download("../logo.png", "1.png") + }) + + var err error + req, err = http.NewRequest(method, url, nil) + if err != nil { + return err + } + + return nil + }, + expectCode: http.StatusOK, + }, + { + name: "Header", + method: "GET", + url: "/header", + setup: func(method, url string) error { + gin.Get("/header", func(ctx httpcontract.Context) { + ctx.Response().Header("Hello", "goravel").String(http.StatusOK, "Goravel") + }) + + var err error + req, err = http.NewRequest(method, url, nil) + if err != nil { + return err + } + + return nil + }, + expectCode: http.StatusOK, + expectBody: "Goravel", + expectHeader: "goravel", + }, + { + name: "Origin", + method: "GET", + url: "/origin", + setup: func(method, url string) error { + gin.GlobalMiddleware(func(ctx httpcontract.Context) { + ctx.Response().Header("global", "goravel") + ctx.Request().Next() + + assert.Equal(t, "Goravel", ctx.Response().Origin().Body().String()) + assert.Equal(t, "goravel", ctx.Response().Origin().Header().Get("global")) + assert.Equal(t, 7, ctx.Response().Origin().Size()) + assert.Equal(t, 200, ctx.Response().Origin().Status()) + }) + gin.Get("/origin", func(ctx httpcontract.Context) { + ctx.Response().String(http.StatusOK, "Goravel") + }) + + var err error + req, err = http.NewRequest(method, url, nil) + if err != nil { + return err + } + + return nil + }, + expectCode: http.StatusOK, + expectBody: "Goravel", + }, + } + + for _, test := range tests { + beforeEach() + err := test.setup(test.method, test.url) + assert.Nil(t, err) + + w := httptest.NewRecorder() + gin.ServeHTTP(w, req) + + if test.expectBody != "" { + assert.Equal(t, test.expectBody, w.Body.String(), test.name) + } + if test.expectHeader != "" { + assert.Equal(t, test.expectHeader, strings.Join(w.Header().Values("Hello"), ""), test.name) + } + assert.Equal(t, test.expectCode, w.Code, test.name) + } +} + +type CreateUser struct { + Name string `form:"name" json:"name"` +} + +func (r *CreateUser) Authorize(ctx httpcontract.Context) error { + return nil +} + +func (r *CreateUser) Rules() map[string]string { + return map[string]string{ + "name": "required", + } +} + +func (r *CreateUser) Messages() map[string]string { + return map[string]string{} +} + +func (r *CreateUser) Attributes() map[string]string { + return map[string]string{} +} + +func (r *CreateUser) PrepareForValidation(data validation.Data) { + if name, exist := data.Get("name"); exist { + _ = data.Set("name", name.(string)+"1") + } +} + +type Unauthorize struct { + Name string `form:"name" json:"name"` +} + +func (r *Unauthorize) Authorize(ctx httpcontract.Context) error { + return errors.New("error") +} + +func (r *Unauthorize) Rules() map[string]string { + return map[string]string{ + "name": "required", + } +} + +func (r *Unauthorize) Messages() map[string]string { + return map[string]string{} +} + +func (r *Unauthorize) Attributes() map[string]string { + return map[string]string{} +} + +func (r *Unauthorize) PrepareForValidation(data validation.Data) { + } diff --git a/route/service_provider.go b/route/service_provider.go index fbb56bcb1..ffcc16231 100644 --- a/route/service_provider.go +++ b/route/service_provider.go @@ -8,8 +8,7 @@ type ServiceProvider struct { } func (route *ServiceProvider) Register() { - app := Application{} - facades.Route = app.Init() + facades.Route = NewGin() } func (route *ServiceProvider) Boot() { diff --git a/route/utils.go b/route/utils.go new file mode 100644 index 000000000..bea355ba7 --- /dev/null +++ b/route/utils.go @@ -0,0 +1,107 @@ +package route + +import ( + "fmt" + "os" + "regexp" + "strings" + "time" + + "github.com/gin-gonic/gin" + + httpcontract "github.com/goravel/framework/contracts/http" + "github.com/goravel/framework/facades" + frameworkhttp "github.com/goravel/framework/http" +) + +func pathToGinPath(relativePath string) string { + return bracketToColon(mergeSlashForPath(relativePath)) +} + +func middlewaresToGinHandlers(middlewares []httpcontract.Middleware) []gin.HandlerFunc { + var ginHandlers []gin.HandlerFunc + for _, item := range middlewares { + ginHandlers = append(ginHandlers, middlewareToGinHandler(item)) + } + + return ginHandlers +} + +func handlerToGinHandler(handler httpcontract.HandlerFunc) gin.HandlerFunc { + return func(ginCtx *gin.Context) { + handler(frameworkhttp.NewGinContext(ginCtx)) + } +} + +func middlewareToGinHandler(handler httpcontract.Middleware) gin.HandlerFunc { + return func(ginCtx *gin.Context) { + handler(frameworkhttp.NewGinContext(ginCtx)) + } +} + +func getDebugLog() gin.HandlerFunc { + logFormatter := func(param gin.LogFormatterParams) string { + var statusColor, methodColor, resetColor string + if param.IsOutputColor() { + statusColor = param.StatusCodeColor() + methodColor = param.MethodColor() + resetColor = param.ResetColor() + } + + if param.Latency > time.Minute { + // Truncate in a golang < 1.8 safe way + param.Latency = param.Latency - param.Latency%time.Second + } + return fmt.Sprintf("[HTTP] %v |%s %3d %s| %13v | %15s |%s %-7s %s %#v\n%s", + param.TimeStamp.Format("2006/01/02 - 15:04:05"), + statusColor, param.StatusCode, resetColor, + param.Latency, + param.ClientIP, + methodColor, param.Method, resetColor, + param.Path, + param.ErrorMessage, + ) + } + + if facades.Config.GetBool("app.debug") { + return gin.LoggerWithFormatter(logFormatter) + } + + return nil +} + +func colonToBracket(relativePath string) string { + arr := strings.Split(relativePath, "/") + var newArr []string + for _, item := range arr { + if strings.HasPrefix(item, ":") { + item = "{" + strings.ReplaceAll(item, ":", "") + "}" + } + newArr = append(newArr, item) + } + + return strings.Join(newArr, "/") +} + +func bracketToColon(relativePath string) string { + compileRegex := regexp.MustCompile("\\{(.*?)\\}") + matchArr := compileRegex.FindAllStringSubmatch(relativePath, -1) + + for _, item := range matchArr { + relativePath = strings.ReplaceAll(relativePath, item[0], ":"+item[1]) + } + + return relativePath +} + +func mergeSlashForPath(path string) string { + path = strings.ReplaceAll(path, "//", "/") + + return strings.ReplaceAll(path, "//", "/") +} + +func runningInConsole() bool { + args := os.Args + + return len(args) >= 2 && args[1] == "artisan" +} diff --git a/route/utils_test.go b/route/utils_test.go new file mode 100644 index 000000000..d8850a01c --- /dev/null +++ b/route/utils_test.go @@ -0,0 +1,11 @@ +package route + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestBracketToColon(t *testing.T) { + assert.Equal(t, "/:id/:name", bracketToColon("/{id}/{name}")) +} diff --git a/schedule/application.go b/schedule/application.go index 4ab0444c2..386af6c0d 100644 --- a/schedule/application.go +++ b/schedule/application.go @@ -13,6 +13,10 @@ type Application struct { cron *cron.Cron } +func NewApplication() *Application { + return &Application{} +} + func (app *Application) Call(callback func()) schedule.Event { return &support.Event{Callback: callback} } diff --git a/schedule/application_test.go b/schedule/application_test.go new file mode 100644 index 000000000..90b3ead1c --- /dev/null +++ b/schedule/application_test.go @@ -0,0 +1,59 @@ +package schedule + +import ( + "context" + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/goravel/framework/contracts/schedule" + "github.com/goravel/framework/testing/mock" +) + +func TestApplication(t *testing.T) { + mockArtisan := mock.Artisan() + mockArtisan.On("Call", "test --name Goravel argument0 argument1").Return().Times(3) + + immediatelyCall := 0 + delayIfStillRunningCall := 0 + skipIfStillRunningCall := 0 + + app := NewApplication() + app.Register([]schedule.Event{ + app.Call(func() { + immediatelyCall++ + }).EveryMinute(), + app.Call(func() { + time.Sleep(61 * time.Second) + delayIfStillRunningCall++ + }).EveryMinute().DelayIfStillRunning(), + app.Call(func() { + time.Sleep(61 * time.Second) + skipIfStillRunningCall++ + }).EveryMinute().SkipIfStillRunning(), + app.Command("test --name Goravel argument0 argument1").EveryMinute(), + }) + + second, _ := strconv.Atoi(time.Now().Format("05")) + // Make sure run 3 times + ctx, _ := context.WithTimeout(context.Background(), time.Duration(120+6+60-second)*time.Second) + go func(ctx context.Context) { + app.Run() + + for { + select { + case <-ctx.Done(): + return + } + } + }(ctx) + + time.Sleep(time.Duration(120+5+60-second) * time.Second) + + assert.Equal(t, 3, immediatelyCall) + assert.Equal(t, 2, delayIfStillRunningCall) + assert.Equal(t, 1, skipIfStillRunningCall) + mockArtisan.AssertExpectations(t) +} diff --git a/schedule/service_provider.go b/schedule/service_provider.go index e76d7c1e1..6714a7663 100644 --- a/schedule/service_provider.go +++ b/schedule/service_provider.go @@ -8,7 +8,7 @@ type ServiceProvider struct { } func (receiver *ServiceProvider) Register() { - facades.Schedule = &Application{} + facades.Schedule = NewApplication() } func (receiver *ServiceProvider) Boot() { diff --git a/support/constant.go b/support/constant.go index 085b4349a..092029ad0 100644 --- a/support/constant.go +++ b/support/constant.go @@ -1,5 +1,5 @@ package support -const Version string = "1.7.3" +const Version string = "1.8.0" var RootPath string diff --git a/support/database/database.go b/support/database/database.go new file mode 100644 index 000000000..028755a59 --- /dev/null +++ b/support/database/database.go @@ -0,0 +1,36 @@ +package database + +import "reflect" + +func GetID(dest any) any { + if dest == nil { + return nil + } + + t := reflect.TypeOf(dest) + v := reflect.ValueOf(dest) + + if t.Kind() == reflect.Pointer { + return GetIDByReflect(t.Elem(), v.Elem()) + } + + return GetIDByReflect(t, v) +} + +func GetIDByReflect(t reflect.Type, v reflect.Value) any { + for i := 0; i < t.NumField(); i++ { + if t.Field(i).Name == "Model" && v.Field(i).Type().Kind() == reflect.Struct { + structField := v.Field(i).Type() + for j := 0; j < structField.NumField(); j++ { + if structField.Field(j).Tag.Get("gorm") == "primaryKey" { + return v.Field(i).Field(j).Interface() + } + } + } + if t.Field(i).Tag.Get("gorm") == "primaryKey" { + return v.Field(i).Interface() + } + } + + return nil +} diff --git a/support/database/database_test.go b/support/database/database_test.go new file mode 100644 index 000000000..5c5112453 --- /dev/null +++ b/support/database/database_test.go @@ -0,0 +1,105 @@ +package database + +import ( + "testing" + + "github.com/goravel/framework/database/orm" + + "github.com/stretchr/testify/assert" +) + +func TestGetID(t *testing.T) { + tests := []struct { + description string + setup func(description string) + }{ + { + description: "return value", + setup: func(description string) { + type User struct { + ID uint `gorm:"primaryKey"` + Name string + Avatar string + } + user := User{} + user.ID = 1 + assert.Equal(t, uint(1), GetID(&user), description) + }, + }, + { + description: "return value with orm.Model", + setup: func(description string) { + type User struct { + orm.Model + Name string + Avatar string + } + user := User{} + user.ID = 1 + assert.Equal(t, uint(1), GetID(&user), description) + }, + }, + { + description: "return nil", + setup: func(description string) { + type User struct { + Name string + Avatar string + } + user := User{} + assert.Nil(t, GetID(&user), description) + }, + }, + { + description: "return value(struct)", + setup: func(description string) { + type User struct { + ID uint `gorm:"primaryKey"` + Name string + Avatar string + } + user := User{} + user.ID = 1 + assert.Equal(t, uint(1), GetID(user), description) + }, + }, + { + description: "return value with orm.Model", + setup: func(description string) { + type User struct { + orm.Model + Name string + Avatar string + } + user := User{} + user.ID = 1 + assert.Equal(t, uint(1), GetID(user), description) + }, + }, + { + description: "return nil", + setup: func(description string) { + type User struct { + Name string + Avatar string + } + user := User{} + assert.Nil(t, GetID(user), description) + }, + }, + { + description: "return nil when model is nil", + setup: func(description string) { + type User struct { + Name string + Avatar string + } + assert.Nil(t, GetID(&User{}), description) + assert.Nil(t, GetID(nil), description) + }, + }, + } + for _, test := range tests { + test.setup(test.description) + } +} diff --git a/testing/docker/docker.go b/testing/docker/docker.go new file mode 100644 index 000000000..0b433f435 --- /dev/null +++ b/testing/docker/docker.go @@ -0,0 +1,47 @@ +package docker + +import ( + "github.com/ory/dockertest/v3" + "github.com/ory/dockertest/v3/docker" + "github.com/pkg/errors" +) + +func Pool() (*dockertest.Pool, error) { + pool, err := dockertest.NewPool("") + if err != nil { + return nil, errors.WithMessage(err, "Could not construct pool") + } + + if err := pool.Client.Ping(); err != nil { + return nil, errors.WithMessage(err, "Could not connect to Docker") + } + + return pool, nil +} + +func Resource(pool *dockertest.Pool, opts *dockertest.RunOptions) (*dockertest.Resource, error) { + return pool.RunWithOptions(opts, func(config *docker.HostConfig) { + // set AutoRemove to true so that stopped container goes away by itself + config.AutoRemove = true + config.RestartPolicy = docker.RestartPolicy{ + Name: "no", + } + }) +} + +func Redis() (*dockertest.Pool, *dockertest.Resource, error) { + pool, err := Pool() + if err != nil { + return nil, nil, err + } + resource, err := Resource(pool, &dockertest.RunOptions{ + Repository: "redis", + Tag: "latest", + Env: []string{}, + }) + if err != nil { + return nil, nil, err + } + + return pool, resource, nil +} diff --git a/testing/mock/mock.go b/testing/mock/mock.go index e65321705..f354f3701 100644 --- a/testing/mock/mock.go +++ b/testing/mock/mock.go @@ -38,11 +38,11 @@ func Artisan() *consolemocks.Artisan { return mockArtisan } -func Orm() (*ormmocks.Orm, *ormmocks.DB, *ormmocks.Transaction) { +func Orm() (*ormmocks.Orm, *ormmocks.DB, *ormmocks.Transaction, *ormmocks.Association) { mockOrm := &ormmocks.Orm{} facades.Orm = mockOrm - return mockOrm, &ormmocks.DB{}, &ormmocks.Transaction{} + return mockOrm, &ormmocks.DB{}, &ormmocks.Transaction{}, &ormmocks.Association{} } func Event() (*eventmocks.Instance, *eventmocks.Task) { @@ -79,7 +79,7 @@ func Storage() (*filesystemmocks.Storage, *filesystemmocks.Driver, *filesystemmo return mockStorage, mockDriver, mockFile } -func Validator() (*validatemocks.Validation, *validatemocks.Validator, *validatemocks.Errors) { +func Validation() (*validatemocks.Validation, *validatemocks.Validator, *validatemocks.Errors) { mockValidation := &validatemocks.Validation{} mockValidator := &validatemocks.Validator{} mockErrors := &validatemocks.Errors{} diff --git a/validation/errors_test.go b/validation/errors_test.go index 06453dc63..b851d36c5 100644 --- a/validation/errors_test.go +++ b/validation/errors_test.go @@ -1,9 +1,10 @@ package validation import ( - httpvalidate "github.com/goravel/framework/contracts/validation" "testing" + httpvalidate "github.com/goravel/framework/contracts/validation" + "github.com/stretchr/testify/assert" ) diff --git a/validation/validation_test.go b/validation/validation_test.go index 321bc0df5..0f17ae2f0 100644 --- a/validation/validation_test.go +++ b/validation/validation_test.go @@ -2,10 +2,11 @@ package validation import ( "errors" - httpvalidate "github.com/goravel/framework/contracts/validation" "strings" "testing" + httpvalidate "github.com/goravel/framework/contracts/validation" + "github.com/spf13/cast" "github.com/stretchr/testify/assert" diff --git a/validation/validator.go b/validation/validator.go index 338de3973..5a33a3b52 100644 --- a/validation/validator.go +++ b/validation/validator.go @@ -1,10 +1,11 @@ package validation import ( - httpvalidate "github.com/goravel/framework/contracts/validation" "net/url" "github.com/gookit/validate" + + httpvalidate "github.com/goravel/framework/contracts/validation" ) func init() {