diff --git a/relay/midjourney/relay-mj.go b/relay/midjourney/relay-mj.go index e8993185..4ff6f40f 100644 --- a/relay/midjourney/relay-mj.go +++ b/relay/midjourney/relay-mj.go @@ -12,7 +12,6 @@ import ( "one-api/common" "one-api/controller" "one-api/model" - providersBase "one-api/providers/base" provider "one-api/providers/midjourney" "one-api/relay" "one-api/relay/util" @@ -141,7 +140,7 @@ func coverMidjourneyTaskDto(originTask *model.Midjourney) (midjourneyTask provid } func RelaySwapFace(c *gin.Context) *provider.MidjourneyResponse { - mjProvider, errWithMJ := getMJProvider(c, provider.RelayModeMidjourneySwapFace, 0, nil) + mjProvider, errWithMJ := getMJProviderWithRequest(c, provider.RelayModeMidjourneySwapFace, nil) if errWithMJ != nil { return errWithMJ } @@ -157,7 +156,7 @@ func RelaySwapFace(c *gin.Context) *provider.MidjourneyResponse { return provider.MidjourneyErrorWrapper(provider.MjRequestError, "sour_base64_and_target_base64_is_required") } - quotaInstance, errWithOA := getQuota(c, mjProvider.GetOriginalModel()) + quotaInstance, errWithOA := getQuota(c, provider.MjActionSwapFace) if errWithOA != nil { return &provider.MidjourneyResponse{ Code: 4, @@ -225,7 +224,7 @@ func RelayMidjourneyTaskImageSeed(c *gin.Context) *provider.MidjourneyResponse { return provider.MidjourneyErrorWrapper(provider.MjRequestError, "task_no_found") } - mjProvider, errWithMJ := getMJProvider(c, provider.RelayModeMidjourneyTaskImageSeed, originTask.ChannelId, nil) + mjProvider, errWithMJ := getMJProviderWithChannelId(c, originTask.ChannelId) if errWithMJ != nil { return errWithMJ } @@ -314,7 +313,6 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *provider.MidjourneyResp } func RelayMidjourneySubmit(c *gin.Context, relayMode int) *provider.MidjourneyResponse { - channelId := 0 userId := c.GetInt("id") consumeQuota := true var midjRequest provider.MidjourneyRequest @@ -323,6 +321,11 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *provider.MidjourneyRe return provider.MidjourneyErrorWrapper(provider.MjRequestError, "bind_request_body_failed") } + mjProvider, errWithMJ := getMJProviderWithRequest(c, relayMode, &midjRequest) + if errWithMJ != nil { + return errWithMJ + } + if relayMode == provider.RelayModeMidjourneyAction { // midjourney plus,需要从customId中获取任务信息 mjErr := CoverPlusActionToNormalAction(&midjRequest) if mjErr != nil { @@ -378,7 +381,10 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *provider.MidjourneyRe } else if originTask.Status != "SUCCESS" && relayMode != provider.RelayModeMidjourneyModal { return provider.MidjourneyErrorWrapper(provider.MjRequestError, "task_status_not_success") } else { //原任务的Status=SUCCESS,则可以做放大UPSCALE、变换VARIATION等动作,此时必须使用原来的请求地址才能正确处理 - channelId = originTask.ChannelId + mjProvider, errWithMJ = getMJProviderWithChannelId(c, originTask.ChannelId) + if errWithMJ != nil { + return errWithMJ + } log.Printf("检测到此操作为放大、变换、重绘,获取原channel信息: %d", originTask.ChannelId) } midjRequest.Prompt = originTask.Prompt @@ -395,17 +401,12 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *provider.MidjourneyRe consumeQuota = false } - mjProvider, errWithMJ := getMJProvider(c, relayMode, channelId, &midjRequest) - if errWithMJ != nil { - return errWithMJ - } - //baseURL := common.ChannelBaseURLs[channelType] requestURL := getMjRequestPath(c.Request.URL.String()) //midjRequest.NotifyHook = "http://127.0.0.1:3000/mj/notify" - quotaInstance, errWithOA := getQuota(c, mjProvider.GetOriginalModel()) + quotaInstance, errWithOA := getQuota(c, midjRequest.Action) if errWithOA != nil { return &provider.MidjourneyResponse{ Code: 4, @@ -538,33 +539,32 @@ func getMjRequestPath(path string) string { return requestURL } -func getQuota(c *gin.Context, modelName string) (*util.Quota, *types.OpenAIErrorWithStatusCode) { - // modelName = CoverActionToModelName(modelName) +func getQuota(c *gin.Context, action string) (*util.Quota, *types.OpenAIErrorWithStatusCode) { + modelName := CoverActionToModelName(action) return util.NewQuota(c, modelName, 1000) } -func getMJProvider(c *gin.Context, relayMode, channel_id int, request *provider.MidjourneyRequest) (*provider.MidjourneyProvider, *provider.MidjourneyResponse) { - var baseProvider providersBase.ProviderInterface - modelName := "" - if channel_id > 0 { - c.Set("specific_channel_id", channel_id) +func getMJProviderWithRequest(c *gin.Context, relayMode int, request *provider.MidjourneyRequest) (*provider.MidjourneyProvider, *provider.MidjourneyResponse) { + midjourneyModel, mjErr, _ := GetMjRequestModel(relayMode, request) + if mjErr != nil { + return nil, MidjourneyErrorFromInternal(mjErr.Code, mjErr.Description) + } + if midjourneyModel == "" { + return nil, MidjourneyErrorFromInternal(provider.MjErrorUnknown, "无效的请求, 无法解析模型") } - if request != nil { - midjourneyModel, mjErr, _ := GetMjRequestModel(relayMode, request) - if mjErr != nil { - return nil, MidjourneyErrorFromInternal(mjErr.Code, mjErr.Description) - } - if midjourneyModel == "" { - return nil, MidjourneyErrorFromInternal(provider.MjErrorUnknown, "无效的请求, 无法解析模型") - } + return getMJProvider(c, midjourneyModel) +} - modelName = midjourneyModel - } +func getMJProviderWithChannelId(c *gin.Context, channel_id int) (*provider.MidjourneyProvider, *provider.MidjourneyResponse) { + c.Set("specific_channel_id", channel_id) - var err error - baseProvider, _, err = relay.GetProvider(c, modelName) + return getMJProvider(c, "") +} + +func getMJProvider(c *gin.Context, modelName string) (*provider.MidjourneyProvider, *provider.MidjourneyResponse) { + baseProvider, _, err := relay.GetProvider(c, modelName) if err != nil { return nil, MidjourneyErrorFromInternal(provider.MjErrorUnknown, "无法获取provider:"+err.Error()) }