🐛 fix: mj error
This commit is contained in:
parent
8f74fadf8a
commit
43818e7c8b
@ -12,7 +12,6 @@ import (
|
||||
"one-api/common"
|
||||
"one-api/controller"
|
||||
"one-api/model"
|
||||
providersBase "one-api/providers/base"
|
||||
provider "one-api/providers/midjourney"
|
||||
"one-api/relay"
|
||||
"one-api/relay/util"
|
||||
@ -141,7 +140,7 @@ func coverMidjourneyTaskDto(originTask *model.Midjourney) (midjourneyTask provid
|
||||
}
|
||||
|
||||
func RelaySwapFace(c *gin.Context) *provider.MidjourneyResponse {
|
||||
mjProvider, errWithMJ := getMJProvider(c, provider.RelayModeMidjourneySwapFace, 0, nil)
|
||||
mjProvider, errWithMJ := getMJProviderWithRequest(c, provider.RelayModeMidjourneySwapFace, nil)
|
||||
if errWithMJ != nil {
|
||||
return errWithMJ
|
||||
}
|
||||
@ -157,7 +156,7 @@ func RelaySwapFace(c *gin.Context) *provider.MidjourneyResponse {
|
||||
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "sour_base64_and_target_base64_is_required")
|
||||
}
|
||||
|
||||
quotaInstance, errWithOA := getQuota(c, mjProvider.GetOriginalModel())
|
||||
quotaInstance, errWithOA := getQuota(c, provider.MjActionSwapFace)
|
||||
if errWithOA != nil {
|
||||
return &provider.MidjourneyResponse{
|
||||
Code: 4,
|
||||
@ -225,7 +224,7 @@ func RelayMidjourneyTaskImageSeed(c *gin.Context) *provider.MidjourneyResponse {
|
||||
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "task_no_found")
|
||||
}
|
||||
|
||||
mjProvider, errWithMJ := getMJProvider(c, provider.RelayModeMidjourneyTaskImageSeed, originTask.ChannelId, nil)
|
||||
mjProvider, errWithMJ := getMJProviderWithChannelId(c, originTask.ChannelId)
|
||||
if errWithMJ != nil {
|
||||
return errWithMJ
|
||||
}
|
||||
@ -314,7 +313,6 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *provider.MidjourneyResp
|
||||
}
|
||||
|
||||
func RelayMidjourneySubmit(c *gin.Context, relayMode int) *provider.MidjourneyResponse {
|
||||
channelId := 0
|
||||
userId := c.GetInt("id")
|
||||
consumeQuota := true
|
||||
var midjRequest provider.MidjourneyRequest
|
||||
@ -323,6 +321,11 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *provider.MidjourneyRe
|
||||
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "bind_request_body_failed")
|
||||
}
|
||||
|
||||
mjProvider, errWithMJ := getMJProviderWithRequest(c, relayMode, &midjRequest)
|
||||
if errWithMJ != nil {
|
||||
return errWithMJ
|
||||
}
|
||||
|
||||
if relayMode == provider.RelayModeMidjourneyAction { // midjourney plus,需要从customId中获取任务信息
|
||||
mjErr := CoverPlusActionToNormalAction(&midjRequest)
|
||||
if mjErr != nil {
|
||||
@ -378,7 +381,10 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *provider.MidjourneyRe
|
||||
} else if originTask.Status != "SUCCESS" && relayMode != provider.RelayModeMidjourneyModal {
|
||||
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "task_status_not_success")
|
||||
} else { //原任务的Status=SUCCESS,则可以做放大UPSCALE、变换VARIATION等动作,此时必须使用原来的请求地址才能正确处理
|
||||
channelId = originTask.ChannelId
|
||||
mjProvider, errWithMJ = getMJProviderWithChannelId(c, originTask.ChannelId)
|
||||
if errWithMJ != nil {
|
||||
return errWithMJ
|
||||
}
|
||||
log.Printf("检测到此操作为放大、变换、重绘,获取原channel信息: %d", originTask.ChannelId)
|
||||
}
|
||||
midjRequest.Prompt = originTask.Prompt
|
||||
@ -395,17 +401,12 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *provider.MidjourneyRe
|
||||
consumeQuota = false
|
||||
}
|
||||
|
||||
mjProvider, errWithMJ := getMJProvider(c, relayMode, channelId, &midjRequest)
|
||||
if errWithMJ != nil {
|
||||
return errWithMJ
|
||||
}
|
||||
|
||||
//baseURL := common.ChannelBaseURLs[channelType]
|
||||
requestURL := getMjRequestPath(c.Request.URL.String())
|
||||
|
||||
//midjRequest.NotifyHook = "http://127.0.0.1:3000/mj/notify"
|
||||
|
||||
quotaInstance, errWithOA := getQuota(c, mjProvider.GetOriginalModel())
|
||||
quotaInstance, errWithOA := getQuota(c, midjRequest.Action)
|
||||
if errWithOA != nil {
|
||||
return &provider.MidjourneyResponse{
|
||||
Code: 4,
|
||||
@ -538,33 +539,32 @@ func getMjRequestPath(path string) string {
|
||||
return requestURL
|
||||
}
|
||||
|
||||
func getQuota(c *gin.Context, modelName string) (*util.Quota, *types.OpenAIErrorWithStatusCode) {
|
||||
// modelName = CoverActionToModelName(modelName)
|
||||
func getQuota(c *gin.Context, action string) (*util.Quota, *types.OpenAIErrorWithStatusCode) {
|
||||
modelName := CoverActionToModelName(action)
|
||||
|
||||
return util.NewQuota(c, modelName, 1000)
|
||||
}
|
||||
|
||||
func getMJProvider(c *gin.Context, relayMode, channel_id int, request *provider.MidjourneyRequest) (*provider.MidjourneyProvider, *provider.MidjourneyResponse) {
|
||||
var baseProvider providersBase.ProviderInterface
|
||||
modelName := ""
|
||||
if channel_id > 0 {
|
||||
c.Set("specific_channel_id", channel_id)
|
||||
func getMJProviderWithRequest(c *gin.Context, relayMode int, request *provider.MidjourneyRequest) (*provider.MidjourneyProvider, *provider.MidjourneyResponse) {
|
||||
midjourneyModel, mjErr, _ := GetMjRequestModel(relayMode, request)
|
||||
if mjErr != nil {
|
||||
return nil, MidjourneyErrorFromInternal(mjErr.Code, mjErr.Description)
|
||||
}
|
||||
if midjourneyModel == "" {
|
||||
return nil, MidjourneyErrorFromInternal(provider.MjErrorUnknown, "无效的请求, 无法解析模型")
|
||||
}
|
||||
|
||||
if request != nil {
|
||||
midjourneyModel, mjErr, _ := GetMjRequestModel(relayMode, request)
|
||||
if mjErr != nil {
|
||||
return nil, MidjourneyErrorFromInternal(mjErr.Code, mjErr.Description)
|
||||
}
|
||||
if midjourneyModel == "" {
|
||||
return nil, MidjourneyErrorFromInternal(provider.MjErrorUnknown, "无效的请求, 无法解析模型")
|
||||
}
|
||||
return getMJProvider(c, midjourneyModel)
|
||||
}
|
||||
|
||||
modelName = midjourneyModel
|
||||
}
|
||||
func getMJProviderWithChannelId(c *gin.Context, channel_id int) (*provider.MidjourneyProvider, *provider.MidjourneyResponse) {
|
||||
c.Set("specific_channel_id", channel_id)
|
||||
|
||||
var err error
|
||||
baseProvider, _, err = relay.GetProvider(c, modelName)
|
||||
return getMJProvider(c, "")
|
||||
}
|
||||
|
||||
func getMJProvider(c *gin.Context, modelName string) (*provider.MidjourneyProvider, *provider.MidjourneyResponse) {
|
||||
baseProvider, _, err := relay.GetProvider(c, modelName)
|
||||
if err != nil {
|
||||
return nil, MidjourneyErrorFromInternal(provider.MjErrorUnknown, "无法获取provider:"+err.Error())
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user