diff --git a/common/ctxkey/key.go b/common/ctxkey/key.go index 90556b3a..e655680f 100644 --- a/common/ctxkey/key.go +++ b/common/ctxkey/key.go @@ -3,6 +3,7 @@ package ctxkey const ( Config = "config" Id = "id" + RequestId = "X-Oneapi-Request-Id" Username = "username" Role = "role" Status = "status" @@ -15,6 +16,7 @@ const ( Group = "group" ModelMapping = "model_mapping" ChannelName = "channel_name" + ContentType = "content_type" TokenId = "token_id" TokenName = "token_name" BaseURL = "base_url" diff --git a/common/gin.go b/common/gin.go index 549d3279..4b68eb06 100644 --- a/common/gin.go +++ b/common/gin.go @@ -2,11 +2,11 @@ package common import ( "bytes" - "encoding/json" - "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common/ctxkey" "io" - "strings" + + "github.com/gin-gonic/gin" + "github.com/pkg/errors" + "github.com/songquanpeng/one-api/common/ctxkey" ) func GetRequestBody(c *gin.Context) ([]byte, error) { @@ -28,18 +28,16 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error { if err != nil { return err } - contentType := c.Request.Header.Get("Content-Type") - if strings.HasPrefix(contentType, "application/json") { - err = json.Unmarshal(requestBody, &v) - } else { - // skip for now - // TODO: someday non json request have variant model, we will need to implementation this - } - if err != nil { - return err - } // Reset request body c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) + defer func() { + c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) + }() + + if err = c.Bind(v); err != nil { + return errors.Wrap(err, "bind request body failed") + } + return nil } diff --git a/controller/relay.go b/controller/relay.go index 49358e25..7de84e3f 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -26,7 +26,8 @@ import ( func relayHelper(c *gin.Context, relayMode int) *model.ErrorWithStatusCode { var err *model.ErrorWithStatusCode switch relayMode { - case relaymode.ImagesGenerations: + case relaymode.ImagesGenerations, + relaymode.ImagesEdits: err = controller.RelayImageHelper(c, relayMode) case relaymode.AudioSpeech: fallthrough @@ -45,10 +46,6 @@ func relayHelper(c *gin.Context, relayMode int) *model.ErrorWithStatusCode { func Relay(c *gin.Context) { ctx := c.Request.Context() relayMode := relaymode.GetByPath(c.Request.URL.Path) - if config.DebugEnabled { - requestBody, _ := common.GetRequestBody(c) - logger.Debugf(ctx, "request body: %s", string(requestBody)) - } channelId := c.GetInt(ctxkey.ChannelId) userId := c.GetInt(ctxkey.Id) bizErr := relayHelper(c, relayMode) @@ -60,6 +57,8 @@ func Relay(c *gin.Context) { channelName := c.GetString(ctxkey.ChannelName) group := c.GetString(ctxkey.Group) originalModel := c.GetString(ctxkey.OriginalModel) + + // BUG: bizErr is shared, should not run this function in goroutine to avoid race go processChannelRelayError(ctx, userId, channelId, channelName, bizErr) requestId := c.GetString(helper.RequestIdKey) retryTimes := config.RetryTimes @@ -90,6 +89,7 @@ func Relay(c *gin.Context) { // BUG: bizErr is in race condition go processChannelRelayError(ctx, userId, channelId, channelName, bizErr) } + if bizErr != nil { if bizErr.StatusCode == http.StatusTooManyRequests { bizErr.Error.Message = "当前分组上游负载已饱和,请稍后再试" diff --git a/middleware/distributor.go b/middleware/distributor.go index 0c4b04c3..a4740c22 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -12,7 +12,7 @@ import ( ) type ModelRequest struct { - Model string `json:"model"` + Model string `json:"model" form:"model"` } func Distribute() func(c *gin.Context) { @@ -61,6 +61,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode c.Set(ctxkey.Channel, channel.Type) c.Set(ctxkey.ChannelId, channel.Id) c.Set(ctxkey.ChannelName, channel.Name) + c.Set(ctxkey.ContentType, c.Request.Header.Get("Content-Type")) c.Set(ctxkey.ModelMapping, channel.GetModelMapping()) c.Set(ctxkey.OriginalModel, modelName) // for retry c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) diff --git a/middleware/logger.go b/middleware/logger.go index 191364f8..587d748c 100644 --- a/middleware/logger.go +++ b/middleware/logger.go @@ -2,6 +2,7 @@ package middleware import ( "fmt" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common/helper" ) diff --git a/middleware/request-id.go b/middleware/request-id.go index bef09e32..c1f3adc2 100644 --- a/middleware/request-id.go +++ b/middleware/request-id.go @@ -2,6 +2,7 @@ package middleware import ( "context" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common/helper" ) diff --git a/middleware/utils.go b/middleware/utils.go index 4d2f8092..2afcab47 100644 --- a/middleware/utils.go +++ b/middleware/utils.go @@ -2,11 +2,12 @@ package middleware import ( "fmt" + "strings" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" - "strings" ) func abortWithMessage(c *gin.Context, statusCode int, message string) { diff --git a/relay/adaptor/common.go b/relay/adaptor/common.go index 8953d7a3..9069255a 100644 --- a/relay/adaptor/common.go +++ b/relay/adaptor/common.go @@ -3,11 +3,13 @@ package adaptor import ( "errors" "fmt" - "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common/client" - "github.com/songquanpeng/one-api/relay/meta" "io" "net/http" + + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/client" + "github.com/songquanpeng/one-api/common/ctxkey" + "github.com/songquanpeng/one-api/relay/meta" ) func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) { @@ -27,6 +29,9 @@ func DoRequestHelper(a Adaptor, c *gin.Context, meta *meta.Meta, requestBody io. if err != nil { return nil, fmt.Errorf("new request failed: %w", err) } + + req.Header.Set("Content-Type", c.GetString(ctxkey.ContentType)) + err = a.SetupRequestHeader(c, req, meta) if err != nil { return nil, fmt.Errorf("setup request header failed: %w", err) diff --git a/relay/adaptor/openai/adaptor.go b/relay/adaptor/openai/adaptor.go index 5dc395ad..120d21a6 100644 --- a/relay/adaptor/openai/adaptor.go +++ b/relay/adaptor/openai/adaptor.go @@ -104,10 +104,13 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Met switch meta.Mode { case relaymode.ImagesGenerations: err, _ = ImageHandler(c, resp) + case relaymode.ImagesEdits: + err, _ = ImagesEditsHandler(c, resp) default: err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName) } } + return } diff --git a/relay/adaptor/openai/image.go b/relay/adaptor/openai/image.go index 0f89618a..433d9421 100644 --- a/relay/adaptor/openai/image.go +++ b/relay/adaptor/openai/image.go @@ -3,12 +3,30 @@ package openai import ( "bytes" "encoding/json" - "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/relay/model" "io" "net/http" + + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/model" ) +// ImagesEditsHandler just copy response body to client +// +// https://platform.openai.com/docs/api-reference/images/createEdit +func ImagesEditsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + c.Writer.WriteHeader(resp.StatusCode) + for k, v := range resp.Header { + c.Writer.Header().Set(k, v[0]) + } + + if _, err := io.Copy(c.Writer, resp.Body); err != nil { + return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil + } + defer resp.Body.Close() + + return nil, nil +} + func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { var imageResponse ImageResponse responseBody, err := io.ReadAll(resp.Body) diff --git a/relay/controller/image.go b/relay/controller/image.go index 1e06e858..7f0536ab 100644 --- a/relay/controller/image.go +++ b/relay/controller/image.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "net/http" + "strings" "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" @@ -134,7 +135,8 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus c.Set("response_format", imageRequest.ResponseFormat) var requestBody io.Reader - if isModelMapped || meta.ChannelType == channeltype.Azure { // make Azure channel request body + if strings.ToLower(c.GetString(ctxkey.ContentType)) == "application/json" && + isModelMapped || meta.ChannelType == channeltype.Azure { // make Azure channel request body jsonStr, err := json.Marshal(imageRequest) if err != nil { return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError) diff --git a/relay/model/image.go b/relay/model/image.go index bab84256..00bd8b79 100644 --- a/relay/model/image.go +++ b/relay/model/image.go @@ -1,12 +1,12 @@ package model type ImageRequest struct { - Model string `json:"model"` - Prompt string `json:"prompt" binding:"required"` - N int `json:"n,omitempty"` - Size string `json:"size,omitempty"` - Quality string `json:"quality,omitempty"` - ResponseFormat string `json:"response_format,omitempty"` - Style string `json:"style,omitempty"` - User string `json:"user,omitempty"` + Model string `json:"model" form:"model"` + Prompt string `json:"prompt" binding:"required" form:"prompt"` + N int `json:"n,omitempty" form:"n"` + Size string `json:"size,omitempty" form:"size"` + Quality string `json:"quality,omitempty" form:"quality"` + ResponseFormat string `json:"response_format,omitempty" form:"response_format"` + Style string `json:"style,omitempty" form:"style"` + User string `json:"user,omitempty" form:"user"` } diff --git a/relay/relaymode/define.go b/relay/relaymode/define.go index aa771205..79999826 100644 --- a/relay/relaymode/define.go +++ b/relay/relaymode/define.go @@ -11,6 +11,7 @@ const ( AudioSpeech AudioTranscription AudioTranslation + ImagesEdits // Proxy is a special relay mode for proxying requests to custom upstream Proxy ) diff --git a/relay/relaymode/helper.go b/relay/relaymode/helper.go index 2cde5b85..35a0535e 100644 --- a/relay/relaymode/helper.go +++ b/relay/relaymode/helper.go @@ -24,8 +24,11 @@ func GetByPath(path string) int { relayMode = AudioTranscription } else if strings.HasPrefix(path, "/v1/audio/translations") { relayMode = AudioTranslation + } else if strings.HasPrefix(path, "/v1/images/edits") { + relayMode = ImagesEdits } else if strings.HasPrefix(path, "/v1/oneapi/proxy") { relayMode = Proxy } + return relayMode } diff --git a/router/relay.go b/router/relay.go index 094ea5fb..899c8209 100644 --- a/router/relay.go +++ b/router/relay.go @@ -24,7 +24,7 @@ func SetRelayRouter(router *gin.Engine) { relayV1Router.POST("/chat/completions", controller.Relay) relayV1Router.POST("/edits", controller.Relay) relayV1Router.POST("/images/generations", controller.Relay) - relayV1Router.POST("/images/edits", controller.RelayNotImplemented) + relayV1Router.POST("/images/edits", controller.Relay) relayV1Router.POST("/images/variations", controller.RelayNotImplemented) relayV1Router.POST("/embeddings", controller.Relay) relayV1Router.POST("/engines/:model/embeddings", controller.Relay)