🐛 fix: mj error
This commit is contained in:
parent
8f74fadf8a
commit
43818e7c8b
@ -12,7 +12,6 @@ import (
|
|||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/controller"
|
"one-api/controller"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
providersBase "one-api/providers/base"
|
|
||||||
provider "one-api/providers/midjourney"
|
provider "one-api/providers/midjourney"
|
||||||
"one-api/relay"
|
"one-api/relay"
|
||||||
"one-api/relay/util"
|
"one-api/relay/util"
|
||||||
@ -141,7 +140,7 @@ func coverMidjourneyTaskDto(originTask *model.Midjourney) (midjourneyTask provid
|
|||||||
}
|
}
|
||||||
|
|
||||||
func RelaySwapFace(c *gin.Context) *provider.MidjourneyResponse {
|
func RelaySwapFace(c *gin.Context) *provider.MidjourneyResponse {
|
||||||
mjProvider, errWithMJ := getMJProvider(c, provider.RelayModeMidjourneySwapFace, 0, nil)
|
mjProvider, errWithMJ := getMJProviderWithRequest(c, provider.RelayModeMidjourneySwapFace, nil)
|
||||||
if errWithMJ != nil {
|
if errWithMJ != nil {
|
||||||
return errWithMJ
|
return errWithMJ
|
||||||
}
|
}
|
||||||
@ -157,7 +156,7 @@ func RelaySwapFace(c *gin.Context) *provider.MidjourneyResponse {
|
|||||||
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "sour_base64_and_target_base64_is_required")
|
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "sour_base64_and_target_base64_is_required")
|
||||||
}
|
}
|
||||||
|
|
||||||
quotaInstance, errWithOA := getQuota(c, mjProvider.GetOriginalModel())
|
quotaInstance, errWithOA := getQuota(c, provider.MjActionSwapFace)
|
||||||
if errWithOA != nil {
|
if errWithOA != nil {
|
||||||
return &provider.MidjourneyResponse{
|
return &provider.MidjourneyResponse{
|
||||||
Code: 4,
|
Code: 4,
|
||||||
@ -225,7 +224,7 @@ func RelayMidjourneyTaskImageSeed(c *gin.Context) *provider.MidjourneyResponse {
|
|||||||
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "task_no_found")
|
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "task_no_found")
|
||||||
}
|
}
|
||||||
|
|
||||||
mjProvider, errWithMJ := getMJProvider(c, provider.RelayModeMidjourneyTaskImageSeed, originTask.ChannelId, nil)
|
mjProvider, errWithMJ := getMJProviderWithChannelId(c, originTask.ChannelId)
|
||||||
if errWithMJ != nil {
|
if errWithMJ != nil {
|
||||||
return errWithMJ
|
return errWithMJ
|
||||||
}
|
}
|
||||||
@ -314,7 +313,6 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *provider.MidjourneyResp
|
|||||||
}
|
}
|
||||||
|
|
||||||
func RelayMidjourneySubmit(c *gin.Context, relayMode int) *provider.MidjourneyResponse {
|
func RelayMidjourneySubmit(c *gin.Context, relayMode int) *provider.MidjourneyResponse {
|
||||||
channelId := 0
|
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
consumeQuota := true
|
consumeQuota := true
|
||||||
var midjRequest provider.MidjourneyRequest
|
var midjRequest provider.MidjourneyRequest
|
||||||
@ -323,6 +321,11 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *provider.MidjourneyRe
|
|||||||
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "bind_request_body_failed")
|
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "bind_request_body_failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mjProvider, errWithMJ := getMJProviderWithRequest(c, relayMode, &midjRequest)
|
||||||
|
if errWithMJ != nil {
|
||||||
|
return errWithMJ
|
||||||
|
}
|
||||||
|
|
||||||
if relayMode == provider.RelayModeMidjourneyAction { // midjourney plus,需要从customId中获取任务信息
|
if relayMode == provider.RelayModeMidjourneyAction { // midjourney plus,需要从customId中获取任务信息
|
||||||
mjErr := CoverPlusActionToNormalAction(&midjRequest)
|
mjErr := CoverPlusActionToNormalAction(&midjRequest)
|
||||||
if mjErr != nil {
|
if mjErr != nil {
|
||||||
@ -378,7 +381,10 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *provider.MidjourneyRe
|
|||||||
} else if originTask.Status != "SUCCESS" && relayMode != provider.RelayModeMidjourneyModal {
|
} else if originTask.Status != "SUCCESS" && relayMode != provider.RelayModeMidjourneyModal {
|
||||||
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "task_status_not_success")
|
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "task_status_not_success")
|
||||||
} else { //原任务的Status=SUCCESS,则可以做放大UPSCALE、变换VARIATION等动作,此时必须使用原来的请求地址才能正确处理
|
} else { //原任务的Status=SUCCESS,则可以做放大UPSCALE、变换VARIATION等动作,此时必须使用原来的请求地址才能正确处理
|
||||||
channelId = originTask.ChannelId
|
mjProvider, errWithMJ = getMJProviderWithChannelId(c, originTask.ChannelId)
|
||||||
|
if errWithMJ != nil {
|
||||||
|
return errWithMJ
|
||||||
|
}
|
||||||
log.Printf("检测到此操作为放大、变换、重绘,获取原channel信息: %d", originTask.ChannelId)
|
log.Printf("检测到此操作为放大、变换、重绘,获取原channel信息: %d", originTask.ChannelId)
|
||||||
}
|
}
|
||||||
midjRequest.Prompt = originTask.Prompt
|
midjRequest.Prompt = originTask.Prompt
|
||||||
@ -395,17 +401,12 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *provider.MidjourneyRe
|
|||||||
consumeQuota = false
|
consumeQuota = false
|
||||||
}
|
}
|
||||||
|
|
||||||
mjProvider, errWithMJ := getMJProvider(c, relayMode, channelId, &midjRequest)
|
|
||||||
if errWithMJ != nil {
|
|
||||||
return errWithMJ
|
|
||||||
}
|
|
||||||
|
|
||||||
//baseURL := common.ChannelBaseURLs[channelType]
|
//baseURL := common.ChannelBaseURLs[channelType]
|
||||||
requestURL := getMjRequestPath(c.Request.URL.String())
|
requestURL := getMjRequestPath(c.Request.URL.String())
|
||||||
|
|
||||||
//midjRequest.NotifyHook = "http://127.0.0.1:3000/mj/notify"
|
//midjRequest.NotifyHook = "http://127.0.0.1:3000/mj/notify"
|
||||||
|
|
||||||
quotaInstance, errWithOA := getQuota(c, mjProvider.GetOriginalModel())
|
quotaInstance, errWithOA := getQuota(c, midjRequest.Action)
|
||||||
if errWithOA != nil {
|
if errWithOA != nil {
|
||||||
return &provider.MidjourneyResponse{
|
return &provider.MidjourneyResponse{
|
||||||
Code: 4,
|
Code: 4,
|
||||||
@ -538,33 +539,32 @@ func getMjRequestPath(path string) string {
|
|||||||
return requestURL
|
return requestURL
|
||||||
}
|
}
|
||||||
|
|
||||||
func getQuota(c *gin.Context, modelName string) (*util.Quota, *types.OpenAIErrorWithStatusCode) {
|
func getQuota(c *gin.Context, action string) (*util.Quota, *types.OpenAIErrorWithStatusCode) {
|
||||||
// modelName = CoverActionToModelName(modelName)
|
modelName := CoverActionToModelName(action)
|
||||||
|
|
||||||
return util.NewQuota(c, modelName, 1000)
|
return util.NewQuota(c, modelName, 1000)
|
||||||
}
|
}
|
||||||
|
|
||||||
func getMJProvider(c *gin.Context, relayMode, channel_id int, request *provider.MidjourneyRequest) (*provider.MidjourneyProvider, *provider.MidjourneyResponse) {
|
func getMJProviderWithRequest(c *gin.Context, relayMode int, request *provider.MidjourneyRequest) (*provider.MidjourneyProvider, *provider.MidjourneyResponse) {
|
||||||
var baseProvider providersBase.ProviderInterface
|
midjourneyModel, mjErr, _ := GetMjRequestModel(relayMode, request)
|
||||||
modelName := ""
|
if mjErr != nil {
|
||||||
if channel_id > 0 {
|
return nil, MidjourneyErrorFromInternal(mjErr.Code, mjErr.Description)
|
||||||
c.Set("specific_channel_id", channel_id)
|
}
|
||||||
|
if midjourneyModel == "" {
|
||||||
|
return nil, MidjourneyErrorFromInternal(provider.MjErrorUnknown, "无效的请求, 无法解析模型")
|
||||||
}
|
}
|
||||||
|
|
||||||
if request != nil {
|
return getMJProvider(c, midjourneyModel)
|
||||||
midjourneyModel, mjErr, _ := GetMjRequestModel(relayMode, request)
|
}
|
||||||
if mjErr != nil {
|
|
||||||
return nil, MidjourneyErrorFromInternal(mjErr.Code, mjErr.Description)
|
|
||||||
}
|
|
||||||
if midjourneyModel == "" {
|
|
||||||
return nil, MidjourneyErrorFromInternal(provider.MjErrorUnknown, "无效的请求, 无法解析模型")
|
|
||||||
}
|
|
||||||
|
|
||||||
modelName = midjourneyModel
|
func getMJProviderWithChannelId(c *gin.Context, channel_id int) (*provider.MidjourneyProvider, *provider.MidjourneyResponse) {
|
||||||
}
|
c.Set("specific_channel_id", channel_id)
|
||||||
|
|
||||||
var err error
|
return getMJProvider(c, "")
|
||||||
baseProvider, _, err = relay.GetProvider(c, modelName)
|
}
|
||||||
|
|
||||||
|
func getMJProvider(c *gin.Context, modelName string) (*provider.MidjourneyProvider, *provider.MidjourneyResponse) {
|
||||||
|
baseProvider, _, err := relay.GetProvider(c, modelName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, MidjourneyErrorFromInternal(provider.MjErrorUnknown, "无法获取provider:"+err.Error())
|
return nil, MidjourneyErrorFromInternal(provider.MjErrorUnknown, "无法获取provider:"+err.Error())
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user