From 852af57c0393429dcbd8de977c8edc09413dda81 Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 23 Apr 2023 18:24:11 +0800 Subject: [PATCH] Relay done but not working --- common/constants.go | 18 ++++++++++ controller/channel.go | 2 +- controller/relay.go | 52 +++++++++++++++++++++++++++++ middleware/auth.go | 69 ++++++++++++--------------------------- middleware/distributor.go | 68 ++++++++++++++++++++++++++++++++++++++ model/channel.go | 23 +++++++++++-- model/main.go | 1 + model/token.go | 27 ++++++++++++++- model/user.go | 13 -------- router/api-router.go | 6 ++-- router/main.go | 1 + router/relay-router.go | 15 +++++++++ 12 files changed, 225 insertions(+), 70 deletions(-) create mode 100644 controller/relay.go create mode 100644 middleware/distributor.go create mode 100644 router/relay-router.go diff --git a/common/constants.go b/common/constants.go index 722df3bc..a5f1eb67 100644 --- a/common/constants.go +++ b/common/constants.go @@ -12,6 +12,8 @@ var SystemName = "One API" var ServerAddress = "http://localhost:3000" var Footer = "" +var UsingSQLite = false + // Any options with "Secret", "Token" in its key won't be return by GetOptions var SessionSecret = uuid.New().String() @@ -84,6 +86,11 @@ const ( UserStatusDisabled = 2 // also don't use 0 ) +const ( + TokenStatusEnabled = 1 // don't use 0, 0 is the default value! + TokenStatusDisabled = 2 // also don't use 0 +) + const ( ChannelStatusUnknown = 0 ChannelStatusEnabled = 1 // don't use 0, 0 is the default value! @@ -100,3 +107,14 @@ const ( ChannelTypeOpenAIMax = 6 ChannelTypeOhMyGPT = 7 ) + +var ChannelHosts = []string{ + "", // 0 + "https://api.openai.com", // 1 + "https://openai.api2d.net", // 2 + "", // 3 + "https://api.openai-asia.com", // 4 + "https://api.openai-sb.com", // 5 + "https://api.openaimax.com", // 6 + "https://api.ohmygpt.com", // 7 +} diff --git a/controller/channel.go b/controller/channel.go index d6bc98d4..3a147019 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -56,7 +56,7 @@ func GetChannel(c *gin.Context) { }) return } - channel, err := model.GetChannelById(id) + channel, err := model.GetChannelById(id, false) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, diff --git a/controller/relay.go b/controller/relay.go new file mode 100644 index 00000000..886e4f0d --- /dev/null +++ b/controller/relay.go @@ -0,0 +1,52 @@ +package controller + +import ( + "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/common" +) + +func Relay(c *gin.Context) { + channelType := c.GetInt("channel") + host := common.ChannelHosts[channelType] + req, err := http.NewRequest(c.Request.Method, fmt.Sprintf("%s/%s", host, c.Request.URL.String()), c.Request.Body) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "error": gin.H{ + "message": err.Error(), + "type": "one_api_error", + }, + }) + return + } + req.Header = c.Request.Header.Clone() + client := &http.Client{} + + resp, err := client.Do(req) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "error": gin.H{ + "message": err.Error(), + "type": "one_api_error", + }, + }) + return + } + for k, v := range resp.Header { + c.Writer.Header().Set(k, v[0]) + } + _, err = io.Copy(c.Writer, resp.Body) + //body, err := io.ReadAll(resp.Body) + //_, err = c.Writer.Write(body) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "error": gin.H{ + "message": err.Error(), + "type": "one_api_error", + }, + }) + return + } +} diff --git a/middleware/auth.go b/middleware/auth.go index 427217e2..22bf9fed 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -6,6 +6,7 @@ import ( "net/http" "one-api/common" "one-api/model" + "strings" ) func authHelper(c *gin.Context, minRole int) { @@ -14,34 +15,13 @@ func authHelper(c *gin.Context, minRole int) { role := session.Get("role") id := session.Get("id") status := session.Get("status") - authByToken := false if username == nil { - // Check token - token := c.Request.Header.Get("Authorization") - if token == "" { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": "无权进行此操作,未登录或 token 无效", - }) - c.Abort() - return - } - user := model.ValidateUserToken(token) - if user != nil && user.Username != "" { - // Token is valid - username = user.Username - role = user.Role - id = user.Id - status = user.Status - } else { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": "无权进行此操作,token 无效", - }) - c.Abort() - return - } - authByToken = true + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无权进行此操作,未登录", + }) + c.Abort() + return } if status.(int) == common.UserStatusDisabled { c.JSON(http.StatusOK, gin.H{ @@ -62,7 +42,6 @@ func authHelper(c *gin.Context, minRole int) { c.Set("username", username) c.Set("role", role) c.Set("id", id) - c.Set("authByToken", authByToken) c.Next() } @@ -84,33 +63,25 @@ func RootAuth() func(c *gin.Context) { } } -// NoTokenAuth You should always use this after normal auth middlewares. -func NoTokenAuth() func(c *gin.Context) { +func TokenAuth() func(c *gin.Context) { return func(c *gin.Context) { - authByToken := c.GetBool("authByToken") - if authByToken { + key := c.Request.Header.Get("Authorization") + parts := strings.Split(key, "-") + key = parts[0] + token, err := model.ValidateUserToken(key) + if err != nil { c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": "本接口不支持使用 token 进行验证", + "error": gin.H{ + "message": err.Error(), + "type": "one_api_error", + }, }) c.Abort() return } - c.Next() - } -} - -// TokenOnlyAuth You should always use this after normal auth middlewares. -func TokenOnlyAuth() func(c *gin.Context) { - return func(c *gin.Context) { - authByToken := c.GetBool("authByToken") - if !authByToken { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": "本接口仅支持使用 token 进行验证", - }) - c.Abort() - return + c.Set("id", token.UserId) + if len(parts) > 1 { + c.Set("channelId", parts[1]) } c.Next() } diff --git a/middleware/distributor.go b/middleware/distributor.go new file mode 100644 index 00000000..08193f8f --- /dev/null +++ b/middleware/distributor.go @@ -0,0 +1,68 @@ +package middleware + +import ( + "fmt" + "github.com/gin-gonic/gin" + "net/http" + "one-api/common" + "one-api/model" + "strconv" +) + +func Distribute() func(c *gin.Context) { + return func(c *gin.Context) { + var channel *model.Channel + channelId, ok := c.Get("channelId") + if ok { + id, err := strconv.Atoi(channelId.(string)) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "error": gin.H{ + "message": "无效的渠道 ID", + "type": "one_api_error", + }, + }) + c.Abort() + return + } + channel, err = model.GetChannelById(id, true) + if err != nil { + c.JSON(200, gin.H{ + "error": gin.H{ + "message": "无效的渠道 ID", + "type": "one_api_error", + }, + }) + c.Abort() + return + } + if channel.Status != common.ChannelStatusEnabled { + c.JSON(200, gin.H{ + "error": gin.H{ + "message": "该渠道已被禁用", + "type": "one_api_error", + }, + }) + c.Abort() + return + } + } else { + // Select a channel for the user + var err error + channel, err = model.GetRandomChannel() + if err != nil { + c.JSON(200, gin.H{ + "error": gin.H{ + "message": "无可用渠道", + "type": "one_api_error", + }, + }) + c.Abort() + return + } + } + c.Set("channel", channel.Type) + c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) + c.Next() + } +} diff --git a/model/channel.go b/model/channel.go index 7c298ee2..ceaf2710 100644 --- a/model/channel.go +++ b/model/channel.go @@ -2,12 +2,13 @@ package model import ( _ "gorm.io/driver/sqlite" + "one-api/common" ) type Channel struct { Id int `json:"id"` Type int `json:"type" gorm:"default:0"` - Key string `json:"key"` + Key string `json:"key" gorm:"not null"` Status int `json:"status" gorm:"default:1"` Name string `json:"name" gorm:"index"` Weight int `json:"weight"` @@ -27,10 +28,26 @@ func SearchChannels(keyword string) (channels []*Channel, err error) { return channels, err } -func GetChannelById(id int) (*Channel, error) { +func GetChannelById(id int, selectAll bool) (*Channel, error) { channel := Channel{Id: id} var err error = nil - err = DB.Omit("key").First(&channel, "id = ?", id).Error + if selectAll { + err = DB.First(&channel, "id = ?", id).Error + } else { + err = DB.Omit("key").First(&channel, "id = ?", id).Error + } + return &channel, err +} + +func GetRandomChannel() (*Channel, error) { + // TODO: consider weight + channel := Channel{} + var err error = nil + if common.UsingSQLite { + err = DB.Where("status = ?", common.ChannelStatusEnabled).Order("RANDOM()").Limit(1).First(&channel).Error + } else { + err = DB.Where("status = ?", common.ChannelStatusEnabled).Order("RAND()").Limit(1).First(&channel).Error + } return &channel, err } diff --git a/model/main.go b/model/main.go index 608cf6f3..8d739cf2 100644 --- a/model/main.go +++ b/model/main.go @@ -45,6 +45,7 @@ func InitDB() (err error) { }) } else { // Use SQLite + common.UsingSQLite = true db, err = gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{ PrepareStmt: true, // precompile SQL }) diff --git a/model/token.go b/model/token.go index cc9e3976..5b3bed56 100644 --- a/model/token.go +++ b/model/token.go @@ -3,12 +3,14 @@ package model import ( "errors" _ "gorm.io/driver/sqlite" + "one-api/common" + "strings" ) type Token struct { Id int `json:"id"` UserId int `json:"user_id"` - Key string `json:"key"` + Key string `json:"key" gorm:"uniqueIndex"` Status int `json:"status" gorm:"default:1"` Name string `json:"name" gorm:"index" ` CreatedTime int64 `json:"created_time" gorm:"bigint"` @@ -27,6 +29,29 @@ func SearchUserTokens(userId int, keyword string) (tokens []*Token, err error) { return tokens, err } +func ValidateUserToken(key string) (token *Token, err error) { + if key == "" { + return nil, errors.New("未提供 token") + } + key = strings.Replace(key, "Bearer ", "", 1) + token = &Token{} + err = DB.Where("key = ?", key).First(token).Error + if err == nil { + if token.Status != common.TokenStatusEnabled { + return nil, errors.New("该 token 已被禁用") + } + go func() { + token.AccessedTime = common.GetTimestamp() + err := token.Update() + if err != nil { + common.SysError("更新 token 访问时间失败:" + err.Error()) + } + }() + return token, nil + } + return nil, err +} + func GetTokenByIds(id int, userId int) (*Token, error) { if id == 0 || userId == 0 { return nil, errors.New("id 或 userId 为空!") diff --git a/model/user.go b/model/user.go index d84119bc..74fc593e 100644 --- a/model/user.go +++ b/model/user.go @@ -3,7 +3,6 @@ package model import ( "errors" "one-api/common" - "strings" ) // User if you add sensitive fields, don't forget to clean them in setupLogin function. @@ -149,18 +148,6 @@ func (user *User) FillUserByUsername() error { return nil } -func ValidateUserToken(token string) (user *User) { - if token == "" { - return nil - } - token = strings.Replace(token, "Bearer ", "", 1) - user = &User{} - if DB.Where("token = ?", token).First(user).RowsAffected == 1 { - return user - } - return nil -} - func IsEmailAlreadyTaken(email string) bool { return DB.Where("email = ?", email).Find(&User{}).RowsAffected == 1 } diff --git a/router/api-router.go b/router/api-router.go index 5e34cea6..1d5cde17 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -28,7 +28,7 @@ func SetApiRouter(router *gin.Engine) { userRoute.GET("/logout", controller.Logout) selfRoute := userRoute.Group("/") - selfRoute.Use(middleware.UserAuth(), middleware.NoTokenAuth()) + selfRoute.Use(middleware.UserAuth()) { selfRoute.GET("/self", controller.GetSelf) selfRoute.PUT("/self", controller.UpdateSelf) @@ -36,7 +36,7 @@ func SetApiRouter(router *gin.Engine) { } adminRoute := userRoute.Group("/") - adminRoute.Use(middleware.AdminAuth(), middleware.NoTokenAuth()) + adminRoute.Use(middleware.AdminAuth()) { adminRoute.GET("/", controller.GetAllUsers) adminRoute.GET("/search", controller.SearchUsers) @@ -48,7 +48,7 @@ func SetApiRouter(router *gin.Engine) { } } optionRoute := apiRouter.Group("/option") - optionRoute.Use(middleware.RootAuth(), middleware.NoTokenAuth()) + optionRoute.Use(middleware.RootAuth()) { optionRoute.GET("/", controller.GetOptions) optionRoute.PUT("/", controller.UpdateOption) diff --git a/router/main.go b/router/main.go index 30b25c82..d0908560 100644 --- a/router/main.go +++ b/router/main.go @@ -7,5 +7,6 @@ import ( func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) { SetApiRouter(router) + SetRelayRouter(router) setWebRouter(router, buildFS, indexPage) } diff --git a/router/relay-router.go b/router/relay-router.go new file mode 100644 index 00000000..99027a22 --- /dev/null +++ b/router/relay-router.go @@ -0,0 +1,15 @@ +package router + +import ( + "github.com/gin-gonic/gin" + "one-api/controller" + "one-api/middleware" +) + +func SetRelayRouter(router *gin.Engine) { + relayRouter := router.Group("/v1") + relayRouter.Use(middleware.GlobalAPIRateLimit(), middleware.TokenAuth(), middleware.Distribute()) + { + relayRouter.POST("/chat/completions", controller.Relay) + } +}