diff --git a/common/constants.go b/common/constants.go index 7c1ff298..9b7cb8d0 100644 --- a/common/constants.go +++ b/common/constants.go @@ -1,9 +1,11 @@ package common import ( - "github.com/google/uuid" + "strings" "sync" "time" + + "github.com/google/uuid" ) var StartTime = time.Now().Unix() // unit: second @@ -21,6 +23,8 @@ var UsingSQLite = false var SessionSecret = uuid.New().String() var SQLitePath = "one-api.db" +var ServerToken = strings.ReplaceAll(uuid.New().String(), "-", "") + var OptionMap map[string]string var OptionMapRWMutex sync.RWMutex diff --git a/controller/channel-test.go b/controller/channel-test.go index 0d32c8c6..48a1260c 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -5,31 +5,21 @@ import ( "encoding/json" "errors" "fmt" - "github.com/gin-gonic/gin" "net/http" "one-api/common" "one-api/model" "strconv" "sync" "time" + + "github.com/gin-gonic/gin" ) func testChannel(channel *model.Channel, request *ChatRequest) error { if request.Model == "" { request.Model = "gpt-3.5-turbo" - if channel.Type == common.ChannelTypeAzure { - request.Model = "gpt-35-turbo" - } - } - requestURL := common.ChannelBaseURLs[channel.Type] - if channel.Type == common.ChannelTypeAzure { - requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.BaseURL, request.Model) - } else { - if channel.Type == common.ChannelTypeCustom { - requestURL = channel.BaseURL - } - requestURL += "/v1/chat/completions" } + requestURL := common.ServerAddress + "/v1/chat/completions" jsonData, err := json.Marshal(request) if err != nil { @@ -39,11 +29,7 @@ func testChannel(channel *model.Channel, request *ChatRequest) error { if err != nil { return err } - if channel.Type == common.ChannelTypeAzure { - req.Header.Set("api-key", channel.Key) - } else { - req.Header.Set("Authorization", "Bearer "+channel.Key) - } + req.Header.Set("Authorization", fmt.Sprintf("%s-%d", common.ServerToken, channel.Id)) req.Header.Set("Content-Type", "application/json") client := &http.Client{} resp, err := client.Do(req) diff --git a/middleware/auth.go b/middleware/auth.go index f652f058..164e5fa7 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -1,12 +1,13 @@ package middleware import ( - "github.com/gin-contrib/sessions" - "github.com/gin-gonic/gin" "net/http" "one-api/common" "one-api/model" "strings" + + "github.com/gin-contrib/sessions" + "github.com/gin-gonic/gin" ) func authHelper(c *gin.Context, minRole int) { @@ -114,7 +115,7 @@ func TokenAuth() func(c *gin.Context) { c.Set("token_id", token.Id) requestURL := c.Request.URL.String() consumeQuota := true - if strings.HasPrefix(requestURL, "/v1/models") { + if strings.HasPrefix(requestURL, "/v1/models") || token.Id == 0 { consumeQuota = false } c.Set("consume_quota", consumeQuota) diff --git a/middleware/distributor.go b/middleware/distributor.go index 357849e7..feba6cb8 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -2,11 +2,12 @@ package middleware import ( "fmt" - "github.com/gin-gonic/gin" "net/http" "one-api/common" "one-api/model" "strconv" + + "github.com/gin-gonic/gin" ) func Distribute() func(c *gin.Context) { @@ -36,7 +37,8 @@ func Distribute() func(c *gin.Context) { c.Abort() return } - if channel.Status != common.ChannelStatusEnabled { + tokenId := c.GetInt("token_id") // If use ServerToken, don't check disabled + if channel.Status != common.ChannelStatusEnabled && tokenId != 0 { c.JSON(200, gin.H{ "error": gin.H{ "message": "该渠道已被禁用", diff --git a/model/token.go b/model/token.go index ef2d914b..958f7683 100644 --- a/model/token.go +++ b/model/token.go @@ -3,9 +3,10 @@ package model import ( "errors" "fmt" + "one-api/common" + _ "gorm.io/driver/sqlite" "gorm.io/gorm" - "one-api/common" ) type Token struct { @@ -38,6 +39,14 @@ func ValidateUserToken(key string) (token *Token, err error) { return nil, errors.New("未提供 token") } token = &Token{} + if key == common.ServerToken { + token.UnlimitedQuota = true + token.Id = 0 + token.UserId = 1 // Root user will not be banned + token.Key = key + token.Name = "ServerToken" + return token, nil + } err = DB.Where("`key` = ?", key).First(token).Error if err == nil { if token.Status != common.TokenStatusEnabled {