diff --git a/relay/channel/ali/image.go b/relay/channel/ali/image.go new file mode 100644 index 00000000..6f78365d --- /dev/null +++ b/relay/channel/ali/image.go @@ -0,0 +1,192 @@ +package ali + +import ( + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/relay/model" + "io" + "net/http" + "strings" + "time" +) + +func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + apiKey := c.Request.Header.Get("Authorization") + apiKey = strings.TrimPrefix(apiKey, "Bearer ") + responseFormat := c.GetString("response_format") + + var aliTaskResponse TaskResponse + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &aliTaskResponse) + if err != nil { + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + + if aliTaskResponse.Message != "" { + logger.SysError("aliAsyncTask err: " + string(responseBody)) + return openai.ErrorWrapper(errors.New(aliTaskResponse.Message), "ali_async_task_failed", http.StatusInternalServerError), nil + } + + aliResponse, _, err := asyncTaskWait(aliTaskResponse.Output.TaskId, apiKey) + if err != nil { + return openai.ErrorWrapper(err, "ali_async_task_wait_failed", http.StatusInternalServerError), nil + } + + if aliResponse.Output.TaskStatus != "SUCCEEDED" { + return &model.ErrorWithStatusCode{ + Error: model.Error{ + Message: aliResponse.Output.Message, + Type: "ali_error", + Param: "", + Code: aliResponse.Output.Code, + }, + StatusCode: resp.StatusCode, + }, nil + } + + fullTextResponse := responseAli2OpenAIImage(aliResponse, responseFormat) + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, nil +} + +func asyncTask(taskID string, key string) (*TaskResponse, error, []byte) { + url := fmt.Sprintf("https://dashscope.aliyuncs.com/api/v1/tasks/%s", taskID) + + var aliResponse TaskResponse + + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return &aliResponse, err, nil + } + + req.Header.Set("Authorization", "Bearer "+key) + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + logger.SysError("aliAsyncTask client.Do err: " + err.Error()) + return &aliResponse, err, nil + } + defer resp.Body.Close() + + responseBody, err := io.ReadAll(resp.Body) + + var response TaskResponse + err = json.Unmarshal(responseBody, &response) + if err != nil { + logger.SysError("aliAsyncTask NewDecoder err: " + err.Error()) + return &aliResponse, err, nil + } + + return &response, nil, responseBody +} + +func asyncTaskWait(taskID string, key string) (*TaskResponse, []byte, error) { + waitSeconds := 2 + step := 0 + maxStep := 20 + + var taskResponse TaskResponse + var responseBody []byte + + for { + step++ + rsp, err, body := asyncTask(taskID, key) + responseBody = body + if err != nil { + return &taskResponse, responseBody, err + } + + if rsp.Output.TaskStatus == "" { + return &taskResponse, responseBody, nil + } + + switch rsp.Output.TaskStatus { + case "FAILED": + fallthrough + case "CANCELED": + fallthrough + case "SUCCEEDED": + fallthrough + case "UNKNOWN": + return rsp, responseBody, nil + } + if step >= maxStep { + break + } + time.Sleep(time.Duration(waitSeconds) * time.Second) + } + + return nil, nil, fmt.Errorf("aliAsyncTaskWait timeout") +} + +func responseAli2OpenAIImage(response *TaskResponse, responseFormat string) *openai.ImageResponse { + imageResponse := openai.ImageResponse{ + Created: helper.GetTimestamp(), + } + + for _, data := range response.Output.Results { + var b64Json string + if responseFormat == "b64_json" { + // 读取 data.Url 的图片数据并转存到 b64Json + imageData, err := getImageData(data.Url) + if err != nil { + // 处理获取图片数据失败的情况 + logger.SysError("getImageData Error getting image data: " + err.Error()) + continue + } + + // 将图片数据转为 Base64 编码的字符串 + b64Json = Base64Encode(imageData) + } else { + // 如果 responseFormat 不是 "b64_json",则直接使用 data.B64Image + b64Json = data.B64Image + } + + imageResponse.Data = append(imageResponse.Data, openai.ImageData{ + Url: data.Url, + B64Json: b64Json, + RevisedPrompt: "", + }) + } + return &imageResponse +} + +func getImageData(url string) ([]byte, error) { + response, err := http.Get(url) + if err != nil { + return nil, err + } + defer response.Body.Close() + + imageData, err := io.ReadAll(response.Body) + if err != nil { + return nil, err + } + + return imageData, nil +} + +func Base64Encode(data []byte) string { + b64Json := base64.StdEncoding.EncodeToString(data) + return b64Json +} diff --git a/relay/channel/ali/main.go b/relay/channel/ali/main.go index c1832b11..49019fea 100644 --- a/relay/channel/ali/main.go +++ b/relay/channel/ali/main.go @@ -2,10 +2,7 @@ package ali import ( "bufio" - "encoding/base64" "encoding/json" - "errors" - "fmt" "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/helper" @@ -15,7 +12,6 @@ import ( "io" "net/http" "strings" - "time" ) // https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r @@ -280,173 +276,3 @@ func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, * _, err = c.Writer.Write(jsonResponse) return nil, &fullTextResponse.Usage } - -func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { - apiKey := c.Request.Header.Get("Authorization") - apiKey = strings.TrimPrefix(apiKey, "Bearer ") - responseFormat := c.GetString("response_format") - - var aliTaskResponse TaskResponse - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil - } - err = resp.Body.Close() - if err != nil { - return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } - err = json.Unmarshal(responseBody, &aliTaskResponse) - if err != nil { - return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil - } - - if aliTaskResponse.Message != "" { - logger.SysError("aliAsyncTask err: " + string(responseBody)) - return openai.ErrorWrapper(errors.New(aliTaskResponse.Message), "ali_async_task_failed", http.StatusInternalServerError), nil - } - - aliResponse, err, _ := asyncTaskWait(aliTaskResponse.Output.TaskId, apiKey) - if err != nil { - return openai.ErrorWrapper(err, "ali_async_task_wait_failed", http.StatusInternalServerError), nil - } - - if aliResponse.Output.TaskStatus != "SUCCEEDED" { - return &model.ErrorWithStatusCode{ - Error: model.Error{ - Message: aliResponse.Output.Message, - Type: "ali_error", - Param: "", - Code: aliResponse.Output.Code, - }, - StatusCode: resp.StatusCode, - }, nil - } - - fullTextResponse := responseAli2OpenAIImage(aliResponse, responseFormat) - jsonResponse, err := json.Marshal(fullTextResponse) - if err != nil { - return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil - } - c.Writer.Header().Set("Content-Type", "application/json") - c.Writer.WriteHeader(resp.StatusCode) - _, err = c.Writer.Write(jsonResponse) - return nil, &fullTextResponse.Usage -} - -func asyncTask(taskID string, key string) (*TaskResponse, error, []byte) { - url := fmt.Sprintf("https://dashscope.aliyuncs.com/api/v1/tasks/%s", taskID) - - var aliResponse TaskResponse - - req, err := http.NewRequest("GET", url, nil) - if err != nil { - return &aliResponse, err, nil - } - - req.Header.Set("Authorization", "Bearer "+key) - - client := &http.Client{} - resp, err := client.Do(req) - if err != nil { - logger.SysError("aliAsyncTask client.Do err: " + err.Error()) - return &aliResponse, err, nil - } - defer resp.Body.Close() - - responseBody, err := io.ReadAll(resp.Body) - - var response TaskResponse - err = json.Unmarshal(responseBody, &response) - if err != nil { - logger.SysError("aliAsyncTask NewDecoder err: " + err.Error()) - return &aliResponse, err, nil - } - - return &response, nil, responseBody -} - -func asyncTaskWait(taskID string, key string) (*TaskResponse, error, []byte) { - waitSeconds := 2 - step := 0 - - var taskResponse TaskResponse - var responseBody []byte - - for { - step++ - rsp, err, body := asyncTask(taskID, key) - responseBody = body - if err != nil { - return &taskResponse, err, responseBody - } - - if rsp.Output.TaskStatus == "" { - return &taskResponse, nil, responseBody - } - - switch rsp.Output.TaskStatus { - case "FAILED": - fallthrough - case "CANCELED": - fallthrough - case "SUCCEEDED": - fallthrough - case "UNKNOWN": - return rsp, nil, responseBody - } - - time.Sleep(time.Duration(waitSeconds) * time.Second) - } - - return &taskResponse, nil, responseBody -} - -func responseAli2OpenAIImage(response *TaskResponse, responseFormat string) *openai.ImageResponse { - imageRespones := openai.ImageResponse{} - - for _, data := range response.Output.Results { - var b64Json string - if responseFormat == "b64_json" { - // 读取 data.Url 的图片数据并转存到 b64Json - imageData, err := getImageData(data.Url) - if err != nil { - // 处理获取图片数据失败的情况 - logger.SysError("getImageData Error getting image data: " + err.Error()) - continue - } - - // 将图片数据转为 Base64 编码的字符串 - b64Json = Base64Encode(imageData) - } else { - // 如果 responseFormat 不是 "b64_json",则直接使用 data.B64Image - b64Json = data.B64Image - } - - imageRespones.Data = append(imageRespones.Data, openai.ImageData{ - Url: data.Url, - B64Json: b64Json, - RevisedPrompt: "", - }) - } - return &imageRespones -} - -func getImageData(url string) ([]byte, error) { - response, err := http.Get(url) - if err != nil { - return nil, err - } - defer response.Body.Close() - - imageData, err := io.ReadAll(response.Body) - if err != nil { - return nil, err - } - - return imageData, nil -} - -func Base64Encode(data []byte) string { - b64Json := base64.StdEncoding.EncodeToString(data) - return b64Json -} diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index 91b5fc2c..3212d8f8 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -92,7 +92,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.Rel } else { switch meta.Mode { case constant.RelayModeImagesGenerations: - err, usage = ImageHandler(c, resp) + err, _ = ImageHandler(c, resp) default: err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName) } diff --git a/relay/channel/openai/main.go b/relay/channel/openai/main.go index df3f0691..7ace3f63 100644 --- a/relay/channel/openai/main.go +++ b/relay/channel/openai/main.go @@ -181,5 +181,5 @@ func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCo if err != nil { return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } - return nil, &imageResponse.Usage + return nil, nil } diff --git a/relay/channel/openai/model.go b/relay/channel/openai/model.go index 6e9c38f1..ce252ff6 100644 --- a/relay/channel/openai/model.go +++ b/relay/channel/openai/model.go @@ -117,9 +117,9 @@ type ImageData struct { } type ImageResponse struct { - Created int `json:"created"` - Data []ImageData `json:"data"` - model.Usage `json:"usage"` + Created int64 `json:"created"` + Data []ImageData `json:"data"` + //model.Usage `json:"usage"` } type ChatCompletionsStreamResponseChoice struct {