Merge branch 'songquanpeng:main' into main

commit b356750bd5

common/random.go (new file, +8)
common/random.go
@@ -0,0 +1,8 @@
+package common
+
+import "math/rand"
+
+// RandRange returns a random number between min and max (max is not included)
+func RandRange(min, max int) int {
+	return min + rand.Intn(max-min)
+}
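The helper draws from the half-open interval [min, max): rand.Intn(max-min) yields 0 through max-min-1, so adding min never reaches max. A minimal usage sketch (the main wrapper is illustrative, not part of the commit):

package main

import (
	"fmt"
	"math/rand"
)

// RandRange returns n with min <= n < max, matching common/random.go.
func RandRange(min, max int) int {
	return min + rand.Intn(max-min)
}

func main() {
	// Every draw lands in [2, 5): only 2, 3, or 4 can appear.
	for i := 0; i < 5; i++ {
		fmt.Println(RandRange(2, 5))
	}
}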
@@ -19,6 +19,7 @@ import (
 	"net/http/httptest"
 	"net/url"
 	"strconv"
+	"strings"
 	"sync"
 	"time"
@@ -61,6 +62,12 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error
 	}
 	adaptor.Init(meta)
 	modelName := adaptor.GetModelList()[0]
+	if !strings.Contains(channel.Models, modelName) {
+		modelNames := strings.Split(channel.Models, ",")
+		if len(modelNames) > 0 {
+			modelName = modelNames[0]
+		}
+	}
 	request := buildTestRequest()
 	request.Model = modelName
 	meta.OriginModelName, meta.ActualModelName = modelName, modelName
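If the adaptor's default test model is not configured on the channel, the test now falls back to the first entry in the channel's comma-separated Models list. A standalone sketch of that fallback under the same inputs (pickTestModel is a hypothetical name; the commit inlines this logic):

package main

import (
	"fmt"
	"strings"
)

// pickTestModel prefers the adaptor's default model and otherwise
// takes the first model configured on the channel.
func pickTestModel(channelModels, defaultModel string) string {
	if strings.Contains(channelModels, defaultModel) {
		return defaultModel
	}
	if names := strings.Split(channelModels, ","); len(names) > 0 {
		return names[0]
	}
	return defaultModel
}

func main() {
	fmt.Println(pickTestModel("qwen-turbo,qwen-plus", "gpt-3.5-turbo")) // qwen-turbo
}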
@@ -62,7 +62,7 @@ func Relay(c *gin.Context) {
 		retryTimes = 0
 	}
 	for i := retryTimes; i > 0; i-- {
-		channel, err := dbmodel.CacheGetRandomSatisfiedChannel(group, originalModel)
+		channel, err := dbmodel.CacheGetRandomSatisfiedChannel(group, originalModel, i != retryTimes)
 		if err != nil {
 			logger.Errorf(ctx, "CacheGetRandomSatisfiedChannel failed: %w", err)
 			break
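The new third argument, i != retryTimes, is false on the first pass through the loop and true on every subsequent retry, so only retries skip the highest-priority channels. A minimal sketch of how the flag evolves (retryTimes value illustrative):

package main

import "fmt"

func main() {
	retryTimes := 3
	for i := retryTimes; i > 0; i-- {
		ignoreFirstPriority := i != retryTimes // false only on the first attempt
		fmt.Printf("attempt %d: ignoreFirstPriority=%v\n", retryTimes-i+1, ignoreFirstPriority)
	}
	// Output:
	// attempt 1: ignoreFirstPriority=false
	// attempt 2: ignoreFirstPriority=true
	// attempt 3: ignoreFirstPriority=true
}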
@@ -68,7 +68,7 @@ func Distribute() func(c *gin.Context) {
 			}
 		}
 		requestModel = modelRequest.Model
-		channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
+		channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, false)
 		if err != nil {
 			message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model) // zh: "no channels available under the current group %s for model %s"
 			if channel != nil {

Here the flag is hard-coded to false: the initial dispatch always considers the highest-priority band, and only the retry loop in Relay falls back to lower priorities.
@@ -191,7 +191,7 @@ func SyncChannelCache(frequency int) {
 	}
 }
 
-func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
+func CacheGetRandomSatisfiedChannel(group string, model string, ignoreFirstPriority bool) (*Channel, error) {
 	if !config.MemoryCacheEnabled {
 		return GetRandomSatisfiedChannel(group, model)
 	}
@@ -213,5 +213,10 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
 		}
 	}
 	idx := rand.Intn(endIdx)
+	if ignoreFirstPriority {
+		if endIdx < len(channels) { // which means there is more than one priority
+			idx = common.RandRange(endIdx, len(channels))
+		}
+	}
 	return channels[idx], nil
 }
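In the cache path, the channel slice is evidently kept sorted so that indexes [0, endIdx) form the first (highest) priority band. The default draw stays inside that band; on retries, when lower bands exist, the draw moves to [endIdx, len(channels)) via the new common.RandRange. A simplified sketch under those assumptions (Channel and pick are illustrative stand-ins, not the real cache types):

package main

import (
	"fmt"
	"math/rand"
)

// Channel is a simplified stand-in for the cached channel type; the slice
// is assumed sorted so that indexes [0, endIdx) hold the top-priority band.
type Channel struct {
	Name     string
	Priority int
}

func pick(channels []Channel, endIdx int, ignoreFirstPriority bool) Channel {
	idx := rand.Intn(endIdx) // default: uniform draw from the first band
	if ignoreFirstPriority && endIdx < len(channels) {
		// Lower-priority bands exist, so retries draw from them instead:
		// equivalent to common.RandRange(endIdx, len(channels)).
		idx = endIdx + rand.Intn(len(channels)-endIdx)
	}
	return channels[idx]
}

func main() {
	channels := []Channel{{"a", 10}, {"b", 10}, {"c", 0}}
	fmt.Println(pick(channels, 2, false).Name) // "a" or "b"
	fmt.Println(pick(channels, 2, true).Name)  // always "c"
}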
@@ -33,6 +33,9 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
 		enableSearch = true
 		aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix)
 	}
+	if request.TopP >= 1 {
+		request.TopP = 0.9999
+	}
 	return &ChatRequest{
 		Model: aliModel,
 		Input: Input{
@@ -42,6 +45,9 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
 			EnableSearch:      enableSearch,
 			IncrementalOutput: request.Stream,
 			Seed:              uint64(request.Seed),
+			MaxTokens:         request.MaxTokens,
+			Temperature:       request.Temperature,
+			TopP:              request.TopP,
 		},
 	}
 }
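The guard pulls top_p just inside the open interval below 1 before it is forwarded, presumably because the upstream Ali API does not accept values of 1 or above (inferred from the clamp itself, not verified here). As a standalone sketch (clampTopP is a hypothetical name):

package main

import "fmt"

// clampTopP mirrors the guard above: values of 1 or above are replaced
// with 0.9999 so the forwarded value stays strictly below 1.
func clampTopP(topP float64) float64 {
	if topP >= 1 {
		return 0.9999
	}
	return topP
}

func main() {
	fmt.Println(clampTopP(1.0), clampTopP(0.8)) // 0.9999 0.8
}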
@@ -16,6 +16,8 @@ type Parameters struct {
 	Seed              uint64  `json:"seed,omitempty"`
 	EnableSearch      bool    `json:"enable_search,omitempty"`
 	IncrementalOutput bool    `json:"incremental_output,omitempty"`
+	MaxTokens         int     `json:"max_tokens,omitempty"`
+	Temperature       float64 `json:"temperature,omitempty"`
 }
 
 type ChatRequest struct {
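Because the new fields carry omitempty, zero values are dropped during marshaling, so requests that never set max_tokens or temperature serialize exactly as before. A quick sketch using the struct fields shown above:

package main

import (
	"encoding/json"
	"fmt"
)

// Parameters reproduces the fields from the hunk above.
type Parameters struct {
	Seed              uint64  `json:"seed,omitempty"`
	EnableSearch      bool    `json:"enable_search,omitempty"`
	IncrementalOutput bool    `json:"incremental_output,omitempty"`
	MaxTokens         int     `json:"max_tokens,omitempty"`
	Temperature       float64 `json:"temperature,omitempty"`
}

func main() {
	b, _ := json.Marshal(Parameters{MaxTokens: 256, Temperature: 0.7})
	fmt.Println(string(b)) // {"max_tokens":256,"temperature":0.7}
}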
@@ -36,6 +36,8 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
 		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
 	case "Embedding-V1":
 		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1"
+	default:
+		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/" + meta.ActualModelName
 	}
 	var accessToken string
 	var err error
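With the new default case, any Baidu model not listed explicitly in the switch now resolves to the generic wenxinworkshop chat endpoint derived from the model name, rather than leaving the URL unset (as far as this hunk shows). A sketch of the resulting URL (the model name is hypothetical):

package main

import "fmt"

func main() {
	actualModelName := "ernie-speed" // hypothetical model not covered by an explicit case
	url := "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/" + actualModelName
	fmt.Println(url)
}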