Relay done but not working

This commit is contained in:
JustSong 2023-04-23 18:24:11 +08:00
parent 9fc375c604
commit 852af57c03
12 changed files with 225 additions and 70 deletions

View File

@ -12,6 +12,8 @@ var SystemName = "One API"
var ServerAddress = "http://localhost:3000"
var Footer = ""
var UsingSQLite = false
// Any options with "Secret", "Token" in its key won't be return by GetOptions
var SessionSecret = uuid.New().String()
@ -84,6 +86,11 @@ const (
UserStatusDisabled = 2 // also don't use 0
)
const (
TokenStatusEnabled = 1 // don't use 0, 0 is the default value!
TokenStatusDisabled = 2 // also don't use 0
)
const (
ChannelStatusUnknown = 0
ChannelStatusEnabled = 1 // don't use 0, 0 is the default value!
@ -100,3 +107,14 @@ const (
ChannelTypeOpenAIMax = 6
ChannelTypeOhMyGPT = 7
)
var ChannelHosts = []string{
"", // 0
"https://api.openai.com", // 1
"https://openai.api2d.net", // 2
"", // 3
"https://api.openai-asia.com", // 4
"https://api.openai-sb.com", // 5
"https://api.openaimax.com", // 6
"https://api.ohmygpt.com", // 7
}

View File

@ -56,7 +56,7 @@ func GetChannel(c *gin.Context) {
})
return
}
channel, err := model.GetChannelById(id)
channel, err := model.GetChannelById(id, false)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,

52
controller/relay.go Normal file
View File

@ -0,0 +1,52 @@
package controller
import (
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
)
func Relay(c *gin.Context) {
channelType := c.GetInt("channel")
host := common.ChannelHosts[channelType]
req, err := http.NewRequest(c.Request.Method, fmt.Sprintf("%s/%s", host, c.Request.URL.String()), c.Request.Body)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"error": gin.H{
"message": err.Error(),
"type": "one_api_error",
},
})
return
}
req.Header = c.Request.Header.Clone()
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"error": gin.H{
"message": err.Error(),
"type": "one_api_error",
},
})
return
}
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
_, err = io.Copy(c.Writer, resp.Body)
//body, err := io.ReadAll(resp.Body)
//_, err = c.Writer.Write(body)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"error": gin.H{
"message": err.Error(),
"type": "one_api_error",
},
})
return
}
}

View File

@ -6,6 +6,7 @@ import (
"net/http"
"one-api/common"
"one-api/model"
"strings"
)
func authHelper(c *gin.Context, minRole int) {
@ -14,34 +15,13 @@ func authHelper(c *gin.Context, minRole int) {
role := session.Get("role")
id := session.Get("id")
status := session.Get("status")
authByToken := false
if username == nil {
// Check token
token := c.Request.Header.Get("Authorization")
if token == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无权进行此操作,未登录或 token 无效",
})
c.Abort()
return
}
user := model.ValidateUserToken(token)
if user != nil && user.Username != "" {
// Token is valid
username = user.Username
role = user.Role
id = user.Id
status = user.Status
} else {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无权进行此操作token 无效",
})
c.Abort()
return
}
authByToken = true
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无权进行此操作,未登录",
})
c.Abort()
return
}
if status.(int) == common.UserStatusDisabled {
c.JSON(http.StatusOK, gin.H{
@ -62,7 +42,6 @@ func authHelper(c *gin.Context, minRole int) {
c.Set("username", username)
c.Set("role", role)
c.Set("id", id)
c.Set("authByToken", authByToken)
c.Next()
}
@ -84,33 +63,25 @@ func RootAuth() func(c *gin.Context) {
}
}
// NoTokenAuth You should always use this after normal auth middlewares.
func NoTokenAuth() func(c *gin.Context) {
func TokenAuth() func(c *gin.Context) {
return func(c *gin.Context) {
authByToken := c.GetBool("authByToken")
if authByToken {
key := c.Request.Header.Get("Authorization")
parts := strings.Split(key, "-")
key = parts[0]
token, err := model.ValidateUserToken(key)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "本接口不支持使用 token 进行验证",
"error": gin.H{
"message": err.Error(),
"type": "one_api_error",
},
})
c.Abort()
return
}
c.Next()
}
}
// TokenOnlyAuth You should always use this after normal auth middlewares.
func TokenOnlyAuth() func(c *gin.Context) {
return func(c *gin.Context) {
authByToken := c.GetBool("authByToken")
if !authByToken {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "本接口仅支持使用 token 进行验证",
})
c.Abort()
return
c.Set("id", token.UserId)
if len(parts) > 1 {
c.Set("channelId", parts[1])
}
c.Next()
}

68
middleware/distributor.go Normal file
View File

@ -0,0 +1,68 @@
package middleware
import (
"fmt"
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
)
func Distribute() func(c *gin.Context) {
return func(c *gin.Context) {
var channel *model.Channel
channelId, ok := c.Get("channelId")
if ok {
id, err := strconv.Atoi(channelId.(string))
if err != nil {
c.JSON(http.StatusOK, gin.H{
"error": gin.H{
"message": "无效的渠道 ID",
"type": "one_api_error",
},
})
c.Abort()
return
}
channel, err = model.GetChannelById(id, true)
if err != nil {
c.JSON(200, gin.H{
"error": gin.H{
"message": "无效的渠道 ID",
"type": "one_api_error",
},
})
c.Abort()
return
}
if channel.Status != common.ChannelStatusEnabled {
c.JSON(200, gin.H{
"error": gin.H{
"message": "该渠道已被禁用",
"type": "one_api_error",
},
})
c.Abort()
return
}
} else {
// Select a channel for the user
var err error
channel, err = model.GetRandomChannel()
if err != nil {
c.JSON(200, gin.H{
"error": gin.H{
"message": "无可用渠道",
"type": "one_api_error",
},
})
c.Abort()
return
}
}
c.Set("channel", channel.Type)
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
c.Next()
}
}

View File

@ -2,12 +2,13 @@ package model
import (
_ "gorm.io/driver/sqlite"
"one-api/common"
)
type Channel struct {
Id int `json:"id"`
Type int `json:"type" gorm:"default:0"`
Key string `json:"key"`
Key string `json:"key" gorm:"not null"`
Status int `json:"status" gorm:"default:1"`
Name string `json:"name" gorm:"index"`
Weight int `json:"weight"`
@ -27,10 +28,26 @@ func SearchChannels(keyword string) (channels []*Channel, err error) {
return channels, err
}
func GetChannelById(id int) (*Channel, error) {
func GetChannelById(id int, selectAll bool) (*Channel, error) {
channel := Channel{Id: id}
var err error = nil
err = DB.Omit("key").First(&channel, "id = ?", id).Error
if selectAll {
err = DB.First(&channel, "id = ?", id).Error
} else {
err = DB.Omit("key").First(&channel, "id = ?", id).Error
}
return &channel, err
}
func GetRandomChannel() (*Channel, error) {
// TODO: consider weight
channel := Channel{}
var err error = nil
if common.UsingSQLite {
err = DB.Where("status = ?", common.ChannelStatusEnabled).Order("RANDOM()").Limit(1).First(&channel).Error
} else {
err = DB.Where("status = ?", common.ChannelStatusEnabled).Order("RAND()").Limit(1).First(&channel).Error
}
return &channel, err
}

View File

@ -45,6 +45,7 @@ func InitDB() (err error) {
})
} else {
// Use SQLite
common.UsingSQLite = true
db, err = gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{
PrepareStmt: true, // precompile SQL
})

View File

@ -3,12 +3,14 @@ package model
import (
"errors"
_ "gorm.io/driver/sqlite"
"one-api/common"
"strings"
)
type Token struct {
Id int `json:"id"`
UserId int `json:"user_id"`
Key string `json:"key"`
Key string `json:"key" gorm:"uniqueIndex"`
Status int `json:"status" gorm:"default:1"`
Name string `json:"name" gorm:"index" `
CreatedTime int64 `json:"created_time" gorm:"bigint"`
@ -27,6 +29,29 @@ func SearchUserTokens(userId int, keyword string) (tokens []*Token, err error) {
return tokens, err
}
func ValidateUserToken(key string) (token *Token, err error) {
if key == "" {
return nil, errors.New("未提供 token")
}
key = strings.Replace(key, "Bearer ", "", 1)
token = &Token{}
err = DB.Where("key = ?", key).First(token).Error
if err == nil {
if token.Status != common.TokenStatusEnabled {
return nil, errors.New("该 token 已被禁用")
}
go func() {
token.AccessedTime = common.GetTimestamp()
err := token.Update()
if err != nil {
common.SysError("更新 token 访问时间失败:" + err.Error())
}
}()
return token, nil
}
return nil, err
}
func GetTokenByIds(id int, userId int) (*Token, error) {
if id == 0 || userId == 0 {
return nil, errors.New("id 或 userId 为空!")

View File

@ -3,7 +3,6 @@ package model
import (
"errors"
"one-api/common"
"strings"
)
// User if you add sensitive fields, don't forget to clean them in setupLogin function.
@ -149,18 +148,6 @@ func (user *User) FillUserByUsername() error {
return nil
}
func ValidateUserToken(token string) (user *User) {
if token == "" {
return nil
}
token = strings.Replace(token, "Bearer ", "", 1)
user = &User{}
if DB.Where("token = ?", token).First(user).RowsAffected == 1 {
return user
}
return nil
}
func IsEmailAlreadyTaken(email string) bool {
return DB.Where("email = ?", email).Find(&User{}).RowsAffected == 1
}

View File

@ -28,7 +28,7 @@ func SetApiRouter(router *gin.Engine) {
userRoute.GET("/logout", controller.Logout)
selfRoute := userRoute.Group("/")
selfRoute.Use(middleware.UserAuth(), middleware.NoTokenAuth())
selfRoute.Use(middleware.UserAuth())
{
selfRoute.GET("/self", controller.GetSelf)
selfRoute.PUT("/self", controller.UpdateSelf)
@ -36,7 +36,7 @@ func SetApiRouter(router *gin.Engine) {
}
adminRoute := userRoute.Group("/")
adminRoute.Use(middleware.AdminAuth(), middleware.NoTokenAuth())
adminRoute.Use(middleware.AdminAuth())
{
adminRoute.GET("/", controller.GetAllUsers)
adminRoute.GET("/search", controller.SearchUsers)
@ -48,7 +48,7 @@ func SetApiRouter(router *gin.Engine) {
}
}
optionRoute := apiRouter.Group("/option")
optionRoute.Use(middleware.RootAuth(), middleware.NoTokenAuth())
optionRoute.Use(middleware.RootAuth())
{
optionRoute.GET("/", controller.GetOptions)
optionRoute.PUT("/", controller.UpdateOption)

View File

@ -7,5 +7,6 @@ import (
func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) {
SetApiRouter(router)
SetRelayRouter(router)
setWebRouter(router, buildFS, indexPage)
}

15
router/relay-router.go Normal file
View File

@ -0,0 +1,15 @@
package router
import (
"github.com/gin-gonic/gin"
"one-api/controller"
"one-api/middleware"
)
func SetRelayRouter(router *gin.Engine) {
relayRouter := router.Group("/v1")
relayRouter.Use(middleware.GlobalAPIRateLimit(), middleware.TokenAuth(), middleware.Distribute())
{
relayRouter.POST("/chat/completions", controller.Relay)
}
}