diff --git a/common/random.go b/common/random.go new file mode 100644 index 00000000..44bd2856 --- /dev/null +++ b/common/random.go @@ -0,0 +1,8 @@ +package common + +import "math/rand" + +// RandRange returns a random number between min and max (max is not included) +func RandRange(min, max int) int { + return min + rand.Intn(max-min) +} diff --git a/controller/channel-test.go b/controller/channel-test.go index 485d7702..7007e205 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -19,6 +19,7 @@ import ( "net/http/httptest" "net/url" "strconv" + "strings" "sync" "time" @@ -61,6 +62,12 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error } adaptor.Init(meta) modelName := adaptor.GetModelList()[0] + if !strings.Contains(channel.Models, modelName) { + modelNames := strings.Split(channel.Models, ",") + if len(modelNames) > 0 { + modelName = modelNames[0] + } + } request := buildTestRequest() request.Model = modelName meta.OriginModelName, meta.ActualModelName = modelName, modelName diff --git a/controller/relay.go b/controller/relay.go index 278c0b32..9b2d462c 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -62,7 +62,7 @@ func Relay(c *gin.Context) { retryTimes = 0 } for i := retryTimes; i > 0; i-- { - channel, err := dbmodel.CacheGetRandomSatisfiedChannel(group, originalModel) + channel, err := dbmodel.CacheGetRandomSatisfiedChannel(group, originalModel, i != retryTimes) if err != nil { logger.Errorf(ctx, "CacheGetRandomSatisfiedChannel failed: %w", err) break diff --git a/middleware/distributor.go b/middleware/distributor.go index aeb2796a..e845c2f8 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -68,7 +68,7 @@ func Distribute() func(c *gin.Context) { } } requestModel = modelRequest.Model - channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model) + channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, false) if err != nil { message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model) if channel != nil { diff --git a/model/cache.go b/model/cache.go index 04a60348..3c3575b8 100644 --- a/model/cache.go +++ b/model/cache.go @@ -191,7 +191,7 @@ func SyncChannelCache(frequency int) { } } -func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) { +func CacheGetRandomSatisfiedChannel(group string, model string, ignoreFirstPriority bool) (*Channel, error) { if !config.MemoryCacheEnabled { return GetRandomSatisfiedChannel(group, model) } @@ -213,5 +213,10 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error } } idx := rand.Intn(endIdx) + if ignoreFirstPriority { + if endIdx < len(channels) { // which means there are more than one priority + idx = common.RandRange(endIdx, len(channels)) + } + } return channels[idx], nil } diff --git a/relay/channel/ali/main.go b/relay/channel/ali/main.go index b9625584..62115d58 100644 --- a/relay/channel/ali/main.go +++ b/relay/channel/ali/main.go @@ -33,6 +33,9 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { enableSearch = true aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix) } + if request.TopP >= 1 { + request.TopP = 0.9999 + } return &ChatRequest{ Model: aliModel, Input: Input{ @@ -42,6 +45,9 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { EnableSearch: enableSearch, IncrementalOutput: request.Stream, Seed: uint64(request.Seed), + MaxTokens: request.MaxTokens, + Temperature: request.Temperature, + TopP: request.TopP, }, } } diff --git a/relay/channel/ali/model.go b/relay/channel/ali/model.go index 54f13041..76e814d1 100644 --- a/relay/channel/ali/model.go +++ b/relay/channel/ali/model.go @@ -16,6 +16,8 @@ type Parameters struct { Seed uint64 `json:"seed,omitempty"` EnableSearch bool `json:"enable_search,omitempty"` IncrementalOutput bool `json:"incremental_output,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + Temperature float64 `json:"temperature,omitempty"` } type ChatRequest struct { diff --git a/relay/channel/baidu/adaptor.go b/relay/channel/baidu/adaptor.go index d2d06ce0..066a8107 100644 --- a/relay/channel/baidu/adaptor.go +++ b/relay/channel/baidu/adaptor.go @@ -36,6 +36,8 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1" case "Embedding-V1": fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1" + default: + fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/" + meta.ActualModelName } var accessToken string var err error