feat: support openai images edits api
This commit is contained in:
parent
da0842272c
commit
f71e4ef151
@ -2,6 +2,7 @@ package ctxkey
|
||||
|
||||
const (
|
||||
Id = "id"
|
||||
RequestId = "X-Oneapi-Request-Id"
|
||||
Username = "username"
|
||||
Role = "role"
|
||||
Status = "status"
|
||||
@ -14,6 +15,7 @@ const (
|
||||
Group = "group"
|
||||
ModelMapping = "model_mapping"
|
||||
ChannelName = "channel_name"
|
||||
ContentType = "content_type"
|
||||
TokenId = "token_id"
|
||||
TokenName = "token_name"
|
||||
BaseURL = "base_url"
|
||||
|
@ -2,10 +2,10 @@ package common
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
const KeyRequestBody = "key_request_body"
|
||||
@ -29,18 +29,16 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
contentType := c.Request.Header.Get("Content-Type")
|
||||
if strings.HasPrefix(contentType, "application/json") {
|
||||
err = json.Unmarshal(requestBody, &v)
|
||||
} else {
|
||||
// skip for now
|
||||
// TODO: someday non json request have variant model, we will need to implementation this
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Reset request body
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
defer func() {
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
}()
|
||||
|
||||
if err = c.Bind(v); err != nil {
|
||||
return errors.Wrap(err, "bind request body failed")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -1,7 +1,3 @@
|
||||
package logger
|
||||
|
||||
const (
|
||||
RequestIdKey = "X-Oneapi-Request-Id"
|
||||
)
|
||||
|
||||
var LogDir string
|
||||
|
@ -12,6 +12,7 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
)
|
||||
|
||||
@ -87,7 +88,7 @@ func logHelper(ctx context.Context, level string, msg string) {
|
||||
if level == loggerINFO {
|
||||
writer = gin.DefaultWriter
|
||||
}
|
||||
id := ctx.Value(RequestIdKey)
|
||||
id := ctx.Value(ctxkey.RequestId)
|
||||
if id == nil {
|
||||
id = helper.GenRequestID()
|
||||
}
|
||||
|
@ -25,7 +25,8 @@ import (
|
||||
func relayHelper(c *gin.Context, relayMode int) *model.ErrorWithStatusCode {
|
||||
var err *model.ErrorWithStatusCode
|
||||
switch relayMode {
|
||||
case relaymode.ImagesGenerations:
|
||||
case relaymode.ImagesGenerations,
|
||||
relaymode.ImagesEdits:
|
||||
err = controller.RelayImageHelper(c, relayMode)
|
||||
case relaymode.AudioSpeech:
|
||||
fallthrough
|
||||
@ -42,10 +43,6 @@ func relayHelper(c *gin.Context, relayMode int) *model.ErrorWithStatusCode {
|
||||
func Relay(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
relayMode := relaymode.GetByPath(c.Request.URL.Path)
|
||||
if config.DebugEnabled {
|
||||
requestBody, _ := common.GetRequestBody(c)
|
||||
logger.Debugf(ctx, "request body: %s", string(requestBody))
|
||||
}
|
||||
channelId := c.GetInt(ctxkey.ChannelId)
|
||||
bizErr := relayHelper(c, relayMode)
|
||||
if bizErr == nil {
|
||||
@ -56,8 +53,9 @@ func Relay(c *gin.Context) {
|
||||
channelName := c.GetString(ctxkey.ChannelName)
|
||||
group := c.GetString(ctxkey.Group)
|
||||
originalModel := c.GetString(ctxkey.OriginalModel)
|
||||
go processChannelRelayError(ctx, channelId, channelName, bizErr)
|
||||
requestId := c.GetString(logger.RequestIdKey)
|
||||
// bizErr is shared, should not run this function in goroutine to avoid race
|
||||
processChannelRelayError(ctx, channelId, channelName, bizErr)
|
||||
requestId := c.GetString(ctxkey.RequestId)
|
||||
retryTimes := config.RetryTimes
|
||||
if !shouldRetry(c, bizErr.StatusCode) {
|
||||
logger.Errorf(ctx, "relay error happen, status code is %d, won't retry in this case", bizErr.StatusCode)
|
||||
@ -83,8 +81,10 @@ func Relay(c *gin.Context) {
|
||||
channelId := c.GetInt(ctxkey.ChannelId)
|
||||
lastFailedChannelId = channelId
|
||||
channelName := c.GetString(ctxkey.ChannelName)
|
||||
go processChannelRelayError(ctx, channelId, channelName, bizErr)
|
||||
// bizErr is shared, should not run this function in goroutine to avoid race
|
||||
processChannelRelayError(ctx, channelId, channelName, bizErr)
|
||||
}
|
||||
|
||||
if bizErr != nil {
|
||||
if bizErr.StatusCode == http.StatusTooManyRequests {
|
||||
bizErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
|
||||
|
@ -12,7 +12,7 @@ import (
|
||||
)
|
||||
|
||||
type ModelRequest struct {
|
||||
Model string `json:"model"`
|
||||
Model string `json:"model" form:"model"`
|
||||
}
|
||||
|
||||
func Distribute() func(c *gin.Context) {
|
||||
@ -61,6 +61,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
|
||||
c.Set(ctxkey.Channel, channel.Type)
|
||||
c.Set(ctxkey.ChannelId, channel.Id)
|
||||
c.Set(ctxkey.ChannelName, channel.Name)
|
||||
c.Set(ctxkey.ContentType, c.Request.Header.Get("Content-Type"))
|
||||
c.Set(ctxkey.ModelMapping, channel.GetModelMapping())
|
||||
c.Set(ctxkey.OriginalModel, modelName) // for retry
|
||||
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
||||
|
@ -2,15 +2,16 @@ package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||
)
|
||||
|
||||
func SetUpLogger(server *gin.Engine) {
|
||||
server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
|
||||
var requestID string
|
||||
if param.Keys != nil {
|
||||
requestID = param.Keys[logger.RequestIdKey].(string)
|
||||
requestID = param.Keys[ctxkey.RequestId].(string)
|
||||
}
|
||||
return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n",
|
||||
param.TimeStamp.Format("2006/01/02 - 15:04:05"),
|
||||
|
@ -2,18 +2,19 @@ package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
)
|
||||
|
||||
func RequestId() func(c *gin.Context) {
|
||||
return func(c *gin.Context) {
|
||||
id := helper.GenRequestID()
|
||||
c.Set(logger.RequestIdKey, id)
|
||||
ctx := context.WithValue(c.Request.Context(), logger.RequestIdKey, id)
|
||||
c.Set(ctxkey.RequestId, id)
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.RequestId, id)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
c.Header(logger.RequestIdKey, id)
|
||||
c.Header(ctxkey.RequestId, id)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
@ -2,17 +2,19 @@ package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func abortWithMessage(c *gin.Context, statusCode int, message string) {
|
||||
c.JSON(statusCode, gin.H{
|
||||
"error": gin.H{
|
||||
"message": helper.MessageWithRequestId(message, c.GetString(logger.RequestIdKey)),
|
||||
"message": helper.MessageWithRequestId(message, c.GetString(ctxkey.RequestId)),
|
||||
"type": "one_api_error",
|
||||
},
|
||||
})
|
||||
|
@ -3,11 +3,13 @@ package adaptor
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/relay/client"
|
||||
"github.com/songquanpeng/one-api/relay/meta"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||
"github.com/songquanpeng/one-api/relay/client"
|
||||
"github.com/songquanpeng/one-api/relay/meta"
|
||||
)
|
||||
|
||||
func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) {
|
||||
@ -27,6 +29,9 @@ func DoRequestHelper(a Adaptor, c *gin.Context, meta *meta.Meta, requestBody io.
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new request failed: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", c.GetString(ctxkey.ContentType))
|
||||
|
||||
err = a.SetupRequestHeader(c, req, meta)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setup request header failed: %w", err)
|
||||
|
@ -93,10 +93,13 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Met
|
||||
switch meta.Mode {
|
||||
case relaymode.ImagesGenerations:
|
||||
err, _ = ImageHandler(c, resp)
|
||||
case relaymode.ImagesEdits:
|
||||
err, _ = ImagesEditsHandler(c, resp)
|
||||
default:
|
||||
err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -3,12 +3,30 @@ package openai
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
)
|
||||
|
||||
// ImagesEditsHandler just copy response body to client
|
||||
//
|
||||
// https://platform.openai.com/docs/api-reference/images/createEdit
|
||||
func ImagesEditsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
for k, v := range resp.Header {
|
||||
c.Writer.Header().Set(k, v[0])
|
||||
}
|
||||
|
||||
if _, err := io.Copy(c.Writer, resp.Body); err != nil {
|
||||
return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
|
||||
var imageResponse ImageResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
|
@ -6,6 +6,10 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
@ -16,8 +20,6 @@ import (
|
||||
"github.com/songquanpeng/one-api/relay/channeltype"
|
||||
"github.com/songquanpeng/one-api/relay/meta"
|
||||
relaymodel "github.com/songquanpeng/one-api/relay/model"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func isWithinRange(element string, value int) bool {
|
||||
@ -56,7 +58,8 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
|
||||
}
|
||||
|
||||
var requestBody io.Reader
|
||||
if isModelMapped || meta.ChannelType == channeltype.Azure { // make Azure channel request body
|
||||
if strings.ToLower(c.GetString(ctxkey.ContentType)) == "application/json" &&
|
||||
isModelMapped || meta.ChannelType == channeltype.Azure { // make Azure channel request body
|
||||
jsonStr, err := json.Marshal(imageRequest)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError)
|
||||
|
@ -1,12 +1,12 @@
|
||||
package model
|
||||
|
||||
type ImageRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt" binding:"required"`
|
||||
N int `json:"n,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Quality string `json:"quality,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
Style string `json:"style,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
Model string `json:"model" form:"model"`
|
||||
Prompt string `json:"prompt" binding:"required" form:"prompt"`
|
||||
N int `json:"n,omitempty" form:"n"`
|
||||
Size string `json:"size,omitempty" form:"size"`
|
||||
Quality string `json:"quality,omitempty" form:"quality"`
|
||||
ResponseFormat string `json:"response_format,omitempty" form:"response_format"`
|
||||
Style string `json:"style,omitempty" form:"style"`
|
||||
User string `json:"user,omitempty" form:"user"`
|
||||
}
|
||||
|
@ -11,4 +11,5 @@ const (
|
||||
AudioSpeech
|
||||
AudioTranscription
|
||||
AudioTranslation
|
||||
ImagesEdits
|
||||
)
|
||||
|
@ -24,6 +24,9 @@ func GetByPath(path string) int {
|
||||
relayMode = AudioTranscription
|
||||
} else if strings.HasPrefix(path, "/v1/audio/translations") {
|
||||
relayMode = AudioTranslation
|
||||
} else if strings.HasPrefix(path, "/v1/images/edits") {
|
||||
relayMode = ImagesEdits
|
||||
}
|
||||
|
||||
return relayMode
|
||||
}
|
||||
|
@ -23,7 +23,7 @@ func SetRelayRouter(router *gin.Engine) {
|
||||
relayV1Router.POST("/chat/completions", controller.Relay)
|
||||
relayV1Router.POST("/edits", controller.Relay)
|
||||
relayV1Router.POST("/images/generations", controller.Relay)
|
||||
relayV1Router.POST("/images/edits", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/images/edits", controller.Relay)
|
||||
relayV1Router.POST("/images/variations", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/embeddings", controller.Relay)
|
||||
relayV1Router.POST("/engines/:model/embeddings", controller.Relay)
|
||||
|
Loading…
Reference in New Issue
Block a user