改进 认证和 JWKS 刷新机制

This commit is contained in:
ivamp 2024-07-16 01:48:05 +08:00
parent f54b8656d5
commit 9a8a33beac
5 changed files with 38 additions and 26 deletions

View File

@ -1,6 +1,7 @@
package logic package logic
import ( import (
"context"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/mitchellh/mapstructure" "github.com/mitchellh/mapstructure"
"go.uber.org/zap" "go.uber.org/zap"
@ -88,3 +89,9 @@ func GetUserId(ctx *gin.Context) string {
logic := AuthLogic{} logic := AuthLogic{}
return logic.GinUser(ctx).Token.Sub return logic.GinUser(ctx).Token.Sub
} }
func GetUser(ctx context.Context) *models.User {
user := ctx.Value(consts.AuthMiddlewareKey)
return user.(*models.User)
}

View File

@ -19,7 +19,7 @@ func JwtAuth(ctx context.Context) (context.Context, error) {
} }
sub := consts.AnonymousUser sub := consts.AnonymousUser
var jwtIdToken = &models.User{} var jwtIdToken = models.User{}
if config.DebugMode.Enable { if config.DebugMode.Enable {
jwtIdToken.Token.Sub = sub jwtIdToken.Token.Sub = sub
@ -45,14 +45,14 @@ func JwtAuth(ctx context.Context) (context.Context, error) {
jwtIdToken.Valid = true jwtIdToken.Valid = true
err = mapstructure.Decode(token.Claims, &jwtIdToken) err = mapstructure.Decode(token.Claims, &jwtIdToken.Token)
if err != nil { if err != nil {
logger.Error("Failed to map token claims to JwtIDToken struct.\nError: " + err.Error()) logger.Error("Failed to map token claims to JwtIDToken struct.\nError: " + err.Error())
return nil, err return nil, err
} }
} }
ctx = logging.InjectFields(ctx, logging.Fields{"auth.sub", sub}) ctx = logging.InjectFields(ctx, logging.Fields{consts.AuthMiddlewareKey, sub})
return context.WithValue(ctx, "auth", jwtIdToken), nil return context.WithValue(ctx, consts.AuthMiddlewareKey, &jwtIdToken), nil
} }

View File

@ -3,10 +3,18 @@ package jwks
import "time" import "time"
func InitJwksRefresh() { func InitJwksRefresh() {
// 先刷新一次
RefreshJWKS()
var firstRefreshed = true
// 启动一个定时器 // 启动一个定时器
go func() { go func() {
for { for {
RefreshJWKS() if firstRefreshed {
firstRefreshed = false
} else {
RefreshJWKS()
}
time.Sleep(refreshRate) time.Sleep(refreshRate)
} }
}() }()

View File

@ -21,7 +21,6 @@ var logger = providers.MustGet[zap.Logger]()
var config = providers.MustGet[providers.GlobalConfig]() var config = providers.MustGet[providers.GlobalConfig]()
func RefreshJWKS() { func RefreshJWKS() {
logger.Info("Refreshing JWKS...") logger.Info("Refreshing JWKS...")
var err error var err error

View File

@ -1,26 +1,24 @@
package models package models
import "time"
type UserTokenInfo struct { type UserTokenInfo struct {
Exp int `json:"exp"` Aud string `json:"aud"`
Iat int `json:"iat"` Iss string `json:"iss"`
AuthTime int `json:"auth_time"` Iat float64 `json:"iat"`
Jti string `json:"jti"` Exp float64 `json:"exp"`
Iss string `json:"iss"` Sub string `json:"sub"`
Aud string `json:"aud"` Scopes []string `json:"scopes"`
Sub string `json:"sub"` Id int `json:"id"`
Typ string `json:"typ"` Uuid string `json:"uuid"`
Azp string `json:"azp"` Avatar string `json:"avatar"`
SessionState string `json:"session_state"` Name string `json:"name"`
AtHash string `json:"at_hash"` EmailVerified bool `json:"email_verified"`
Acr string `json:"acr"` RealNameVerified bool `json:"real_name_verified"`
Sid string `json:"sid"` PhoneVerified bool `json:"phone_verified"`
EmailVerified bool `json:"email_verified"` Email string `json:"email"`
Name string `json:"name"` Phone string `json:"phone"`
PreferredUsername string `json:"preferred_username"` CreatedAt time.Time `json:"created_at"`
GivenName string `json:"given_name"`
FamilyName string `json:"family_name"`
Email string `json:"email"`
Groups []string `json:"groups"`
} }
type User struct { type User struct {