From 72826a1293febdbb1ed8939837610627ed8de736 Mon Sep 17 00:00:00 2001 From: Chuan-Heng Hsiao Date: Thu, 3 Aug 2023 12:28:37 -0400 Subject: [PATCH] add refresh token. 1. /refresh and /refreshtoken/info 2. VerifyJwt with whether check expireTS. 3. *gin.Context in api --- api/00-config.go | 13 ++- api/api.go | 2 +- api/api_login_required.go | 4 +- api/api_login_required_path.go | 4 +- api/attempt_change_email.go | 2 +- api/attempt_change_email_test.go | 2 +- api/attempt_set_id_email.go | 2 +- api/attempt_set_id_email_test.go | 2 +- api/auth_utils.go | 113 ++++++++++++++++++++--- api/auth_utils_test.go | 4 +- api/change_email.go | 2 +- api/change_email_test.go | 2 +- api/change_passwd.go | 2 +- api/change_passwd_test.go | 2 +- api/check_exists_user.go | 2 +- api/check_exists_user_test.go | 2 +- api/config.go | 9 +- api/config_util.go | 18 +++- api/const.go | 4 + api/create_article.go | 10 +- api/create_article_test.go | 2 +- api/create_board.go | 2 +- api/create_board_test.go | 2 +- api/create_comment.go | 2 +- api/create_comment_test.go | 2 +- api/cross_post.go | 2 +- api/cross_post_test.go | 2 +- api/delete_articles.go | 2 +- api/delete_articles_test.go | 2 +- api/edit_article.go | 2 +- api/edit_article_test.go | 2 +- api/get_article.go | 2 +- api/get_article_test.go | 2 +- api/get_email_token_info.go | 8 +- api/get_email_token_info_test.go | 6 +- api/get_favorites.go | 2 +- api/get_favorites_test.go | 2 +- api/get_post_template.go | 2 +- api/get_post_template_test.go | 2 +- api/get_refresh_token_info.go | 48 ++++++++++ api/get_refresh_token_info_test.go | 61 ++++++++++++ api/get_token_info.go | 6 +- api/get_token_info_test.go | 5 +- api/get_user.go | 2 +- api/get_user_test.go | 2 +- api/get_user_visit_count.go | 2 +- api/get_user_visit_count_test.go | 2 +- api/get_version.go | 2 +- api/index.go | 2 +- api/index_test.go | 2 +- api/is_board_valid_user.go | 2 +- api/is_board_valid_user_test.go | 2 +- api/is_boards_valid_user.go | 2 +- api/is_boards_valid_user_test.go | 2 +- api/load_auto_complete_boards.go | 2 +- api/load_auto_complete_boards_test.go | 2 +- api/load_board_detail.go | 2 +- api/load_board_detail_test.go | 2 +- api/load_board_summary.go | 2 +- api/load_board_summary_test.go | 2 +- api/load_boards_by_bids.go | 2 +- api/load_boards_by_bids_test.go | 2 +- api/load_bottom_articles.go | 2 +- api/load_bottom_articles_test.go | 2 +- api/load_class_boards.go | 2 +- api/load_class_boards_test.go | 2 +- api/load_full_class_boards.go | 2 +- api/load_full_class_boards_test.go | 2 +- api/load_general_articles.go | 2 +- api/load_general_articles_test.go | 2 +- api/load_general_board_details.go | 2 +- api/load_general_board_details_test.go | 2 +- api/load_general_boards.go | 2 +- api/load_general_boards_by_class.go | 2 +- api/load_general_boards_by_class_test.go | 2 +- api/load_general_boards_test.go | 2 +- api/load_hot_boards.go | 2 +- api/load_hot_boards_test.go | 2 +- api/login.go | 9 +- api/login_test.go | 2 +- api/refresh.go | 73 +++++++++++++++ api/refresh_test.go | 70 ++++++++++++++ api/register.go | 2 +- api/register_test.go | 2 +- api/reload_uhash.go | 2 +- api/reload_uhash_test.go | 2 +- api/set_id_email.go | 2 +- api/set_id_email_test.go | 2 +- api/set_user_perm.go | 2 +- api/set_user_perm_test.go | 2 +- api/types.go | 14 ++- api/user_utils.go | 2 +- api/write_favorites.go | 2 +- api/write_favorites_test.go | 2 +- apidoc/apidoc.py | 15 +++ apidoc/defs/token.yaml | 4 +- apidoc/get_email_token_info.yaml | 2 + apidoc/get_refresh_token_info.yaml | 32 +++++++ apidoc/get_token_info.yaml | 3 + apidoc/refresh.yaml | 21 +++++ initgin/init_gin.go | 3 + 101 files changed, 582 insertions(+), 127 deletions(-) create mode 100644 api/get_refresh_token_info.go create mode 100644 api/get_refresh_token_info_test.go create mode 100644 api/refresh.go create mode 100644 api/refresh_test.go create mode 100644 apidoc/get_refresh_token_info.yaml create mode 100644 apidoc/refresh.yaml diff --git a/api/00-config.go b/api/00-config.go index cbd048fc..118f6b30 100644 --- a/api/00-config.go +++ b/api/00-config.go @@ -6,14 +6,19 @@ import ( var ( // Creating JWT Token - JWT_SECRET = []byte("jwt_secret") JWT_ISSUER = "go-pttbbs" GUEST = "guest" - EMAIL_JWT_SECRET = []byte("email_jwt_secret") + JWT_SECRET = []byte("jwt_secret") + JWT_TOKEN_EXPIRE_TS = 86400 * 1 // 1 days + JWT_TOKEN_EXPIRE_DURATION = time.Duration(JWT_TOKEN_EXPIRE_TS) * time.Second - JWT_TOKEN_EXPIRE_TS = 86400 * 1 // 1 days - JWT_TOKEN_EXPIRE_DURATION = time.Duration(JWT_TOKEN_EXPIRE_TS) * time.Second + EMAIL_JWT_SECRET = []byte("email_jwt_secret") EMAIL_JWT_TOKEN_EXPIRE_TS = 60 * 15 // 15 mins EMAIL_JWT_TOKEN_EXPIRE_DURATION = time.Duration(EMAIL_JWT_TOKEN_EXPIRE_TS) * time.Second + + REFRESH_JWT_CLAIM_TYPE = "refresh" + REFRESH_JWT_SECRET = []byte("refresh_jwt_secret") + REFRESH_JWT_TOKEN_EXPIRE_TS = 86400 * 7 // 7 days + REFRESH_JWT_TOKEN_EXPIRE_DURATION = time.Duration(REFRESH_JWT_TOKEN_EXPIRE_TS) * time.Second ) diff --git a/api/api.go b/api/api.go index 1405d6d1..acb760f5 100644 --- a/api/api.go +++ b/api/api.go @@ -40,6 +40,6 @@ func process(theFunc APIFunc, params interface{}, c *gin.Context) { return } - result, err := theFunc(remoteAddr, params) + result, err := theFunc(remoteAddr, params, c) processResult(c, result, err) } diff --git a/api/api_login_required.go b/api/api_login_required.go index 1f5b4135..32d87972 100644 --- a/api/api_login_required.go +++ b/api/api_login_required.go @@ -43,11 +43,11 @@ func loginRequiredProcess(theFunc LoginRequiredAPIFunc, params interface{}, c *g jwt := GetJwt(c) - userID, _, err := VerifyJwt(jwt) + userID, _, _, err := VerifyJwt(jwt, true) if err != nil { userID = bbs.UUserID(GUEST) } - result, err := theFunc(remoteAddr, userID, params) + result, err := theFunc(remoteAddr, userID, params, c) processResult(c, result, err) } diff --git a/api/api_login_required_path.go b/api/api_login_required_path.go index 60aeb561..d4a1e35d 100644 --- a/api/api_login_required_path.go +++ b/api/api_login_required_path.go @@ -49,11 +49,11 @@ func loginRequiredPathProcess(theFunc LoginRequiredPathAPIFunc, params interface jwt := GetJwt(c) - userID, _, err := VerifyJwt(jwt) + userID, _, _, err := VerifyJwt(jwt, true) if err != nil { userID = bbs.UUserID(GUEST) } - result, err := theFunc(remoteAddr, userID, params, path) + result, err := theFunc(remoteAddr, userID, params, path, c) processResult(c, result, err) } diff --git a/api/attempt_change_email.go b/api/attempt_change_email.go index 460c250c..d9684ce8 100644 --- a/api/attempt_change_email.go +++ b/api/attempt_change_email.go @@ -28,7 +28,7 @@ func AttemptChangeEmailWrapper(c *gin.Context) { LoginRequiredPathJSON(AttemptChangeEmail, params, path, c) } -func AttemptChangeEmail(remoteAddr string, uuserID bbs.UUserID, params interface{}, path interface{}) (result interface{}, err error) { +func AttemptChangeEmail(remoteAddr string, uuserID bbs.UUserID, params interface{}, path interface{}, c *gin.Context) (result interface{}, err error) { theParams, ok := params.(*AttemptChangeEmailParams) if !ok { return nil, ErrInvalidParams diff --git a/api/attempt_change_email_test.go b/api/attempt_change_email_test.go index 8916493d..8fa30136 100644 --- a/api/attempt_change_email_test.go +++ b/api/attempt_change_email_test.go @@ -51,7 +51,7 @@ func TestAttemptChangeEmail(t *testing.T) { wg.Add(1) t.Run(tt.name, func(t *testing.T) { defer wg.Done() - gotResult, err := AttemptChangeEmail(tt.args.remoteAddr, tt.args.uuserID, tt.args.params, tt.args.path) + gotResult, err := AttemptChangeEmail(tt.args.remoteAddr, tt.args.uuserID, tt.args.params, tt.args.path, nil) if (err != nil) != tt.wantErr { t.Errorf("AttemptChangeEmail() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/api/attempt_set_id_email.go b/api/attempt_set_id_email.go index ef0dc10e..cb0e6583 100644 --- a/api/attempt_set_id_email.go +++ b/api/attempt_set_id_email.go @@ -28,7 +28,7 @@ func AttemptSetIDEmailWrapper(c *gin.Context) { LoginRequiredPathJSON(AttemptSetIDEmail, params, path, c) } -func AttemptSetIDEmail(remoteAddr string, uuserID bbs.UUserID, params interface{}, path interface{}) (result interface{}, err error) { +func AttemptSetIDEmail(remoteAddr string, uuserID bbs.UUserID, params interface{}, path interface{}, c *gin.Context) (result interface{}, err error) { theParams, ok := params.(*AttemptSetIDEmailParams) if !ok { return nil, ErrInvalidParams diff --git a/api/attempt_set_id_email_test.go b/api/attempt_set_id_email_test.go index 1241ea75..24e2f08e 100644 --- a/api/attempt_set_id_email_test.go +++ b/api/attempt_set_id_email_test.go @@ -50,7 +50,7 @@ func TestAttemptSetIDEmail(t *testing.T) { wg.Add(1) t.Run(tt.name, func(t *testing.T) { defer wg.Done() - gotResult, err := AttemptSetIDEmail(tt.args.remoteAddr, tt.args.uuserID, tt.args.params, tt.args.path) + gotResult, err := AttemptSetIDEmail(tt.args.remoteAddr, tt.args.uuserID, tt.args.params, tt.args.path, nil) if (err != nil) != tt.wantErr { t.Errorf("AttemptSetIDEmail() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/api/auth_utils.go b/api/auth_utils.go index f5cd9b35..609b68c9 100644 --- a/api/auth_utils.go +++ b/api/auth_utils.go @@ -20,22 +20,24 @@ func GetJwt(c *gin.Context) (jwt string) { return tokenList[1] } -func VerifyJwt(raw string) (userID bbs.UUserID, clientInfo string, err error) { +func VerifyJwt(raw string, isCheckExpire bool) (userID bbs.UUserID, expireTS int, clientInfo string, err error) { if raw == "" { - return bbs.UUserID(GUEST), "", nil + return bbs.UUserID(GUEST), 0, "", nil } cl, err := parseJwtClaim(raw) if err != nil { - return "", "", ErrInvalidToken + return "", 0, "", ErrInvalidToken } - currentTS := int(types.NowTS()) - if currentTS > cl.Expire { - return "", "", ErrInvalidToken + if isCheckExpire { + currentTS := int(types.NowTS()) + if currentTS > cl.Expire { + return "", 0, "", ErrInvalidToken + } } - return bbs.UUserID(cl.UUserID), cl.ClientInfo, nil + return bbs.UUserID(cl.UUserID), cl.Expire, cl.ClientInfo, nil } func parseJwtClaim(raw string) (cl *JwtClaim, err error) { @@ -95,26 +97,26 @@ func CreateToken(userID bbs.UUserID, clientInfo string) (raw string, err error) return raw, nil } -func VerifyEmailJwt(raw string, context EmailTokenContext) (userID bbs.UUserID, clientInfo string, email string, err error) { +func VerifyEmailJwt(raw string, context EmailTokenContext) (userID bbs.UUserID, expireTS int, clientInfo string, email string, err error) { if raw == "" { - return "", "", "", ErrInvalidToken + return "", 0, "", "", ErrInvalidToken } cl, err := parseEmailJwtClaim(raw) if err != nil { - return "", "", "", ErrInvalidToken + return "", 0, "", "", ErrInvalidToken } currentTS := int(types.NowTS()) if currentTS > cl.Expire { - return "", "", "", ErrInvalidToken + return "", 0, "", "", ErrInvalidToken } if cl.Context != string(context) { - return "", "", "", ErrInvalidToken + return "", 0, "", "", ErrInvalidToken } - return bbs.UUserID(cl.UUserID), cl.ClientInfo, cl.Email, nil + return bbs.UUserID(cl.UUserID), cl.Expire, cl.ClientInfo, cl.Email, nil } func parseEmailJwtClaim(raw string) (cl *EmailJwtClaim, err error) { @@ -186,6 +188,91 @@ func CreateEmailToken(userID bbs.UUserID, clientInfo string, email string, conte return raw, nil } +func VerifyRefreshJwt(raw string) (userID bbs.UUserID, expireTS int, clientInfo string, err error) { + if raw == "" { + return bbs.UUserID(GUEST), 0, "", nil + } + + cl, err := parseRefreshJwtClaim(raw) + if err != nil { + return "", 0, "", ErrInvalidToken + } + + currentTS := int(types.NowTS()) + if currentTS > cl.Expire { + return "", 0, "", ErrInvalidToken + } + + if cl.TheType != REFRESH_JWT_CLAIM_TYPE { + return "", 0, "", ErrInvalidToken + } + + return bbs.UUserID(cl.UUserID), cl.Expire, cl.ClientInfo, nil +} + +func parseRefreshJwtClaim(raw string) (cl *RefreshJwtClaim, err error) { + tok, err := ParseJwt(raw, REFRESH_JWT_SECRET) + if err != nil { + return nil, err + } + + claim, ok := tok.Claims.(jwt.MapClaims) + if !ok { + return nil, ErrInvalidToken + } + + cli, err := ParseClaimString(claim, "cli") + if err != nil { + return nil, err + } + sub, err := ParseClaimString(claim, "sub") + if err != nil { + return nil, err + } + exp, err := ParseClaimInt(claim, "exp") + if err != nil { + return nil, err + } + typ, err := ParseClaimString(claim, "typ") + if err != nil { + return nil, err + } + + cl = &RefreshJwtClaim{ + ClientInfo: cli, + UUserID: sub, + Expire: exp, + TheType: typ, + } + + return cl, nil +} + +func CreateRefreshToken(userID bbs.UUserID, clientInfo string) (raw string, err error) { + defer func() { + err2 := recover() + if err2 == nil { + return + } + + err = types.ErrRecover(err2) + }() + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "cli": clientInfo, + "sub": userID, + "exp": int(types.NowTS()) + REFRESH_JWT_TOKEN_EXPIRE_TS, + "typ": REFRESH_JWT_CLAIM_TYPE, + }) + + raw, err = token.SignedString(REFRESH_JWT_SECRET) + if err != nil { + return "", err + } + + return raw, nil +} + func ParseJwt(raw string, secret []byte) (tok *jwt.Token, err error) { tok, err = jwt.Parse(raw, func(token *jwt.Token) (interface{}, error) { return secret, nil diff --git a/api/auth_utils_test.go b/api/auth_utils_test.go index bbed6454..17040f00 100644 --- a/api/auth_utils_test.go +++ b/api/auth_utils_test.go @@ -45,7 +45,7 @@ func TestVerifyJwt(t *testing.T) { wg.Add(1) t.Run(tt.name, func(t *testing.T) { defer wg.Done() - gotUserID, _, err := VerifyJwt(tt.args.raw) + gotUserID, _, _, err := VerifyJwt(tt.args.raw, true) if (err != nil) != tt.wantErr { t.Errorf("VerifyJwt() error = %v, wantErr %v", err, tt.wantErr) return @@ -88,7 +88,7 @@ func TestVerifyEmailJwt(t *testing.T) { wg.Add(1) t.Run(tt.name, func(t *testing.T) { defer wg.Done() - gotUserID, gotClientInfo, gotEmail, err := VerifyEmailJwt(tt.args.raw, CONTEXT_CHANGE_EMAIL) + gotUserID, _, gotClientInfo, gotEmail, err := VerifyEmailJwt(tt.args.raw, CONTEXT_CHANGE_EMAIL) if (err != nil) != tt.wantErr { t.Errorf("VerifyEmailJwt() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/api/change_email.go b/api/change_email.go index 869c29e5..1dd43fb3 100644 --- a/api/change_email.go +++ b/api/change_email.go @@ -30,7 +30,7 @@ func ChangeEmailWrapper(c *gin.Context) { // // Sysop initiates only attempt-change-mail. // Sysop does not change email directly. -func ChangeEmail(remoteAddr string, uuserID bbs.UUserID, params interface{}, path interface{}) (result interface{}, err error) { +func ChangeEmail(remoteAddr string, uuserID bbs.UUserID, params interface{}, path interface{}, c *gin.Context) (result interface{}, err error) { theParams, ok := params.(*ChangeEmailParams) if !ok { return nil, ErrInvalidParams diff --git a/api/change_email_test.go b/api/change_email_test.go index 3af1b75c..3693c787 100644 --- a/api/change_email_test.go +++ b/api/change_email_test.go @@ -49,7 +49,7 @@ func TestChangeEmail(t *testing.T) { wg.Add(1) t.Run(tt.name, func(t *testing.T) { defer wg.Done() - gotResult, err := ChangeEmail(tt.args.remoteAddr, tt.args.uuserID, tt.args.params, tt.args.path) + gotResult, err := ChangeEmail(tt.args.remoteAddr, tt.args.uuserID, tt.args.params, tt.args.path, nil) if (err != nil) != tt.wantErr { t.Errorf("ChangeEmail() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/api/change_passwd.go b/api/change_passwd.go index cb9ffebe..57c63674 100644 --- a/api/change_passwd.go +++ b/api/change_passwd.go @@ -29,7 +29,7 @@ func ChangePasswdWrapper(c *gin.Context) { LoginRequiredPathJSON(ChangePasswd, params, path, c) } -func ChangePasswd(remoteAddr string, uuserID bbs.UUserID, params interface{}, path interface{}) (result interface{}, err error) { +func ChangePasswd(remoteAddr string, uuserID bbs.UUserID, params interface{}, path interface{}, c *gin.Context) (result interface{}, err error) { theParams, ok := params.(*ChangePasswdParams) if !ok { return nil, ErrInvalidParams diff --git a/api/change_passwd_test.go b/api/change_passwd_test.go index c7e875a0..ae0411c1 100644 --- a/api/change_passwd_test.go +++ b/api/change_passwd_test.go @@ -69,7 +69,7 @@ func TestChangePasswd(t *testing.T) { wg.Add(1) t.Run(tt.name, func(t *testing.T) { defer wg.Done() - gotResult, err := ChangePasswd(tt.args.remoteAddr, tt.args.uuserID, tt.args.params, tt.args.path) + gotResult, err := ChangePasswd(tt.args.remoteAddr, tt.args.uuserID, tt.args.params, tt.args.path, nil) if (err != nil) != tt.wantErr { t.Errorf("ChangePasswd() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/api/check_exists_user.go b/api/check_exists_user.go index 960e1dec..3d8e4552 100644 --- a/api/check_exists_user.go +++ b/api/check_exists_user.go @@ -21,7 +21,7 @@ func CheckExistsUserWrapper(c *gin.Context) { JSON(CheckExistsUser, params, c) } -func CheckExistsUser(remoteAddr string, params interface{}) (result interface{}, err error) { +func CheckExistsUser(remoteAddr string, params interface{}, c *gin.Context) (result interface{}, err error) { theParams, ok := params.(*CheckExistsUserParams) if !ok { return nil, ErrInvalidParams diff --git a/api/check_exists_user_test.go b/api/check_exists_user_test.go index ff2f3aec..db227b41 100644 --- a/api/check_exists_user_test.go +++ b/api/check_exists_user_test.go @@ -47,7 +47,7 @@ func TestCheckExistsUser(t *testing.T) { wg.Add(1) t.Run(tt.name, func(t *testing.T) { defer wg.Done() - gotResult, err := CheckExistsUser(tt.args.remoteAddr, tt.args.params) + gotResult, err := CheckExistsUser(tt.args.remoteAddr, tt.args.params, nil) if (err != nil) != tt.wantErr { t.Errorf("CheckExistsUser() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/api/config.go b/api/config.go index a13acaae..1fa8f388 100644 --- a/api/config.go +++ b/api/config.go @@ -1,13 +1,16 @@ package api func config() { - JWT_SECRET = setBytesConfig("JWT_SECRET", JWT_SECRET) JWT_ISSUER = setStringConfig("JWT_ISSUER", JWT_ISSUER) GUEST = setStringConfig("GUEST", GUEST) - EMAIL_JWT_SECRET = setBytesConfig("EMAIL_JWT_SECRET", EMAIL_JWT_SECRET) - + JWT_SECRET = setBytesConfig("JWT_SECRET", JWT_SECRET) JWT_TOKEN_EXPIRE_TS = setIntConfig("JWT_TOKEN_EXPIRE_TS", JWT_TOKEN_EXPIRE_TS) + EMAIL_JWT_SECRET = setBytesConfig("EMAIL_JWT_SECRET", EMAIL_JWT_SECRET) EMAIL_JWT_TOKEN_EXPIRE_TS = setIntConfig("EMAIL_JWT_TOKEN_EXPIRE_TS", EMAIL_JWT_TOKEN_EXPIRE_TS) + + REFRESH_JWT_CLAIM_TYPE = setStringConfig("REFRESH_JWT_CLAIM_TYPE", REFRESH_JWT_CLAIM_TYPE) + REFRESH_JWT_SECRET = setBytesConfig("REFRESH_JWT_SECRET", REFRESH_JWT_SECRET) + REFRESH_JWT_TOKEN_EXPIRE_TS = setIntConfig("REFRESH_JWT_TOKEN_EXPIRE_TS", REFRESH_JWT_TOKEN_EXPIRE_TS) } diff --git a/api/config_util.go b/api/config_util.go index 1643f5dd..77c61076 100644 --- a/api/config_util.go +++ b/api/config_util.go @@ -29,22 +29,32 @@ func setIntConfig(idx string, orig int) int { func postInitConfig() { _ = setJwtTokenExpireTS(JWT_TOKEN_EXPIRE_TS) _ = setEmailJwtTokenExpireTS(EMAIL_JWT_TOKEN_EXPIRE_TS) + _ = setRefreshJwtTokenExpireTS(REFRESH_JWT_TOKEN_EXPIRE_TS) } -func setJwtTokenExpireTS(JwtTokenExpireTS int) (origJwtTokenExpireTS int) { +func setJwtTokenExpireTS(jwtTokenExpireTS int) (origJwtTokenExpireTS int) { origJwtTokenExpireTS = JWT_TOKEN_EXPIRE_TS - JWT_TOKEN_EXPIRE_TS = JwtTokenExpireTS + JWT_TOKEN_EXPIRE_TS = jwtTokenExpireTS JWT_TOKEN_EXPIRE_DURATION = time.Duration(JWT_TOKEN_EXPIRE_TS) * time.Second return origJwtTokenExpireTS } -func setEmailJwtTokenExpireTS(EmailJwtTokenExpireTS int) (origEmailJwtTokenExpireTS int) { +func setEmailJwtTokenExpireTS(emailJwtTokenExpireTS int) (origEmailJwtTokenExpireTS int) { origEmailJwtTokenExpireTS = EMAIL_JWT_TOKEN_EXPIRE_TS - EMAIL_JWT_TOKEN_EXPIRE_TS = EmailJwtTokenExpireTS + EMAIL_JWT_TOKEN_EXPIRE_TS = emailJwtTokenExpireTS EMAIL_JWT_TOKEN_EXPIRE_DURATION = time.Duration(EMAIL_JWT_TOKEN_EXPIRE_TS) * time.Second return origEmailJwtTokenExpireTS } + +func setRefreshJwtTokenExpireTS(refreshJwtTokenExpireTS int) (origRefreshJwtTokenExpireTS int) { + origRefreshJwtTokenExpireTS = REFRESH_JWT_TOKEN_EXPIRE_TS + + REFRESH_JWT_TOKEN_EXPIRE_TS = refreshJwtTokenExpireTS + REFRESH_JWT_TOKEN_EXPIRE_DURATION = time.Duration(REFRESH_JWT_TOKEN_EXPIRE_TS) * time.Second + + return origRefreshJwtTokenExpireTS +} diff --git a/api/const.go b/api/const.go index 778f64ec..7877ee91 100644 --- a/api/const.go +++ b/api/const.go @@ -1 +1,5 @@ package api + +const ( + EPSILON_EXPIRE_TS = 2 +) diff --git a/api/create_article.go b/api/create_article.go index f96724e4..bdd6f5f6 100644 --- a/api/create_article.go +++ b/api/create_article.go @@ -25,15 +25,7 @@ func CreateArticleWrapper(c *gin.Context) { LoginRequiredPathJSON(CreateArticle, params, path, c) } -func CreateArticle( - remoteAddr string, - uuser bbs.UUserID, - params interface{}, - path interface{}) ( - - result interface{}, - err error) { - +func CreateArticle(remoteAddr string, uuser bbs.UUserID, params interface{}, path interface{}, c *gin.Context) (result interface{}, err error) { theParams, ok := params.(*CreateArticleParams) if !ok { return nil, ErrInvalidParams diff --git a/api/create_article_test.go b/api/create_article_test.go index c81385b4..d08456c6 100644 --- a/api/create_article_test.go +++ b/api/create_article_test.go @@ -69,7 +69,7 @@ func TestCreateArticle(t *testing.T) { wg.Add(1) t.Run(tt.name, func(t *testing.T) { defer wg.Done() - gotResult, err := CreateArticle(tt.args.remoteAddr, tt.args.uuser, tt.args.params, tt.args.path) + gotResult, err := CreateArticle(tt.args.remoteAddr, tt.args.uuser, tt.args.params, tt.args.path, nil) if (err != nil) != tt.wantErr { t.Errorf("CreateArticle() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/api/create_board.go b/api/create_board.go index 6a7d236b..63825a2e 100644 --- a/api/create_board.go +++ b/api/create_board.go @@ -31,7 +31,7 @@ func CreateBoardWrapper(c *gin.Context) { LoginRequiredPathJSON(CreateBoard, params, path, c) } -func CreateBoard(remoteAddr string, uuserID bbs.UUserID, params interface{}, path interface{}) (result interface{}, err error) { +func CreateBoard(remoteAddr string, uuserID bbs.UUserID, params interface{}, path interface{}, c *gin.Context) (result interface{}, err error) { theParams, ok := params.(*CreateBoardParams) if !ok { return nil, ErrInvalidParams diff --git a/api/create_board_test.go b/api/create_board_test.go index 2eb13619..004ed755 100644 --- a/api/create_board_test.go +++ b/api/create_board_test.go @@ -63,7 +63,7 @@ func TestCreateBoard(t *testing.T) { wg.Add(1) t.Run(tt.name, func(t *testing.T) { defer wg.Done() - gotResult, err := CreateBoard(tt.args.remoteAddr, tt.args.uuserID, tt.args.params, tt.args.path) + gotResult, err := CreateBoard(tt.args.remoteAddr, tt.args.uuserID, tt.args.params, tt.args.path, nil) if (err != nil) != tt.wantErr { t.Errorf("CreateBoard() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/api/create_comment.go b/api/create_comment.go index 77b1ddae..3172ca7a 100644 --- a/api/create_comment.go +++ b/api/create_comment.go @@ -30,7 +30,7 @@ func CreateCommentWrapper(c *gin.Context) { LoginRequiredPathJSON(CreateComment, params, path, c) } -func CreateComment(remoteAddr string, uuserID bbs.UUserID, params interface{}, path interface{}) (result interface{}, err error) { +func CreateComment(remoteAddr string, uuserID bbs.UUserID, params interface{}, path interface{}, c *gin.Context) (result interface{}, err error) { theParams, ok := params.(*CreateCommentParams) if !ok { return nil, ErrInvalidParams diff --git a/api/create_comment_test.go b/api/create_comment_test.go index 718644e2..79df01f2 100644 --- a/api/create_comment_test.go +++ b/api/create_comment_test.go @@ -55,7 +55,7 @@ func TestCreateComment(t *testing.T) { wg.Add(1) t.Run(tt.name, func(t *testing.T) { defer wg.Done() - gotResult, err := CreateComment(tt.args.remoteAddr, tt.args.uuserID, tt.args.params, tt.args.path) + gotResult, err := CreateComment(tt.args.remoteAddr, tt.args.uuserID, tt.args.params, tt.args.path, nil) if (err != nil) != tt.wantErr { t.Errorf("CreateComment() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/api/cross_post.go b/api/cross_post.go index a3546e88..5f7d8416 100644 --- a/api/cross_post.go +++ b/api/cross_post.go @@ -31,7 +31,7 @@ type CrossPostResult struct { CommentMTime types.Time4 `json:"comment_mtime"` } -func CrossPost(remoteAddr string, uuserID bbs.UUserID, params interface{}, path interface{}) (result interface{}, err error) { +func CrossPost(remoteAddr string, uuserID bbs.UUserID, params interface{}, path interface{}, c *gin.Context) (result interface{}, err error) { theParams, ok := params.(*CrossPostParams) if !ok { return nil, ErrInvalidParams diff --git a/api/cross_post_test.go b/api/cross_post_test.go index 6c5c4aec..fe72903d 100644 --- a/api/cross_post_test.go +++ b/api/cross_post_test.go @@ -94,7 +94,7 @@ func TestCrossPost(t *testing.T) { wg.Add(1) t.Run(tt.name, func(t *testing.T) { defer wg.Done() - gotResult, err := CrossPost(tt.args.remoteAddr, tt.args.uuserID, tt.args.params, tt.args.path) + gotResult, err := CrossPost(tt.args.remoteAddr, tt.args.uuserID, tt.args.params, tt.args.path, nil) if (err != nil) != tt.wantErr { t.Errorf("CrossPost() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/api/delete_articles.go b/api/delete_articles.go index dc5c9d46..0d1401c6 100644 --- a/api/delete_articles.go +++ b/api/delete_articles.go @@ -25,7 +25,7 @@ func DeleteArticlesWrapper(c *gin.Context) { LoginRequiredPathJSON(DeleteArticles, params, path, c) } -func DeleteArticles(remoteAddr string, uuserID bbs.UUserID, params interface{}, path interface{}) (result interface{}, err error) { +func DeleteArticles(remoteAddr string, uuserID bbs.UUserID, params interface{}, path interface{}, c *gin.Context) (result interface{}, err error) { theParams, ok := params.(*DeleteArticlesParams) if !ok { return nil, ErrInvalidParams diff --git a/api/delete_articles_test.go b/api/delete_articles_test.go index f24a824e..810b8e59 100644 --- a/api/delete_articles_test.go +++ b/api/delete_articles_test.go @@ -104,7 +104,7 @@ func TestDeleteArticles(t *testing.T) { wg.Add(1) t.Run(tt.name, func(t *testing.T) { wg.Done() - gotResult, err := DeleteArticles(tt.args.remoteAddr, tt.args.uuserID, tt.args.params, tt.args.path) + gotResult, err := DeleteArticles(tt.args.remoteAddr, tt.args.uuserID, tt.args.params, tt.args.path, nil) if (err != nil) != tt.wantErr { t.Errorf("DeleteArticles() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/api/edit_article.go b/api/edit_article.go index b2ffd6c8..1708a9a1 100644 --- a/api/edit_article.go +++ b/api/edit_article.go @@ -37,7 +37,7 @@ func EditArticleWrapper(c *gin.Context) { LoginRequiredPathJSON(EditArticle, params, path, c) } -func EditArticle(remoteAddr string, uuserID bbs.UUserID, params interface{}, path interface{}) (result interface{}, err error) { +func EditArticle(remoteAddr string, uuserID bbs.UUserID, params interface{}, path interface{}, c *gin.Context) (result interface{}, err error) { theParams, ok := params.(*EditArticleParams) if !ok { return nil, ErrInvalidParams diff --git a/api/edit_article_test.go b/api/edit_article_test.go index 31232540..cbd12447 100644 --- a/api/edit_article_test.go +++ b/api/edit_article_test.go @@ -135,7 +135,7 @@ func TestEditArticle(t *testing.T) { wg.Add(1) t.Run(tt.name, func(t *testing.T) { defer wg.Done() - gotResult, err := EditArticle(tt.args.remoteAddr, tt.args.uuserID, tt.args.params, tt.args.path) + gotResult, err := EditArticle(tt.args.remoteAddr, tt.args.uuserID, tt.args.params, tt.args.path, nil) if (err != nil) != tt.wantErr { t.Errorf("EditArticle() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/api/get_article.go b/api/get_article.go index 573b9259..bf7d46c6 100644 --- a/api/get_article.go +++ b/api/get_article.go @@ -41,7 +41,7 @@ func GetArticleWrapper(c *gin.Context) { // // Require middleware to parse the content. // Require middleware to take care of user-read-article. -func GetArticle(remoteAddr string, uuserID bbs.UUserID, params interface{}, path interface{}) (result interface{}, err error) { +func GetArticle(remoteAddr string, uuserID bbs.UUserID, params interface{}, path interface{}, c *gin.Context) (result interface{}, err error) { theParams, ok := params.(*GetArticleParams) if !ok { return nil, ErrInvalidParams diff --git a/api/get_article_test.go b/api/get_article_test.go index 94a8290c..6745a129 100644 --- a/api/get_article_test.go +++ b/api/get_article_test.go @@ -89,7 +89,7 @@ func TestGetArticle(t *testing.T) { wg.Add(1) t.Run(tt.name, func(t *testing.T) { defer wg.Done() - gotResult, err := GetArticle(tt.args.remoteAddr, tt.args.uuserID, tt.args.params, tt.args.path) + gotResult, err := GetArticle(tt.args.remoteAddr, tt.args.uuserID, tt.args.params, tt.args.path, nil) if (err != nil) != tt.wantErr { t.Errorf("GetArticle() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/api/get_email_token_info.go b/api/get_email_token_info.go index 7a90c82c..1b3080f5 100644 --- a/api/get_email_token_info.go +++ b/api/get_email_token_info.go @@ -16,6 +16,7 @@ type GetEmailTokenInfoResult struct { ClientInfo string `json:"client_info"` UserID bbs.UUserID `json:"user_id"` Email string `json:"email"` + Expire int `json:"expire"` } func GetEmailTokenInfoWrapper(c *gin.Context) { @@ -24,26 +25,27 @@ func GetEmailTokenInfoWrapper(c *gin.Context) { LoginRequiredJSON(GetEmailTokenInfo, params, c) } -func GetEmailTokenInfo(remoteAddr string, uuserID bbs.UUserID, params interface{}) (result interface{}, err error) { +func GetEmailTokenInfo(remoteAddr string, uuserID bbs.UUserID, params interface{}, c *gin.Context) (result interface{}, err error) { theParams, ok := params.(*GetEmailTokenInfoParams) if !ok { return nil, ErrInvalidParams } - userID, clientInfo, email, err := VerifyEmailJwt(theParams.Jwt, theParams.Context) + userID, expireTS, clientInfo, email, err := VerifyEmailJwt(theParams.Jwt, theParams.Context) if err != nil { return nil, err } isValid, _ := userInfoIsValidEmailUser(uuserID, userID, theParams.Jwt, theParams.Context, true) if !isValid { - return nil, ErrInvalidUser + return nil, ErrInvalidToken } result = &GetEmailTokenInfoResult{ ClientInfo: clientInfo, UserID: userID, Email: email, + Expire: expireTS, } return result, nil diff --git a/api/get_email_token_info_test.go b/api/get_email_token_info_test.go index 4f254ee9..c04156e3 100644 --- a/api/get_email_token_info_test.go +++ b/api/get_email_token_info_test.go @@ -38,12 +38,14 @@ func TestGetEmailTokenInfo(t *testing.T) { wg.Add(1) t.Run(tt.name, func(t *testing.T) { defer wg.Done() - gotResult, err := GetEmailTokenInfo(tt.args.remoteAddr, tt.args.uuserID, tt.args.params) + gotResult, err := GetEmailTokenInfo(tt.args.remoteAddr, tt.args.uuserID, tt.args.params, nil) if (err != nil) != tt.wantErr { t.Errorf("GetEmailTokenInfo() error = %v, wantErr %v", err, tt.wantErr) return } - if !reflect.DeepEqual(gotResult, tt.expectedResult) { + result, _ := gotResult.(*GetEmailTokenInfoResult) + result.Expire = 0 + if !reflect.DeepEqual(result, tt.expectedResult) { t.Errorf("GetEmailTokenInfo() = %v, want %v", gotResult, tt.expectedResult) } }) diff --git a/api/get_favorites.go b/api/get_favorites.go index 5d078f8d..392163b6 100644 --- a/api/get_favorites.go +++ b/api/get_favorites.go @@ -27,7 +27,7 @@ func GetFavoritesWrapper(c *gin.Context) { LoginRequiredPathQuery(GetFavorites, params, path, c) } -func GetFavorites(remoteAddr string, uuserID bbs.UUserID, params interface{}, path interface{}) (result interface{}, err error) { +func GetFavorites(remoteAddr string, uuserID bbs.UUserID, params interface{}, path interface{}, c *gin.Context) (result interface{}, err error) { theParams, ok := params.(*GetFavoritesParams) if !ok { return nil, ErrInvalidParams diff --git a/api/get_favorites_test.go b/api/get_favorites_test.go index 06ca7251..75575213 100644 --- a/api/get_favorites_test.go +++ b/api/get_favorites_test.go @@ -58,7 +58,7 @@ func TestGetFavorites(t *testing.T) { wg.Add(1) t.Run(tt.name, func(t *testing.T) { defer wg.Done() - gotResult, err := GetFavorites(tt.args.remoteAddr, tt.args.uuserID, tt.args.params, tt.args.path) + gotResult, err := GetFavorites(tt.args.remoteAddr, tt.args.uuserID, tt.args.params, tt.args.path, nil) if (err != nil) != tt.wantErr { t.Errorf("GetFavorites() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/api/get_post_template.go b/api/get_post_template.go index 053e911b..f70b6b37 100644 --- a/api/get_post_template.go +++ b/api/get_post_template.go @@ -34,7 +34,7 @@ func GetPostTemplateWrapper(c *gin.Context) { LoginRequiredPathQuery(GetPostTemplate, params, path, c) } -func GetPostTemplate(remoteAddr string, uuserID bbs.UUserID, params interface{}, path interface{}) (result interface{}, err error) { +func GetPostTemplate(remoteAddr string, uuserID bbs.UUserID, params interface{}, path interface{}, c *gin.Context) (result interface{}, err error) { theParams, ok := params.(*GetPostTemplateParams) if !ok { return nil, ErrInvalidParams diff --git a/api/get_post_template_test.go b/api/get_post_template_test.go index 2b81635f..4b150034 100644 --- a/api/get_post_template_test.go +++ b/api/get_post_template_test.go @@ -59,7 +59,7 @@ func TestGetPostTemplate(t *testing.T) { wg.Add(1) t.Run(tt.name, func(t *testing.T) { defer wg.Done() - gotResult, err := GetPostTemplate(tt.args.remoteAddr, tt.args.uuserID, tt.args.params, tt.args.path) + gotResult, err := GetPostTemplate(tt.args.remoteAddr, tt.args.uuserID, tt.args.params, tt.args.path, nil) if (err != nil) != tt.wantErr { t.Errorf("GetPostTemplate() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/api/get_refresh_token_info.go b/api/get_refresh_token_info.go new file mode 100644 index 00000000..5de291bd --- /dev/null +++ b/api/get_refresh_token_info.go @@ -0,0 +1,48 @@ +package api + +import ( + "github.com/Ptt-official-app/go-pttbbs/bbs" + "github.com/gin-gonic/gin" +) + +const GET_REFRESH_TOKEN_INFO_R = "/refreshtoken/info" + +type GetRefreshTokenInfoParams struct { + Jwt string `json:"token" form:"token" url:"token"` +} + +type GetRefreshTokenInfoResult struct { + ClientInfo string `json:"client_info"` + UserID bbs.UUserID `json:"user_id"` + Expire int `json:"expire"` +} + +func GetRefreshTokenInfoWrapper(c *gin.Context) { + params := &GetRefreshTokenInfoParams{} + + LoginRequiredJSON(GetRefreshTokenInfo, params, c) +} + +func GetRefreshTokenInfo(remoteAddr string, uuserID bbs.UUserID, params interface{}, c *gin.Context) (result interface{}, err error) { + theParams, ok := params.(*GetRefreshTokenInfoParams) + if !ok { + return nil, ErrInvalidParams + } + + userID, expireTS, clientInfo, err := VerifyRefreshJwt(theParams.Jwt) + if err != nil { + return nil, err + } + + if userID != uuserID { + return nil, ErrInvalidToken + } + + result = &GetRefreshTokenInfoResult{ + ClientInfo: clientInfo, + UserID: userID, + Expire: expireTS, + } + + return result, nil +} diff --git a/api/get_refresh_token_info_test.go b/api/get_refresh_token_info_test.go new file mode 100644 index 00000000..143ef4f9 --- /dev/null +++ b/api/get_refresh_token_info_test.go @@ -0,0 +1,61 @@ +package api + +import ( + "reflect" + "sync" + "testing" + + "github.com/Ptt-official-app/go-pttbbs/bbs" + "github.com/gin-gonic/gin" +) + +func TestGetRefreshTokenInfo(t *testing.T) { + setupTest(t.Name()) + defer teardownTest(t.Name()) + + refreshJwt, _ := CreateRefreshToken("SYSOP", "") + + params0 := &GetRefreshTokenInfoParams{ + Jwt: refreshJwt, + } + expected0 := &GetRefreshTokenInfoResult{ + UserID: "SYSOP", + } + + type args struct { + remoteAddr string + uuserID bbs.UUserID + params interface{} + c *gin.Context + } + tests := []struct { + name string + args args + expected interface{} + wantErr bool + }{ + // TODO: Add test cases. + { + args: args{uuserID: "SYSOP", params: params0}, + expected: expected0, + }, + } + var wg sync.WaitGroup + for _, tt := range tests { + wg.Add(1) + t.Run(tt.name, func(t *testing.T) { + defer wg.Done() + gotResult, err := GetRefreshTokenInfo(tt.args.remoteAddr, tt.args.uuserID, tt.args.params, tt.args.c) + if (err != nil) != tt.wantErr { + t.Errorf("GetRefreshTokenInfo() error = %v, wantErr %v", err, tt.wantErr) + return + } + result, _ := gotResult.(*GetRefreshTokenInfoResult) + result.Expire = 0 + if !reflect.DeepEqual(gotResult, tt.expected) { + t.Errorf("GetRefreshTokenInfo() = %v, want %v", gotResult, tt.expected) + } + }) + wg.Wait() + } +} diff --git a/api/get_token_info.go b/api/get_token_info.go index eaf809ea..ccbdfd54 100644 --- a/api/get_token_info.go +++ b/api/get_token_info.go @@ -14,6 +14,7 @@ type GetTokenInfoParams struct { type GetTokenInfoResult struct { ClientInfo string `json:"client_info"` UserID bbs.UUserID `json:"user_id"` + Expire int `json:"expire"` } func GetTokenInfoWrapper(c *gin.Context) { @@ -22,13 +23,13 @@ func GetTokenInfoWrapper(c *gin.Context) { LoginRequiredJSON(GetTokenInfo, params, c) } -func GetTokenInfo(remoteAddr string, uuserID bbs.UUserID, params interface{}) (result interface{}, err error) { +func GetTokenInfo(remoteAddr string, uuserID bbs.UUserID, params interface{}, c *gin.Context) (result interface{}, err error) { theParams, ok := params.(*GetTokenInfoParams) if !ok { return nil, ErrInvalidParams } - userID, clientInfo, err := VerifyJwt(theParams.Jwt) + userID, expireTS, clientInfo, err := VerifyJwt(theParams.Jwt, true) if err != nil { return nil, err } @@ -39,6 +40,7 @@ func GetTokenInfo(remoteAddr string, uuserID bbs.UUserID, params interface{}) (r result = &GetTokenInfoResult{ ClientInfo: clientInfo, UserID: userID, + Expire: expireTS, } return result, nil diff --git a/api/get_token_info_test.go b/api/get_token_info_test.go index 0725afb6..f4b81c07 100644 --- a/api/get_token_info_test.go +++ b/api/get_token_info_test.go @@ -44,11 +44,14 @@ func TestGetTokenInfo(t *testing.T) { wg.Add(1) t.Run(tt.name, func(t *testing.T) { defer wg.Done() - gotResult, err := GetTokenInfo(tt.args.remoteAddr, tt.args.uuserID, tt.args.params) + gotResult, err := GetTokenInfo(tt.args.remoteAddr, tt.args.uuserID, tt.args.params, nil) if (err != nil) != tt.wantErr { t.Errorf("GetTokenInfo() error = %v, wantErr %v", err, tt.wantErr) return } + + result, _ := gotResult.(*GetTokenInfoResult) + result.Expire = 0 if !reflect.DeepEqual(gotResult, tt.expectedResult) { t.Errorf("GetTokenInfo() = %v, want %v", gotResult, tt.expectedResult) } diff --git a/api/get_user.go b/api/get_user.go index bbc14e23..cb014675 100644 --- a/api/get_user.go +++ b/api/get_user.go @@ -21,7 +21,7 @@ func GetUserWrapper(c *gin.Context) { LoginRequiredPathQuery(GetUser, nil, path, c) } -func GetUser(remoteAddr string, uuserID bbs.UUserID, params interface{}, path interface{}) (result interface{}, err error) { +func GetUser(remoteAddr string, uuserID bbs.UUserID, params interface{}, path interface{}, c *gin.Context) (result interface{}, err error) { thePath, ok := path.(*GetUserPath) if !ok { return nil, ErrInvalidPath diff --git a/api/get_user_test.go b/api/get_user_test.go index 6571cea6..70a8489b 100644 --- a/api/get_user_test.go +++ b/api/get_user_test.go @@ -38,7 +38,7 @@ func TestGetUser(t *testing.T) { wg.Add(1) t.Run(tt.name, func(t *testing.T) { defer wg.Done() - gotResult, err := GetUser(tt.args.remoteAddr, tt.args.uuserID, tt.args.params, tt.args.path) + gotResult, err := GetUser(tt.args.remoteAddr, tt.args.uuserID, tt.args.params, tt.args.path, nil) if (err != nil) != tt.wantErr { t.Errorf("GetUser() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/api/get_user_visit_count.go b/api/get_user_visit_count.go index a32d5350..a903560c 100644 --- a/api/get_user_visit_count.go +++ b/api/get_user_visit_count.go @@ -15,7 +15,7 @@ func GetUserVisitCountWrapper(c *gin.Context) { Query(GetUserVisitCount, nil, c) } -func GetUserVisitCount(remoteAddr string, params interface{}) (interface{}, error) { +func GetUserVisitCount(remoteAddr string, params interface{}, c *gin.Context) (interface{}, error) { total := bbs.GetUserVisitCount() return &GetUserVisitCountResult{total}, nil } diff --git a/api/get_user_visit_count_test.go b/api/get_user_visit_count_test.go index 6e04e9e5..c202e2f5 100644 --- a/api/get_user_visit_count_test.go +++ b/api/get_user_visit_count_test.go @@ -31,7 +31,7 @@ func TestGetUserVisitCount(t *testing.T) { wg.Add(1) t.Run(tt.name, func(t *testing.T) { defer wg.Done() - got, err := GetUserVisitCount(testIP, nil) + got, err := GetUserVisitCount(testIP, nil, nil) if (err != nil) != tt.wantErr { t.Errorf("GetUserVisitCount() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/api/get_version.go b/api/get_version.go index 04ee4e74..1588f929 100644 --- a/api/get_version.go +++ b/api/get_version.go @@ -16,7 +16,7 @@ func GetVersionWrapper(c *gin.Context) { Query(GetVersion, nil, c) } -func GetVersion(remoteAddr string, params interface{}) (interface{}, error) { +func GetVersion(remoteAddr string, params interface{}, c *gin.Context) (interface{}, error) { return &GetVersionResult{ Version: types.VERSION, GitVersion: types.GIT_VERSION, diff --git a/api/index.go b/api/index.go index 218a1815..a6532ffd 100644 --- a/api/index.go +++ b/api/index.go @@ -18,7 +18,7 @@ func IndexWrapper(c *gin.Context) { LoginRequiredJSON(Index, params, c) } -func Index(remoteAddr string, uuserID bbs.UUserID, params interface{}) (interface{}, error) { +func Index(remoteAddr string, uuserID bbs.UUserID, params interface{}, c *gin.Context) (interface{}, error) { result := &IndexResult{Data: "index"} return result, nil } diff --git a/api/index_test.go b/api/index_test.go index eb914e1f..8ae88465 100644 --- a/api/index_test.go +++ b/api/index_test.go @@ -37,7 +37,7 @@ func TestIndex(t *testing.T) { t.Run(tt.name, func(t *testing.T) { defer wg.Done() - got, err := Index(testIP, tt.args.uuserID, tt.args.params) + got, err := Index(testIP, tt.args.uuserID, tt.args.params, nil) if (err != nil) != tt.wantErr { t.Errorf("Index() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/api/is_board_valid_user.go b/api/is_board_valid_user.go index 178ed7b3..0261834a 100644 --- a/api/is_board_valid_user.go +++ b/api/is_board_valid_user.go @@ -20,7 +20,7 @@ func IsBoardValidUserWrapper(c *gin.Context) { LoginRequiredPathQuery(IsBoardValidUser, nil, path, c) } -func IsBoardValidUser(remoteAddr string, uuserID bbs.UUserID, params interface{}, path interface{}) (result interface{}, err error) { +func IsBoardValidUser(remoteAddr string, uuserID bbs.UUserID, params interface{}, path interface{}, c *gin.Context) (result interface{}, err error) { thePath, ok := path.(*IsBoardValidUserPath) if !ok { return nil, ErrInvalidPath diff --git a/api/is_board_valid_user_test.go b/api/is_board_valid_user_test.go index 6146424d..7f23a3e3 100644 --- a/api/is_board_valid_user_test.go +++ b/api/is_board_valid_user_test.go @@ -43,7 +43,7 @@ func TestIsBoardValidUser(t *testing.T) { wg.Add(1) t.Run(tt.name, func(t *testing.T) { defer wg.Done() - gotResult, err := IsBoardValidUser(tt.args.remoteAddr, tt.args.uuserID, tt.args.params, tt.args.path) + gotResult, err := IsBoardValidUser(tt.args.remoteAddr, tt.args.uuserID, tt.args.params, tt.args.path, nil) if (err != nil) != tt.wantErr { t.Errorf("IsBoardValidUser() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/api/is_boards_valid_user.go b/api/is_boards_valid_user.go index 13232f31..336e2c88 100644 --- a/api/is_boards_valid_user.go +++ b/api/is_boards_valid_user.go @@ -20,7 +20,7 @@ func IsBoardsValidUserWrapper(c *gin.Context) { LoginRequiredJSON(IsBoardsValidUser, params, c) } -func IsBoardsValidUser(remoteAddr string, uuserID bbs.UUserID, params interface{}) (result interface{}, err error) { +func IsBoardsValidUser(remoteAddr string, uuserID bbs.UUserID, params interface{}, c *gin.Context) (result interface{}, err error) { theParams, ok := params.(*IsBoardsValidUserParams) if !ok { return nil, ErrInvalidParams diff --git a/api/is_boards_valid_user_test.go b/api/is_boards_valid_user_test.go index 32586bd1..79487bf1 100644 --- a/api/is_boards_valid_user_test.go +++ b/api/is_boards_valid_user_test.go @@ -50,7 +50,7 @@ func TestIsBoardsValidUser(t *testing.T) { wg.Add(1) t.Run(tt.name, func(t *testing.T) { defer wg.Done() - gotResult, err := IsBoardsValidUser(tt.args.remoteAddr, tt.args.uuserID, tt.args.params) + gotResult, err := IsBoardsValidUser(tt.args.remoteAddr, tt.args.uuserID, tt.args.params, nil) if (err != nil) != tt.wantErr { t.Errorf("IsBoardsValidUser() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/api/load_auto_complete_boards.go b/api/load_auto_complete_boards.go index d40f745d..0a2ead3e 100644 --- a/api/load_auto_complete_boards.go +++ b/api/load_auto_complete_boards.go @@ -19,7 +19,7 @@ func LoadAutoCompleteBoardsWrapper(c *gin.Context) { LoginRequiredQuery(LoadAutoCompleteBoards, params, c) } -func LoadAutoCompleteBoards(remoteAddr string, uuserID bbs.UUserID, params interface{}) (interface{}, error) { +func LoadAutoCompleteBoards(remoteAddr string, uuserID bbs.UUserID, params interface{}, c *gin.Context) (interface{}, error) { theParams, ok := params.(*LoadAutoCompleteBoardsParams) if !ok { return nil, ErrInvalidParams diff --git a/api/load_auto_complete_boards_test.go b/api/load_auto_complete_boards_test.go index 086a49f7..a609ef92 100644 --- a/api/load_auto_complete_boards_test.go +++ b/api/load_auto_complete_boards_test.go @@ -46,7 +46,7 @@ func TestLoadAutoCompleteBoards(t *testing.T) { wg.Add(1) t.Run(tt.name, func(t *testing.T) { defer wg.Done() - got, err := LoadAutoCompleteBoards(tt.args.remoteAddr, tt.args.uuserID, tt.args.params) + got, err := LoadAutoCompleteBoards(tt.args.remoteAddr, tt.args.uuserID, tt.args.params, nil) if (err != nil) != tt.wantErr { t.Errorf("LoadAutoCompleteBoards() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/api/load_board_detail.go b/api/load_board_detail.go index 4f19eaca..c4d3237a 100644 --- a/api/load_board_detail.go +++ b/api/load_board_detail.go @@ -24,7 +24,7 @@ func LoadBoardDetailWrapper(c *gin.Context) { loginRequiredPathProcess(LoadBoardDetail, nil, path, c) } -func LoadBoardDetail(remoteAddr string, uuserID bbs.UUserID, params interface{}, path interface{}) (ret interface{}, err error) { +func LoadBoardDetail(remoteAddr string, uuserID bbs.UUserID, params interface{}, path interface{}, c *gin.Context) (ret interface{}, err error) { thePath, ok := path.(*LoadBoardDetailPath) if !ok { return nil, ErrInvalidPath diff --git a/api/load_board_detail_test.go b/api/load_board_detail_test.go index 572be26d..459bd715 100644 --- a/api/load_board_detail_test.go +++ b/api/load_board_detail_test.go @@ -43,7 +43,7 @@ func TestLoadBoardDetail(t *testing.T) { wg.Add(1) t.Run(tt.name, func(t *testing.T) { defer wg.Done() - gotResults, err := LoadBoardDetail(tt.args.remoteAddr, tt.args.uuserID, tt.args.params, tt.args.path) + gotResults, err := LoadBoardDetail(tt.args.remoteAddr, tt.args.uuserID, tt.args.params, tt.args.path, nil) if (err != nil) != tt.wantErr { t.Errorf("LoadBoardDetail() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/api/load_board_summary.go b/api/load_board_summary.go index 3055e6ab..bea5257e 100644 --- a/api/load_board_summary.go +++ b/api/load_board_summary.go @@ -24,7 +24,7 @@ func LoadBoardSummaryWrapper(c *gin.Context) { loginRequiredPathProcess(LoadBoardSummary, nil, path, c) } -func LoadBoardSummary(remoteAddr string, uuserID bbs.UUserID, params interface{}, path interface{}) (result interface{}, err error) { +func LoadBoardSummary(remoteAddr string, uuserID bbs.UUserID, params interface{}, path interface{}, c *gin.Context) (result interface{}, err error) { thePath, ok := path.(*LoadBoardSummaryPath) if !ok { return nil, ErrInvalidPath diff --git a/api/load_board_summary_test.go b/api/load_board_summary_test.go index 8bd795ba..fbacb87e 100644 --- a/api/load_board_summary_test.go +++ b/api/load_board_summary_test.go @@ -43,7 +43,7 @@ func TestLoadBoardSummary(t *testing.T) { wg.Add(1) t.Run(tt.name, func(t *testing.T) { defer wg.Done() - gotResults, err := LoadBoardSummary(tt.args.remoteAddr, tt.args.uuserID, tt.args.params, tt.args.path) + gotResults, err := LoadBoardSummary(tt.args.remoteAddr, tt.args.uuserID, tt.args.params, tt.args.path, nil) if (err != nil) != tt.wantErr { t.Errorf("LoadBoardSummary() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/api/load_boards_by_bids.go b/api/load_boards_by_bids.go index cc1c747f..b18428b8 100644 --- a/api/load_boards_by_bids.go +++ b/api/load_boards_by_bids.go @@ -21,7 +21,7 @@ func LoadBoardsByBidsWrapper(c *gin.Context) { LoginRequiredJSON(LoadBoardsByBids, params, c) } -func LoadBoardsByBids(remoteAddr string, uuserID bbs.UUserID, params interface{}) (result interface{}, err error) { +func LoadBoardsByBids(remoteAddr string, uuserID bbs.UUserID, params interface{}, c *gin.Context) (result interface{}, err error) { theParams, ok := params.(*LoadBoardsByBidsParams) if !ok { return nil, ErrInvalidParams diff --git a/api/load_boards_by_bids_test.go b/api/load_boards_by_bids_test.go index f4561775..91f6ccc6 100644 --- a/api/load_boards_by_bids_test.go +++ b/api/load_boards_by_bids_test.go @@ -43,7 +43,7 @@ func TestLoadBoardsByBids(t *testing.T) { wg.Add(1) t.Run(tt.name, func(t *testing.T) { defer wg.Done() - gotResult, err := LoadBoardsByBids(tt.args.remoteAddr, tt.args.uuserID, tt.args.params) + gotResult, err := LoadBoardsByBids(tt.args.remoteAddr, tt.args.uuserID, tt.args.params, nil) if (err != nil) != tt.wantErr { t.Errorf("LoadBoardsByBids() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/api/load_bottom_articles.go b/api/load_bottom_articles.go index eb82c595..9e945a99 100644 --- a/api/load_bottom_articles.go +++ b/api/load_bottom_articles.go @@ -20,7 +20,7 @@ func LoadBottomArticlesWrapper(c *gin.Context) { LoginRequiredPathQuery(LoadBottomArticles, nil, path, c) } -func LoadBottomArticles(remoteAddr string, uuserID bbs.UUserID, params interface{}, path interface{}) (result interface{}, err error) { +func LoadBottomArticles(remoteAddr string, uuserID bbs.UUserID, params interface{}, path interface{}, c *gin.Context) (result interface{}, err error) { thePath, ok := path.(*LoadGeneralArticlesPath) if !ok { return nil, ErrInvalidPath diff --git a/api/load_bottom_articles_test.go b/api/load_bottom_articles_test.go index 51ee80b2..ba5b41ff 100644 --- a/api/load_bottom_articles_test.go +++ b/api/load_bottom_articles_test.go @@ -41,7 +41,7 @@ func TestLoadBottomArticles(t *testing.T) { wg.Add(1) t.Run(tt.name, func(t *testing.T) { defer wg.Done() - gotResult, err := LoadBottomArticles(tt.args.remoteAddr, tt.args.uuserID, tt.args.params, tt.args.path) + gotResult, err := LoadBottomArticles(tt.args.remoteAddr, tt.args.uuserID, tt.args.params, tt.args.path, nil) if (err != nil) != tt.wantErr { t.Errorf("LoadBottomArticles() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/api/load_class_boards.go b/api/load_class_boards.go index bd234d0b..e2c29d01 100644 --- a/api/load_class_boards.go +++ b/api/load_class_boards.go @@ -27,7 +27,7 @@ func LoadClassBoardsWrapper(c *gin.Context) { LoginRequiredPathQuery(LoadClassBoards, params, path, c) } -func LoadClassBoards(remoteAddr string, uuserID bbs.UUserID, params interface{}, path interface{}) (ret interface{}, err error) { +func LoadClassBoards(remoteAddr string, uuserID bbs.UUserID, params interface{}, path interface{}, c *gin.Context) (ret interface{}, err error) { theParams, ok := params.(*LoadClassBoardsParams) if !ok { return nil, ErrInvalidParams diff --git a/api/load_class_boards_test.go b/api/load_class_boards_test.go index b28ab117..e174e727 100644 --- a/api/load_class_boards_test.go +++ b/api/load_class_boards_test.go @@ -41,7 +41,7 @@ func TestLoadClassBoards(t *testing.T) { wg.Add(1) t.Run(tt.name, func(t *testing.T) { defer wg.Done() - gotRet, err := LoadClassBoards(tt.args.remoteAddr, tt.args.uuserID, tt.args.params, tt.args.path) + gotRet, err := LoadClassBoards(tt.args.remoteAddr, tt.args.uuserID, tt.args.params, tt.args.path, nil) if (err != nil) != tt.wantErr { t.Errorf("LoadClassBoards() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/api/load_full_class_boards.go b/api/load_full_class_boards.go index 2c0c2a8d..a4338f7e 100644 --- a/api/load_full_class_boards.go +++ b/api/load_full_class_boards.go @@ -24,7 +24,7 @@ func LoadFullClassBoardsWrapper(c *gin.Context) { LoginRequiredQuery(LoadFullClassBoards, params, c) } -func LoadFullClassBoards(remoteAddr string, uuserID bbs.UUserID, params interface{}) (ret interface{}, err error) { +func LoadFullClassBoards(remoteAddr string, uuserID bbs.UUserID, params interface{}, c *gin.Context) (ret interface{}, err error) { theParams, ok := params.(*LoadFullClassBoardsParams) if !ok { return nil, ErrInvalidParams diff --git a/api/load_full_class_boards_test.go b/api/load_full_class_boards_test.go index 15cc70fa..040c0d48 100644 --- a/api/load_full_class_boards_test.go +++ b/api/load_full_class_boards_test.go @@ -43,7 +43,7 @@ func TestLoadFullClassBoards(t *testing.T) { wg.Add(1) t.Run(tt.name, func(t *testing.T) { defer wg.Done() - gotRet, err := LoadFullClassBoards(tt.args.remoteAddr, tt.args.uuserID, tt.args.params) + gotRet, err := LoadFullClassBoards(tt.args.remoteAddr, tt.args.uuserID, tt.args.params, nil) if (err != nil) != tt.wantErr { t.Errorf("LoadFullClassBoards() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/api/load_general_articles.go b/api/load_general_articles.go index 259109b6..185927fa 100644 --- a/api/load_general_articles.go +++ b/api/load_general_articles.go @@ -40,7 +40,7 @@ func LoadGeneralArticlesWrapper(c *gin.Context) { LoginRequiredPathQuery(LoadGeneralArticles, params, path, c) } -func LoadGeneralArticles(remoteAddr string, uuserID bbs.UUserID, params interface{}, path interface{}) (result interface{}, err error) { +func LoadGeneralArticles(remoteAddr string, uuserID bbs.UUserID, params interface{}, path interface{}, c *gin.Context) (result interface{}, err error) { theParams, ok := params.(*LoadGeneralArticlesParams) if !ok { return nil, ErrInvalidParams diff --git a/api/load_general_articles_test.go b/api/load_general_articles_test.go index 44079d77..3646955f 100644 --- a/api/load_general_articles_test.go +++ b/api/load_general_articles_test.go @@ -89,7 +89,7 @@ func TestLoadGeneralArticles(t *testing.T) { wg.Add(1) t.Run(tt.name, func(t *testing.T) { defer wg.Done() - gotResult, err := LoadGeneralArticles(tt.args.remoteAddr, tt.args.uuserID, tt.args.params, tt.args.path) + gotResult, err := LoadGeneralArticles(tt.args.remoteAddr, tt.args.uuserID, tt.args.params, tt.args.path, nil) if (err != nil) != tt.wantErr { t.Errorf("LoadGeneralArticles() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/api/load_general_board_details.go b/api/load_general_board_details.go index 57f090c7..162e4284 100644 --- a/api/load_general_board_details.go +++ b/api/load_general_board_details.go @@ -31,7 +31,7 @@ func LoadGeneralBoardDetailsWrapper(c *gin.Context) { LoginRequiredQuery(LoadGeneralBoardDetails, params, c) } -func LoadGeneralBoardDetails(remoteAddr string, uuserID bbs.UUserID, params interface{}) (interface{}, error) { +func LoadGeneralBoardDetails(remoteAddr string, uuserID bbs.UUserID, params interface{}, c *gin.Context) (interface{}, error) { return loadGeneralBoardDetailsCore(remoteAddr, uuserID, params, ptttype.BSORT_BY_NAME) } diff --git a/api/load_general_board_details_test.go b/api/load_general_board_details_test.go index 0fd312cd..f4ade034 100644 --- a/api/load_general_board_details_test.go +++ b/api/load_general_board_details_test.go @@ -41,7 +41,7 @@ func TestLoadGeneralBoardDetails(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := LoadGeneralBoardDetails(tt.args.remoteAddr, tt.args.uuserID, tt.args.params) + got, err := LoadGeneralBoardDetails(tt.args.remoteAddr, tt.args.uuserID, tt.args.params, nil) if (err != nil) != tt.wantErr { t.Errorf("LoadGeneralBoardDetails() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/api/load_general_boards.go b/api/load_general_boards.go index b38c02d1..0a72e1b8 100644 --- a/api/load_general_boards.go +++ b/api/load_general_boards.go @@ -33,7 +33,7 @@ func LoadGeneralBoardsWrapper(c *gin.Context) { LoginRequiredQuery(LoadGeneralBoards, params, c) } -func LoadGeneralBoards(remoteAddr string, uuserID bbs.UUserID, params interface{}) (interface{}, error) { +func LoadGeneralBoards(remoteAddr string, uuserID bbs.UUserID, params interface{}, c *gin.Context) (interface{}, error) { return loadGeneralBoardsCore(remoteAddr, uuserID, params, ptttype.BSORT_BY_NAME) } diff --git a/api/load_general_boards_by_class.go b/api/load_general_boards_by_class.go index 797f3c5c..35857d0f 100644 --- a/api/load_general_boards_by_class.go +++ b/api/load_general_boards_by_class.go @@ -13,6 +13,6 @@ func LoadGeneralBoardsByClassWrapper(c *gin.Context) { LoginRequiredQuery(LoadGeneralBoardsByClass, params, c) } -func LoadGeneralBoardsByClass(remoteAddr string, uuserID bbs.UUserID, params interface{}) (interface{}, error) { +func LoadGeneralBoardsByClass(remoteAddr string, uuserID bbs.UUserID, params interface{}, c *gin.Context) (interface{}, error) { return loadGeneralBoardsCore(remoteAddr, uuserID, params, ptttype.BSORT_BY_CLASS) } diff --git a/api/load_general_boards_by_class_test.go b/api/load_general_boards_by_class_test.go index 29d26d34..ab401c7b 100644 --- a/api/load_general_boards_by_class_test.go +++ b/api/load_general_boards_by_class_test.go @@ -45,7 +45,7 @@ func TestLoadGeneralBoardsByClass(t *testing.T) { wg.Add(1) t.Run(tt.name, func(t *testing.T) { defer wg.Done() - got, err := LoadGeneralBoardsByClass(tt.args.remoteAddr, tt.args.uuserID, tt.args.params) + got, err := LoadGeneralBoardsByClass(tt.args.remoteAddr, tt.args.uuserID, tt.args.params, nil) if (err != nil) != tt.wantErr { t.Errorf("LoadGeneralBoardsByClass() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/api/load_general_boards_test.go b/api/load_general_boards_test.go index e69512e7..84645486 100644 --- a/api/load_general_boards_test.go +++ b/api/load_general_boards_test.go @@ -61,7 +61,7 @@ func TestLoadGeneralBoards(t *testing.T) { wg.Add(1) t.Run(tt.name, func(t *testing.T) { defer wg.Done() - got, err := LoadGeneralBoards(testIP, tt.args.uuserID, tt.args.params) + got, err := LoadGeneralBoards(testIP, tt.args.uuserID, tt.args.params, nil) if (err != nil) != tt.wantErr { t.Errorf("LoadGeneralBoards() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/api/load_hot_boards.go b/api/load_hot_boards.go index 6aa6be80..3bb5f851 100644 --- a/api/load_hot_boards.go +++ b/api/load_hot_boards.go @@ -16,7 +16,7 @@ func LoadHotBoardsWrapper(c *gin.Context) { } // We have only 128 hot-boards. -func LoadHotBoards(remoteAddr string, uuserID bbs.UUserID, params interface{}) (result interface{}, err error) { +func LoadHotBoards(remoteAddr string, uuserID bbs.UUserID, params interface{}, c *gin.Context) (result interface{}, err error) { summary, err := bbs.LoadHotBoards(uuserID) if err != nil { return nil, err diff --git a/api/load_hot_boards_test.go b/api/load_hot_boards_test.go index 25744189..b790e71f 100644 --- a/api/load_hot_boards_test.go +++ b/api/load_hot_boards_test.go @@ -44,7 +44,7 @@ func TestLoadHotBoards(t *testing.T) { wg.Add(1) t.Run(tt.name, func(t *testing.T) { defer wg.Done() - gotResult, err := LoadHotBoards(tt.args.remoteAddr, tt.args.uuserID, tt.args.params) + gotResult, err := LoadHotBoards(tt.args.remoteAddr, tt.args.uuserID, tt.args.params, nil) if (err != nil) != tt.wantErr { t.Errorf("LoadHotBoards() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/api/login.go b/api/login.go index e7558a01..22acead3 100644 --- a/api/login.go +++ b/api/login.go @@ -17,6 +17,7 @@ type LoginResult struct { UserID bbs.UUserID `json:"user_id"` Jwt string `json:"access_token"` TokenType string `json:"token_type"` + Refresh string `json:"refresh_token"` } func LoginWrapper(c *gin.Context) { @@ -24,7 +25,7 @@ func LoginWrapper(c *gin.Context) { JSON(Login, params, c) } -func Login(remoteAddr string, params interface{}) (interface{}, error) { +func Login(remoteAddr string, params interface{}, c *gin.Context) (interface{}, error) { loginParams, ok := params.(*LoginParams) if !ok { return nil, ErrInvalidParams @@ -40,10 +41,16 @@ func Login(remoteAddr string, params interface{}) (interface{}, error) { return nil, err } + refreshToken, err := CreateRefreshToken(uuserID, loginParams.ClientInfo) + if err != nil { + return nil, err + } + result := &LoginResult{ UserID: uuserID, Jwt: token, TokenType: "bearer", + Refresh: refreshToken, } return result, nil diff --git a/api/login_test.go b/api/login_test.go index fa5ddfa1..20ad9812 100644 --- a/api/login_test.go +++ b/api/login_test.go @@ -33,7 +33,7 @@ func TestLogin(t *testing.T) { wg.Add(1) t.Run(tt.name, func(t *testing.T) { defer wg.Done() - got, err := Login(testIP, tt.args.params) + got, err := Login(testIP, tt.args.params, nil) if (err != nil) != tt.wantErr { t.Errorf("Login() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/api/refresh.go b/api/refresh.go new file mode 100644 index 00000000..d77d5e83 --- /dev/null +++ b/api/refresh.go @@ -0,0 +1,73 @@ +package api + +import ( + "github.com/gin-gonic/gin" +) + +const REFRESH_R = "/refresh" + +type RefreshParams struct { + ClientInfo string `json:"client_info"` + Refresh string `json:"refresh_token"` +} + +type RefreshResult LoginResult + +func RefreshWrapper(c *gin.Context) { + params := &RefreshParams{} + JSON(Refresh, params, c) +} + +func Refresh(remoteAddr string, params interface{}, c *gin.Context) (result interface{}, err error) { + theParams, ok := params.(*RefreshParams) + if !ok { + return nil, ErrInvalidParams + } + + jwt := GetJwt(c) + + jwtUserID, jwtExpireTS, jwtClientInfo, err := VerifyJwt(jwt, false) + if err != nil { + return nil, ErrInvalidToken + } + + userID, refreshExpireTS, clientInfo, err := VerifyRefreshJwt(theParams.Refresh) + if err != nil { + return nil, ErrInvalidToken + } + + // verify that jwt and refresh-jwt are with same pair. + diffExpireTS := refreshExpireTS - jwtExpireTS + expectedDiffExpireTS := REFRESH_JWT_TOKEN_EXPIRE_TS - JWT_TOKEN_EXPIRE_TS + diffDiffExpireTS := diffExpireTS - expectedDiffExpireTS + if diffDiffExpireTS > EPSILON_EXPIRE_TS || diffDiffExpireTS < -EPSILON_EXPIRE_TS { + return nil, ErrInvalidToken + } + + if clientInfo != theParams.ClientInfo && clientInfo != jwtClientInfo { + return nil, ErrInvalidToken + } + + if userID != jwtUserID { + return nil, ErrInvalidToken + } + + token, err := CreateToken(userID, clientInfo) + if err != nil { + return nil, err + } + + refreshToken, err := CreateRefreshToken(userID, clientInfo) + if err != nil { + return nil, err + } + + result = &RefreshResult{ + UserID: userID, + Jwt: token, + TokenType: "bearer", + Refresh: refreshToken, + } + + return result, nil +} diff --git a/api/refresh_test.go b/api/refresh_test.go new file mode 100644 index 00000000..cd6ab014 --- /dev/null +++ b/api/refresh_test.go @@ -0,0 +1,70 @@ +package api + +import ( + "net/http" + "reflect" + "sync" + "testing" + + "github.com/gin-gonic/gin" + "github.com/sirupsen/logrus" +) + +func TestRefresh(t *testing.T) { + setupTest(t.Name()) + defer teardownTest(t.Name()) + + jwt, _ := CreateToken("SYSOP", "") + refreshJwt, _ := CreateRefreshToken("SYSOP", "") + + logrus.Infof("TestRefresh: jwt: %v refreshJwt: %v", jwt, refreshJwt) + + params0 := &RefreshParams{ + ClientInfo: "", + Refresh: refreshJwt, + } + + req, _ := http.NewRequest("POST", "http://localhost/refresh", nil) + req.Header = map[string][]string{ + "Authorization": {"bearer " + jwt}, + } + c0 := &gin.Context{} + c0.Request = req + + type args struct { + remoteAddr string + params interface{} + c *gin.Context + } + tests := []struct { + name string + args args + expected interface{} + wantErr bool + }{ + // TODO: Add test cases. + { + args: args{params: params0, c: c0}, + expected: &RefreshResult{UserID: "SYSOP", TokenType: "bearer"}, + }, + } + var wg sync.WaitGroup + for _, tt := range tests { + wg.Add(1) + t.Run(tt.name, func(t *testing.T) { + defer wg.Done() + gotResult, err := Refresh(tt.args.remoteAddr, tt.args.params, tt.args.c) + if (err != nil) != tt.wantErr { + t.Errorf("Refresh() error = %v, wantErr %v", err, tt.wantErr) + return + } + result, _ := gotResult.(*RefreshResult) + result.Jwt = "" + result.Refresh = "" + if !reflect.DeepEqual(gotResult, tt.expected) { + t.Errorf("Refresh() = %v, want %v", gotResult, tt.expected) + } + }) + wg.Wait() + } +} diff --git a/api/register.go b/api/register.go index c1a3b1bd..6ce1a6d2 100644 --- a/api/register.go +++ b/api/register.go @@ -31,7 +31,7 @@ func RegisterWrapper(c *gin.Context) { JSON(Register, params, c) } -func Register(remoteAddr string, params interface{}) (interface{}, error) { +func Register(remoteAddr string, params interface{}, c *gin.Context) (interface{}, error) { registerParams, ok := params.(*RegisterParams) if !ok { return nil, ErrInvalidParams diff --git a/api/register_test.go b/api/register_test.go index 2535a042..2e49d377 100644 --- a/api/register_test.go +++ b/api/register_test.go @@ -28,7 +28,7 @@ func TestRegister(t *testing.T) { wg.Add(1) t.Run(tt.name, func(t *testing.T) { defer wg.Done() - got, err := Register(testIP, tt.args.params) + got, err := Register(testIP, tt.args.params, nil) if (err != nil) != tt.wantErr { t.Errorf("Register() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/api/reload_uhash.go b/api/reload_uhash.go index da02f15c..2abb6aa3 100644 --- a/api/reload_uhash.go +++ b/api/reload_uhash.go @@ -15,7 +15,7 @@ func ReloadUHashWrapper(c *gin.Context) { LoginRequiredQuery(ReloadUHash, nil, c) } -func ReloadUHash(remoteAddr string, uuserID bbs.UUserID, params interface{}) (result interface{}, err error) { +func ReloadUHash(remoteAddr string, uuserID bbs.UUserID, params interface{}, c *gin.Context) (result interface{}, err error) { err = bbs.ReloadUHash(uuserID) if err != nil { return nil, err diff --git a/api/reload_uhash_test.go b/api/reload_uhash_test.go index 0d93263b..a96c6040 100644 --- a/api/reload_uhash_test.go +++ b/api/reload_uhash_test.go @@ -35,7 +35,7 @@ func TestReloadUHash(t *testing.T) { wg.Add(1) t.Run(tt.name, func(t *testing.T) { defer wg.Done() - gotResult, err := ReloadUHash(tt.args.remoteAddr, tt.args.uuserID, tt.args.params) + gotResult, err := ReloadUHash(tt.args.remoteAddr, tt.args.uuserID, tt.args.params, nil) if (err != nil) != tt.wantErr { t.Errorf("ReloadUHash() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/api/set_id_email.go b/api/set_id_email.go index 9f06c1f9..a580ba42 100644 --- a/api/set_id_email.go +++ b/api/set_id_email.go @@ -29,7 +29,7 @@ func SetIDEmailWrapper(c *gin.Context) { LoginRequiredPathJSON(SetIDEmail, params, path, c) } -func SetIDEmail(remoteAddr string, uuserID bbs.UUserID, params interface{}, path interface{}) (result interface{}, err error) { +func SetIDEmail(remoteAddr string, uuserID bbs.UUserID, params interface{}, path interface{}, c *gin.Context) (result interface{}, err error) { theParams, ok := params.(*SetIDEmailParams) if !ok { return nil, ErrInvalidParams diff --git a/api/set_id_email_test.go b/api/set_id_email_test.go index 735c8472..8630b4c5 100644 --- a/api/set_id_email_test.go +++ b/api/set_id_email_test.go @@ -53,7 +53,7 @@ func TestSetIDEmail(t *testing.T) { wg.Add(1) t.Run(tt.name, func(t *testing.T) { defer wg.Done() - gotResult, err := SetIDEmail(tt.args.remoteAddr, tt.args.uuserID, tt.args.params, tt.args.path) + gotResult, err := SetIDEmail(tt.args.remoteAddr, tt.args.uuserID, tt.args.params, tt.args.path, nil) if (err != nil) != tt.wantErr { t.Errorf("SetIDEmail() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/api/set_user_perm.go b/api/set_user_perm.go index b7c9c904..a2e667fe 100644 --- a/api/set_user_perm.go +++ b/api/set_user_perm.go @@ -26,7 +26,7 @@ func SetUserPermWrapper(c *gin.Context) { LoginRequiredPathJSON(SetUserPerm, params, path, c) } -func SetUserPerm(remoteAddr string, uuserID bbs.UUserID, params interface{}, path interface{}) (result interface{}, err error) { +func SetUserPerm(remoteAddr string, uuserID bbs.UUserID, params interface{}, path interface{}, c *gin.Context) (result interface{}, err error) { theParams, ok := params.(*SetUserPermParams) if !ok { return nil, ErrInvalidParams diff --git a/api/set_user_perm_test.go b/api/set_user_perm_test.go index b97f158b..f91a0737 100644 --- a/api/set_user_perm_test.go +++ b/api/set_user_perm_test.go @@ -51,7 +51,7 @@ func TestSetUserPerm(t *testing.T) { wg.Add(1) t.Run(tt.name, func(t *testing.T) { defer wg.Done() - gotResult, err := SetUserPerm(tt.args.remoteAddr, tt.args.uuserID, tt.args.params, tt.args.path) + gotResult, err := SetUserPerm(tt.args.remoteAddr, tt.args.uuserID, tt.args.params, tt.args.path, nil) if (err != nil) != tt.wantErr { t.Errorf("SetUserPerm() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/api/types.go b/api/types.go index 25fc79b6..35ca5e94 100644 --- a/api/types.go +++ b/api/types.go @@ -2,13 +2,14 @@ package api import ( "github.com/Ptt-official-app/go-pttbbs/bbs" + "github.com/gin-gonic/gin" ) -type APIFunc func(remoteAddr string, params interface{}) (interface{}, error) +type APIFunc func(remoteAddr string, params interface{}, c *gin.Context) (interface{}, error) -type LoginRequiredAPIFunc func(remoteAddr string, uuserID bbs.UUserID, params interface{}) (interface{}, error) +type LoginRequiredAPIFunc func(remoteAddr string, uuserID bbs.UUserID, params interface{}, c *gin.Context) (interface{}, error) -type LoginRequiredPathAPIFunc func(remoteAddr string, uuserID bbs.UUserID, params interface{}, path interface{}) (interface{}, error) +type LoginRequiredPathAPIFunc func(remoteAddr string, uuserID bbs.UUserID, params interface{}, path interface{}, c *gin.Context) (interface{}, error) type JwtClaim struct { ClientInfo string `json:"cli"` @@ -16,6 +17,13 @@ type JwtClaim struct { Expire int `json:"exp"` } +type RefreshJwtClaim struct { + ClientInfo string `json:"cli"` + UUserID string `json:"sub"` + Expire int `json:"exp"` + TheType string `json:"typ"` +} + type EmailJwtClaim struct { ClientInfo string `json:"cli"` UUserID string `json:"sub"` diff --git a/api/user_utils.go b/api/user_utils.go index 6b458956..18ffd096 100644 --- a/api/user_utils.go +++ b/api/user_utils.go @@ -35,7 +35,7 @@ func userInfoIsValidEmailUser(uuserID bbs.UUserID, queryUUserID bbs.UUserID, jwt return false, "" } - emailUserID, _, email, err := VerifyEmailJwt(jwt, context) + emailUserID, _, _, email, err := VerifyEmailJwt(jwt, context) if err != nil { return false, "" } diff --git a/api/write_favorites.go b/api/write_favorites.go index f4f8574f..60266413 100644 --- a/api/write_favorites.go +++ b/api/write_favorites.go @@ -26,7 +26,7 @@ func WriteFavoritesWrapper(c *gin.Context) { LoginRequiredPathJSON(WriteFavorites, params, path, c) } -func WriteFavorites(remoteAddr string, uuserID bbs.UUserID, params interface{}, path interface{}) (result interface{}, err error) { +func WriteFavorites(remoteAddr string, uuserID bbs.UUserID, params interface{}, path interface{}, c *gin.Context) (result interface{}, err error) { theParams, ok := params.(*WriteFavoritesParams) if !ok { return nil, ErrInvalidParams diff --git a/api/write_favorites_test.go b/api/write_favorites_test.go index d5867c45..2310bdfc 100644 --- a/api/write_favorites_test.go +++ b/api/write_favorites_test.go @@ -61,7 +61,7 @@ func TestWriteFavorites(t *testing.T) { t.Run(tt.name, func(t *testing.T) { defer wg.Done() now := types.NowTS() - gotResult, err := WriteFavorites(tt.args.remoteAddr, tt.args.uuserID, tt.args.params, tt.args.path) + gotResult, err := WriteFavorites(tt.args.remoteAddr, tt.args.uuserID, tt.args.params, tt.args.path, nil) if (err != nil) != tt.wantErr { t.Errorf("WriteFavorites() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/apidoc/apidoc.py b/apidoc/apidoc.py index 43bd6b6c..b5698f8a 100644 --- a/apidoc/apidoc.py +++ b/apidoc/apidoc.py @@ -289,6 +289,21 @@ def _get_email_token_info(): """ return '' +@app.route(_with_app_prefix('/refresh'), methods=['POST']) +def _refresh(): + """ + swagger_from_file: apidoc/refresh.yaml + """ + return '' + + +@app.route(_with_app_prefix('/refreshtoken/info'), methods=['POST']) +def _get_refresh_token_info(): + """ + swagger_from_file: apidoc/get_refresh_token_info.yaml + """ + return '' + @app.route(_with_app_prefix('/user//favorites'), methods=['GET']) def _get_fav(uid): diff --git a/apidoc/defs/token.yaml b/apidoc/defs/token.yaml index e44c2290..8110603c 100644 --- a/apidoc/defs/token.yaml +++ b/apidoc/defs/token.yaml @@ -6,4 +6,6 @@ Token: access_token: type: string token_type: - type: string \ No newline at end of file + type: string + refresh_token: + type: string diff --git a/apidoc/get_email_token_info.yaml b/apidoc/get_email_token_info.yaml index 27439b1b..ccb4b888 100644 --- a/apidoc/get_email_token_info.yaml +++ b/apidoc/get_email_token_info.yaml @@ -33,3 +33,5 @@ responses: type: string email: type: string + expire: + type: number diff --git a/apidoc/get_refresh_token_info.yaml b/apidoc/get_refresh_token_info.yaml new file mode 100644 index 00000000..07dd8060 --- /dev/null +++ b/apidoc/get_refresh_token_info.yaml @@ -0,0 +1,32 @@ +getRefreshTokenInfo +--- +tags: + - user +description: get refresh token info +parameters: + - '$ref': '#/definitions/Host' + - '$ref': '#/definitions/XForwardedFor' + - '$ref': '#/definitions/Authorization' + - name: params + in: body + schema: + '$id': https://json-schema.org/draft/2019-09/output/schema + type: object + properties: + token: + type: string + required: true + description: token +responses: + 200: + schema: + '$id': https://json-schema.org/draft/2019-09/output/schema + type: object + properties: + client_info: + type: string + user_id: + type: string + expire: + type: number + description: expire in ts diff --git a/apidoc/get_token_info.yaml b/apidoc/get_token_info.yaml index fb4abcbf..01621391 100644 --- a/apidoc/get_token_info.yaml +++ b/apidoc/get_token_info.yaml @@ -27,3 +27,6 @@ responses: type: string user_id: type: string + expire: + type: number + description: expire in ts diff --git a/apidoc/refresh.yaml b/apidoc/refresh.yaml new file mode 100644 index 00000000..be0efd28 --- /dev/null +++ b/apidoc/refresh.yaml @@ -0,0 +1,21 @@ +Refresh +--- +tags: + - user +description: refresh token +parameters: + - '$ref': '#/definitions/Host' + - '$ref': '#/definitions/XForwardedFor' + - '$ref': '#/definitions/Authorization' + - name: params + in: body + schema: + '$id': https://json-schema.org/draft/2019-09/output/schema + type: object + properties: +responses: + 200: + description: index-response + schema: + '$id': https://json-schema.org/draft/2019-09/output/schema + '$ref': '#/definitions/Token' diff --git a/initgin/init_gin.go b/initgin/init_gin.go index d777463f..a3a7fcfb 100644 --- a/initgin/init_gin.go +++ b/initgin/init_gin.go @@ -70,6 +70,9 @@ func InitGin() (*gin.Engine, error) { router.GET(withPrefix(api.GET_USER_VISIT_COUNT_R), api.GetUserVisitCountWrapper) router.POST(withPrefix(api.WRITE_FAV_R), api.WriteFavoritesWrapper) + router.POST(withPrefix(api.REFRESH_R), api.RefreshWrapper) + router.POST(withPrefix(api.GET_REFRESH_TOKEN_INFO_R), api.GetRefreshTokenInfoWrapper) + // admin router.GET(withPrefix(api.RELOAD_UHASH_R), api.ReloadUHashWrapper) router.POST(withPrefix(api.SET_USER_PERM_R), api.SetUserPermWrapper)