From 40b7d0f03e2f7f676dae638b0009ee7fc58fbc29 Mon Sep 17 00:00:00 2001 From: mo Date: Fri, 29 Mar 2024 15:59:19 +0800 Subject: [PATCH] Support Ali stable-diffusion-xl and wanx-v1 model --- common/model-ratio.go | 2 + relay/channel/aiproxy/adaptor.go | 7 ++ relay/channel/ali/adaptor.go | 27 ++++- relay/channel/ali/main.go | 185 +++++++++++++++++++++++++++++ relay/channel/ali/model.go | 73 ++++++++++++ relay/channel/anthropic/adaptor.go | 7 ++ relay/channel/baidu/adaptor.go | 7 ++ relay/channel/gemini/adaptor.go | 7 ++ relay/channel/interface.go | 1 + relay/channel/ollama/adaptor.go | 7 ++ relay/channel/openai/adaptor.go | 23 +++- relay/channel/openai/main.go | 34 ++++++ relay/channel/openai/model.go | 13 +- relay/channel/palm/adaptor.go | 7 ++ relay/channel/tencent/adaptor.go | 7 ++ relay/channel/xunfei/adaptor.go | 7 ++ relay/channel/zhipu/adaptor.go | 7 ++ relay/constant/image.go | 17 +++ relay/controller/helper.go | 8 +- relay/controller/image.go | 98 +++++---------- relay/model/image.go | 12 ++ 21 files changed, 478 insertions(+), 78 deletions(-) create mode 100644 relay/model/image.go diff --git a/common/model-ratio.go b/common/model-ratio.go index 460d4843..3c615ed8 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -100,6 +100,8 @@ var ModelRatio = map[string]float64{ "qwen-max": 1.4286, // ¥0.02 / 1k tokens "qwen-max-longcontext": 1.4286, // ¥0.02 / 1k tokens "text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens + "stable-diffusion-xl": 0, // todo + "wanx-v1": 0, // todo "SparkDesk": 1.2858, // ¥0.018 / 1k tokens "SparkDesk-v1.1": 1.2858, // ¥0.018 / 1k tokens "SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens diff --git a/relay/channel/aiproxy/adaptor.go b/relay/channel/aiproxy/adaptor.go index 2b4e3022..386c5ba4 100644 --- a/relay/channel/aiproxy/adaptor.go +++ b/relay/channel/aiproxy/adaptor.go @@ -38,6 +38,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return aiProxyLibraryRequest, nil } +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { return channel.DoRequestHelper(a, c, meta, requestBody) } diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go index 902b0fc1..f04bf56f 100644 --- a/relay/channel/ali/adaptor.go +++ b/relay/channel/ali/adaptor.go @@ -23,10 +23,16 @@ func (a *Adaptor) Init(meta *util.RelayMeta) { } func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { - fullRequestURL := fmt.Sprintf("%s/api/v1/services/aigc/text-generation/generation", meta.BaseURL) - if meta.Mode == constant.RelayModeEmbeddings { + fullRequestURL := "" + switch meta.Mode { + case constant.RelayModeEmbeddings: fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", meta.BaseURL) + case constant.RelayModeImagesGenerations: + fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", meta.BaseURL) + default: + fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text-generation/generation", meta.BaseURL) } + return fullRequestURL, nil } @@ -34,10 +40,12 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *ut channel.SetupCommonRequestHeader(c, req, meta) if meta.IsStream { req.Header.Set("Accept", "text/event-stream") + req.Header.Set("X-DashScope-SSE", "enable") } req.Header.Set("Authorization", "Bearer "+meta.APIKey) - if meta.IsStream { - req.Header.Set("X-DashScope-SSE", "enable") + + if meta.Mode == constant.RelayModeImagesGenerations { + req.Header.Set("X-DashScope-Async", "enable") } if c.GetString(common.ConfigKeyPlugin) != "" { req.Header.Set("X-DashScope-Plugin", c.GetString(common.ConfigKeyPlugin)) @@ -59,6 +67,15 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G } } +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + + aliRequest := ConvertImageRequest(*request) + return aliRequest, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { return channel.DoRequestHelper(a, c, meta, requestBody) } @@ -70,6 +87,8 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.Rel switch meta.Mode { case constant.RelayModeEmbeddings: err, usage = EmbeddingHandler(c, resp) + case constant.RelayModeImagesGenerations: + err, usage = ImageHandler(c, resp) default: err, usage = Handler(c, resp) } diff --git a/relay/channel/ali/main.go b/relay/channel/ali/main.go index 62115d58..df28f82b 100644 --- a/relay/channel/ali/main.go +++ b/relay/channel/ali/main.go @@ -2,7 +2,10 @@ package ali import ( "bufio" + "encoding/base64" "encoding/json" + "errors" + "fmt" "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/helper" @@ -12,6 +15,7 @@ import ( "io" "net/http" "strings" + "time" ) // https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r @@ -63,6 +67,17 @@ func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingReque } } +func ConvertImageRequest(request model.ImageRequest) *ImageRequest { + var imageRequest ImageRequest + imageRequest.Input.Prompt = request.Prompt + imageRequest.Model = request.Model + imageRequest.Parameters.Size = strings.Replace(request.Size, "x", "*", -1) + imageRequest.Parameters.N = request.N + imageRequest.ResponseFormat = request.ResponseFormat + + return &imageRequest +} + func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { var aliResponse EmbeddingResponse err := json.NewDecoder(resp.Body).Decode(&aliResponse) @@ -261,3 +276,173 @@ func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, * _, err = c.Writer.Write(jsonResponse) return nil, &fullTextResponse.Usage } + +func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + apiKey := c.Request.Header.Get("Authorization") + apiKey = strings.TrimPrefix(apiKey, "Bearer ") + responseFormat := c.GetString("response_format") + + var aliTaskResponse TaskResponse + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &aliTaskResponse) + if err != nil { + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + + if aliTaskResponse.Message != "" { + logger.SysError("aliAsyncTask err: " + string(responseBody)) + return openai.ErrorWrapper(errors.New(aliTaskResponse.Message), "ali_async_task_failed", http.StatusInternalServerError), nil + } + + aliResponse, err, _ := asyncTaskWait(aliTaskResponse.Output.TaskId, apiKey) + if err != nil { + return openai.ErrorWrapper(err, "ali_async_task_wait_failed", http.StatusInternalServerError), nil + } + + if aliResponse.Output.TaskStatus != "SUCCEEDED" { + return &model.ErrorWithStatusCode{ + Error: model.Error{ + Message: aliResponse.Output.Message, + Type: "ali_error", + Param: "", + Code: aliResponse.Output.Code, + }, + StatusCode: resp.StatusCode, + }, nil + } + + fullTextResponse := responseAli2OpenAIImage(aliResponse, responseFormat) + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &fullTextResponse.Usage +} + +func asyncTask(taskID string, key string) (*TaskResponse, error, []byte) { + url := fmt.Sprintf("https://dashscope.aliyuncs.com/api/v1/tasks/%s", taskID) + + var aliResponse TaskResponse + + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return &aliResponse, err, nil + } + + req.Header.Set("Authorization", "Bearer "+key) + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + logger.SysError("aliAsyncTask client.Do err: " + err.Error()) + return &aliResponse, err, nil + } + defer resp.Body.Close() + + responseBody, err := io.ReadAll(resp.Body) + + var response TaskResponse + err = json.Unmarshal(responseBody, &response) + if err != nil { + logger.SysError("aliAsyncTask NewDecoder err: " + err.Error()) + return &aliResponse, err, nil + } + + return &response, nil, responseBody +} + +func asyncTaskWait(taskID string, key string) (*TaskResponse, error, []byte) { + waitSeconds := 2 + step := 0 + + var taskResponse TaskResponse + var responseBody []byte + + for { + step++ + rsp, err, body := asyncTask(taskID, key) + responseBody = body + if err != nil { + return &taskResponse, err, responseBody + } + + if rsp.Output.TaskStatus == "" { + return &taskResponse, nil, responseBody + } + + switch rsp.Output.TaskStatus { + case "FAILED": + fallthrough + case "CANCELED": + fallthrough + case "SUCCEEDED": + fallthrough + case "UNKNOWN": + return rsp, nil, responseBody + } + + time.Sleep(time.Duration(waitSeconds) * time.Second) + } + + return &taskResponse, nil, responseBody +} + +func responseAli2OpenAIImage(response *TaskResponse, responseFormat string) *openai.ImageResponse { + imageRespones := openai.ImageResponse{} + + for _, data := range response.Output.Results { + var b64Json string + if responseFormat == "b64_json" { + // 读取 data.Url 的图片数据并转存到 b64Json + imageData, err := getImageData(data.Url) + if err != nil { + // 处理获取图片数据失败的情况 + logger.SysError("getImageData Error getting image data: " + err.Error()) + continue + } + + // 将图片数据转为 Base64 编码的字符串 + b64Json = Base64Encode(imageData) + } else { + // 如果 responseFormat 不是 "b64_json",则直接使用 data.B64Image + b64Json = data.B64Image + } + + imageRespones.Data = append(imageRespones.Data, openai.ImageData{ + Url: data.Url, + B64Json: b64Json, + RevisedPrompt: "", + }) + } + return &imageRespones +} + +func getImageData(url string) ([]byte, error) { + response, err := http.Get(url) + if err != nil { + return nil, err + } + defer response.Body.Close() + + imageData, err := io.ReadAll(response.Body) + if err != nil { + return nil, err + } + + return imageData, nil +} + +func Base64Encode(data []byte) string { + b64Json := base64.StdEncoding.EncodeToString(data) + return b64Json +} diff --git a/relay/channel/ali/model.go b/relay/channel/ali/model.go index 76e814d1..34515b16 100644 --- a/relay/channel/ali/model.go +++ b/relay/channel/ali/model.go @@ -26,6 +26,79 @@ type ChatRequest struct { Parameters Parameters `json:"parameters,omitempty"` } +type ImageRequest struct { + Model string `json:"model"` + Input struct { + Prompt string `json:"prompt"` + NegativePrompt string `json:"negative_prompt,omitempty"` + } `json:"input"` + Parameters struct { + Size string `json:"size,omitempty"` + N int `json:"n,omitempty"` + Steps string `json:"steps,omitempty"` + Scale string `json:"scale,omitempty"` + } `json:"parameters,omitempty"` + ResponseFormat string `json:"response_format,omitempty"` +} + +type TaskResponse struct { + StatusCode int `json:"status_code,omitempty"` + RequestId string `json:"request_id,omitempty"` + Code string `json:"code,omitempty"` + Message string `json:"message,omitempty"` + Output struct { + TaskId string `json:"task_id,omitempty"` + TaskStatus string `json:"task_status,omitempty"` + Code string `json:"code,omitempty"` + Message string `json:"message,omitempty"` + Results []struct { + B64Image string `json:"b64_image,omitempty"` + Url string `json:"url,omitempty"` + Code string `json:"code,omitempty"` + Message string `json:"message,omitempty"` + } `json:"results,omitempty"` + TaskMetrics struct { + Total int `json:"TOTAL,omitempty"` + Succeeded int `json:"SUCCEEDED,omitempty"` + Failed int `json:"FAILED,omitempty"` + } `json:"task_metrics,omitempty"` + } `json:"output,omitempty"` + Usage Usage `json:"usage"` +} + +type Header struct { + Action string `json:"action,omitempty"` + Streaming string `json:"streaming,omitempty"` + TaskID string `json:"task_id,omitempty"` + Event string `json:"event,omitempty"` + ErrorCode string `json:"error_code,omitempty"` + ErrorMessage string `json:"error_message,omitempty"` + Attributes any `json:"attributes,omitempty"` +} + +type Payload struct { + Model string `json:"model,omitempty"` + Task string `json:"task,omitempty"` + TaskGroup string `json:"task_group,omitempty"` + Function string `json:"function,omitempty"` + Parameters struct { + SampleRate int `json:"sample_rate,omitempty"` + Rate float64 `json:"rate,omitempty"` + Format string `json:"format,omitempty"` + } `json:"parameters,omitempty"` + Input struct { + Text string `json:"text,omitempty"` + } `json:"input,omitempty"` + Usage struct { + Characters int `json:"characters,omitempty"` + } `json:"usage,omitempty"` +} + +type WSSMessage struct { + Header Header `json:"header,omitempty"` + Payload Payload `json:"payload,omitempty"` +} + type EmbeddingRequest struct { Model string `json:"model"` Input struct { diff --git a/relay/channel/anthropic/adaptor.go b/relay/channel/anthropic/adaptor.go index a165b35c..12e01c52 100644 --- a/relay/channel/anthropic/adaptor.go +++ b/relay/channel/anthropic/adaptor.go @@ -41,6 +41,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return ConvertRequest(*request), nil } +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { return channel.DoRequestHelper(a, c, meta, requestBody) } diff --git a/relay/channel/baidu/adaptor.go b/relay/channel/baidu/adaptor.go index 2d2e24f6..471f2dc5 100644 --- a/relay/channel/baidu/adaptor.go +++ b/relay/channel/baidu/adaptor.go @@ -91,6 +91,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G } } +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { return channel.DoRequestHelper(a, c, meta, requestBody) } diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index f3305e5d..30002c53 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -42,6 +42,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return ConvertRequest(*request), nil } +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { return channelhelper.DoRequestHelper(a, c, meta, requestBody) } diff --git a/relay/channel/interface.go b/relay/channel/interface.go index e25db83f..78d6ace1 100644 --- a/relay/channel/interface.go +++ b/relay/channel/interface.go @@ -13,6 +13,7 @@ type Adaptor interface { GetRequestURL(meta *util.RelayMeta) (string, error) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) + ConvertImageRequest(request *model.ImageRequest) (any, error) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) GetModelList() []string diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go index e2ae7d2b..33635c5c 100644 --- a/relay/channel/ollama/adaptor.go +++ b/relay/channel/ollama/adaptor.go @@ -48,6 +48,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G } } +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { return channel.DoRequestHelper(a, c, meta, requestBody) } diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index 1f153c3e..9cd98382 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -7,6 +7,7 @@ import ( "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/relay/channel" "github.com/songquanpeng/one-api/relay/channel/minimax" + "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/util" "io" @@ -25,6 +26,13 @@ func (a *Adaptor) Init(meta *util.RelayMeta) { func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { switch meta.ChannelType { case common.ChannelTypeAzure: + if meta.Mode == constant.RelayModeImagesGenerations { + // https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api + // https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2024-03-01-preview + fullRequestURL := fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", meta.BaseURL, meta.ActualModelName, meta.APIVersion) + return fullRequestURL, nil + } + // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api requestURL := strings.Split(meta.RequestURLPath, "?")[0] requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, meta.APIVersion) @@ -63,6 +71,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return request, nil } +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { return channel.DoRequestHelper(a, c, meta, requestBody) } @@ -73,7 +88,13 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.Rel err, responseText, _ = StreamHandler(c, resp, meta.Mode) usage = ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens) } else { - err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName) + switch meta.Mode { + case constant.RelayModeImagesGenerations: + err, usage = ImageHandler(c, resp) + default: + err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName) + } + } return } diff --git a/relay/channel/openai/main.go b/relay/channel/openai/main.go index d47cd164..e53dafa6 100644 --- a/relay/channel/openai/main.go +++ b/relay/channel/openai/main.go @@ -148,3 +148,37 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st } return nil, &textResponse.Usage } + +func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + var imageResponse ImageResponse + responseBody, err := io.ReadAll(resp.Body) + + if err != nil { + return ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), &imageResponse.Usage + } + err = resp.Body.Close() + if err != nil { + return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), &imageResponse.Usage + } + err = json.Unmarshal(responseBody, &imageResponse) + if err != nil { + return ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), &imageResponse.Usage + } + + resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + + for k, v := range resp.Header { + c.Writer.Header().Set(k, v[0]) + } + c.Writer.WriteHeader(resp.StatusCode) + + _, err = io.Copy(c.Writer, resp.Body) + if err != nil { + return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), &imageResponse.Usage + } + err = resp.Body.Close() + if err != nil { + return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), &imageResponse.Usage + } + return nil, nil +} diff --git a/relay/channel/openai/model.go b/relay/channel/openai/model.go index 6c0b2c53..31d9fe97 100644 --- a/relay/channel/openai/model.go +++ b/relay/channel/openai/model.go @@ -110,11 +110,16 @@ type EmbeddingResponse struct { model.Usage `json:"usage"` } +type ImageData struct { + Url string `json:"url,omitempty"` + B64Json string `json:"b64_json,omitempty"` + RevisedPrompt string `json:"revised_prompt,omitempty"` +} + type ImageResponse struct { - Created int `json:"created"` - Data []struct { - Url string `json:"url"` - } + Created int `json:"created"` + Data []ImageData `json:"data"` + model.Usage `json:"usage"` } type ChatCompletionsStreamResponseChoice struct { diff --git a/relay/channel/palm/adaptor.go b/relay/channel/palm/adaptor.go index efd0620c..8d3c9b0b 100644 --- a/relay/channel/palm/adaptor.go +++ b/relay/channel/palm/adaptor.go @@ -36,6 +36,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return ConvertRequest(*request), nil } +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { return channel.DoRequestHelper(a, c, meta, requestBody) } diff --git a/relay/channel/tencent/adaptor.go b/relay/channel/tencent/adaptor.go index f348674e..84ed8527 100644 --- a/relay/channel/tencent/adaptor.go +++ b/relay/channel/tencent/adaptor.go @@ -52,6 +52,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return tencentRequest, nil } +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { return channel.DoRequestHelper(a, c, meta, requestBody) } diff --git a/relay/channel/xunfei/adaptor.go b/relay/channel/xunfei/adaptor.go index 92d9d7d6..8d83ff1b 100644 --- a/relay/channel/xunfei/adaptor.go +++ b/relay/channel/xunfei/adaptor.go @@ -38,6 +38,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return nil, nil } +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { // xunfei's request is not http request, so we don't need to do anything here dummyResp := &http.Response{} diff --git a/relay/channel/zhipu/adaptor.go b/relay/channel/zhipu/adaptor.go index 0ca23d59..75ff977b 100644 --- a/relay/channel/zhipu/adaptor.go +++ b/relay/channel/zhipu/adaptor.go @@ -67,6 +67,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return ConvertRequest(*request), nil } +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { return channel.DoRequestHelper(a, c, meta, requestBody) } diff --git a/relay/constant/image.go b/relay/constant/image.go index 5e04895f..3d011bb2 100644 --- a/relay/constant/image.go +++ b/relay/constant/image.go @@ -11,14 +11,31 @@ var DalleSizeRatios = map[string]map[string]float64{ "1024x1792": 2, "1792x1024": 2, }, + "stable-diffusion-xl": { + "512x1024": 1, + "1024x768": 1, + "1024x1024": 1, + "576x1024": 1, + "1024x576": 1, + }, + "wanx-v1": { + "1024x1024": 1, + "720x1280": 1, + "1280x720": 1, + }, } var DalleGenerationImageAmounts = map[string][2]int{ "dall-e-2": {1, 10}, "dall-e-3": {1, 1}, // OpenAI allows n=1 currently. + "stable-diffusion-xl": {1, 4}, // Ali + "wanx-v1": {1, 4}, // Ali } var DalleImagePromptLengthLimitations = map[string]int{ "dall-e-2": 1000, "dall-e-3": 4000, + "dall-e-2": 1000, + "dall-e-3": 4000, + "stable-diffusion-xl": 4000, } diff --git a/relay/controller/helper.go b/relay/controller/helper.go index 600a8d65..c78d22c7 100644 --- a/relay/controller/helper.go +++ b/relay/controller/helper.go @@ -36,8 +36,8 @@ func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.Gener return textRequest, nil } -func getImageRequest(c *gin.Context, relayMode int) (*openai.ImageRequest, error) { - imageRequest := &openai.ImageRequest{} +func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) { + imageRequest := &relaymodel.ImageRequest{} err := common.UnmarshalBodyReusable(c, imageRequest) if err != nil { return nil, err @@ -54,7 +54,7 @@ func getImageRequest(c *gin.Context, relayMode int) (*openai.ImageRequest, error return imageRequest, nil } -func validateImageRequest(imageRequest *openai.ImageRequest, meta *util.RelayMeta) *relaymodel.ErrorWithStatusCode { +func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *util.RelayMeta) *relaymodel.ErrorWithStatusCode { // model validation _, hasValidSize := constant.DalleSizeRatios[imageRequest.Model][imageRequest.Size] if !hasValidSize { @@ -77,7 +77,7 @@ func validateImageRequest(imageRequest *openai.ImageRequest, meta *util.RelayMet return nil } -func getImageCostRatio(imageRequest *openai.ImageRequest) (float64, error) { +func getImageCostRatio(imageRequest *relaymodel.ImageRequest) (float64, error) { if imageRequest == nil { return 0, errors.New("imageRequest is nil") } diff --git a/relay/controller/image.go b/relay/controller/image.go index d81dadf6..9d614300 100644 --- a/relay/controller/image.go +++ b/relay/controller/image.go @@ -6,18 +6,17 @@ import ( "encoding/json" "errors" "fmt" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/constant" + "github.com/songquanpeng/one-api/relay/helper" relaymodel "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/util" "io" "net/http" - "strings" - - "github.com/gin-gonic/gin" ) func isWithinRange(element string, value int) bool { @@ -56,15 +55,6 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus return openai.ErrorWrapper(err, "get_image_cost_ratio_failed", http.StatusInternalServerError) } - requestURL := c.Request.URL.String() - fullRequestURL := util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType) - if meta.ChannelType == common.ChannelTypeAzure { - // https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api - apiVersion := util.GetAzureAPIVersion(c) - // https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2024-03-01-preview - fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", meta.BaseURL, imageRequest.Model, apiVersion) - } - var requestBody io.Reader if isModelMapped || meta.ChannelType == common.ChannelTypeAzure { // make Azure channel request body jsonStr, err := json.Marshal(imageRequest) @@ -76,6 +66,29 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus requestBody = c.Request.Body } + adaptor := helper.GetAdaptor(meta.APIType) + if adaptor == nil { + return openai.ErrorWrapper(fmt.Errorf("invalid api type: %d", meta.APIType), "invalid_api_type", http.StatusBadRequest) + } + + switch meta.ChannelType { + case common.ChannelTypeAli: + fallthrough + case common.ChannelTypeBaidu: + fallthrough + case common.ChannelTypeZhipu: + finalRequest, err := adaptor.ConvertImageRequest(imageRequest) + if err != nil { + return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError) + } + + jsonStr, err := json.Marshal(finalRequest) + if err != nil { + return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError) + } + requestBody = bytes.NewBuffer(jsonStr) + } + modelRatio := common.GetModelRatio(imageRequest.Model) groupRatio := common.GetGroupRatio(meta.Group) ratio := modelRatio * groupRatio @@ -87,36 +100,13 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) } - req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) - if err != nil { - return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) - } - token := c.Request.Header.Get("Authorization") - if meta.ChannelType == common.ChannelTypeAzure { // Azure authentication - token = strings.TrimPrefix(token, "Bearer ") - req.Header.Set("api-key", token) - } else { - req.Header.Set("Authorization", token) - } - - req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) - req.Header.Set("Accept", c.Request.Header.Get("Accept")) - - resp, err := util.HTTPClient.Do(req) + // do request + resp, err := adaptor.DoRequest(c, meta, requestBody) if err != nil { + logger.Errorf(ctx, "DoRequest failed: %s", err.Error()) return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) } - err = req.Body.Close() - if err != nil { - return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) - } - err = c.Request.Body.Close() - if err != nil { - return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) - } - var imageResponse openai.ImageResponse - defer func(ctx context.Context) { if resp.StatusCode != http.StatusOK { return @@ -139,34 +129,12 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus } }(c.Request.Context()) - responseBody, err := io.ReadAll(resp.Body) - - if err != nil { - return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) - } - err = resp.Body.Close() - if err != nil { - return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) - } - err = json.Unmarshal(responseBody, &imageResponse) - if err != nil { - return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) + // do response + _, respErr := adaptor.DoResponse(c, resp, meta) + if respErr != nil { + logger.Errorf(ctx, "respErr is not nil: %+v", respErr) + return respErr } - resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) - - for k, v := range resp.Header { - c.Writer.Header().Set(k, v[0]) - } - c.Writer.WriteHeader(resp.StatusCode) - - _, err = io.Copy(c.Writer, resp.Body) - if err != nil { - return openai.ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) - } - err = resp.Body.Close() - if err != nil { - return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) - } return nil } diff --git a/relay/model/image.go b/relay/model/image.go new file mode 100644 index 00000000..bab84256 --- /dev/null +++ b/relay/model/image.go @@ -0,0 +1,12 @@ +package model + +type ImageRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt" binding:"required"` + N int `json:"n,omitempty"` + Size string `json:"size,omitempty"` + Quality string `json:"quality,omitempty"` + ResponseFormat string `json:"response_format,omitempty"` + Style string `json:"style,omitempty"` + User string `json:"user,omitempty"` +}