Merge branch 'songquanpeng:main' into main

commit b356750bd5
Author: qingfengfenga
Date: 2024-03-05 17:55:12 +08:00 (committed by GitHub)
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

8 changed files with 33 additions and 3 deletions

File 1 of 8: common/random.go (new file)

@@ -0,0 +1,8 @@
+package common
+
+import "math/rand"
+
+// RandRange returns a random number between min and max (max is not included)
+func RandRange(min, max int) int {
+	return min + rand.Intn(max-min)
+}
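For reference: RandRange draws from the half-open interval [min, max), and rand.Intn panics when its argument is not positive, so callers must keep max > min. A minimal usage sketch (assuming the module path github.com/songquanpeng/one-api, which is not shown in this diff):

package main

import (
	"fmt"

	"github.com/songquanpeng/one-api/common"
)

func main() {
	// Prints one of 2, 3, or 4; the upper bound 5 is excluded.
	fmt.Println(common.RandRange(2, 5))
}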

File 2 of 8

@@ -19,6 +19,7 @@ import (
 	"net/http/httptest"
 	"net/url"
 	"strconv"
+	"strings"
 	"sync"
 	"time"
@@ -61,6 +62,12 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error
 	}
 	adaptor.Init(meta)
 	modelName := adaptor.GetModelList()[0]
+	if !strings.Contains(channel.Models, modelName) {
+		modelNames := strings.Split(channel.Models, ",")
+		if len(modelNames) > 0 {
+			modelName = modelNames[0]
+		}
+	}
 	request := buildTestRequest()
 	request.Model = modelName
 	meta.OriginModelName, meta.ActualModelName = modelName, modelName
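The fallback above handles channels whose configured model list does not include the adaptor's first default model: channel.Models is a comma-separated string, and the channel's own first entry is used instead. A self-contained sketch of the selection (pickTestModel is an invented name, not part of the commit):

package main

import (
	"fmt"
	"strings"
)

// pickTestModel mirrors the fallback above: prefer the adaptor's default
// model, but fall back to the channel's first configured model when the
// default is absent from the comma-separated Models field.
func pickTestModel(channelModels, defaultModel string) string {
	if !strings.Contains(channelModels, defaultModel) {
		if names := strings.Split(channelModels, ","); len(names) > 0 {
			return names[0]
		}
	}
	return defaultModel
}

func main() {
	fmt.Println(pickTestModel("qwen-turbo,qwen-plus", "gpt-3.5-turbo")) // qwen-turbo
	fmt.Println(pickTestModel("gpt-3.5-turbo,gpt-4", "gpt-3.5-turbo"))  // gpt-3.5-turbo
}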

File 3 of 8

@@ -62,7 +62,7 @@ func Relay(c *gin.Context) {
 		retryTimes = 0
 	}
 	for i := retryTimes; i > 0; i-- {
-		channel, err := dbmodel.CacheGetRandomSatisfiedChannel(group, originalModel)
+		channel, err := dbmodel.CacheGetRandomSatisfiedChannel(group, originalModel, i != retryTimes)
 		if err != nil {
 			logger.Errorf(ctx, "CacheGetRandomSatisfiedChannel failed: %w", err)
 			break
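The new third argument, i != retryTimes, is false only on the first pass of the retry loop, so the first retry may still hit a first-priority channel while every later retry skips that band. A tiny trace of the flag (illustrative only):

package main

import "fmt"

func main() {
	retryTimes := 3
	for i := retryTimes; i > 0; i-- {
		// false on the first retry, true on every retry after that
		fmt.Printf("retry %d: ignoreFirstPriority=%v\n", retryTimes-i+1, i != retryTimes)
	}
}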

File 4 of 8

@@ -68,7 +68,7 @@ func Distribute() func(c *gin.Context) {
 			}
 		}
 		requestModel = modelRequest.Model
-		channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
+		channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, false)
 		if err != nil {
 			message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
 			if channel != nil {
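For readers without Chinese: the message built by fmt.Sprintf above roughly translates to "no available channel for model %s under the current group %s". Note that Distribute passes false here, so an initial request always draws from the first-priority channels; only the retry path in Relay escalates past them.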

File 5 of 8

@@ -191,7 +191,7 @@ func SyncChannelCache(frequency int) {
 	}
 }
 
-func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
+func CacheGetRandomSatisfiedChannel(group string, model string, ignoreFirstPriority bool) (*Channel, error) {
 	if !config.MemoryCacheEnabled {
 		return GetRandomSatisfiedChannel(group, model)
 	}
@@ -213,5 +213,10 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
 		}
 	}
 	idx := rand.Intn(endIdx)
+	if ignoreFirstPriority {
+		if endIdx < len(channels) { // which means there are more than one priority
+			idx = common.RandRange(endIdx, len(channels))
+		}
+	}
 	return channels[idx], nil
 }
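Here the cached channel slice appears to be ordered so that indexes [0, endIdx) hold the highest-priority band; with ignoreFirstPriority set and lower-priority channels present, the pick moves into the remainder via the new RandRange helper. A standalone sketch of the index logic (pickIndex is invented; the sorted layout is an assumption read off this hunk):

package main

import (
	"fmt"
	"math/rand"
)

// pickIndex is a hypothetical standalone rendering of the selection above:
// indexes [0, endIdx) are assumed to hold the first-priority channels.
func pickIndex(ignoreFirstPriority bool, endIdx, total int) int {
	if ignoreFirstPriority && endIdx < total {
		// Skip the first-priority band; equivalent to common.RandRange(endIdx, total).
		return endIdx + rand.Intn(total-endIdx)
	}
	return rand.Intn(endIdx)
}

func main() {
	fmt.Println(pickIndex(false, 2, 5)) // 0 or 1
	fmt.Println(pickIndex(true, 2, 5))  // 2, 3, or 4
}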

File 6 of 8

@@ -33,6 +33,9 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
 		enableSearch = true
 		aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix)
 	}
+	if request.TopP >= 1 {
+		request.TopP = 0.9999
+	}
 	return &ChatRequest{
 		Model: aliModel,
 		Input: Input{
@@ -42,6 +45,9 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
 			EnableSearch:      enableSearch,
 			IncrementalOutput: request.Stream,
 			Seed:              uint64(request.Seed),
+			MaxTokens:         request.MaxTokens,
+			Temperature:       request.Temperature,
+			TopP:              request.TopP,
 		},
 	}
 }
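The clamp suggests the upstream Ali endpoint rejects top_p values at or above 1 (the OpenAI-style default of 1.0 would otherwise fail), so such values are nudged down to 0.9999 before the request is assembled. A sketch of the guard in isolation (clampTopP is an invented name):

package main

import "fmt"

// clampTopP mirrors the guard above; the upstream API reportedly accepts
// top_p only strictly below 1.
func clampTopP(topP float64) float64 {
	if topP >= 1 {
		return 0.9999
	}
	return topP
}

func main() {
	fmt.Println(clampTopP(1.0), clampTopP(0.5)) // 0.9999 0.5
}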

File 7 of 8

@@ -16,6 +16,8 @@ type Parameters struct {
 	Seed              uint64  `json:"seed,omitempty"`
 	EnableSearch      bool    `json:"enable_search,omitempty"`
 	IncrementalOutput bool    `json:"incremental_output,omitempty"`
+	MaxTokens         int     `json:"max_tokens,omitempty"`
+	Temperature       float64 `json:"temperature,omitempty"`
 }
 
 type ChatRequest struct {
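Since ConvertRequest now populates MaxTokens, Temperature, and TopP, and TopP is absent from this struct hunk, the TopP field must already exist in Parameters. With omitempty, zero values are dropped from the request body. A marshaling sketch using a trimmed copy of the struct (the top_p tag on TopP is assumed, not shown in this diff):

package main

import (
	"encoding/json"
	"fmt"
)

// Trimmed copy of Parameters for illustration; the real struct has more fields.
type Parameters struct {
	MaxTokens   int     `json:"max_tokens,omitempty"`
	Temperature float64 `json:"temperature,omitempty"`
	TopP        float64 `json:"top_p,omitempty"`
}

func main() {
	b, _ := json.Marshal(Parameters{MaxTokens: 1024, Temperature: 0.7, TopP: 0.9999})
	fmt.Println(string(b)) // {"max_tokens":1024,"temperature":0.7,"top_p":0.9999}
}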

File 8 of 8

@@ -36,6 +36,8 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
 		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
 	case "Embedding-V1":
 		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1"
+	default:
+		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/" + meta.ActualModelName
 	}
 	var accessToken string
 	var err error
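The new default case means a Baidu model without an explicit case no longer falls through; its endpoint is derived from the model name itself. A trimmed, invented rendering of the switch (chatURL and "some-new-model" are hypothetical):

package main

import "fmt"

// chatURL sketches the routing change: a model with no explicit case now
// gets a URL derived from its own name.
func chatURL(model string) string {
	switch model {
	case "Embedding-V1":
		return "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1"
	default:
		return "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/" + model
	}
}

func main() {
	fmt.Println(chatURL("some-new-model"))
}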