fix: support embedding models for doubao (#1662)

Fixes #1594
This commit is contained in:
igophper 2024-07-22 22:38:50 +08:00 committed by GitHub
parent 2a892c1937
commit 39383e5532
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 33 additions and 27 deletions

View File

@ -7,8 +7,12 @@ import (
) )
func GetRequestURL(meta *meta.Meta) (string, error) { func GetRequestURL(meta *meta.Meta) (string, error) {
if meta.Mode == relaymode.ChatCompletions { switch meta.Mode {
case relaymode.ChatCompletions:
return fmt.Sprintf("%s/api/v3/chat/completions", meta.BaseURL), nil return fmt.Sprintf("%s/api/v3/chat/completions", meta.BaseURL), nil
case relaymode.Embeddings:
return fmt.Sprintf("%s/api/v3/embeddings", meta.BaseURL), nil
default:
} }
return "", fmt.Errorf("unsupported relay mode %d for doubao", meta.Mode) return "", fmt.Errorf("unsupported relay mode %d for doubao", meta.Mode)
} }

View File

@ -10,6 +10,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay" "github.com/songquanpeng/one-api/relay"
"github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/apitype" "github.com/songquanpeng/one-api/relay/apitype"
"github.com/songquanpeng/one-api/relay/billing" "github.com/songquanpeng/one-api/relay/billing"
@ -31,9 +32,8 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
meta.IsStream = textRequest.Stream meta.IsStream = textRequest.Stream
// map model name // map model name
var isModelMapped bool
meta.OriginModelName = textRequest.Model meta.OriginModelName = textRequest.Model
textRequest.Model, isModelMapped = getMappedModelName(textRequest.Model, meta.ModelMapping) textRequest.Model, _ = getMappedModelName(textRequest.Model, meta.ModelMapping)
meta.ActualModelName = textRequest.Model meta.ActualModelName = textRequest.Model
// get model ratio & group ratio // get model ratio & group ratio
modelRatio := billingratio.GetModelRatio(textRequest.Model, meta.ChannelType) modelRatio := billingratio.GetModelRatio(textRequest.Model, meta.ChannelType)
@ -55,30 +55,9 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
adaptor.Init(meta) adaptor.Init(meta)
// get request body // get request body
var requestBody io.Reader requestBody, err := getRequestBody(c, meta, textRequest, adaptor)
if meta.APIType == apitype.OpenAI { if err != nil {
// no need to convert request for openai return openai.ErrorWrapper(err, "convert_request_failed", http.StatusInternalServerError)
shouldResetRequestBody := isModelMapped || meta.ChannelType == channeltype.Baichuan // frequency_penalty 0 is not acceptable for baichuan
if shouldResetRequestBody {
jsonStr, err := json.Marshal(textRequest)
if err != nil {
return openai.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
} else {
requestBody = c.Request.Body
}
} else {
convertedRequest, err := adaptor.ConvertRequest(c, meta.Mode, textRequest)
if err != nil {
return openai.ErrorWrapper(err, "convert_request_failed", http.StatusInternalServerError)
}
jsonData, err := json.Marshal(convertedRequest)
if err != nil {
return openai.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
}
logger.Debugf(ctx, "converted request: \n%s", string(jsonData))
requestBody = bytes.NewBuffer(jsonData)
} }
// do request // do request
@ -103,3 +82,26 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
go postConsumeQuota(ctx, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio) go postConsumeQuota(ctx, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio)
return nil return nil
} }
func getRequestBody(c *gin.Context, meta *meta.Meta, textRequest *model.GeneralOpenAIRequest, adaptor adaptor.Adaptor) (io.Reader, error) {
if meta.APIType == apitype.OpenAI && meta.OriginModelName == meta.ActualModelName && meta.ChannelType != channeltype.Baichuan {
// no need to convert request for openai
return c.Request.Body, nil
}
// get request body
var requestBody io.Reader
convertedRequest, err := adaptor.ConvertRequest(c, meta.Mode, textRequest)
if err != nil {
logger.Debugf(c.Request.Context(), "converted request failed: %s\n", err.Error())
return nil, err
}
jsonData, err := json.Marshal(convertedRequest)
if err != nil {
logger.Debugf(c.Request.Context(), "converted request json_marshal_failed: %s\n", err.Error())
return nil, err
}
logger.Debugf(c.Request.Context(), "converted request: \n%s", string(jsonData))
requestBody = bytes.NewBuffer(jsonData)
return requestBody, nil
}