🐛 fix: mj error

This commit is contained in:
MartialBE 2024-04-06 15:07:33 +08:00
parent 8f74fadf8a
commit 43818e7c8b
No known key found for this signature in database
GPG Key ID: F5A7AC860020C896

View File

@ -12,7 +12,6 @@ import (
"one-api/common"
"one-api/controller"
"one-api/model"
providersBase "one-api/providers/base"
provider "one-api/providers/midjourney"
"one-api/relay"
"one-api/relay/util"
@ -141,7 +140,7 @@ func coverMidjourneyTaskDto(originTask *model.Midjourney) (midjourneyTask provid
}
func RelaySwapFace(c *gin.Context) *provider.MidjourneyResponse {
mjProvider, errWithMJ := getMJProvider(c, provider.RelayModeMidjourneySwapFace, 0, nil)
mjProvider, errWithMJ := getMJProviderWithRequest(c, provider.RelayModeMidjourneySwapFace, nil)
if errWithMJ != nil {
return errWithMJ
}
@ -157,7 +156,7 @@ func RelaySwapFace(c *gin.Context) *provider.MidjourneyResponse {
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "sour_base64_and_target_base64_is_required")
}
quotaInstance, errWithOA := getQuota(c, mjProvider.GetOriginalModel())
quotaInstance, errWithOA := getQuota(c, provider.MjActionSwapFace)
if errWithOA != nil {
return &provider.MidjourneyResponse{
Code: 4,
@ -225,7 +224,7 @@ func RelayMidjourneyTaskImageSeed(c *gin.Context) *provider.MidjourneyResponse {
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "task_no_found")
}
mjProvider, errWithMJ := getMJProvider(c, provider.RelayModeMidjourneyTaskImageSeed, originTask.ChannelId, nil)
mjProvider, errWithMJ := getMJProviderWithChannelId(c, originTask.ChannelId)
if errWithMJ != nil {
return errWithMJ
}
@ -314,7 +313,6 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *provider.MidjourneyResp
}
func RelayMidjourneySubmit(c *gin.Context, relayMode int) *provider.MidjourneyResponse {
channelId := 0
userId := c.GetInt("id")
consumeQuota := true
var midjRequest provider.MidjourneyRequest
@ -323,6 +321,11 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *provider.MidjourneyRe
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "bind_request_body_failed")
}
mjProvider, errWithMJ := getMJProviderWithRequest(c, relayMode, &midjRequest)
if errWithMJ != nil {
return errWithMJ
}
if relayMode == provider.RelayModeMidjourneyAction { // midjourney plus需要从customId中获取任务信息
mjErr := CoverPlusActionToNormalAction(&midjRequest)
if mjErr != nil {
@ -378,7 +381,10 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *provider.MidjourneyRe
} else if originTask.Status != "SUCCESS" && relayMode != provider.RelayModeMidjourneyModal {
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "task_status_not_success")
} else { //原任务的Status=SUCCESS则可以做放大UPSCALE、变换VARIATION等动作此时必须使用原来的请求地址才能正确处理
channelId = originTask.ChannelId
mjProvider, errWithMJ = getMJProviderWithChannelId(c, originTask.ChannelId)
if errWithMJ != nil {
return errWithMJ
}
log.Printf("检测到此操作为放大、变换、重绘获取原channel信息: %d", originTask.ChannelId)
}
midjRequest.Prompt = originTask.Prompt
@ -395,17 +401,12 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *provider.MidjourneyRe
consumeQuota = false
}
mjProvider, errWithMJ := getMJProvider(c, relayMode, channelId, &midjRequest)
if errWithMJ != nil {
return errWithMJ
}
//baseURL := common.ChannelBaseURLs[channelType]
requestURL := getMjRequestPath(c.Request.URL.String())
//midjRequest.NotifyHook = "http://127.0.0.1:3000/mj/notify"
quotaInstance, errWithOA := getQuota(c, mjProvider.GetOriginalModel())
quotaInstance, errWithOA := getQuota(c, midjRequest.Action)
if errWithOA != nil {
return &provider.MidjourneyResponse{
Code: 4,
@ -538,33 +539,32 @@ func getMjRequestPath(path string) string {
return requestURL
}
func getQuota(c *gin.Context, modelName string) (*util.Quota, *types.OpenAIErrorWithStatusCode) {
// modelName = CoverActionToModelName(modelName)
func getQuota(c *gin.Context, action string) (*util.Quota, *types.OpenAIErrorWithStatusCode) {
modelName := CoverActionToModelName(action)
return util.NewQuota(c, modelName, 1000)
}
func getMJProvider(c *gin.Context, relayMode, channel_id int, request *provider.MidjourneyRequest) (*provider.MidjourneyProvider, *provider.MidjourneyResponse) {
var baseProvider providersBase.ProviderInterface
modelName := ""
if channel_id > 0 {
c.Set("specific_channel_id", channel_id)
func getMJProviderWithRequest(c *gin.Context, relayMode int, request *provider.MidjourneyRequest) (*provider.MidjourneyProvider, *provider.MidjourneyResponse) {
midjourneyModel, mjErr, _ := GetMjRequestModel(relayMode, request)
if mjErr != nil {
return nil, MidjourneyErrorFromInternal(mjErr.Code, mjErr.Description)
}
if midjourneyModel == "" {
return nil, MidjourneyErrorFromInternal(provider.MjErrorUnknown, "无效的请求, 无法解析模型")
}
if request != nil {
midjourneyModel, mjErr, _ := GetMjRequestModel(relayMode, request)
if mjErr != nil {
return nil, MidjourneyErrorFromInternal(mjErr.Code, mjErr.Description)
}
if midjourneyModel == "" {
return nil, MidjourneyErrorFromInternal(provider.MjErrorUnknown, "无效的请求, 无法解析模型")
}
return getMJProvider(c, midjourneyModel)
}
modelName = midjourneyModel
}
func getMJProviderWithChannelId(c *gin.Context, channel_id int) (*provider.MidjourneyProvider, *provider.MidjourneyResponse) {
c.Set("specific_channel_id", channel_id)
var err error
baseProvider, _, err = relay.GetProvider(c, modelName)
return getMJProvider(c, "")
}
func getMJProvider(c *gin.Context, modelName string) (*provider.MidjourneyProvider, *provider.MidjourneyResponse) {
baseProvider, _, err := relay.GetProvider(c, modelName)
if err != nil {
return nil, MidjourneyErrorFromInternal(provider.MjErrorUnknown, "无法获取provider:"+err.Error())
}