From 880e12c85573e582cbe554723cfbe9c7a22bc1bc Mon Sep 17 00:00:00 2001 From: JustSong Date: Sat, 6 Apr 2024 00:30:08 +0800 Subject: [PATCH] feat: support cogview-3 --- common/model-ratio.go | 21 +++++++-------- relay/channel/openai/image.go | 44 ++++++++++++++++++++++++++++++++ relay/channel/openai/main.go | 34 ------------------------ relay/channel/zhipu/adaptor.go | 25 ++++++++++++++---- relay/channel/zhipu/constants.go | 1 + relay/channel/zhipu/model.go | 6 +++++ relay/constant/image.go | 36 ++++++++++++++++---------- relay/controller/helper.go | 25 +++++++++++++----- relay/controller/image.go | 8 +++--- 9 files changed, 127 insertions(+), 73 deletions(-) create mode 100644 relay/channel/openai/image.go diff --git a/common/model-ratio.go b/common/model-ratio.go index c8a2b5b8..d226e954 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -62,8 +62,8 @@ var ModelRatio = map[string]float64{ "text-search-ada-doc-001": 10, "text-moderation-stable": 0.1, "text-moderation-latest": 0.1, - "dall-e-2": 8, // $0.016 - $0.020 / image - "dall-e-3": 20, // $0.040 - $0.120 / image + "dall-e-2": 0.02 * USD, // $0.016 - $0.020 / image + "dall-e-3": 0.04 * USD, // $0.040 - $0.120 / image // https://www.anthropic.com/api#pricing "claude-instant-1.2": 0.8 / 1000 * USD, "claude-2.0": 8.0 / 1000 * USD, @@ -96,14 +96,15 @@ var ModelRatio = map[string]float64{ "gemini-1.0-pro-001": 1, "gemini-1.5-pro": 1, // https://open.bigmodel.cn/pricing - "glm-4": 0.1 * RMB, - "glm-4v": 0.1 * RMB, - "glm-3-turbo": 0.005 * RMB, - "embedding-2": 0.0005 * RMB, - "chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens - "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens - "chatglm_std": 0.3572, // ¥0.005 / 1k tokens - "chatglm_lite": 0.1429, // ¥0.002 / 1k tokens + "glm-4": 0.1 * RMB, + "glm-4v": 0.1 * RMB, + "glm-3-turbo": 0.005 * RMB, + "embedding-2": 0.0005 * RMB, + "chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens + "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens + "chatglm_std": 0.3572, // ¥0.005 / 1k tokens + "chatglm_lite": 0.1429, // ¥0.002 / 1k tokens + "cogview-3": 0.25 * RMB, // https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing "qwen-turbo": 0.5715, // ¥0.008 / 1k tokens "qwen-plus": 1.4286, // ¥0.02 / 1k tokens diff --git a/relay/channel/openai/image.go b/relay/channel/openai/image.go new file mode 100644 index 00000000..0f89618a --- /dev/null +++ b/relay/channel/openai/image.go @@ -0,0 +1,44 @@ +package openai + +import ( + "bytes" + "encoding/json" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/model" + "io" + "net/http" +) + +func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + var imageResponse ImageResponse + responseBody, err := io.ReadAll(resp.Body) + + if err != nil { + return ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &imageResponse) + if err != nil { + return ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + + resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + + for k, v := range resp.Header { + c.Writer.Header().Set(k, v[0]) + } + c.Writer.WriteHeader(resp.StatusCode) + + _, err = io.Copy(c.Writer, resp.Body) + if err != nil { + return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + return nil, nil +} diff --git a/relay/channel/openai/main.go b/relay/channel/openai/main.go index 7ace3f63..63cb9ae8 100644 --- a/relay/channel/openai/main.go +++ b/relay/channel/openai/main.go @@ -149,37 +149,3 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st } return nil, &textResponse.Usage } - -func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { - var imageResponse ImageResponse - responseBody, err := io.ReadAll(resp.Body) - - if err != nil { - return ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil - } - err = resp.Body.Close() - if err != nil { - return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } - err = json.Unmarshal(responseBody, &imageResponse) - if err != nil { - return ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil - } - - resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) - - for k, v := range resp.Header { - c.Writer.Header().Set(k, v[0]) - } - c.Writer.WriteHeader(resp.StatusCode) - - _, err = io.Copy(c.Writer, resp.Body) - if err != nil { - return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil - } - err = resp.Body.Close() - if err != nil { - return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } - return nil, nil -} diff --git a/relay/channel/zhipu/adaptor.go b/relay/channel/zhipu/adaptor.go index dbcf240d..14c581dd 100644 --- a/relay/channel/zhipu/adaptor.go +++ b/relay/channel/zhipu/adaptor.go @@ -32,13 +32,16 @@ func (a *Adaptor) SetVersionByModeName(modelName string) { } func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { + switch meta.Mode { + case constant.RelayModeImagesGenerations: + return fmt.Sprintf("%s/api/paas/v4/images/generations", meta.BaseURL), nil + case constant.RelayModeEmbeddings: + return fmt.Sprintf("%s/api/paas/v4/embeddings", meta.BaseURL), nil + } a.SetVersionByModeName(meta.ActualModelName) if a.APIVersion == "v4" { return fmt.Sprintf("%s/api/paas/v4/chat/completions", meta.BaseURL), nil } - if meta.Mode == constant.RelayModeEmbeddings { - return fmt.Sprintf("%s/api/paas/v4/embeddings", meta.BaseURL), nil - } method := "invoke" if meta.IsStream { method = "sse-invoke" @@ -81,7 +84,12 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) if request == nil { return nil, errors.New("request is nil") } - return request, nil + newRequest := ImageRequest{ + Model: request.Model, + Prompt: request.Prompt, + UserId: request.User, + } + return newRequest, nil } func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { @@ -98,10 +106,17 @@ func (a *Adaptor) DoResponseV4(c *gin.Context, resp *http.Response, meta *util.R } func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + switch meta.Mode { + case constant.RelayModeEmbeddings: + err, usage = EmbeddingsHandler(c, resp) + return + case constant.RelayModeImagesGenerations: + err, usage = openai.ImageHandler(c, resp) + return + } if a.APIVersion == "v4" { return a.DoResponseV4(c, resp, meta) } - if meta.IsStream { err, usage = StreamHandler(c, resp) } else { diff --git a/relay/channel/zhipu/constants.go b/relay/channel/zhipu/constants.go index 2daeb19c..e1192123 100644 --- a/relay/channel/zhipu/constants.go +++ b/relay/channel/zhipu/constants.go @@ -3,4 +3,5 @@ package zhipu var ModelList = []string{ "chatglm_turbo", "chatglm_pro", "chatglm_std", "chatglm_lite", "glm-4", "glm-4v", "glm-3-turbo", "embedding-2", + "cogview-3", } diff --git a/relay/channel/zhipu/model.go b/relay/channel/zhipu/model.go index 4f308ab3..f91de1dc 100644 --- a/relay/channel/zhipu/model.go +++ b/relay/channel/zhipu/model.go @@ -62,3 +62,9 @@ type EmbeddingData struct { Object string `json:"object"` Embedding []float64 `json:"embedding"` } + +type ImageRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt"` + UserId string `json:"user_id,omitempty"` +} diff --git a/relay/constant/image.go b/relay/constant/image.go index 4f61bcc3..e3e3260a 100644 --- a/relay/constant/image.go +++ b/relay/constant/image.go @@ -1,6 +1,6 @@ package constant -var DalleSizeRatios = map[string]map[string]float64{ +var ImageSizeRatios = map[string]map[string]float64{ "dall-e-2": { "256x256": 1, "512x512": 1.125, @@ -11,7 +11,14 @@ var DalleSizeRatios = map[string]map[string]float64{ "1024x1792": 2, "1792x1024": 2, }, - "stable-diffusion-xl": { + "ali-stable-diffusion-xl": { + "512x1024": 1, + "1024x768": 1, + "1024x1024": 1, + "576x1024": 1, + "1024x576": 1, + }, + "ali-stable-diffusion-v1.5": { "512x1024": 1, "1024x768": 1, "1024x1024": 1, @@ -25,17 +32,20 @@ var DalleSizeRatios = map[string]map[string]float64{ }, } -var DalleGenerationImageAmounts = map[string][2]int{ - "dall-e-2": {1, 10}, - "dall-e-3": {1, 1}, // OpenAI allows n=1 currently. - "stable-diffusion-xl": {1, 4}, // Ali - "wanx-v1": {1, 4}, // Ali +var ImageGenerationAmounts = map[string][2]int{ + "dall-e-2": {1, 10}, + "dall-e-3": {1, 1}, // OpenAI allows n=1 currently. + "ali-stable-diffusion-xl": {1, 4}, // Ali + "ali-stable-diffusion-v1.5": {1, 4}, // Ali + "wanx-v1": {1, 4}, // Ali + "cogview-3": {1, 1}, } -var DalleImagePromptLengthLimitations = map[string]int{ - "dall-e-2": 1000, - "dall-e-3": 4000, - "stable-diffusion-xl": 4000, - "wanx-v1": 4000, - "cogview-3": 833, +var ImagePromptLengthLimitations = map[string]int{ + "dall-e-2": 1000, + "dall-e-3": 4000, + "ali-stable-diffusion-xl": 4000, + "ali-stable-diffusion-v1.5": 4000, + "wanx-v1": 4000, + "cogview-3": 833, } diff --git a/relay/controller/helper.go b/relay/controller/helper.go index c78d22c7..d591984e 100644 --- a/relay/controller/helper.go +++ b/relay/controller/helper.go @@ -54,9 +54,25 @@ func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, e return imageRequest, nil } +func isValidImageSize(model string, size string) bool { + if model == "cogview-3" { + return true + } + _, ok := constant.ImageSizeRatios[model][size] + return ok +} + +func getImageSizeRatio(model string, size string) float64 { + ratio, ok := constant.ImageSizeRatios[model][size] + if !ok { + return 1 + } + return ratio +} + func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *util.RelayMeta) *relaymodel.ErrorWithStatusCode { // model validation - _, hasValidSize := constant.DalleSizeRatios[imageRequest.Model][imageRequest.Size] + hasValidSize := isValidImageSize(imageRequest.Model, imageRequest.Size) if !hasValidSize { return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest) } @@ -64,7 +80,7 @@ func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *util.Rela if imageRequest.Prompt == "" { return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) } - if len(imageRequest.Prompt) > constant.DalleImagePromptLengthLimitations[imageRequest.Model] { + if len(imageRequest.Prompt) > constant.ImagePromptLengthLimitations[imageRequest.Model] { return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest) } // Number of generated images validation @@ -81,10 +97,7 @@ func getImageCostRatio(imageRequest *relaymodel.ImageRequest) (float64, error) { if imageRequest == nil { return 0, errors.New("imageRequest is nil") } - imageCostRatio, hasValidSize := constant.DalleSizeRatios[imageRequest.Model][imageRequest.Size] - if !hasValidSize { - return 0, fmt.Errorf("size not supported for this image model: %s", imageRequest.Size) - } + imageCostRatio := getImageSizeRatio(imageRequest.Model, imageRequest.Size) if imageRequest.Quality == "hd" && imageRequest.Model == "dall-e-3" { if imageRequest.Size == "1024x1024" { imageCostRatio *= 2 diff --git a/relay/controller/image.go b/relay/controller/image.go index 9d614300..ee0c4495 100644 --- a/relay/controller/image.go +++ b/relay/controller/image.go @@ -20,12 +20,11 @@ import ( ) func isWithinRange(element string, value int) bool { - if _, ok := constant.DalleGenerationImageAmounts[element]; !ok { + if _, ok := constant.ImageGenerationAmounts[element]; !ok { return false } - min := constant.DalleGenerationImageAmounts[element][0] - max := constant.DalleGenerationImageAmounts[element][1] - + min := constant.ImageGenerationAmounts[element][0] + max := constant.ImageGenerationAmounts[element][1] return value >= min && value <= max } @@ -81,7 +80,6 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus if err != nil { return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError) } - jsonStr, err := json.Marshal(finalRequest) if err != nil { return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError)