diff --git a/common/ctxkey/key.go b/common/ctxkey/key.go index 568cb095..56dbd8b3 100644 --- a/common/ctxkey/key.go +++ b/common/ctxkey/key.go @@ -2,6 +2,7 @@ package ctxkey const ( Id = "id" + RequestId = "X-Oneapi-Request-Id" Username = "username" Role = "role" Status = "status" @@ -14,6 +15,7 @@ const ( Group = "group" ModelMapping = "model_mapping" ChannelName = "channel_name" + ContentType = "content_type" TokenId = "token_id" TokenName = "token_name" BaseURL = "base_url" diff --git a/common/gin.go b/common/gin.go index b6ef96a6..92e42474 100644 --- a/common/gin.go +++ b/common/gin.go @@ -2,10 +2,10 @@ package common import ( "bytes" - "encoding/json" - "github.com/gin-gonic/gin" "io" - "strings" + + "github.com/gin-gonic/gin" + "github.com/pkg/errors" ) const KeyRequestBody = "key_request_body" @@ -29,18 +29,16 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error { if err != nil { return err } - contentType := c.Request.Header.Get("Content-Type") - if strings.HasPrefix(contentType, "application/json") { - err = json.Unmarshal(requestBody, &v) - } else { - // skip for now - // TODO: someday non json request have variant model, we will need to implementation this - } - if err != nil { - return err - } // Reset request body c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) + defer func() { + c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) + }() + + if err = c.Bind(v); err != nil { + return errors.Wrap(err, "bind request body failed") + } + return nil } diff --git a/common/logger/constants.go b/common/logger/constants.go index 78d32062..49df31ec 100644 --- a/common/logger/constants.go +++ b/common/logger/constants.go @@ -1,7 +1,3 @@ package logger -const ( - RequestIdKey = "X-Oneapi-Request-Id" -) - var LogDir string diff --git a/common/logger/logger.go b/common/logger/logger.go index 858e33e2..1e13a894 100644 --- a/common/logger/logger.go +++ b/common/logger/logger.go @@ -12,6 +12,7 @@ import ( "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/helper" ) @@ -87,7 +88,7 @@ func logHelper(ctx context.Context, level string, msg string) { if level == loggerINFO { writer = gin.DefaultWriter } - id := ctx.Value(RequestIdKey) + id := ctx.Value(ctxkey.RequestId) if id == nil { id = helper.GenRequestID() } diff --git a/controller/relay.go b/controller/relay.go index 5fd22f85..ffd626bf 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -25,7 +25,8 @@ import ( func relayHelper(c *gin.Context, relayMode int) *model.ErrorWithStatusCode { var err *model.ErrorWithStatusCode switch relayMode { - case relaymode.ImagesGenerations: + case relaymode.ImagesGenerations, + relaymode.ImagesEdits: err = controller.RelayImageHelper(c, relayMode) case relaymode.AudioSpeech: fallthrough @@ -42,10 +43,6 @@ func relayHelper(c *gin.Context, relayMode int) *model.ErrorWithStatusCode { func Relay(c *gin.Context) { ctx := c.Request.Context() relayMode := relaymode.GetByPath(c.Request.URL.Path) - if config.DebugEnabled { - requestBody, _ := common.GetRequestBody(c) - logger.Debugf(ctx, "request body: %s", string(requestBody)) - } channelId := c.GetInt(ctxkey.ChannelId) bizErr := relayHelper(c, relayMode) if bizErr == nil { @@ -56,8 +53,9 @@ func Relay(c *gin.Context) { channelName := c.GetString(ctxkey.ChannelName) group := c.GetString(ctxkey.Group) originalModel := c.GetString(ctxkey.OriginalModel) - go processChannelRelayError(ctx, channelId, channelName, bizErr) - requestId := c.GetString(logger.RequestIdKey) + // bizErr is shared, should not run this function in goroutine to avoid race + processChannelRelayError(ctx, channelId, channelName, bizErr) + requestId := c.GetString(ctxkey.RequestId) retryTimes := config.RetryTimes if !shouldRetry(c, bizErr.StatusCode) { logger.Errorf(ctx, "relay error happen, status code is %d, won't retry in this case", bizErr.StatusCode) @@ -83,8 +81,10 @@ func Relay(c *gin.Context) { channelId := c.GetInt(ctxkey.ChannelId) lastFailedChannelId = channelId channelName := c.GetString(ctxkey.ChannelName) - go processChannelRelayError(ctx, channelId, channelName, bizErr) + // bizErr is shared, should not run this function in goroutine to avoid race + processChannelRelayError(ctx, channelId, channelName, bizErr) } + if bizErr != nil { if bizErr.StatusCode == http.StatusTooManyRequests { bizErr.Error.Message = "当前分组上游负载已饱和,请稍后再试" diff --git a/middleware/distributor.go b/middleware/distributor.go index a4c34085..d4a6a120 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -12,7 +12,7 @@ import ( ) type ModelRequest struct { - Model string `json:"model"` + Model string `json:"model" form:"model"` } func Distribute() func(c *gin.Context) { @@ -61,6 +61,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode c.Set(ctxkey.Channel, channel.Type) c.Set(ctxkey.ChannelId, channel.Id) c.Set(ctxkey.ChannelName, channel.Name) + c.Set(ctxkey.ContentType, c.Request.Header.Get("Content-Type")) c.Set(ctxkey.ModelMapping, channel.GetModelMapping()) c.Set(ctxkey.OriginalModel, modelName) // for retry c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) diff --git a/middleware/logger.go b/middleware/logger.go index 6aae4f23..81d427a3 100644 --- a/middleware/logger.go +++ b/middleware/logger.go @@ -2,15 +2,16 @@ package middleware import ( "fmt" + "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/common/ctxkey" ) func SetUpLogger(server *gin.Engine) { server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string { var requestID string if param.Keys != nil { - requestID = param.Keys[logger.RequestIdKey].(string) + requestID = param.Keys[ctxkey.RequestId].(string) } return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n", param.TimeStamp.Format("2006/01/02 - 15:04:05"), diff --git a/middleware/request-id.go b/middleware/request-id.go index a4c49ddb..09d3809a 100644 --- a/middleware/request-id.go +++ b/middleware/request-id.go @@ -2,18 +2,19 @@ package middleware import ( "context" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/helper" - "github.com/songquanpeng/one-api/common/logger" ) func RequestId() func(c *gin.Context) { return func(c *gin.Context) { id := helper.GenRequestID() - c.Set(logger.RequestIdKey, id) - ctx := context.WithValue(c.Request.Context(), logger.RequestIdKey, id) + c.Set(ctxkey.RequestId, id) + ctx := context.WithValue(c.Request.Context(), ctxkey.RequestId, id) c.Request = c.Request.WithContext(ctx) - c.Header(logger.RequestIdKey, id) + c.Header(ctxkey.RequestId, id) c.Next() } } diff --git a/middleware/utils.go b/middleware/utils.go index b65b018b..ab6a6fb5 100644 --- a/middleware/utils.go +++ b/middleware/utils.go @@ -2,17 +2,19 @@ package middleware import ( "fmt" + "strings" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" - "strings" ) func abortWithMessage(c *gin.Context, statusCode int, message string) { c.JSON(statusCode, gin.H{ "error": gin.H{ - "message": helper.MessageWithRequestId(message, c.GetString(logger.RequestIdKey)), + "message": helper.MessageWithRequestId(message, c.GetString(ctxkey.RequestId)), "type": "one_api_error", }, }) diff --git a/relay/adaptor/common.go b/relay/adaptor/common.go index 82a5160e..0014323e 100644 --- a/relay/adaptor/common.go +++ b/relay/adaptor/common.go @@ -3,11 +3,13 @@ package adaptor import ( "errors" "fmt" - "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/relay/client" - "github.com/songquanpeng/one-api/relay/meta" "io" "net/http" + + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/ctxkey" + "github.com/songquanpeng/one-api/relay/client" + "github.com/songquanpeng/one-api/relay/meta" ) func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) { @@ -27,6 +29,9 @@ func DoRequestHelper(a Adaptor, c *gin.Context, meta *meta.Meta, requestBody io. if err != nil { return nil, fmt.Errorf("new request failed: %w", err) } + + req.Header.Set("Content-Type", c.GetString(ctxkey.ContentType)) + err = a.SetupRequestHeader(c, req, meta) if err != nil { return nil, fmt.Errorf("setup request header failed: %w", err) diff --git a/relay/adaptor/openai/adaptor.go b/relay/adaptor/openai/adaptor.go index 4bb2384e..05ecc2a9 100644 --- a/relay/adaptor/openai/adaptor.go +++ b/relay/adaptor/openai/adaptor.go @@ -93,10 +93,13 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Met switch meta.Mode { case relaymode.ImagesGenerations: err, _ = ImageHandler(c, resp) + case relaymode.ImagesEdits: + err, _ = ImagesEditsHandler(c, resp) default: err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName) } } + return } diff --git a/relay/adaptor/openai/image.go b/relay/adaptor/openai/image.go index 0f89618a..433d9421 100644 --- a/relay/adaptor/openai/image.go +++ b/relay/adaptor/openai/image.go @@ -3,12 +3,30 @@ package openai import ( "bytes" "encoding/json" - "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/relay/model" "io" "net/http" + + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/model" ) +// ImagesEditsHandler just copy response body to client +// +// https://platform.openai.com/docs/api-reference/images/createEdit +func ImagesEditsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + c.Writer.WriteHeader(resp.StatusCode) + for k, v := range resp.Header { + c.Writer.Header().Set(k, v[0]) + } + + if _, err := io.Copy(c.Writer, resp.Body); err != nil { + return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil + } + defer resp.Body.Close() + + return nil, nil +} + func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { var imageResponse ImageResponse responseBody, err := io.ReadAll(resp.Body) diff --git a/relay/controller/image.go b/relay/controller/image.go index 216e4700..94df1d55 100644 --- a/relay/controller/image.go +++ b/relay/controller/image.go @@ -6,6 +6,10 @@ import ( "encoding/json" "errors" "fmt" + "io" + "net/http" + "strings" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/logger" @@ -16,8 +20,6 @@ import ( "github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/meta" relaymodel "github.com/songquanpeng/one-api/relay/model" - "io" - "net/http" ) func isWithinRange(element string, value int) bool { @@ -56,7 +58,8 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus } var requestBody io.Reader - if isModelMapped || meta.ChannelType == channeltype.Azure { // make Azure channel request body + if strings.ToLower(c.GetString(ctxkey.ContentType)) == "application/json" && + isModelMapped || meta.ChannelType == channeltype.Azure { // make Azure channel request body jsonStr, err := json.Marshal(imageRequest) if err != nil { return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError) diff --git a/relay/model/image.go b/relay/model/image.go index bab84256..00bd8b79 100644 --- a/relay/model/image.go +++ b/relay/model/image.go @@ -1,12 +1,12 @@ package model type ImageRequest struct { - Model string `json:"model"` - Prompt string `json:"prompt" binding:"required"` - N int `json:"n,omitempty"` - Size string `json:"size,omitempty"` - Quality string `json:"quality,omitempty"` - ResponseFormat string `json:"response_format,omitempty"` - Style string `json:"style,omitempty"` - User string `json:"user,omitempty"` + Model string `json:"model" form:"model"` + Prompt string `json:"prompt" binding:"required" form:"prompt"` + N int `json:"n,omitempty" form:"n"` + Size string `json:"size,omitempty" form:"size"` + Quality string `json:"quality,omitempty" form:"quality"` + ResponseFormat string `json:"response_format,omitempty" form:"response_format"` + Style string `json:"style,omitempty" form:"style"` + User string `json:"user,omitempty" form:"user"` } diff --git a/relay/relaymode/define.go b/relay/relaymode/define.go index 96d09438..a4ba0cad 100644 --- a/relay/relaymode/define.go +++ b/relay/relaymode/define.go @@ -11,4 +11,5 @@ const ( AudioSpeech AudioTranscription AudioTranslation + ImagesEdits ) diff --git a/relay/relaymode/helper.go b/relay/relaymode/helper.go index 926dd42e..2110239b 100644 --- a/relay/relaymode/helper.go +++ b/relay/relaymode/helper.go @@ -24,6 +24,9 @@ func GetByPath(path string) int { relayMode = AudioTranscription } else if strings.HasPrefix(path, "/v1/audio/translations") { relayMode = AudioTranslation + } else if strings.HasPrefix(path, "/v1/images/edits") { + relayMode = ImagesEdits } + return relayMode } diff --git a/router/relay.go b/router/relay.go index 65072c86..12758a46 100644 --- a/router/relay.go +++ b/router/relay.go @@ -23,7 +23,7 @@ func SetRelayRouter(router *gin.Engine) { relayV1Router.POST("/chat/completions", controller.Relay) relayV1Router.POST("/edits", controller.Relay) relayV1Router.POST("/images/generations", controller.Relay) - relayV1Router.POST("/images/edits", controller.RelayNotImplemented) + relayV1Router.POST("/images/edits", controller.Relay) relayV1Router.POST("/images/variations", controller.RelayNotImplemented) relayV1Router.POST("/embeddings", controller.Relay) relayV1Router.POST("/engines/:model/embeddings", controller.Relay)