From 1883c7c795c745523d31da47e0dc54dfe8e7c29a Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Mon, 6 Nov 2023 02:40:23 +0800 Subject: [PATCH] =?UTF-8?q?=E5=9C=A8=E7=BA=BF=E5=85=85=E5=80=BC=E6=94=AF?= =?UTF-8?q?=E6=8C=81=E8=AE=BE=E7=BD=AE=E6=9C=80=E4=BD=8E=E5=85=85=E5=80=BC?= =?UTF-8?q?=E6=95=B0=E9=87=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- common/constants.go | 1 + controller/midjourney.go | 49 +++++++++++-- controller/misc.go | 1 + controller/relay-mj.go | 102 +++++++++++++++++++++++++--- controller/relay.go | 2 + controller/topup.go | 23 +++---- model/option.go | 3 + web/src/components/SystemSetting.js | 10 +++ web/src/pages/TopUp/index.js | 4 +- 9 files changed, 167 insertions(+), 28 deletions(-) diff --git a/common/constants.go b/common/constants.go index 19a5c1c9..3628bd13 100644 --- a/common/constants.go +++ b/common/constants.go @@ -17,6 +17,7 @@ var PayAddress = "" var EpayId = "" var EpayKey = "" var Price = 7.3 +var MinCharge = 1 var Footer = "" var Logo = "" var TopUpLink = "" diff --git a/controller/midjourney.go b/controller/midjourney.go index 498d19b0..d9f4f896 100644 --- a/controller/midjourney.go +++ b/controller/midjourney.go @@ -2,14 +2,17 @@ package controller import ( "bytes" + "context" "encoding/json" "fmt" "github.com/gin-gonic/gin" + "io" "log" "net/http" "one-api/common" "one-api/model" "strconv" + "strings" "time" ) @@ -25,7 +28,9 @@ func UpdateMidjourneyTask() { time.Sleep(time.Duration(15) * time.Second) tasks := model.GetAllUnFinishTasks() if len(tasks) != 0 { + log.Printf("检测到未完成的任务数有: %v", len(tasks)) for _, task := range tasks { + log.Printf("未完成的任务信息: %v", task) midjourneyChannel, err := model.GetChannelById(task.ChannelId, true) if err != nil { log.Printf("UpdateMidjourneyTask: %v", err) @@ -39,6 +44,7 @@ func UpdateMidjourneyTask() { continue } requestUrl := fmt.Sprintf("%s/mj/task/%s/fetch", *midjourneyChannel.BaseURL, task.MjId) + log.Printf("requestUrl: %s", requestUrl) req, err := http.NewRequest("GET", requestUrl, bytes.NewBuffer([]byte(""))) if err != nil { @@ -46,7 +52,16 @@ func UpdateMidjourneyTask() { continue } + // 设置超时时间 + timeout := time.Second * 5 + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + // 使用带有超时的 context 创建新的请求 + req = req.WithContext(ctx) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer midjourney-proxy") req.Header.Set("mj-api-secret", midjourneyChannel.Key) resp, err := httpClient.Do(req) if err != nil { @@ -54,11 +69,37 @@ func UpdateMidjourneyTask() { continue } defer resp.Body.Close() + responseBody, err := io.ReadAll(resp.Body) + log.Printf("responseBody: %s", string(responseBody)) var responseItem Midjourney - err = json.NewDecoder(resp.Body).Decode(&responseItem) + // err = json.NewDecoder(resp.Body).Decode(&responseItem) + err = json.Unmarshal(responseBody, &responseItem) if err != nil { - log.Printf("UpdateMidjourneyTask error: %v", err) - continue + if strings.Contains(err.Error(), "cannot unmarshal number into Go struct field Midjourney.status of type string") { + var responseWithoutStatus MidjourneyWithoutStatus + var responseStatus MidjourneyStatus + err1 := json.Unmarshal(responseBody, &responseWithoutStatus) + err2 := json.Unmarshal(responseBody, &responseStatus) + if err1 == nil && err2 == nil { + jsonData, err3 := json.Marshal(responseWithoutStatus) + if err3 != nil { + log.Fatalf("UpdateMidjourneyTask error1: %v", err3) + continue + } + err4 := json.Unmarshal(jsonData, &responseStatus) + if err4 != nil { + log.Fatalf("UpdateMidjourneyTask error2: %v", err4) + continue + } + responseItem.Status = strconv.Itoa(responseStatus.Status) + } else { + log.Printf("UpdateMidjourneyTask error3: %v", err) + continue + } + } else { + log.Printf("UpdateMidjourneyTask error4: %v", err) + continue + } } task.Code = 1 task.Progress = responseItem.Progress @@ -94,7 +135,7 @@ func UpdateMidjourneyTask() { err = task.Update() if err != nil { - log.Printf("UpdateMidjourneyTask error: %v", err) + log.Printf("UpdateMidjourneyTask error5: %v", err) } log.Printf("UpdateMidjourneyTask success: %v", task) } diff --git a/controller/misc.go b/controller/misc.go index cf324b78..316e01c6 100644 --- a/controller/misc.go +++ b/controller/misc.go @@ -31,6 +31,7 @@ func GetStatus(c *gin.Context) { "epay_id": common.EpayId, "epay_key": common.EpayKey, "price": common.Price, + "min_charge": common.MinCharge, "turnstile_check": common.TurnstileCheckEnabled, "turnstile_site_key": common.TurnstileSiteKey, "top_up_link": common.TopUpLink, diff --git a/controller/relay-mj.go b/controller/relay-mj.go index 20e86890..4f852341 100644 --- a/controller/relay-mj.go +++ b/controller/relay-mj.go @@ -12,6 +12,7 @@ import ( "one-api/model" "strconv" "strings" + "time" "github.com/gin-gonic/gin" ) @@ -32,6 +33,28 @@ type Midjourney struct { FailReason string `json:"failReason"` } +type MidjourneyStatus struct { + Status int `json:"status"` +} +type MidjourneyWithoutStatus struct { + Id int `json:"id"` + Code int `json:"code"` + UserId int `json:"user_id" gorm:"index"` + Action string `json:"action"` + MjId string `json:"mj_id" gorm:"index"` + Prompt string `json:"prompt"` + PromptEn string `json:"prompt_en"` + Description string `json:"description"` + State string `json:"state"` + SubmitTime int64 `json:"submit_time"` + StartTime int64 `json:"start_time"` + FinishTime int64 `json:"finish_time"` + ImageUrl string `json:"image_url"` + Progress string `json:"progress"` + FailReason string `json:"fail_reason"` + ChannelId int `json:"channel_id"` +} + func RelayMidjourneyImage(c *gin.Context) { taskId := c.Param("id") midjourneyTask := model.GetByMJId(taskId) @@ -54,7 +77,7 @@ func RelayMidjourneyImage(c *gin.Context) { return } c.Header("Content-Type", "image/jpeg") - //c.Header("Content-Length", string(rune(len(data)))) + //c.HeaderBar("Content-Length", string(rune(len(data)))) c.Data(http.StatusOK, "image/jpeg", data) } @@ -115,7 +138,13 @@ func relayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse { midjourneyTask.SubmitTime = originTask.SubmitTime midjourneyTask.StartTime = originTask.StartTime midjourneyTask.FinishTime = originTask.FinishTime - midjourneyTask.ImageUrl = common.ServerAddress + "/mj/image/" + originTask.MjId + midjourneyTask.ImageUrl = "" + if originTask.ImageUrl != "" { + midjourneyTask.ImageUrl = common.ServerAddress + "/mj/image/" + originTask.MjId + if originTask.Status != "SUCCESS" { + midjourneyTask.ImageUrl += "?rand=" + strconv.FormatInt(time.Now().UnixNano(), 10) + } + } midjourneyTask.Status = originTask.Status midjourneyTask.FailReason = originTask.FailReason midjourneyTask.Action = originTask.Action @@ -157,7 +186,7 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse { } } } - if relayMode == RelayModeMidjourneyImagine { + if relayMode == RelayModeMidjourneyImagine { //绘画任务,此类任务可重复 if midjRequest.Prompt == "" { return &MidjourneyResponse{ Code: 4, @@ -165,7 +194,11 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse { } } midjRequest.Action = "IMAGINE" - } else if midjRequest.TaskId != "" { + } else if relayMode == RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复 + midjRequest.Action = "DESCRIBE" + } else if relayMode == RelayModeMidjourneyBlend { //绘画任务,此类任务可重复 + midjRequest.Action = "BLEND" + } else if midjRequest.TaskId != "" { //放大、变换任务,此类任务,如果重复且已有结果,远端api会直接返回最终结果 originTask := model.GetByMJId(midjRequest.TaskId) if originTask == nil { return &MidjourneyResponse{ @@ -183,7 +216,17 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse { Code: 4, Description: "task_status_is_not_success", } - + } else { //原任务的Status=SUCCESS,则可以做放大UPSCALE、变换VARIATION等动作,此时必须使用原来的请求地址才能正确处理 + channel, err := model.GetChannelById(originTask.ChannelId, false) + if err != nil { + return &MidjourneyResponse{ + Code: 4, + Description: "channel_not_found", + } + } + c.Set("base_url", channel.GetBaseURL()) + c.Set("channel_id", originTask.ChannelId) + log.Printf("检测到此操作为放大、变换,获取原channel信息: %s,%s", strconv.Itoa(originTask.ChannelId), channel.GetBaseURL()) } midjRequest.Prompt = originTask.Prompt } else if relayMode == RelayModeMidjourneyChange { @@ -234,6 +277,7 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse { //midjRequest.NotifyHook = "http://127.0.0.1:3000/mj/notify" fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) + log.Printf("fullRequestURL: %s", fullRequestURL) var requestBody io.Reader if isModelMapped { @@ -283,6 +327,7 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse { //if c.Request.Header.Get("Authorization") != "" { // mjToken = strings.Split(c.Request.Header.Get("Authorization"), " ")[1] //} + req.Header.Set("Authorization", "Bearer midjourney-proxy") req.Header.Set("mj-api-secret", strings.Split(c.Request.Header.Get("Authorization"), " ")[1]) // print request header log.Printf("request header: %s", req.Header) @@ -367,10 +412,14 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse { Description: "unmarshal_response_body_failed", } } - if midjResponse.Code == 24 || midjResponse.Code == 21 || midjResponse.Code == 4 { - consumeQuota = false - } + // 文档:https://github.com/novicezk/midjourney-proxy/blob/main/docs/api.md + //1-提交成功 + // 21-任务已存在(处理中或者有结果了) {"code":21,"description":"任务已存在","result":"0741798445574458","properties":{"status":"SUCCESS","imageUrl":"https://xxxx"}} + // 22-排队中 {"code":22,"description":"排队中,前面还有1个任务","result":"0741798445574458","properties":{"numberOfQueues":1,"discordInstanceId":"1118138338562560102"}} + // 23-队列已满,请稍后再试 {"code":23,"description":"队列已满,请稍后尝试","result":"14001929738841620","properties":{"discordInstanceId":"1118138338562560102"}} + // 24-prompt包含敏感词 {"code":24,"description":"可能包含敏感词","properties":{"promptEn":"nude body","bannedWord":"nude"}} + // other: 提交错误,description为错误描述 midjourneyTask := &model.Midjourney{ UserId: userId, Code: midjResponse.Code, @@ -380,7 +429,7 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse { PromptEn: "", Description: midjResponse.Description, State: "", - SubmitTime: 0, + SubmitTime: time.Now().UnixNano() / int64(time.Millisecond), StartTime: 0, FinishTime: 0, ImageUrl: "", @@ -389,9 +438,35 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse { FailReason: "", ChannelId: c.GetInt("channel_id"), } - if midjResponse.Code == 4 || midjResponse.Code == 24 { + + if midjResponse.Code != 1 && midjResponse.Code != 21 && midjResponse.Code != 22 { + //非1-提交成功,21-任务已存在和22-排队中,则记录错误原因 midjourneyTask.FailReason = midjResponse.Description + consumeQuota = false } + + if midjResponse.Code == 21 { //21-任务已存在(处理中或者有结果了) + // 将 properties 转换为一个 map + properties, ok := midjResponse.Properties.(map[string]interface{}) + if ok { + imageUrl, ok1 := properties["imageUrl"].(string) + status, ok2 := properties["status"].(string) + if ok1 && ok2 { + midjourneyTask.ImageUrl = imageUrl + midjourneyTask.Status = status + if status == "SUCCESS" { + midjourneyTask.Progress = "100%" + midjourneyTask.StartTime = time.Now().UnixNano() / int64(time.Millisecond) + midjourneyTask.FinishTime = time.Now().UnixNano() / int64(time.Millisecond) + midjResponse.Code = 1 + } + } + } + //修改返回值 + newBody := strings.Replace(string(responseBody), `"code":21`, `"code":1`, -1) + responseBody = []byte(newBody) + } + err = midjourneyTask.Insert() if err != nil { return &MidjourneyResponse{ @@ -399,6 +474,13 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse { Description: "insert_midjourney_task_failed", } } + + if midjResponse.Code == 22 { //22-排队中,说明任务已存在 + //修改返回值 + newBody := strings.Replace(string(responseBody), `"code":22`, `"code":1`, -1) + responseBody = []byte(newBody) + } + resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) for k, v := range resp.Header { diff --git a/controller/relay.go b/controller/relay.go index cd8d80b6..cd9d6bf9 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -26,6 +26,8 @@ const ( RelayModeImagesGenerations RelayModeEdits RelayModeMidjourneyImagine + RelayModeMidjourneyDescribe + RelayModeMidjourneyBlend RelayModeMidjourneyChange RelayModeMidjourneyNotify RelayModeMidjourneyTaskFetch diff --git a/controller/topup.go b/controller/topup.go index 5d945878..9f3a5032 100644 --- a/controller/topup.go +++ b/controller/topup.go @@ -38,14 +38,14 @@ func GetEpayClient() *epay.Client { return withUrl } -func GetAmount(count float64, user model.User) float64 { +func GetPayMoney(amount float64, user model.User) float64 { // 别问为什么用float64,问就是这么点钱没必要 topupGroupRatio := common.GetTopupGroupRatio(user.Group) if topupGroupRatio == 0 { topupGroupRatio = 1 } - amount := count * common.Price * topupGroupRatio - return amount + money := amount * common.Price * topupGroupRatio + return money } func RequestEpay(c *gin.Context) { @@ -55,14 +55,14 @@ func RequestEpay(c *gin.Context) { c.JSON(200, gin.H{"message": err.Error(), "data": 10}) return } - if req.Amount < 1 { - c.JSON(200, gin.H{"message": "充值金额不能小于1", "data": 10}) + if req.Amount < common.MinCharge { + c.JSON(200, gin.H{"message": fmt.Sprintf("最小充值数量为%d", common.MinCharge), "data": 10}) return } id := c.GetInt("id") user, _ := model.GetUserById(id, false) - amount := GetAmount(float64(req.Amount), *user) + needToPay := GetPayMoney(float64(req.Amount), *user) if req.PaymentMethod == "zfb" { req.PaymentMethod = "alipay" @@ -74,7 +74,6 @@ func RequestEpay(c *gin.Context) { returnUrl, _ := url.Parse(common.ServerAddress + "/log") notifyUrl, _ := url.Parse(common.ServerAddress + "/api/user/epay/notify") tradeNo := strconv.FormatInt(time.Now().Unix(), 10) - payMoney := amount client := GetEpayClient() if client == nil { c.JSON(200, gin.H{"message": "error", "data": "当前管理员未配置支付信息"}) @@ -84,7 +83,7 @@ func RequestEpay(c *gin.Context) { Type: epay.PurchaseType(req.PaymentMethod), ServiceTradeNo: "A" + tradeNo, Name: "B" + tradeNo, - Money: strconv.FormatFloat(payMoney, 'f', 2, 64), + Money: strconv.FormatFloat(needToPay, 'f', 2, 64), Device: epay.PC, NotifyUrl: notifyUrl, ReturnUrl: returnUrl, @@ -96,7 +95,7 @@ func RequestEpay(c *gin.Context) { topUp := &model.TopUp{ UserId: id, Amount: req.Amount, - Money: int(amount), + Money: int(needToPay), TradeNo: "A" + tradeNo, CreateTime: time.Now().Unix(), Status: "pending", @@ -167,11 +166,11 @@ func RequestAmount(c *gin.Context) { c.JSON(200, gin.H{"message": "error", "data": "参数错误"}) return } - if req.Amount < 1 { - c.JSON(200, gin.H{"message": "error", "data": "充值金额不能小于1"}) + if req.Amount < common.MinCharge { + c.JSON(200, gin.H{"message": fmt.Sprintf("最小充值数量为%d", common.MinCharge), "data": 10}) return } id := c.GetInt("id") user, _ := model.GetUserById(id, false) - c.JSON(200, gin.H{"message": "success", "data": GetAmount(float64(req.Amount), *user)}) + c.JSON(200, gin.H{"message": "success", "data": GetPayMoney(float64(req.Amount), *user)}) } diff --git a/model/option.go b/model/option.go index a56ab4a5..3068686f 100644 --- a/model/option.go +++ b/model/option.go @@ -57,6 +57,7 @@ func InitOptionMap() { common.OptionMap["EpayId"] = "" common.OptionMap["EpayKey"] = "" common.OptionMap["Price"] = strconv.FormatFloat(common.Price, 'f', -1, 64) + common.OptionMap["MinCharge"] = strconv.Itoa(common.MinCharge) common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString() common.OptionMap["GitHubClientId"] = "" common.OptionMap["GitHubClientSecret"] = "" @@ -185,6 +186,8 @@ func updateOptionMap(key string, value string) (err error) { common.EpayKey = value case "Price": common.Price, _ = strconv.ParseFloat(value, 64) + case "MinCharge": + common.MinCharge, _ = strconv.Atoi(value) case "TopupGroupRatio": err = common.UpdateTopupGroupRatioByJSONString(value) case "GitHubClientId": diff --git a/web/src/components/SystemSetting.js b/web/src/components/SystemSetting.js index 2197f3fb..ef067845 100644 --- a/web/src/components/SystemSetting.js +++ b/web/src/components/SystemSetting.js @@ -20,6 +20,7 @@ const SystemSetting = () => { EpayId: '', EpayKey: '', Price: 7.3, + MinCharge: 1, TopupGroupRatio: '', PayAddress: '', Footer: '', @@ -308,6 +309,15 @@ const SystemSetting = () => { min={0} onChange={handleInputChange} /> +