Relay done but not working
This commit is contained in:
parent
9fc375c604
commit
852af57c03
@ -12,6 +12,8 @@ var SystemName = "One API"
|
||||
var ServerAddress = "http://localhost:3000"
|
||||
var Footer = ""
|
||||
|
||||
var UsingSQLite = false
|
||||
|
||||
// Any options with "Secret", "Token" in its key won't be return by GetOptions
|
||||
|
||||
var SessionSecret = uuid.New().String()
|
||||
@ -84,6 +86,11 @@ const (
|
||||
UserStatusDisabled = 2 // also don't use 0
|
||||
)
|
||||
|
||||
const (
|
||||
TokenStatusEnabled = 1 // don't use 0, 0 is the default value!
|
||||
TokenStatusDisabled = 2 // also don't use 0
|
||||
)
|
||||
|
||||
const (
|
||||
ChannelStatusUnknown = 0
|
||||
ChannelStatusEnabled = 1 // don't use 0, 0 is the default value!
|
||||
@ -100,3 +107,14 @@ const (
|
||||
ChannelTypeOpenAIMax = 6
|
||||
ChannelTypeOhMyGPT = 7
|
||||
)
|
||||
|
||||
var ChannelHosts = []string{
|
||||
"", // 0
|
||||
"https://api.openai.com", // 1
|
||||
"https://openai.api2d.net", // 2
|
||||
"", // 3
|
||||
"https://api.openai-asia.com", // 4
|
||||
"https://api.openai-sb.com", // 5
|
||||
"https://api.openaimax.com", // 6
|
||||
"https://api.ohmygpt.com", // 7
|
||||
}
|
||||
|
@ -56,7 +56,7 @@ func GetChannel(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
channel, err := model.GetChannelById(id)
|
||||
channel, err := model.GetChannelById(id, false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
|
52
controller/relay.go
Normal file
52
controller/relay.go
Normal file
@ -0,0 +1,52 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
)
|
||||
|
||||
func Relay(c *gin.Context) {
|
||||
channelType := c.GetInt("channel")
|
||||
host := common.ChannelHosts[channelType]
|
||||
req, err := http.NewRequest(c.Request.Method, fmt.Sprintf("%s/%s", host, c.Request.URL.String()), c.Request.Body)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"error": gin.H{
|
||||
"message": err.Error(),
|
||||
"type": "one_api_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
req.Header = c.Request.Header.Clone()
|
||||
client := &http.Client{}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"error": gin.H{
|
||||
"message": err.Error(),
|
||||
"type": "one_api_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
for k, v := range resp.Header {
|
||||
c.Writer.Header().Set(k, v[0])
|
||||
}
|
||||
_, err = io.Copy(c.Writer, resp.Body)
|
||||
//body, err := io.ReadAll(resp.Body)
|
||||
//_, err = c.Writer.Write(body)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"error": gin.H{
|
||||
"message": err.Error(),
|
||||
"type": "one_api_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
@ -6,6 +6,7 @@ import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func authHelper(c *gin.Context, minRole int) {
|
||||
@ -14,34 +15,13 @@ func authHelper(c *gin.Context, minRole int) {
|
||||
role := session.Get("role")
|
||||
id := session.Get("id")
|
||||
status := session.Get("status")
|
||||
authByToken := false
|
||||
if username == nil {
|
||||
// Check token
|
||||
token := c.Request.Header.Get("Authorization")
|
||||
if token == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无权进行此操作,未登录或 token 无效",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
user := model.ValidateUserToken(token)
|
||||
if user != nil && user.Username != "" {
|
||||
// Token is valid
|
||||
username = user.Username
|
||||
role = user.Role
|
||||
id = user.Id
|
||||
status = user.Status
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无权进行此操作,token 无效",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
authByToken = true
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无权进行此操作,未登录",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
if status.(int) == common.UserStatusDisabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@ -62,7 +42,6 @@ func authHelper(c *gin.Context, minRole int) {
|
||||
c.Set("username", username)
|
||||
c.Set("role", role)
|
||||
c.Set("id", id)
|
||||
c.Set("authByToken", authByToken)
|
||||
c.Next()
|
||||
}
|
||||
|
||||
@ -84,33 +63,25 @@ func RootAuth() func(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// NoTokenAuth You should always use this after normal auth middlewares.
|
||||
func NoTokenAuth() func(c *gin.Context) {
|
||||
func TokenAuth() func(c *gin.Context) {
|
||||
return func(c *gin.Context) {
|
||||
authByToken := c.GetBool("authByToken")
|
||||
if authByToken {
|
||||
key := c.Request.Header.Get("Authorization")
|
||||
parts := strings.Split(key, "-")
|
||||
key = parts[0]
|
||||
token, err := model.ValidateUserToken(key)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "本接口不支持使用 token 进行验证",
|
||||
"error": gin.H{
|
||||
"message": err.Error(),
|
||||
"type": "one_api_error",
|
||||
},
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// TokenOnlyAuth You should always use this after normal auth middlewares.
|
||||
func TokenOnlyAuth() func(c *gin.Context) {
|
||||
return func(c *gin.Context) {
|
||||
authByToken := c.GetBool("authByToken")
|
||||
if !authByToken {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "本接口仅支持使用 token 进行验证",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
c.Set("id", token.UserId)
|
||||
if len(parts) > 1 {
|
||||
c.Set("channelId", parts[1])
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
|
68
middleware/distributor.go
Normal file
68
middleware/distributor.go
Normal file
@ -0,0 +1,68 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
func Distribute() func(c *gin.Context) {
|
||||
return func(c *gin.Context) {
|
||||
var channel *model.Channel
|
||||
channelId, ok := c.Get("channelId")
|
||||
if ok {
|
||||
id, err := strconv.Atoi(channelId.(string))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "无效的渠道 ID",
|
||||
"type": "one_api_error",
|
||||
},
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
channel, err = model.GetChannelById(id, true)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "无效的渠道 ID",
|
||||
"type": "one_api_error",
|
||||
},
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
if channel.Status != common.ChannelStatusEnabled {
|
||||
c.JSON(200, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "该渠道已被禁用",
|
||||
"type": "one_api_error",
|
||||
},
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// Select a channel for the user
|
||||
var err error
|
||||
channel, err = model.GetRandomChannel()
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "无可用渠道",
|
||||
"type": "one_api_error",
|
||||
},
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
}
|
||||
c.Set("channel", channel.Type)
|
||||
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
||||
c.Next()
|
||||
}
|
||||
}
|
@ -2,12 +2,13 @@ package model
|
||||
|
||||
import (
|
||||
_ "gorm.io/driver/sqlite"
|
||||
"one-api/common"
|
||||
)
|
||||
|
||||
type Channel struct {
|
||||
Id int `json:"id"`
|
||||
Type int `json:"type" gorm:"default:0"`
|
||||
Key string `json:"key"`
|
||||
Key string `json:"key" gorm:"not null"`
|
||||
Status int `json:"status" gorm:"default:1"`
|
||||
Name string `json:"name" gorm:"index"`
|
||||
Weight int `json:"weight"`
|
||||
@ -27,10 +28,26 @@ func SearchChannels(keyword string) (channels []*Channel, err error) {
|
||||
return channels, err
|
||||
}
|
||||
|
||||
func GetChannelById(id int) (*Channel, error) {
|
||||
func GetChannelById(id int, selectAll bool) (*Channel, error) {
|
||||
channel := Channel{Id: id}
|
||||
var err error = nil
|
||||
err = DB.Omit("key").First(&channel, "id = ?", id).Error
|
||||
if selectAll {
|
||||
err = DB.First(&channel, "id = ?", id).Error
|
||||
} else {
|
||||
err = DB.Omit("key").First(&channel, "id = ?", id).Error
|
||||
}
|
||||
return &channel, err
|
||||
}
|
||||
|
||||
func GetRandomChannel() (*Channel, error) {
|
||||
// TODO: consider weight
|
||||
channel := Channel{}
|
||||
var err error = nil
|
||||
if common.UsingSQLite {
|
||||
err = DB.Where("status = ?", common.ChannelStatusEnabled).Order("RANDOM()").Limit(1).First(&channel).Error
|
||||
} else {
|
||||
err = DB.Where("status = ?", common.ChannelStatusEnabled).Order("RAND()").Limit(1).First(&channel).Error
|
||||
}
|
||||
return &channel, err
|
||||
}
|
||||
|
||||
|
@ -45,6 +45,7 @@ func InitDB() (err error) {
|
||||
})
|
||||
} else {
|
||||
// Use SQLite
|
||||
common.UsingSQLite = true
|
||||
db, err = gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{
|
||||
PrepareStmt: true, // precompile SQL
|
||||
})
|
||||
|
@ -3,12 +3,14 @@ package model
|
||||
import (
|
||||
"errors"
|
||||
_ "gorm.io/driver/sqlite"
|
||||
"one-api/common"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Token struct {
|
||||
Id int `json:"id"`
|
||||
UserId int `json:"user_id"`
|
||||
Key string `json:"key"`
|
||||
Key string `json:"key" gorm:"uniqueIndex"`
|
||||
Status int `json:"status" gorm:"default:1"`
|
||||
Name string `json:"name" gorm:"index" `
|
||||
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
||||
@ -27,6 +29,29 @@ func SearchUserTokens(userId int, keyword string) (tokens []*Token, err error) {
|
||||
return tokens, err
|
||||
}
|
||||
|
||||
func ValidateUserToken(key string) (token *Token, err error) {
|
||||
if key == "" {
|
||||
return nil, errors.New("未提供 token")
|
||||
}
|
||||
key = strings.Replace(key, "Bearer ", "", 1)
|
||||
token = &Token{}
|
||||
err = DB.Where("key = ?", key).First(token).Error
|
||||
if err == nil {
|
||||
if token.Status != common.TokenStatusEnabled {
|
||||
return nil, errors.New("该 token 已被禁用")
|
||||
}
|
||||
go func() {
|
||||
token.AccessedTime = common.GetTimestamp()
|
||||
err := token.Update()
|
||||
if err != nil {
|
||||
common.SysError("更新 token 访问时间失败:" + err.Error())
|
||||
}
|
||||
}()
|
||||
return token, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func GetTokenByIds(id int, userId int) (*Token, error) {
|
||||
if id == 0 || userId == 0 {
|
||||
return nil, errors.New("id 或 userId 为空!")
|
||||
|
@ -3,7 +3,6 @@ package model
|
||||
import (
|
||||
"errors"
|
||||
"one-api/common"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// User if you add sensitive fields, don't forget to clean them in setupLogin function.
|
||||
@ -149,18 +148,6 @@ func (user *User) FillUserByUsername() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func ValidateUserToken(token string) (user *User) {
|
||||
if token == "" {
|
||||
return nil
|
||||
}
|
||||
token = strings.Replace(token, "Bearer ", "", 1)
|
||||
user = &User{}
|
||||
if DB.Where("token = ?", token).First(user).RowsAffected == 1 {
|
||||
return user
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func IsEmailAlreadyTaken(email string) bool {
|
||||
return DB.Where("email = ?", email).Find(&User{}).RowsAffected == 1
|
||||
}
|
||||
|
@ -28,7 +28,7 @@ func SetApiRouter(router *gin.Engine) {
|
||||
userRoute.GET("/logout", controller.Logout)
|
||||
|
||||
selfRoute := userRoute.Group("/")
|
||||
selfRoute.Use(middleware.UserAuth(), middleware.NoTokenAuth())
|
||||
selfRoute.Use(middleware.UserAuth())
|
||||
{
|
||||
selfRoute.GET("/self", controller.GetSelf)
|
||||
selfRoute.PUT("/self", controller.UpdateSelf)
|
||||
@ -36,7 +36,7 @@ func SetApiRouter(router *gin.Engine) {
|
||||
}
|
||||
|
||||
adminRoute := userRoute.Group("/")
|
||||
adminRoute.Use(middleware.AdminAuth(), middleware.NoTokenAuth())
|
||||
adminRoute.Use(middleware.AdminAuth())
|
||||
{
|
||||
adminRoute.GET("/", controller.GetAllUsers)
|
||||
adminRoute.GET("/search", controller.SearchUsers)
|
||||
@ -48,7 +48,7 @@ func SetApiRouter(router *gin.Engine) {
|
||||
}
|
||||
}
|
||||
optionRoute := apiRouter.Group("/option")
|
||||
optionRoute.Use(middleware.RootAuth(), middleware.NoTokenAuth())
|
||||
optionRoute.Use(middleware.RootAuth())
|
||||
{
|
||||
optionRoute.GET("/", controller.GetOptions)
|
||||
optionRoute.PUT("/", controller.UpdateOption)
|
||||
|
@ -7,5 +7,6 @@ import (
|
||||
|
||||
func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) {
|
||||
SetApiRouter(router)
|
||||
SetRelayRouter(router)
|
||||
setWebRouter(router, buildFS, indexPage)
|
||||
}
|
||||
|
15
router/relay-router.go
Normal file
15
router/relay-router.go
Normal file
@ -0,0 +1,15 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"one-api/controller"
|
||||
"one-api/middleware"
|
||||
)
|
||||
|
||||
func SetRelayRouter(router *gin.Engine) {
|
||||
relayRouter := router.Group("/v1")
|
||||
relayRouter.Use(middleware.GlobalAPIRateLimit(), middleware.TokenAuth(), middleware.Distribute())
|
||||
{
|
||||
relayRouter.POST("/chat/completions", controller.Relay)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user