rag/internal/logic/auth.go

104 lines
2.3 KiB
Go
Raw Permalink Normal View History

2024-07-14 09:44:49 +00:00
package logic
import (
2024-07-15 17:48:05 +00:00
"context"
2024-07-14 09:44:49 +00:00
"github.com/gin-gonic/gin"
"github.com/mitchellh/mapstructure"
"go.uber.org/zap"
2024-07-14 15:58:23 +00:00
"leafdev.top/leaf/rag/consts"
"leafdev.top/leaf/rag/internal/providers"
"leafdev.top/leaf/rag/internal/providers/jwks"
"leafdev.top/leaf/rag/models"
2024-07-14 09:44:49 +00:00
"strings"
)
// AuthLogic groups the authentication helpers used by the Gin layer.
// It carries no state of its own; its methods read the package-level
// config and logger.
type AuthLogic struct{}
var (
2024-07-14 14:16:03 +00:00
config = *providers.MustGet[providers.GlobalConfig]()
logger = *providers.MustGet[zap.Logger]()
2024-07-14 09:44:49 +00:00
)
// NewAuthLogic allocates a new, stateless AuthLogic.
func NewAuthLogic() *AuthLogic {
	return new(AuthLogic)
}
2024-07-14 14:14:20 +00:00
func (a *AuthLogic) GinMiddlewareAuth(tokenType models.JWTTokenTypes, c *gin.Context) (*models.User, error) {
2024-07-14 10:14:01 +00:00
var sub = consts.AnonymousUser
2024-07-14 14:14:20 +00:00
var jwtIdToken = &models.User{}
2024-07-14 09:44:49 +00:00
if config.DebugMode.Enable {
jwtIdToken.Token.Sub = sub
jwtIdToken.Valid = true
return jwtIdToken, nil
} else {
authorization := c.Request.Header.Get(consts.AuthHeader)
if authorization == "" {
2024-07-14 14:16:03 +00:00
return nil, consts.ErrJWTFormatError
2024-07-14 09:44:49 +00:00
}
authSplit := strings.Split(authorization, " ")
if len(authSplit) != 2 {
2024-07-14 14:16:03 +00:00
return nil, consts.ErrJWTFormatError
2024-07-14 09:44:49 +00:00
}
if authSplit[0] != consts.AuthPrefix {
2024-07-14 14:16:03 +00:00
return nil, consts.ErrNotBearerType
2024-07-14 09:44:49 +00:00
}
token, err := jwks.ParseJWT(authSplit[1])
if err != nil {
2024-07-14 14:16:03 +00:00
return nil, consts.ErrNotValidToken
2024-07-14 09:44:49 +00:00
}
sub, err = token.Claims.GetSubject()
if err != nil {
2024-07-14 14:16:03 +00:00
return nil, consts.ErrNotValidToken
2024-07-14 09:44:49 +00:00
}
2024-07-14 10:14:01 +00:00
// 如果 token.Header 中没有 typ
if token.Header["typ"] == "" {
2024-07-14 14:16:03 +00:00
return nil, consts.ErrEmptyResponse
2024-07-14 10:14:01 +00:00
}
// 验证 token 类型
if tokenType != "" && tokenType.String() != token.Header["typ"] {
2024-07-14 14:16:03 +00:00
return nil, consts.ErrTokenError
2024-07-14 10:14:01 +00:00
}
jwtIdToken.Valid = true
2024-07-14 09:44:49 +00:00
err = mapstructure.Decode(token.Claims, &jwtIdToken.Token)
if err != nil {
logger.Error("Failed to map token claims to JwtIDToken struct.\nError: " + err.Error())
return nil, nil
}
2024-07-14 10:14:01 +00:00
2024-07-14 09:44:49 +00:00
}
return jwtIdToken, nil
}
2024-07-14 14:14:20 +00:00
func (a *AuthLogic) GinUser(c *gin.Context) *models.User {
2024-07-14 09:44:49 +00:00
user, _ := c.Get(consts.AuthMiddlewareKey)
2024-07-14 14:14:20 +00:00
return user.(*models.User)
2024-07-14 09:44:49 +00:00
}
2024-07-15 18:05:02 +00:00
// GetUserId returns the Token.Sub of the user stored in ctx.
// It panics under the same conditions as GetUser.
func GetUserId(ctx context.Context) string {
	return GetUser(ctx).Token.Sub
}
2024-07-15 17:48:05 +00:00
// GetUser returns the *models.User stored in ctx under
// consts.AuthMiddlewareKey. The unchecked type assertion panics if no such
// value is present or if it has a different type.
func GetUser(ctx context.Context) *models.User {
	return ctx.Value(consts.AuthMiddlewareKey).(*models.User)
}
2024-07-17 12:58:03 +00:00
// SetUser returns a derived context that carries user under
// consts.AuthMiddlewareKey, the same key GetUser reads. ctx itself is not
// modified.
//
// NOTE(review): if consts.AuthMiddlewareKey is a plain string, staticcheck
// (SA1029) recommends an unexported key type to avoid collisions with other
// packages' context values — confirm against consts.
func SetUser(ctx context.Context, user *models.User) context.Context {
	return context.WithValue(ctx, consts.AuthMiddlewareKey, user)
}