diff --git a/README.md b/README.md index 896d2c7b..f3b86860 100644 --- a/README.md +++ b/README.md @@ -58,6 +58,16 @@ _本项目是基于[one-api](https://github.com/songquanpeng/one-api)二次开 请查看[文档](https://github.com/MartialBE/one-api/wiki) +## 感谢 + +- 本程序使用了以下开源项目 + - [one-api](https://github.com/songquanpeng/one-api)为本项目的基础 + - [Berry Free React Admin Template](https://github.com/codedthemes/berry-free-react-admin-template)为本项目的前端界面 + - [minimal-ui-kit](https://github.com/minimal-ui-kit/material-kit-react),使用了其中的部分样式 + - [new api](https://github.com/Calcium-Ion/new-api),Midjourney 模块的代码来源于此 + +感谢以上项目的作者和贡献者 + ## 其他 diff --git a/common/constants.go b/common/constants.go index 97ed4212..b44ca170 100644 --- a/common/constants.go +++ b/common/constants.go @@ -37,6 +37,9 @@ var WeChatAuthEnabled = false var TurnstileCheckEnabled = false var RegisterEnabled = true +// mj +var MjNotifyEnabled = false + var EmailDomainRestrictionEnabled = false var EmailDomainWhitelist = []string{ "gmail.com", @@ -161,6 +164,7 @@ const ( ChannelTypeGroq = 31 ChannelTypeBedrock = 32 ChannelTypeLingyi = 33 + ChannelTypeMidjourney = 34 ) var ChannelBaseURLs = []string{ @@ -198,6 +202,7 @@ var ChannelBaseURLs = []string{ "https://api.groq.com/openai", //31 "", //32 "https://api.lingyiwanwu.com", //33 + "", //34 } const ( diff --git a/common/go-channel.go b/common/go-channel.go new file mode 100644 index 00000000..4f00dff2 --- /dev/null +++ b/common/go-channel.go @@ -0,0 +1,32 @@ +package common + +import ( + "fmt" + "runtime/debug" +) + +func SafeGoroutine(f func()) { + go func() { + defer func() { + if r := recover(); r != nil { + SysError(fmt.Sprintf("child goroutine panic occured: error: %v, stack: %s", r, string(debug.Stack()))) + } + }() + f() + }() +} + +func SafeSend(ch chan bool, value bool) (closed bool) { + defer func() { + // Recover from panic if one occured. A panic would mean the channel was closed. + if recover() != nil { + closed = true + } + }() + + // This will panic if the channel is closed. + ch <- value + + // If the code reaches here, then the channel was not closed. + return false +} diff --git a/common/logger.go b/common/logger.go index d2548679..73a3539c 100644 --- a/common/logger.go +++ b/common/logger.go @@ -106,7 +106,10 @@ func logHelper(ctx context.Context, level string, msg string) { if level == loggerINFO { writer = gin.DefaultWriter } - id := ctx.Value(RequestIdKey) + id, ok := ctx.Value(RequestIdKey).(string) + if !ok { + id = "unknown" + } now := time.Now() _, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg) logCount++ // we don't need accurate count, so no lock here diff --git a/common/requester/http_requester.go b/common/requester/http_requester.go index 17d870a9..ccc27eb8 100644 --- a/common/requester/http_requester.go +++ b/common/requester/http_requester.go @@ -23,6 +23,7 @@ type HTTPRequester struct { CreateFormBuilder func(io.Writer) FormBuilder ErrorHandler HttpErrorHandler proxyAddr string + Context context.Context } // NewHTTPRequester 创建一个新的 HTTPRequester 实例。 @@ -37,6 +38,7 @@ func NewHTTPRequester(proxyAddr string, errorHandler HttpErrorHandler) *HTTPRequ }, ErrorHandler: errorHandler, proxyAddr: proxyAddr, + Context: context.Background(), } } @@ -47,18 +49,18 @@ type requestOptions struct { type requestOption func(*requestOptions) -func (r *HTTPRequester) getContext() context.Context { +func (r *HTTPRequester) setProxy() context.Context { if r.proxyAddr == "" { - return context.Background() + return r.Context } // 如果是以 socks5:// 开头的地址,那么使用 socks5 代理 if strings.HasPrefix(r.proxyAddr, "socks5://") { - return context.WithValue(context.Background(), ProxySock5AddrKey, r.proxyAddr) + return context.WithValue(r.Context, ProxySock5AddrKey, r.proxyAddr) } // 否则使用 http 代理 - return context.WithValue(context.Background(), ProxyHTTPAddrKey, r.proxyAddr) + return context.WithValue(r.Context, ProxyHTTPAddrKey, r.proxyAddr) } @@ -71,7 +73,7 @@ func (r *HTTPRequester) NewRequest(method, url string, setters ...requestOption) for _, setter := range setters { setter(args) } - req, err := r.requestBuilder.Build(r.getContext(), method, url, args.body, args.header) + req, err := r.requestBuilder.Build(r.setProxy(), method, url, args.body, args.header) if err != nil { return nil, err } diff --git a/controller/midjourney.go b/controller/midjourney.go new file mode 100644 index 00000000..8286807a --- /dev/null +++ b/controller/midjourney.go @@ -0,0 +1,285 @@ +// Author: Calcium-Ion +// GitHub: https://github.com/Calcium-Ion/new-api +// Path: controller/midjourney.go +package controller + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "one-api/common" + "one-api/common/requester" + "one-api/model" + provider "one-api/providers/midjourney" + "time" + + "github.com/gin-gonic/gin" +) + +var activeMidjourneyTask = make(chan bool, 1) + +func InitMidjourneyTask() { + common.SafeGoroutine(func() { + midjourneyTask() + }) + + ActivateUpdateMidjourneyTaskBulk() +} + +func midjourneyTask() { + for { + select { + case <-activeMidjourneyTask: + UpdateMidjourneyTaskBulk() + } + } +} + +func ActivateUpdateMidjourneyTaskBulk() { + if len(activeMidjourneyTask) == 0 { + activeMidjourneyTask <- true + } +} + +func UpdateMidjourneyTaskBulk() { + ctx := context.WithValue(context.Background(), common.RequestIdKey, "MidjourneyTask") + for { + common.LogInfo(ctx, "running") + + tasks := model.GetAllUnFinishTasks() + + // 如果没有未完成的任务,则等待 + if len(tasks) == 0 { + for len(activeMidjourneyTask) > 0 { + <-activeMidjourneyTask + } + common.LogInfo(ctx, "no tasks, waiting...") + return + } + + common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks))) + taskChannelM := make(map[int][]string) + taskM := make(map[string]*model.Midjourney) + nullTaskIds := make([]int, 0) + for _, task := range tasks { + if task.MjId == "" { + // 统计失败的未完成任务 + nullTaskIds = append(nullTaskIds, task.Id) + continue + } + taskM[task.MjId] = task + taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], task.MjId) + } + if len(nullTaskIds) > 0 { + err := model.MjBulkUpdateByTaskIds(nullTaskIds, map[string]any{ + "status": "FAILURE", + "progress": "100%", + }) + if err != nil { + common.LogError(ctx, fmt.Sprintf("Fix null mj_id task error: %v", err)) + } else { + common.LogInfo(ctx, fmt.Sprintf("Fix null mj_id task success: %v", nullTaskIds)) + } + } + if len(taskChannelM) == 0 { + continue + } + + for channelId, taskIds := range taskChannelM { + common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds))) + if len(taskIds) == 0 { + continue + } + midjourneyChannel := model.ChannelGroup.GetChannel(channelId) + if midjourneyChannel == nil { + err := model.MjBulkUpdate(taskIds, map[string]any{ + "fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId), + "status": "FAILURE", + "progress": "100%", + }) + common.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err)) + continue + } + requestUrl := fmt.Sprintf("%s/mj/task/list-by-condition", *midjourneyChannel.BaseURL) + + body, _ := json.Marshal(map[string]any{ + "ids": taskIds, + }) + req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(body)) + if err != nil { + common.LogError(ctx, fmt.Sprintf("Get Task error: %v", err)) + continue + } + // 设置超时时间 + timeout := time.Second * 5 + ctx, cancel := context.WithTimeout(context.Background(), timeout) + // 使用带有超时的 context 创建新的请求 + req = req.WithContext(ctx) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("mj-api-secret", midjourneyChannel.Key) + resp, err := requester.HTTPClient.Do(req) + if err != nil { + common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err)) + continue + } + if resp.StatusCode != http.StatusOK { + common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode)) + continue + } + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + common.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err)) + continue + } + var responseItems []provider.MidjourneyDto + err = json.Unmarshal(responseBody, &responseItems) + if err != nil { + common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody))) + continue + } + resp.Body.Close() + req.Body.Close() + cancel() + + for _, responseItem := range responseItems { + task := taskM[responseItem.MjId] + + useTime := (time.Now().UnixNano() / int64(time.Millisecond)) - task.SubmitTime + // 如果时间超过一小时,且进度不是100%,则认为任务失败 + if useTime > 3600000 && task.Progress != "100%" { + responseItem.FailReason = "上游任务超时(超过1小时)" + responseItem.Status = "FAILURE" + } + if !checkMjTaskNeedUpdate(task, responseItem) { + continue + } + task.Code = 1 + task.Progress = responseItem.Progress + task.PromptEn = responseItem.PromptEn + task.State = responseItem.State + task.SubmitTime = responseItem.SubmitTime + task.StartTime = responseItem.StartTime + task.FinishTime = responseItem.FinishTime + task.ImageUrl = responseItem.ImageUrl + task.Status = responseItem.Status + task.FailReason = responseItem.FailReason + if responseItem.Properties != nil { + propertiesStr, _ := json.Marshal(responseItem.Properties) + task.Properties = string(propertiesStr) + } + if responseItem.Buttons != nil { + buttonStr, _ := json.Marshal(responseItem.Buttons) + task.Buttons = string(buttonStr) + } + + if (task.Progress != "100%" && responseItem.FailReason != "") || (task.Progress == "100%" && task.Status == "FAILURE") { + common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason) + task.Progress = "100%" + err = model.CacheUpdateUserQuota(task.UserId) + if err != nil { + common.LogError(ctx, "error update user quota cache: "+err.Error()) + } else { + quota := task.Quota + if quota != 0 { + err = model.IncreaseUserQuota(task.UserId, quota) + if err != nil { + common.LogError(ctx, "fail to increase user quota: "+err.Error()) + } + logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, common.LogQuota(quota)) + model.RecordLog(task.UserId, model.LogTypeSystem, logContent) + } + } + } + err = task.Update() + if err != nil { + common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error()) + } + } + } + time.Sleep(time.Duration(15) * time.Second) + } +} + +func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask provider.MidjourneyDto) bool { + if oldTask.Code != 1 { + return true + } + if oldTask.Progress != newTask.Progress { + return true + } + if oldTask.PromptEn != newTask.PromptEn { + return true + } + if oldTask.State != newTask.State { + return true + } + if oldTask.SubmitTime != newTask.SubmitTime { + return true + } + if oldTask.StartTime != newTask.StartTime { + return true + } + if oldTask.FinishTime != newTask.FinishTime { + return true + } + if oldTask.ImageUrl != newTask.ImageUrl { + return true + } + if oldTask.Status != newTask.Status { + return true + } + if oldTask.FailReason != newTask.FailReason { + return true + } + if oldTask.FinishTime != newTask.FinishTime { + return true + } + if oldTask.Progress != "100%" && newTask.FailReason != "" { + return true + } + + return false +} + +func GetAllMidjourney(c *gin.Context) { + var params model.TaskQueryParams + if err := c.ShouldBindQuery(¶ms); err != nil { + common.APIRespondWithError(c, http.StatusOK, err) + return + } + + midjourneys, err := model.GetAllTasks(¶ms) + if err != nil { + common.APIRespondWithError(c, http.StatusOK, err) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": midjourneys, + }) +} + +func GetUserMidjourney(c *gin.Context) { + userId := c.GetInt("id") + + var params model.TaskQueryParams + if err := c.ShouldBindQuery(¶ms); err != nil { + common.APIRespondWithError(c, http.StatusOK, err) + return + } + + midjourneys, err := model.GetAllUserTask(userId, ¶ms) + if err != nil { + common.APIRespondWithError(c, http.StatusOK, err) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": midjourneys, + }) +} diff --git a/controller/misc.go b/controller/misc.go index fd723a49..810e02d6 100644 --- a/controller/misc.go +++ b/controller/misc.go @@ -40,6 +40,7 @@ func GetStatus(c *gin.Context) { "quota_per_unit": common.QuotaPerUnit, "display_in_currency": common.DisplayInCurrencyEnabled, "telegram_bot": telegram_bot, + "mj_notify_enabled": common.MjNotifyEnabled, }, }) } diff --git a/main.go b/main.go index 18b31de8..02417e03 100644 --- a/main.go +++ b/main.go @@ -45,6 +45,8 @@ func main() { // Initialize Telegram bot telegram.InitTelegramBot() + controller.InitMidjourneyTask() + initHttpServer() } diff --git a/middleware/auth.go b/middleware/auth.go index 2537ceae..91e3dd56 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -83,43 +83,54 @@ func RootAuth() func(c *gin.Context) { } } -func TokenAuth() func(c *gin.Context) { - return func(c *gin.Context) { - key := c.Request.Header.Get("Authorization") - key = strings.TrimPrefix(key, "Bearer ") - key = strings.TrimPrefix(key, "sk-") - parts := strings.Split(key, "-") - key = parts[0] - token, err := model.ValidateUserToken(key) - if err != nil { - abortWithMessage(c, http.StatusUnauthorized, err.Error()) - return - } - userEnabled, err := model.CacheIsUserEnabled(token.UserId) - if err != nil { - abortWithMessage(c, http.StatusInternalServerError, err.Error()) - return - } - if !userEnabled { - abortWithMessage(c, http.StatusForbidden, "用户已被封禁") - return - } - c.Set("id", token.UserId) - c.Set("token_id", token.Id) - c.Set("token_name", token.Name) - if len(parts) > 1 { - if model.IsAdmin(token.UserId) { - channelId := common.String2Int(parts[1]) - if channelId == 0 { - abortWithMessage(c, http.StatusForbidden, "无效的渠道 Id") - return - } - c.Set("specific_channel_id", channelId) - } else { - abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道") +func tokenAuth(c *gin.Context, key string) { + key = strings.TrimPrefix(key, "Bearer ") + key = strings.TrimPrefix(key, "sk-") + parts := strings.Split(key, "-") + key = parts[0] + token, err := model.ValidateUserToken(key) + if err != nil { + abortWithMessage(c, http.StatusUnauthorized, err.Error()) + return + } + userEnabled, err := model.CacheIsUserEnabled(token.UserId) + if err != nil { + abortWithMessage(c, http.StatusInternalServerError, err.Error()) + return + } + if !userEnabled { + abortWithMessage(c, http.StatusForbidden, "用户已被封禁") + return + } + c.Set("id", token.UserId) + c.Set("token_id", token.Id) + c.Set("token_name", token.Name) + if len(parts) > 1 { + if model.IsAdmin(token.UserId) { + channelId := common.String2Int(parts[1]) + if channelId == 0 { + abortWithMessage(c, http.StatusForbidden, "无效的渠道 Id") return } + c.Set("specific_channel_id", channelId) + } else { + abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道") + return } - c.Next() + } + c.Next() +} + +func OpenaiAuth() func(c *gin.Context) { + return func(c *gin.Context) { + key := c.Request.Header.Get("Authorization") + tokenAuth(c, key) + } +} + +func MjAuth() func(c *gin.Context) { + return func(c *gin.Context) { + key := c.Request.Header.Get("mj-api-secret") + tokenAuth(c, key) } } diff --git a/model/balancer.go b/model/balancer.go index 44bdcb9e..ecf2721d 100644 --- a/model/balancer.go +++ b/model/balancer.go @@ -114,6 +114,17 @@ func (cc *ChannelsChooser) GetGroupModels(group string) ([]string, error) { return models, nil } +func (cc *ChannelsChooser) GetChannel(channelId int) *Channel { + cc.RLock() + defer cc.RUnlock() + + if choice, ok := cc.Channels[channelId]; ok { + return choice.Channel + } + + return nil +} + var ChannelGroup = ChannelsChooser{} func (cc *ChannelsChooser) Load() { diff --git a/model/main.go b/model/main.go index fcdb952e..61e8c9b6 100644 --- a/model/main.go +++ b/model/main.go @@ -139,6 +139,10 @@ func InitDB() (err error) { if err != nil { return err } + err = db.AutoMigrate(&Midjourney{}) + if err != nil { + return err + } common.SysLog("database migrated") err = createRootAccountIfNeed() return err diff --git a/model/midjourney.go b/model/midjourney.go new file mode 100644 index 00000000..a300db3c --- /dev/null +++ b/model/midjourney.go @@ -0,0 +1,182 @@ +// Copyright (c) 2024 Calcium-Ion +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. +// +// Author: Calcium-Ion +// GitHub: https://github.com/Calcium-Ion/new-api + +package model + +type Midjourney struct { + Id int `json:"id"` + Code int `json:"code"` + UserId int `json:"user_id" gorm:"index"` + Action string `json:"action" gorm:"type:varchar(40);index"` + MjId string `json:"mj_id" gorm:"index"` + Prompt string `json:"prompt"` + PromptEn string `json:"prompt_en"` + Description string `json:"description"` + State string `json:"state"` + SubmitTime int64 `json:"submit_time" gorm:"index"` + StartTime int64 `json:"start_time" gorm:"index"` + FinishTime int64 `json:"finish_time" gorm:"index"` + ImageUrl string `json:"image_url"` + Status string `json:"status" gorm:"type:varchar(20);index"` + Progress string `json:"progress" gorm:"type:varchar(30);index"` + FailReason string `json:"fail_reason"` + ChannelId int `json:"channel_id"` + Quota int `json:"quota"` + Buttons string `json:"buttons"` + Properties string `json:"properties"` +} + +// TaskQueryParams 用于包含所有搜索条件的结构体,可以根据需求添加更多字段 +type TaskQueryParams struct { + ChannelID int `form:"channel_id"` + MjID string `form:"mj_id"` + StartTimestamp int `form:"start_timestamp"` + EndTimestamp int `form:"end_timestamp"` + PaginationParams +} + +var allowedMidjourneyOrderFields = map[string]bool{ + "id": true, + "user_id": true, + "code": true, + "action": true, + "mj_id": true, + "submit_time": true, + "start_time": true, + "finish_time": true, + "status": true, + "channel_id": true, +} + +func GetAllUserTask(userId int, params *TaskQueryParams) (*DataResult[Midjourney], error) { + var tasks []*Midjourney + + // 初始化查询构建器 + query := DB.Where("user_id = ?", userId) + + if params.MjID != "" { + query = query.Where("mj_id = ?", params.MjID) + } + if params.StartTimestamp != 0 { + // 假设您已将前端传来的时间戳转换为数据库所需的时间格式,并处理了时间戳的验证和解析 + query = query.Where("submit_time >= ?", params.StartTimestamp) + } + if params.EndTimestamp != 0 { + query = query.Where("submit_time <= ?", params.EndTimestamp) + } + + return PaginateAndOrder(query, ¶ms.PaginationParams, &tasks, allowedMidjourneyOrderFields) +} + +func GetAllTasks(params *TaskQueryParams) (*DataResult[Midjourney], error) { + var tasks []*Midjourney + + // 初始化查询构建器 + query := DB + + // 添加过滤条件 + if params.ChannelID != 0 { + query = query.Where("channel_id = ?", params.ChannelID) + } + if params.MjID != "" { + query = query.Where("mj_id = ?", params.MjID) + } + if params.StartTimestamp != 0 { + query = query.Where("submit_time >= ?", params.StartTimestamp) + } + if params.EndTimestamp != 0 { + query = query.Where("submit_time <= ?", params.EndTimestamp) + } + + return PaginateAndOrder(query, ¶ms.PaginationParams, &tasks, allowedMidjourneyOrderFields) +} + +func GetAllUnFinishTasks() []*Midjourney { + var tasks []*Midjourney + // get all tasks progress is not 100% + err := DB.Where("progress != ?", "100%").Find(&tasks).Error + if err != nil { + return nil + } + return tasks +} + +func GetByOnlyMJId(mjId string) *Midjourney { + var mj *Midjourney + err := DB.Where("mj_id = ?", mjId).First(&mj).Error + if err != nil { + return nil + } + return mj +} + +func GetByMJId(userId int, mjId string) *Midjourney { + var mj *Midjourney + err := DB.Where("user_id = ? and mj_id = ?", userId, mjId).First(&mj).Error + if err != nil { + return nil + } + return mj +} + +func GetByMJIds(userId int, mjIds []string) []*Midjourney { + var mj []*Midjourney + err := DB.Where("user_id = ? and mj_id in (?)", userId, mjIds).Find(&mj).Error + if err != nil { + return nil + } + return mj +} + +func GetMjByuId(id int) *Midjourney { + var mj *Midjourney + err := DB.Where("id = ?", id).First(&mj).Error + if err != nil { + return nil + } + return mj +} + +func UpdateProgress(id int, progress string) error { + return DB.Model(&Midjourney{}).Where("id = ?", id).Update("progress", progress).Error +} + +func (midjourney *Midjourney) Insert() error { + return DB.Create(midjourney).Error +} + +func (midjourney *Midjourney) Update() error { + return DB.Save(midjourney).Error +} + +func MjBulkUpdate(mjIds []string, params map[string]any) error { + return DB.Model(&Midjourney{}). + Where("mj_id in (?)", mjIds). + Updates(params).Error +} + +func MjBulkUpdateByTaskIds(taskIDs []int, params map[string]any) error { + return DB.Model(&Midjourney{}). + Where("id in (?)", taskIDs). + Updates(params).Error +} diff --git a/model/option.go b/model/option.go index 62714ce1..94103bd8 100644 --- a/model/option.go +++ b/model/option.go @@ -74,6 +74,8 @@ func InitOptionMap() { common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes) common.OptionMap["RetryCooldownSeconds"] = strconv.Itoa(common.RetryCooldownSeconds) + common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(common.MjNotifyEnabled) + common.OptionMapRWMutex.Unlock() loadOptionsFromDatabase() } @@ -138,6 +140,7 @@ var optionBoolMap = map[string]*bool{ "LogConsumeEnabled": &common.LogConsumeEnabled, "DisplayInCurrencyEnabled": &common.DisplayInCurrencyEnabled, "DisplayTokenStatEnabled": &common.DisplayTokenStatEnabled, + "MjNotifyEnabled": &common.MjNotifyEnabled, } var optionStringMap = map[string]*string{ diff --git a/model/price.go b/model/price.go index 3f5d14e0..7a75e292 100644 --- a/model/price.go +++ b/model/price.go @@ -301,5 +301,33 @@ func GetDefaultPrice() []*Price { }) } + var DefaultMJPrice = map[string]float64{ + "mj_imagine": 50, + "mj_variation": 50, + "mj_reroll": 50, + "mj_blend": 50, + "mj_modal": 50, + "mj_zoom": 50, + "mj_shorten": 50, + "mj_high_variation": 50, + "mj_low_variation": 50, + "mj_pan": 50, + "mj_inpaint": 0, + "mj_custom_zoom": 0, + "mj_describe": 25, + "mj_upscale": 25, + "swap_face": 25, + } + + for model, mjPrice := range DefaultMJPrice { + prices = append(prices, &Price{ + Model: model, + Type: TimesPriceType, + ChannelType: common.ChannelTypeMidjourney, + Input: mjPrice, + Output: mjPrice, + }) + } + return prices } diff --git a/providers/midjourney/base.go b/providers/midjourney/base.go new file mode 100644 index 00000000..5b6a2569 --- /dev/null +++ b/providers/midjourney/base.go @@ -0,0 +1,121 @@ +package midjourney + +import ( + "bytes" + "context" + "encoding/json" + "io" + "log" + "net/http" + "one-api/common" + "one-api/common/requester" + "one-api/model" + "one-api/providers/base" + "time" +) + +// 定义供应商工厂 +type MidjourneyProviderFactory struct{} + +// 创建 MidjourneyProvider +func (f MidjourneyProviderFactory) Create(channel *model.Channel) base.ProviderInterface { + return &MidjourneyProvider{ + BaseProvider: base.BaseProvider{ + Config: getConfig(), + Channel: channel, + Requester: requester.NewHTTPRequester(*channel.Proxy, nil), + }, + } +} + +func getConfig() base.ProviderConfig { + return base.ProviderConfig{ + BaseURL: "", + } +} + +type MidjourneyProvider struct { + base.BaseProvider +} + +func (p *MidjourneyProvider) Send(timeout int, requestURL string) (*MidjourneyResponseWithStatusCode, []byte, error) { + var nullBytes []byte + var mapResult map[string]interface{} + if p.Context.Request.Method != "GET" { + err := json.NewDecoder(p.Context.Request.Body).Decode(&mapResult) + if err != nil { + return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "read_request_body_failed", http.StatusInternalServerError), nullBytes, err + } + delete(mapResult, "accountFilter") + if !common.MjNotifyEnabled { + delete(mapResult, "notifyHook") + } + } + + reqBody, err := json.Marshal(mapResult) + if err != nil { + return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "marshal_request_body_failed", http.StatusInternalServerError), nullBytes, err + } + + fullRequestURL := p.GetFullRequestURL(requestURL, "") + + var cancel context.CancelFunc + p.Requester.Context, cancel = context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second) + headers := p.GetRequestHeaders() + defer cancel() + + req, err := p.Requester.NewRequest(p.Context.Request.Method, fullRequestURL, p.Requester.WithBody(bytes.NewBuffer(reqBody)), p.Requester.WithHeader(headers)) + if err != nil { + return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "create_request_failed", http.StatusInternalServerError), nullBytes, err + } + + resp, errWith := p.Requester.SendRequestRaw(req) + if errWith != nil { + common.SysError("do request failed: " + errWith.Error()) + return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "do_request_failed", http.StatusInternalServerError), nullBytes, err + } + statusCode := resp.StatusCode + err = req.Body.Close() + if err != nil { + return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err + } + err = p.Context.Request.Body.Close() + if err != nil { + return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err + } + var midjResponse MidjourneyResponse + + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "read_response_body_failed", statusCode), nullBytes, err + } + err = resp.Body.Close() + if err != nil { + return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "close_response_body_failed", statusCode), responseBody, err + } + respStr := string(responseBody) + log.Printf("responseBody: %s", respStr) + if respStr == "" { + return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "empty_response_body", statusCode), responseBody, nil + } else { + err = json.Unmarshal(responseBody, &midjResponse) + if err != nil { + return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "unmarshal_response_body_failed", statusCode), responseBody, err + } + } + + return &MidjourneyResponseWithStatusCode{ + StatusCode: statusCode, + Response: midjResponse, + }, responseBody, nil + +} + +func (p *MidjourneyProvider) GetRequestHeaders() (headers map[string]string) { + headers = make(map[string]string) + headers["mj-api-secret"] = p.Channel.Key + headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type") + headers["Accept"] = p.Context.Request.Header.Get("Accept") + + return headers +} diff --git a/providers/midjourney/constant.go b/providers/midjourney/constant.go new file mode 100644 index 00000000..770ee685 --- /dev/null +++ b/providers/midjourney/constant.go @@ -0,0 +1,69 @@ +// Author: Calcium-Ion +// GitHub: https://github.com/Calcium-Ion/new-api +// Path: relay/constant/relay_mode.go +package midjourney + +const ( + RelayModeUnknown = iota + RelayModeMidjourneyImagine + RelayModeMidjourneyDescribe + RelayModeMidjourneyBlend + RelayModeMidjourneyChange + RelayModeMidjourneySimpleChange + RelayModeMidjourneyNotify + RelayModeMidjourneyTaskFetch + RelayModeMidjourneyTaskImageSeed + RelayModeMidjourneyTaskFetchByCondition + RelayModeAudioSpeech + RelayModeAudioTranscription + RelayModeAudioTranslation + RelayModeMidjourneyAction + RelayModeMidjourneyModal + RelayModeMidjourneyShorten + RelayModeMidjourneySwapFace +) + +// Author: Calcium-Ion +// GitHub: https://github.com/Calcium-Ion/new-api +// Path: constant/midjourney.go + +const ( + MjErrorUnknown = 5 + MjRequestError = 4 +) + +const ( + MjActionImagine = "IMAGINE" + MjActionDescribe = "DESCRIBE" + MjActionBlend = "BLEND" + MjActionUpscale = "UPSCALE" + MjActionVariation = "VARIATION" + MjActionReRoll = "REROLL" + MjActionInPaint = "INPAINT" + MjActionModal = "MODAL" + MjActionZoom = "ZOOM" + MjActionCustomZoom = "CUSTOM_ZOOM" + MjActionShorten = "SHORTEN" + MjActionHighVariation = "HIGH_VARIATION" + MjActionLowVariation = "LOW_VARIATION" + MjActionPan = "PAN" + MjActionSwapFace = "SWAP_FACE" +) + +var MidjourneyModel2Action = map[string]string{ + "mj_imagine": MjActionImagine, + "mj_describe": MjActionDescribe, + "mj_blend": MjActionBlend, + "mj_upscale": MjActionUpscale, + "mj_variation": MjActionVariation, + "mj_reroll": MjActionReRoll, + "mj_modal": MjActionModal, + "mj_inpaint": MjActionInPaint, + "mj_zoom": MjActionZoom, + "mj_custom_zoom": MjActionCustomZoom, + "mj_shorten": MjActionShorten, + "mj_high_variation": MjActionHighVariation, + "mj_low_variation": MjActionLowVariation, + "mj_pan": MjActionPan, + "swap_face": MjActionSwapFace, +} diff --git a/providers/midjourney/error.go b/providers/midjourney/error.go new file mode 100644 index 00000000..04c1f90b --- /dev/null +++ b/providers/midjourney/error.go @@ -0,0 +1,18 @@ +// Author: Calcium-Ion +// GitHub: https://github.com/Calcium-Ion/new-api +// Path: service/error.go +package midjourney + +func MidjourneyErrorWithStatusCodeWrapper(code int, desc string, statusCode int) *MidjourneyResponseWithStatusCode { + return &MidjourneyResponseWithStatusCode{ + StatusCode: statusCode, + Response: *MidjourneyErrorWrapper(code, desc), + } +} + +func MidjourneyErrorWrapper(code int, desc string) *MidjourneyResponse { + return &MidjourneyResponse{ + Code: code, + Description: desc, + } +} diff --git a/providers/midjourney/type.go b/providers/midjourney/type.go new file mode 100644 index 00000000..42d156cb --- /dev/null +++ b/providers/midjourney/type.go @@ -0,0 +1,92 @@ +// Author: Calcium-Ion +// GitHub: https://github.com/Calcium-Ion/new-api +// Path: dto/midjourney.go +package midjourney + +type SwapFaceRequest struct { + SourceBase64 string `json:"sourceBase64"` + TargetBase64 string `json:"targetBase64"` +} + +type MidjourneyRequest struct { + Prompt string `json:"prompt"` + CustomId string `json:"customId"` + BotType string `json:"botType"` + NotifyHook string `json:"notifyHook"` + Action string `json:"action"` + Index int `json:"index"` + State string `json:"state"` + TaskId string `json:"taskId"` + Base64Array []string `json:"base64Array"` + Content string `json:"content"` + MaskBase64 string `json:"maskBase64"` +} + +type MidjourneyResponse struct { + Code int `json:"code"` + Description string `json:"description"` + Properties interface{} `json:"properties"` + Result string `json:"result"` + Type string `json:"type,omitempty"` +} + +type MidjourneyResponseWithStatusCode struct { + StatusCode int `json:"statusCode"` + Response MidjourneyResponse +} + +type MidjourneyDto struct { + MjId string `json:"id"` + Action string `json:"action"` + CustomId string `json:"customId"` + BotType string `json:"botType"` + Prompt string `json:"prompt"` + PromptEn string `json:"promptEn"` + Description string `json:"description"` + State string `json:"state"` + SubmitTime int64 `json:"submitTime"` + StartTime int64 `json:"startTime"` + FinishTime int64 `json:"finishTime"` + ImageUrl string `json:"imageUrl"` + Status string `json:"status"` + Progress string `json:"progress"` + FailReason string `json:"failReason"` + Buttons any `json:"buttons"` + MaskBase64 string `json:"maskBase64"` + Properties *Properties `json:"properties"` +} + +type MidjourneyStatus struct { + Status int `json:"status"` +} +type MidjourneyWithoutStatus struct { + Id int `json:"id"` + Code int `json:"code"` + UserId int `json:"user_id" gorm:"index"` + Action string `json:"action"` + MjId string `json:"mj_id" gorm:"index"` + Prompt string `json:"prompt"` + PromptEn string `json:"prompt_en"` + Description string `json:"description"` + State string `json:"state"` + SubmitTime int64 `json:"submit_time"` + StartTime int64 `json:"start_time"` + FinishTime int64 `json:"finish_time"` + ImageUrl string `json:"image_url"` + Progress string `json:"progress"` + FailReason string `json:"fail_reason"` + ChannelId int `json:"channel_id"` +} + +type ActionButton struct { + CustomId any `json:"customId"` + Emoji any `json:"emoji"` + Label any `json:"label"` + Type any `json:"type"` + Style any `json:"style"` +} + +type Properties struct { + FinalPrompt string `json:"finalPrompt"` + FinalZhPrompt string `json:"finalZhPrompt"` +} diff --git a/providers/providers.go b/providers/providers.go index c40d988b..a28cf708 100644 --- a/providers/providers.go +++ b/providers/providers.go @@ -14,6 +14,7 @@ import ( "one-api/providers/deepseek" "one-api/providers/gemini" "one-api/providers/groq" + "one-api/providers/midjourney" "one-api/providers/minimax" "one-api/providers/mistral" "one-api/providers/openai" @@ -52,6 +53,7 @@ func init() { providerFactories[common.ChannelTypeMistral] = mistral.MistralProviderFactory{} providerFactories[common.ChannelTypeGroq] = groq.GroqProviderFactory{} providerFactories[common.ChannelTypeBedrock] = bedrock.BedrockProviderFactory{} + providerFactories[common.ChannelTypeMidjourney] = midjourney.MidjourneyProviderFactory{} } diff --git a/relay/base.go b/relay/base.go index 0f644f02..5891010c 100644 --- a/relay/base.go +++ b/relay/base.go @@ -27,7 +27,7 @@ type RelayBaseInterface interface { } func (r *relayBase) setProvider(modelName string) error { - provider, modelName, fail := getProvider(r.c, modelName) + provider, modelName, fail := GetProvider(r.c, modelName) if fail != nil { return fail } diff --git a/relay/common.go b/relay/common.go index d3443f4e..c3567999 100644 --- a/relay/common.go +++ b/relay/common.go @@ -45,7 +45,7 @@ func Path2Relay(c *gin.Context, path string) RelayBaseInterface { return nil } -func getProvider(c *gin.Context, modeName string) (provider providersBase.ProviderInterface, newModelName string, fail error) { +func GetProvider(c *gin.Context, modeName string) (provider providersBase.ProviderInterface, newModelName string, fail error) { channel, fail := fetchChannel(c, modeName) if fail != nil { return diff --git a/relay/midjourney/LICENSE b/relay/midjourney/LICENSE new file mode 100644 index 00000000..f0fec6b4 --- /dev/null +++ b/relay/midjourney/LICENSE @@ -0,0 +1,19 @@ +Copyright (c) 2024 Calcium-Ion + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/relay/midjourney/relay-mj.go b/relay/midjourney/relay-mj.go new file mode 100644 index 00000000..e8993185 --- /dev/null +++ b/relay/midjourney/relay-mj.go @@ -0,0 +1,578 @@ +// Author: Calcium-Ion +// GitHub: https://github.com/Calcium-Ion/new-api +// Path: relay/relay-mj.go +package midjourney + +import ( + "bytes" + "encoding/json" + "io" + "log" + "net/http" + "one-api/common" + "one-api/controller" + "one-api/model" + providersBase "one-api/providers/base" + provider "one-api/providers/midjourney" + "one-api/relay" + "one-api/relay/util" + "one-api/types" + "strconv" + "strings" + "time" + + "github.com/gin-gonic/gin" +) + +func RelayMidjourneyImage(c *gin.Context) { + taskId := c.Param("id") + midjourneyTask := model.GetByOnlyMJId(taskId) + if midjourneyTask == nil { + c.JSON(400, gin.H{ + "error": "midjourney_task_not_found", + }) + return + } + resp, err := http.Get(midjourneyTask.ImageUrl) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "http_get_image_failed", + }) + return + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + responseBody, _ := io.ReadAll(resp.Body) + c.JSON(resp.StatusCode, gin.H{ + "error": string(responseBody), + }) + return + } + // 从Content-Type头获取MIME类型 + contentType := resp.Header.Get("Content-Type") + if contentType == "" { + // 如果无法确定内容类型,则默认为jpeg + contentType = "image/jpeg" + } + // 设置响应的内容类型 + c.Writer.Header().Set("Content-Type", contentType) + // 将图片流式传输到响应体 + _, err = io.Copy(c.Writer, resp.Body) + if err != nil { + log.Println("Failed to stream image:", err) + } +} + +func RelayMidjourneyNotify(c *gin.Context) *provider.MidjourneyResponse { + var midjRequest provider.MidjourneyDto + err := common.UnmarshalBodyReusable(c, &midjRequest) + if err != nil { + return &provider.MidjourneyResponse{ + Code: 4, + Description: "bind_request_body_failed", + Properties: nil, + Result: "", + } + } + midjourneyTask := model.GetByOnlyMJId(midjRequest.MjId) + if midjourneyTask == nil { + return &provider.MidjourneyResponse{ + Code: 4, + Description: "midjourney_task_not_found", + Properties: nil, + Result: "", + } + } + midjourneyTask.Progress = midjRequest.Progress + midjourneyTask.PromptEn = midjRequest.PromptEn + midjourneyTask.State = midjRequest.State + midjourneyTask.SubmitTime = midjRequest.SubmitTime + midjourneyTask.StartTime = midjRequest.StartTime + midjourneyTask.FinishTime = midjRequest.FinishTime + midjourneyTask.ImageUrl = midjRequest.ImageUrl + midjourneyTask.Status = midjRequest.Status + midjourneyTask.FailReason = midjRequest.FailReason + err = midjourneyTask.Update() + if err != nil { + return &provider.MidjourneyResponse{ + Code: 4, + Description: "update_midjourney_task_failed", + } + } + + return nil +} + +func coverMidjourneyTaskDto(originTask *model.Midjourney) (midjourneyTask provider.MidjourneyDto) { + midjourneyTask.MjId = originTask.MjId + midjourneyTask.Progress = originTask.Progress + midjourneyTask.PromptEn = originTask.PromptEn + midjourneyTask.State = originTask.State + midjourneyTask.SubmitTime = originTask.SubmitTime + midjourneyTask.StartTime = originTask.StartTime + midjourneyTask.FinishTime = originTask.FinishTime + midjourneyTask.ImageUrl = "" + if originTask.ImageUrl != "" { + midjourneyTask.ImageUrl = common.ServerAddress + "/mj/image/" + originTask.MjId + if originTask.Status != "SUCCESS" { + midjourneyTask.ImageUrl += "?rand=" + strconv.FormatInt(time.Now().UnixNano(), 10) + } + } + midjourneyTask.Status = originTask.Status + midjourneyTask.FailReason = originTask.FailReason + midjourneyTask.Action = originTask.Action + midjourneyTask.Description = originTask.Description + midjourneyTask.Prompt = originTask.Prompt + if originTask.Buttons != "" { + var buttons []provider.ActionButton + err := json.Unmarshal([]byte(originTask.Buttons), &buttons) + if err == nil { + midjourneyTask.Buttons = buttons + } + } + if originTask.Properties != "" { + var properties provider.Properties + err := json.Unmarshal([]byte(originTask.Properties), &properties) + if err == nil { + midjourneyTask.Properties = &properties + } + } + return +} + +func RelaySwapFace(c *gin.Context) *provider.MidjourneyResponse { + mjProvider, errWithMJ := getMJProvider(c, provider.RelayModeMidjourneySwapFace, 0, nil) + if errWithMJ != nil { + return errWithMJ + } + + startTime := time.Now().UnixNano() / int64(time.Millisecond) + userId := c.GetInt("id") + var swapFaceRequest provider.SwapFaceRequest + err := common.UnmarshalBodyReusable(c, &swapFaceRequest) + if err != nil { + return provider.MidjourneyErrorWrapper(provider.MjRequestError, "bind_request_body_failed") + } + if swapFaceRequest.SourceBase64 == "" || swapFaceRequest.TargetBase64 == "" { + return provider.MidjourneyErrorWrapper(provider.MjRequestError, "sour_base64_and_target_base64_is_required") + } + + quotaInstance, errWithOA := getQuota(c, mjProvider.GetOriginalModel()) + if errWithOA != nil { + return &provider.MidjourneyResponse{ + Code: 4, + Description: errWithOA.Message, + } + } + requestURL := getMjRequestPath(c.Request.URL.String()) + mjResp, _, err := mjProvider.Send(60, requestURL) + if err != nil { + quotaInstance.Undo(c) + return &mjResp.Response + } + if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 { + quotaInstance.Consume(c, &types.Usage{CompletionTokens: 0, PromptTokens: 1000, TotalTokens: 1000}) + } else { + quotaInstance.Undo(c) + } + + quota := int(quotaInstance.GetInputRatio() * 1000) + + midjResponse := &mjResp.Response + midjourneyTask := &model.Midjourney{ + UserId: userId, + Code: midjResponse.Code, + Action: provider.MjActionSwapFace, + MjId: midjResponse.Result, + Prompt: "InsightFace", + PromptEn: "", + Description: midjResponse.Description, + State: "", + SubmitTime: startTime, + StartTime: time.Now().UnixNano() / int64(time.Millisecond), + FinishTime: 0, + ImageUrl: "", + Status: "", + Progress: "0%", + FailReason: "", + ChannelId: c.GetInt("channel_id"), + Quota: quota, + } + err = midjourneyTask.Insert() + if err != nil { + return provider.MidjourneyErrorWrapper(provider.MjRequestError, "insert_midjourney_task_failed") + } + // 开始激活任务 + controller.ActivateUpdateMidjourneyTaskBulk() + + c.Writer.WriteHeader(mjResp.StatusCode) + respBody, err := json.Marshal(midjResponse) + if err != nil { + return provider.MidjourneyErrorWrapper(provider.MjRequestError, "unmarshal_response_body_failed") + } + _, err = io.Copy(c.Writer, bytes.NewBuffer(respBody)) + if err != nil { + return provider.MidjourneyErrorWrapper(provider.MjRequestError, "copy_response_body_failed") + } + return nil +} + +func RelayMidjourneyTaskImageSeed(c *gin.Context) *provider.MidjourneyResponse { + taskId := c.Param("id") + userId := c.GetInt("id") + originTask := model.GetByMJId(userId, taskId) + if originTask == nil { + return provider.MidjourneyErrorWrapper(provider.MjRequestError, "task_no_found") + } + + mjProvider, errWithMJ := getMJProvider(c, provider.RelayModeMidjourneyTaskImageSeed, originTask.ChannelId, nil) + if errWithMJ != nil { + return errWithMJ + } + + requestURL := getMjRequestPath(c.Request.URL.String()) + midjResponseWithStatus, _, err := mjProvider.Send(30, requestURL) + if err != nil { + return &midjResponseWithStatus.Response + } + midjResponse := &midjResponseWithStatus.Response + c.Writer.WriteHeader(midjResponseWithStatus.StatusCode) + respBody, err := json.Marshal(midjResponse) + if err != nil { + return provider.MidjourneyErrorWrapper(provider.MjRequestError, "unmarshal_response_body_failed") + } + _, err = io.Copy(c.Writer, bytes.NewBuffer(respBody)) + if err != nil { + return provider.MidjourneyErrorWrapper(provider.MjRequestError, "copy_response_body_failed") + } + return nil +} + +func RelayMidjourneyTask(c *gin.Context, relayMode int) *provider.MidjourneyResponse { + userId := c.GetInt("id") + var err error + var respBody []byte + switch relayMode { + case provider.RelayModeMidjourneyTaskFetch: + taskId := c.Param("id") + originTask := model.GetByMJId(userId, taskId) + if originTask == nil { + return &provider.MidjourneyResponse{ + Code: 4, + Description: "task_no_found", + } + } + midjourneyTask := coverMidjourneyTaskDto(originTask) + respBody, err = json.Marshal(midjourneyTask) + if err != nil { + return &provider.MidjourneyResponse{ + Code: 4, + Description: "unmarshal_response_body_failed", + } + } + case provider.RelayModeMidjourneyTaskFetchByCondition: + var condition = struct { + IDs []string `json:"ids"` + }{} + err = c.BindJSON(&condition) + if err != nil { + return &provider.MidjourneyResponse{ + Code: 4, + Description: "do_request_failed", + } + } + var tasks []provider.MidjourneyDto + if len(condition.IDs) != 0 { + originTasks := model.GetByMJIds(userId, condition.IDs) + for _, originTask := range originTasks { + midjourneyTask := coverMidjourneyTaskDto(originTask) + tasks = append(tasks, midjourneyTask) + } + } + if tasks == nil { + tasks = make([]provider.MidjourneyDto, 0) + } + respBody, err = json.Marshal(tasks) + if err != nil { + return &provider.MidjourneyResponse{ + Code: 4, + Description: "unmarshal_response_body_failed", + } + } + } + + c.Writer.Header().Set("Content-Type", "application/json") + + _, err = io.Copy(c.Writer, bytes.NewBuffer(respBody)) + if err != nil { + return &provider.MidjourneyResponse{ + Code: 4, + Description: "copy_response_body_failed", + } + } + return nil +} + +func RelayMidjourneySubmit(c *gin.Context, relayMode int) *provider.MidjourneyResponse { + channelId := 0 + userId := c.GetInt("id") + consumeQuota := true + var midjRequest provider.MidjourneyRequest + err := common.UnmarshalBodyReusable(c, &midjRequest) + if err != nil { + return provider.MidjourneyErrorWrapper(provider.MjRequestError, "bind_request_body_failed") + } + + if relayMode == provider.RelayModeMidjourneyAction { // midjourney plus,需要从customId中获取任务信息 + mjErr := CoverPlusActionToNormalAction(&midjRequest) + if mjErr != nil { + return mjErr + } + relayMode = provider.RelayModeMidjourneyChange + } + + if relayMode == provider.RelayModeMidjourneyImagine { //绘画任务,此类任务可重复 + if midjRequest.Prompt == "" { + return provider.MidjourneyErrorWrapper(provider.MjRequestError, "prompt_is_required") + } + midjRequest.Action = provider.MjActionImagine + } else if relayMode == provider.RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复 + midjRequest.Action = provider.MjActionDescribe + } else if relayMode == provider.RelayModeMidjourneyShorten { //缩短任务,此类任务可重复,plus only + midjRequest.Action = provider.MjActionShorten + } else if relayMode == provider.RelayModeMidjourneyBlend { //绘画任务,此类任务可重复 + midjRequest.Action = provider.MjActionBlend + } else if midjRequest.TaskId != "" { //放大、变换任务,此类任务,如果重复且已有结果,远端api会直接返回最终结果 + mjId := "" + if relayMode == provider.RelayModeMidjourneyChange { + if midjRequest.TaskId == "" { + return provider.MidjourneyErrorWrapper(provider.MjRequestError, "task_id_is_required") + } else if midjRequest.Action == "" { + return provider.MidjourneyErrorWrapper(provider.MjRequestError, "action_is_required") + } else if midjRequest.Index == 0 { + return provider.MidjourneyErrorWrapper(provider.MjRequestError, "index_is_required") + } + //action = midjRequest.Action + mjId = midjRequest.TaskId + } else if relayMode == provider.RelayModeMidjourneySimpleChange { + if midjRequest.Content == "" { + return provider.MidjourneyErrorWrapper(provider.MjRequestError, "content_is_required") + } + params := ConvertSimpleChangeParams(midjRequest.Content) + if params == nil { + return provider.MidjourneyErrorWrapper(provider.MjRequestError, "content_parse_failed") + } + mjId = params.TaskId + midjRequest.Action = params.Action + } else if relayMode == provider.RelayModeMidjourneyModal { + //if midjRequest.MaskBase64 == "" { + // return provider.MidjourneyErrorWrapper(provider.MjRequestError, "mask_base64_is_required") + //} + mjId = midjRequest.TaskId + midjRequest.Action = provider.MjActionModal + } + + originTask := model.GetByMJId(userId, mjId) + if originTask == nil { + return provider.MidjourneyErrorWrapper(provider.MjRequestError, "task_not_found") + } else if originTask.Status != "SUCCESS" && relayMode != provider.RelayModeMidjourneyModal { + return provider.MidjourneyErrorWrapper(provider.MjRequestError, "task_status_not_success") + } else { //原任务的Status=SUCCESS,则可以做放大UPSCALE、变换VARIATION等动作,此时必须使用原来的请求地址才能正确处理 + channelId = originTask.ChannelId + log.Printf("检测到此操作为放大、变换、重绘,获取原channel信息: %d", originTask.ChannelId) + } + midjRequest.Prompt = originTask.Prompt + + //if channelType == common.ChannelTypeMidjourneyPlus { + // // plus + //} else { + // // 普通版渠道 + // + //} + } + + if midjRequest.Action == provider.MjActionInPaint || midjRequest.Action == provider.MjActionCustomZoom { + consumeQuota = false + } + + mjProvider, errWithMJ := getMJProvider(c, relayMode, channelId, &midjRequest) + if errWithMJ != nil { + return errWithMJ + } + + //baseURL := common.ChannelBaseURLs[channelType] + requestURL := getMjRequestPath(c.Request.URL.String()) + + //midjRequest.NotifyHook = "http://127.0.0.1:3000/mj/notify" + + quotaInstance, errWithOA := getQuota(c, mjProvider.GetOriginalModel()) + if errWithOA != nil { + return &provider.MidjourneyResponse{ + Code: 4, + Description: errWithOA.Message, + } + } + + midjResponseWithStatus, responseBody, err := mjProvider.Send(60, requestURL) + if err != nil { + quotaInstance.Undo(c) + return &midjResponseWithStatus.Response + } + + if consumeQuota && midjResponseWithStatus.StatusCode == 200 { + quotaInstance.Consume(c, &types.Usage{CompletionTokens: 0, PromptTokens: 1, TotalTokens: 1}) + } else { + quotaInstance.Undo(c) + } + quota := int(quotaInstance.GetInputRatio() * 1000) + + midjResponse := &midjResponseWithStatus.Response + + // 文档:https://github.com/novicezk/midjourney-proxy/blob/main/docs/api.md + //1-提交成功 + // 21-任务已存在(处理中或者有结果了) {"code":21,"description":"任务已存在","result":"0741798445574458","properties":{"status":"SUCCESS","imageUrl":"https://xxxx"}} + // 22-排队中 {"code":22,"description":"排队中,前面还有1个任务","result":"0741798445574458","properties":{"numberOfQueues":1,"discordInstanceId":"1118138338562560102"}} + // 23-队列已满,请稍后再试 {"code":23,"description":"队列已满,请稍后尝试","result":"14001929738841620","properties":{"discordInstanceId":"1118138338562560102"}} + // 24-prompt包含敏感词 {"code":24,"description":"可能包含敏感词","properties":{"promptEn":"nude body","bannedWord":"nude"}} + // other: 提交错误,description为错误描述 + midjourneyTask := &model.Midjourney{ + UserId: userId, + Code: midjResponse.Code, + Action: midjRequest.Action, + MjId: midjResponse.Result, + Prompt: midjRequest.Prompt, + PromptEn: "", + Description: midjResponse.Description, + State: "", + SubmitTime: time.Now().UnixNano() / int64(time.Millisecond), + StartTime: 0, + FinishTime: 0, + ImageUrl: "", + Status: "", + Progress: "0%", + FailReason: "", + ChannelId: c.GetInt("channel_id"), + Quota: quota, + } + + if midjResponse.Code != 1 && midjResponse.Code != 21 && midjResponse.Code != 22 { + //非1-提交成功,21-任务已存在和22-排队中,则记录错误原因 + midjourneyTask.FailReason = midjResponse.Description + consumeQuota = false + } + + if midjResponse.Code == 21 { //21-任务已存在(处理中或者有结果了) + // 将 properties 转换为一个 map + properties, ok := midjResponse.Properties.(map[string]interface{}) + if ok { + imageUrl, ok1 := properties["imageUrl"].(string) + status, ok2 := properties["status"].(string) + if ok1 && ok2 { + midjourneyTask.ImageUrl = imageUrl + midjourneyTask.Status = status + if status == "SUCCESS" { + midjourneyTask.Progress = "100%" + midjourneyTask.StartTime = time.Now().UnixNano() / int64(time.Millisecond) + midjourneyTask.FinishTime = time.Now().UnixNano() / int64(time.Millisecond) + midjResponse.Code = 1 + } + } + } + //修改返回值 + if midjRequest.Action != provider.MjActionInPaint && midjRequest.Action != provider.MjActionCustomZoom { + newBody := strings.Replace(string(responseBody), `"code":21`, `"code":1`, -1) + responseBody = []byte(newBody) + } + } + + err = midjourneyTask.Insert() + if err != nil { + return &provider.MidjourneyResponse{ + Code: 4, + Description: "insert_midjourney_task_failed", + } + } + // 开始激活任务 + controller.ActivateUpdateMidjourneyTaskBulk() + + if midjResponse.Code == 22 { //22-排队中,说明任务已存在 + //修改返回值 + newBody := strings.Replace(string(responseBody), `"code":22`, `"code":1`, -1) + responseBody = []byte(newBody) + } + + //resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + bodyReader := io.NopCloser(bytes.NewBuffer(responseBody)) + + //for k, v := range resp.Header { + // c.Writer.Header().Set(k, v[0]) + //} + c.Writer.WriteHeader(midjResponseWithStatus.StatusCode) + + _, err = io.Copy(c.Writer, bodyReader) + if err != nil { + return &provider.MidjourneyResponse{ + Code: 4, + Description: "copy_response_body_failed", + } + } + err = bodyReader.Close() + if err != nil { + return &provider.MidjourneyResponse{ + Code: 4, + Description: "close_response_body_failed", + } + } + return nil +} + +func getMjRequestPath(path string) string { + requestURL := path + if strings.Contains(requestURL, "/mj-") { + urls := strings.Split(requestURL, "/mj/") + if len(urls) < 2 { + return requestURL + } + requestURL = "/mj/" + urls[1] + } + return requestURL +} + +func getQuota(c *gin.Context, modelName string) (*util.Quota, *types.OpenAIErrorWithStatusCode) { + // modelName = CoverActionToModelName(modelName) + + return util.NewQuota(c, modelName, 1000) +} + +func getMJProvider(c *gin.Context, relayMode, channel_id int, request *provider.MidjourneyRequest) (*provider.MidjourneyProvider, *provider.MidjourneyResponse) { + var baseProvider providersBase.ProviderInterface + modelName := "" + if channel_id > 0 { + c.Set("specific_channel_id", channel_id) + } + + if request != nil { + midjourneyModel, mjErr, _ := GetMjRequestModel(relayMode, request) + if mjErr != nil { + return nil, MidjourneyErrorFromInternal(mjErr.Code, mjErr.Description) + } + if midjourneyModel == "" { + return nil, MidjourneyErrorFromInternal(provider.MjErrorUnknown, "无效的请求, 无法解析模型") + } + + modelName = midjourneyModel + } + + var err error + baseProvider, _, err = relay.GetProvider(c, modelName) + if err != nil { + return nil, MidjourneyErrorFromInternal(provider.MjErrorUnknown, "无法获取provider:"+err.Error()) + } + + mjProvider, ok := baseProvider.(*provider.MidjourneyProvider) + if !ok { + return nil, MidjourneyErrorFromInternal(provider.MjErrorUnknown, "无效的请求, 无法获取midjourney provider") + } + + return mjProvider, nil +} diff --git a/relay/midjourney/relay.go b/relay/midjourney/relay.go new file mode 100644 index 00000000..e4ce2f15 --- /dev/null +++ b/relay/midjourney/relay.go @@ -0,0 +1,95 @@ +// Author: Calcium-Ion +// GitHub: https://github.com/Calcium-Ion/new-api +// Path: controller/relay.go +package midjourney + +import ( + "fmt" + "net/http" + "one-api/common" + provider "one-api/providers/midjourney" + "strings" + + "github.com/gin-gonic/gin" +) + +func RelayMidjourney(c *gin.Context) { + relayMode := Path2RelayModeMidjourney(c.Request.URL.Path) + var err *provider.MidjourneyResponse + switch relayMode { + case provider.RelayModeMidjourneyNotify: + err = RelayMidjourneyNotify(c) + case provider.RelayModeMidjourneyTaskFetch, provider.RelayModeMidjourneyTaskFetchByCondition: + err = RelayMidjourneyTask(c, relayMode) + case provider.RelayModeMidjourneyTaskImageSeed: + err = RelayMidjourneyTaskImageSeed(c) + case provider.RelayModeMidjourneySwapFace: + err = RelaySwapFace(c) + default: + err = RelayMidjourneySubmit(c, relayMode) + } + + if err != nil { + statusCode := http.StatusBadRequest + if err.Code == 30 { + err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。" + statusCode = http.StatusTooManyRequests + } + + typeMsg := "upstream_error" + if err.Type != "" { + typeMsg = err.Type + } + c.JSON(statusCode, gin.H{ + "description": fmt.Sprintf("%s %s", err.Description, err.Result), + "type": typeMsg, + "code": err.Code, + }) + channelId := c.GetInt("channel_id") + common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, fmt.Sprintf("%s %s", err.Description, err.Result))) + } +} + +func MidjourneyErrorFromInternal(code int, description string) *provider.MidjourneyResponse { + return &provider.MidjourneyResponse{ + Code: code, + Description: description, + Type: "internal_error", + } +} + +func Path2RelayModeMidjourney(path string) int { + relayMode := provider.RelayModeUnknown + if strings.HasSuffix(path, "/mj/submit/action") { + // midjourney plus + relayMode = provider.RelayModeMidjourneyAction + } else if strings.HasSuffix(path, "/mj/submit/modal") { + // midjourney plus + relayMode = provider.RelayModeMidjourneyModal + } else if strings.HasSuffix(path, "/mj/submit/shorten") { + // midjourney plus + relayMode = provider.RelayModeMidjourneyShorten + } else if strings.HasSuffix(path, "/mj/insight-face/swap") { + // midjourney plus + relayMode = provider.RelayModeMidjourneySwapFace + } else if strings.HasSuffix(path, "/mj/submit/imagine") { + relayMode = provider.RelayModeMidjourneyImagine + } else if strings.HasSuffix(path, "/mj/submit/blend") { + relayMode = provider.RelayModeMidjourneyBlend + } else if strings.HasSuffix(path, "/mj/submit/describe") { + relayMode = provider.RelayModeMidjourneyDescribe + } else if strings.HasSuffix(path, "/mj/notify") { + relayMode = provider.RelayModeMidjourneyNotify + } else if strings.HasSuffix(path, "/mj/submit/change") { + relayMode = provider.RelayModeMidjourneyChange + } else if strings.HasSuffix(path, "/mj/submit/simple-change") { + relayMode = provider.RelayModeMidjourneyChange + } else if strings.HasSuffix(path, "/fetch") { + relayMode = provider.RelayModeMidjourneyTaskFetch + } else if strings.HasSuffix(path, "/image-seed") { + relayMode = provider.RelayModeMidjourneyTaskImageSeed + } else if strings.HasSuffix(path, "/list-by-condition") { + relayMode = provider.RelayModeMidjourneyTaskFetchByCondition + } + return relayMode +} diff --git a/relay/midjourney/service.go b/relay/midjourney/service.go new file mode 100644 index 00000000..1118a785 --- /dev/null +++ b/relay/midjourney/service.go @@ -0,0 +1,148 @@ +// Author: Calcium-Ion +// GitHub: https://github.com/Calcium-Ion/new-api +// Path: service/midjourney.go +package midjourney + +import ( + mjProvider "one-api/providers/midjourney" + "strconv" + "strings" +) + +func CoverActionToModelName(mjAction string) string { + modelName := "mj_" + strings.ToLower(mjAction) + if mjAction == mjProvider.MjActionSwapFace { + modelName = "swap_face" + } + return modelName +} + +func GetMjRequestModel(relayMode int, midjRequest *mjProvider.MidjourneyRequest) (string, *mjProvider.MidjourneyResponse, bool) { + action := "" + if relayMode == mjProvider.RelayModeMidjourneyAction { + // plus request + err := CoverPlusActionToNormalAction(midjRequest) + if err != nil { + return "", err, false + } + action = midjRequest.Action + } else { + switch relayMode { + case mjProvider.RelayModeMidjourneyImagine: + action = mjProvider.MjActionImagine + case mjProvider.RelayModeMidjourneyDescribe: + action = mjProvider.MjActionDescribe + case mjProvider.RelayModeMidjourneyBlend: + action = mjProvider.MjActionBlend + case mjProvider.RelayModeMidjourneyShorten: + action = mjProvider.MjActionShorten + case mjProvider.RelayModeMidjourneyChange: + action = midjRequest.Action + case mjProvider.RelayModeMidjourneyModal: + action = mjProvider.MjActionModal + case mjProvider.RelayModeMidjourneySwapFace: + action = mjProvider.MjActionSwapFace + case mjProvider.RelayModeMidjourneySimpleChange: + params := ConvertSimpleChangeParams(midjRequest.Content) + if params == nil { + return "", mjProvider.MidjourneyErrorWrapper(mjProvider.MjRequestError, "invalid_request"), false + } + action = params.Action + case mjProvider.RelayModeMidjourneyTaskFetch, mjProvider.RelayModeMidjourneyTaskFetchByCondition, mjProvider.RelayModeMidjourneyNotify: + return "", nil, true + default: + return "", mjProvider.MidjourneyErrorWrapper(mjProvider.MjRequestError, "unknown_relay_action"), false + } + } + + modelName := CoverActionToModelName(action) + return modelName, nil, true +} + +func CoverPlusActionToNormalAction(midjRequest *mjProvider.MidjourneyRequest) *mjProvider.MidjourneyResponse { + // "customId": "MJ::JOB::upsample::2::3dbbd469-36af-4a0f-8f02-df6c579e7011" + customId := midjRequest.CustomId + if customId == "" { + return mjProvider.MidjourneyErrorWrapper(mjProvider.MjRequestError, "custom_id_is_required") + } + splits := strings.Split(customId, "::") + var action string + if splits[1] == "JOB" { + action = splits[2] + } else { + action = splits[1] + } + + if action == "" { + return mjProvider.MidjourneyErrorWrapper(mjProvider.MjRequestError, "unknown_action") + } + if strings.Contains(action, "upsample") { + index, err := strconv.Atoi(splits[3]) + if err != nil { + return mjProvider.MidjourneyErrorWrapper(mjProvider.MjRequestError, "index_parse_failed") + } + midjRequest.Index = index + midjRequest.Action = mjProvider.MjActionUpscale + } else if strings.Contains(action, "variation") { + midjRequest.Index = 1 + if action == "variation" { + index, err := strconv.Atoi(splits[3]) + if err != nil { + return mjProvider.MidjourneyErrorWrapper(mjProvider.MjRequestError, "index_parse_failed") + } + midjRequest.Index = index + midjRequest.Action = mjProvider.MjActionVariation + } else if action == "low_variation" { + midjRequest.Action = mjProvider.MjActionLowVariation + } else if action == "high_variation" { + midjRequest.Action = mjProvider.MjActionHighVariation + } + } else if strings.Contains(action, "pan") { + midjRequest.Action = mjProvider.MjActionPan + midjRequest.Index = 1 + } else if strings.Contains(action, "reroll") { + midjRequest.Action = mjProvider.MjActionReRoll + midjRequest.Index = 1 + } else if action == "Outpaint" { + midjRequest.Action = mjProvider.MjActionZoom + midjRequest.Index = 1 + } else if action == "CustomZoom" { + midjRequest.Action = mjProvider.MjActionCustomZoom + midjRequest.Index = 1 + } else if action == "Inpaint" { + midjRequest.Action = mjProvider.MjActionInPaint + midjRequest.Index = 1 + } else { + return mjProvider.MidjourneyErrorWrapper(mjProvider.MjRequestError, "unknown_action:"+customId) + } + return nil +} + +func ConvertSimpleChangeParams(content string) *mjProvider.MidjourneyRequest { + split := strings.Split(content, " ") + if len(split) != 2 { + return nil + } + + action := strings.ToLower(split[1]) + changeParams := &mjProvider.MidjourneyRequest{} + changeParams.TaskId = split[0] + + if action[0] == 'u' { + changeParams.Action = "UPSCALE" + } else if action[0] == 'v' { + changeParams.Action = "VARIATION" + } else if action == "r" { + changeParams.Action = "REROLL" + return changeParams + } else { + return nil + } + + index, err := strconv.Atoi(action[1:2]) + if err != nil || index < 1 || index > 4 { + return nil + } + changeParams.Index = index + return changeParams +} diff --git a/relay/util/quota.go b/relay/util/quota.go index 0274f51c..13392881 100644 --- a/relay/util/quota.go +++ b/relay/util/quota.go @@ -170,3 +170,7 @@ func (q *Quota) Consume(c *gin.Context, usage *types.Usage) { } }(c.Request.Context()) } + +func (q *Quota) GetInputRatio() float64 { + return q.inputRatio +} diff --git a/relay/util/type.go b/relay/util/type.go index 5c16d288..a3ea0300 100644 --- a/relay/util/type.go +++ b/relay/util/type.go @@ -7,22 +7,23 @@ var ModelOwnedBy map[int]string func init() { ModelOwnedBy = map[int]string{ - common.ChannelTypeOpenAI: "OpenAI", - common.ChannelTypeAnthropic: "Anthropic", - common.ChannelTypeBaidu: "Baidu", - common.ChannelTypePaLM: "Google PaLM", - common.ChannelTypeGemini: "Google Gemini", - common.ChannelTypeZhipu: "Zhipu", - common.ChannelTypeAli: "Ali", - common.ChannelTypeXunfei: "Xunfei", - common.ChannelType360: "360", - common.ChannelTypeTencent: "Tencent", - common.ChannelTypeBaichuan: "Baichuan", - common.ChannelTypeMiniMax: "MiniMax", - common.ChannelTypeDeepseek: "Deepseek", - common.ChannelTypeMoonshot: "Moonshot", - common.ChannelTypeMistral: "Mistral", - common.ChannelTypeGroq: "Groq", - common.ChannelTypeLingyi: "Lingyiwanwu", + common.ChannelTypeOpenAI: "OpenAI", + common.ChannelTypeAnthropic: "Anthropic", + common.ChannelTypeBaidu: "Baidu", + common.ChannelTypePaLM: "Google PaLM", + common.ChannelTypeGemini: "Google Gemini", + common.ChannelTypeZhipu: "Zhipu", + common.ChannelTypeAli: "Ali", + common.ChannelTypeXunfei: "Xunfei", + common.ChannelType360: "360", + common.ChannelTypeTencent: "Tencent", + common.ChannelTypeBaichuan: "Baichuan", + common.ChannelTypeMiniMax: "MiniMax", + common.ChannelTypeDeepseek: "Deepseek", + common.ChannelTypeMoonshot: "Moonshot", + common.ChannelTypeMistral: "Mistral", + common.ChannelTypeGroq: "Groq", + common.ChannelTypeLingyi: "Lingyiwanwu", + common.ChannelTypeMidjourney: "Midjourney", } } diff --git a/router/api-router.go b/router/api-router.go index b8742c16..923048ef 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -145,6 +145,9 @@ func SetApiRouter(router *gin.Engine) { } + mjRoute := apiRouter.Group("/mj") + mjRoute.GET("/self", middleware.UserAuth(), controller.GetUserMidjourney) + mjRoute.GET("/", middleware.AdminAuth(), controller.GetAllMidjourney) } } diff --git a/router/dashboard.go b/router/dashboard.go index fa3cac88..1a7350b0 100644 --- a/router/dashboard.go +++ b/router/dashboard.go @@ -1,10 +1,11 @@ package router import ( - "github.com/gin-contrib/gzip" - "github.com/gin-gonic/gin" "one-api/controller" "one-api/middleware" + + "github.com/gin-contrib/gzip" + "github.com/gin-gonic/gin" ) func SetDashboardRouter(router *gin.Engine) { @@ -12,7 +13,7 @@ func SetDashboardRouter(router *gin.Engine) { apiRouter := router.Group("/") apiRouter.Use(gzip.Gzip(gzip.DefaultCompression)) apiRouter.Use(middleware.GlobalAPIRateLimit()) - apiRouter.Use(middleware.TokenAuth()) + apiRouter.Use(middleware.OpenaiAuth()) { apiRouter.GET("/dashboard/billing/subscription", controller.GetSubscription) apiRouter.GET("/v1/dashboard/billing/subscription", controller.GetSubscription) diff --git a/router/relay-router.go b/router/relay-router.go index 824cc1b3..0d4416fa 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -4,6 +4,7 @@ import ( "one-api/controller" "one-api/middleware" "one-api/relay" + "one-api/relay/midjourney" "github.com/gin-gonic/gin" ) @@ -11,14 +12,19 @@ import ( func SetRelayRouter(router *gin.Engine) { router.Use(middleware.CORS()) // https://platform.openai.com/docs/api-reference/introduction + setOpenAIRouter(router) + setMJRouter(router) +} + +func setOpenAIRouter(router *gin.Engine) { modelsRouter := router.Group("/v1/models") - modelsRouter.Use(middleware.TokenAuth(), middleware.Distribute()) + modelsRouter.Use(middleware.OpenaiAuth(), middleware.Distribute()) { modelsRouter.GET("", relay.ListModels) modelsRouter.GET("/:model", relay.RetrieveModel) } relayV1Router := router.Group("/v1") - relayV1Router.Use(middleware.RelayPanicRecover(), middleware.TokenAuth(), middleware.Distribute()) + relayV1Router.Use(middleware.RelayPanicRecover(), middleware.OpenaiAuth(), middleware.Distribute()) { relayV1Router.POST("/completions", relay.Relay) relayV1Router.POST("/chat/completions", relay.Relay) @@ -71,3 +77,34 @@ func SetRelayRouter(router *gin.Engine) { relayV1Router.GET("/threads/:id/runs/:runsId/steps", controller.RelayNotImplemented) } } + +func setMJRouter(router *gin.Engine) { + relayMjRouter := router.Group("/mj") + registerMjRouterGroup(relayMjRouter) + + relayMjModeRouter := router.Group("/:mode/mj") + registerMjRouterGroup(relayMjModeRouter) +} + +// Author: Calcium-Ion +// GitHub: https://github.com/Calcium-Ion/new-api +// Path: router/relay-router.go +func registerMjRouterGroup(relayMjRouter *gin.RouterGroup) { + relayMjRouter.GET("/image/:id", midjourney.RelayMidjourneyImage) + relayMjRouter.Use(middleware.MjAuth(), middleware.Distribute()) + { + relayMjRouter.POST("/submit/action", midjourney.RelayMidjourney) + relayMjRouter.POST("/submit/shorten", midjourney.RelayMidjourney) + relayMjRouter.POST("/submit/modal", midjourney.RelayMidjourney) + relayMjRouter.POST("/submit/imagine", midjourney.RelayMidjourney) + relayMjRouter.POST("/submit/change", midjourney.RelayMidjourney) + relayMjRouter.POST("/submit/simple-change", midjourney.RelayMidjourney) + relayMjRouter.POST("/submit/describe", midjourney.RelayMidjourney) + relayMjRouter.POST("/submit/blend", midjourney.RelayMidjourney) + relayMjRouter.POST("/notify", midjourney.RelayMidjourney) + relayMjRouter.GET("/task/:id/fetch", midjourney.RelayMidjourney) + relayMjRouter.GET("/task/:id/image-seed", midjourney.RelayMidjourney) + relayMjRouter.POST("/task/list-by-condition", midjourney.RelayMidjourney) + relayMjRouter.POST("/insight-face/swap", midjourney.RelayMidjourney) + } +} diff --git a/web/README.md b/web/README.md index 07ca93ca..38c1ae6c 100644 --- a/web/README.md +++ b/web/README.md @@ -7,7 +7,7 @@ 使用了以下开源项目作为我们项目的一部分: - [Berry Free React Admin Template](https://github.com/codedthemes/berry-free-react-admin-template) -- [minimal-ui-kit](minimal-ui-kit) +- [minimal-ui-kit](https://github.com/minimal-ui-kit/material-kit-react) ## 许可证 diff --git a/web/src/constants/ChannelConstants.js b/web/src/constants/ChannelConstants.js index 49f82ae1..376ea6ed 100644 --- a/web/src/constants/ChannelConstants.js +++ b/web/src/constants/ChannelConstants.js @@ -132,6 +132,13 @@ export const CHANNEL_OPTIONS = { color: 'primary', url: 'https://platform.lingyiwanwu.com/details' }, + 34: { + key: 34, + text: 'Midjourney', + value: 34, + color: 'orange', + url: '' + }, 24: { key: 24, text: 'Azure Speech', diff --git a/web/src/menu-items/panel.js b/web/src/menu-items/panel.js index a85f6a02..05b41c3f 100644 --- a/web/src/menu-items/panel.js +++ b/web/src/menu-items/panel.js @@ -11,7 +11,8 @@ import { IconUserScan, IconActivity, IconBrandTelegram, - IconReceipt2 + IconReceipt2, + IconBrush } from '@tabler/icons-react'; // constant @@ -27,7 +28,8 @@ const icons = { IconUserScan, IconActivity, IconBrandTelegram, - IconReceipt2 + IconReceipt2, + IconBrush }; // ==============================|| DASHBOARD MENU ITEMS ||============================== // @@ -96,6 +98,14 @@ const panel = { icon: icons.IconGardenCart, breadcrumbs: false }, + { + id: 'midjourney', + title: 'Midjourney', + type: 'item', + url: '/panel/midjourney', + icon: icons.IconBrush, + breadcrumbs: false + }, { id: 'user', title: '用户', diff --git a/web/src/routes/MainRoutes.js b/web/src/routes/MainRoutes.js index 45abc176..582fd1f0 100644 --- a/web/src/routes/MainRoutes.js +++ b/web/src/routes/MainRoutes.js @@ -16,6 +16,7 @@ const NotFoundView = Loadable(lazy(() => import('views/Error'))); const Analytics = Loadable(lazy(() => import('views/Analytics'))); const Telegram = Loadable(lazy(() => import('views/Telegram'))); const Pricing = Loadable(lazy(() => import('views/Pricing'))); +const Midjourney = Loadable(lazy(() => import('views/Midjourney'))); // dashboard routing const Dashboard = Loadable(lazy(() => import('views/Dashboard'))); @@ -81,6 +82,10 @@ const MainRoutes = { { path: 'pricing', element: + }, + { + path: 'midjourney', + element: } ] }; diff --git a/web/src/themes/compStyleOverride.js b/web/src/themes/compStyleOverride.js index 0bdc608d..67a3dd14 100644 --- a/web/src/themes/compStyleOverride.js +++ b/web/src/themes/compStyleOverride.js @@ -12,15 +12,7 @@ export default function componentStyleOverrides(theme) { } } }, - MuiMenuItem: { - styleOverrides: { - root: { - '&:hover': { - backgroundColor: theme.colors?.grey100 - } - } - } - }, //MuiAutocomplete-popper MuiPopover-root + //MuiAutocomplete-popper MuiPopover-root MuiAutocomplete: { styleOverrides: { popper: { @@ -247,7 +239,7 @@ export default function componentStyleOverrides(theme) { MuiTooltip: { styleOverrides: { tooltip: { - color: theme.paper, + color: theme.colors.paper, background: theme.colors?.grey700 } } @@ -266,6 +258,9 @@ export default function componentStyleOverrides(theme) { .apexcharts-menu { background: ${theme.backgroundDefault} !important } + .apexcharts-gridline, .apexcharts-xaxistooltip-background, .apexcharts-yaxistooltip-background { + stroke: ${theme.divider} !important; + } ` } }; diff --git a/web/src/ui-component/Footer.js b/web/src/ui-component/Footer.js index 522c52b3..86e7bd9b 100644 --- a/web/src/ui-component/Footer.js +++ b/web/src/ui-component/Footer.js @@ -19,14 +19,14 @@ const Footer = () => { {siteInfo.system_name} {process.env.REACT_APP_VERSION}{' '} 由{' '} - - JustSong - {' '} - 构建, MartialBE - 修改,源代码遵循 + 开发,基于 + + JustSong + {' '} + One API,源代码遵循 MIT 协议 )} diff --git a/web/src/views/Channel/type/Config.js b/web/src/views/Channel/type/Config.js index c224fb15..2c0da6d5 100644 --- a/web/src/views/Channel/type/Config.js +++ b/web/src/views/Channel/type/Config.js @@ -234,6 +234,33 @@ const typeConfig = { test_model: 'yi-34b-chat-0205' }, modelGroup: 'Lingyiwanwu' + }, + 34: { + input: { + models: [ + 'mj_imagine', + 'mj_variation', + 'mj_reroll', + 'mj_blend', + 'mj_modal', + 'mj_zoom', + 'mj_shorten', + 'mj_high_variation', + 'mj_low_variation', + 'mj_pan', + 'mj_inpaint', + 'mj_custom_zoom', + 'mj_describe', + 'mj_upscale', + 'swap_face' + ] + }, + prompt: { + key: '密钥填写midjourney-proxy的密钥,如果没有设置密钥,可以随便填', + base_url: '地址填写midjourney-proxy部署的地址', + test_model: '' + }, + modelGroup: 'Midjourney' } }; diff --git a/web/src/views/Midjourney/component/TableRow.js b/web/src/views/Midjourney/component/TableRow.js new file mode 100644 index 00000000..31f4acb2 --- /dev/null +++ b/web/src/views/Midjourney/component/TableRow.js @@ -0,0 +1,174 @@ +import PropTypes from 'prop-types'; + +import { useState } from 'react'; +import { + TableRow, + TableCell, + Button, + Dialog, + DialogActions, + DialogContent, + ButtonGroup, + Popover, + MenuItem, + MenuList, + Tooltip +} from '@mui/material'; + +import { timestamp2string, copy } from 'utils/common'; +import Label from 'ui-component/Label'; +import { ACTION_TYPE, CODE_TYPE, STATUS_TYPE } from '../type/Type'; +import { IconCaretDownFilled, IconCopy, IconDownload, IconExternalLink } from '@tabler/icons-react'; + +function renderType(types, type) { + const typeOption = types[type]; + if (typeOption) { + return ( + + ); + } else { + return ( + + ); + } +} +async function downloadImage(url, filename) { + const response = await fetch(url); + const blob = await response.blob(); + const blobUrl = URL.createObjectURL(blob); + const link = document.createElement('a'); + link.href = blobUrl; + link.download = filename; + link.click(); + URL.revokeObjectURL(blobUrl); +} + +function TruncatedText(text) { + const truncatedText = text.length > 30 ? text.substring(0, 100) + '...' : text; + + return ( + { + copy(text, ''); + }} + > + {truncatedText} + + ); +} + +export default function LogTableRow({ item, userIsAdmin }) { + const [open, setOpen] = useState(false); + const [menuOpen, setMenuOpen] = useState(null); + const handleClickOpen = () => { + setOpen(true); + }; + + const handleClose = () => { + setOpen(false); + }; + + const handleOpenMenu = (event) => { + setMenuOpen(event.currentTarget); + }; + + const handleCloseMenu = () => { + setMenuOpen(null); + }; + + return ( + <> + + {item.mj_id} + {timestamp2string(item.submit_time / 1000)} + + {userIsAdmin && {item.channel_id || ''}} + {userIsAdmin && {item.user_id || ''}} + + {renderType(ACTION_TYPE, item.action)} + {userIsAdmin && {renderType(CODE_TYPE, item.code)}} + {userIsAdmin && {renderType(STATUS_TYPE, item.status)}} + {item.progress} + + {item.image_url == '' ? ( + '无' + ) : ( + + + + + )} + + {TruncatedText(item.prompt)} + {TruncatedText(item.prompt_en)} + {TruncatedText(item.fail_reason)} + + + + item + + + + + + + + + { + handleCloseMenu(); + copy(item.image_url, '图片地址'); + }} + > + + 复制地址 + + + { + handleCloseMenu(); + await downloadImage(item.image_url, item.mj_id + '.png'); + }} + > + 下载图片{' '} + + { + handleCloseMenu(); + }} + > + 新窗口打开{' '} + + + + + ); +} + +LogTableRow.propTypes = { + item: PropTypes.object, + userIsAdmin: PropTypes.bool +}; diff --git a/web/src/views/Midjourney/component/TableToolBar.js b/web/src/views/Midjourney/component/TableToolBar.js new file mode 100644 index 00000000..39d2f6de --- /dev/null +++ b/web/src/views/Midjourney/component/TableToolBar.js @@ -0,0 +1,113 @@ +import PropTypes from 'prop-types'; +import { useTheme } from '@mui/material/styles'; +import { IconBroadcast, IconCalendarEvent } from '@tabler/icons-react'; +import { InputAdornment, OutlinedInput, Stack, FormControl, InputLabel } from '@mui/material'; +import { LocalizationProvider, DateTimePicker } from '@mui/x-date-pickers'; +import { AdapterDayjs } from '@mui/x-date-pickers/AdapterDayjs'; +import dayjs from 'dayjs'; +require('dayjs/locale/zh-cn'); +// ---------------------------------------------------------------------- + +export default function TableToolBar({ filterName, handleFilterName, userIsAdmin }) { + const theme = useTheme(); + const grey500 = theme.palette.grey[500]; + + return ( + <> + + {userIsAdmin && ( + + 渠道ID + + + + } + /> + + )} + + 任务ID + + + + } + /> + + + + + { + if (value === null) { + handleFilterName({ target: { name: 'start_timestamp', value: 0 } }); + return; + } + handleFilterName({ target: { name: 'start_timestamp', value: value.unix() * 1000 } }); + }} + slotProps={{ + actionBar: { + actions: ['clear', 'today', 'accept'] + } + }} + /> + + + + + + { + if (value === null) { + handleFilterName({ target: { name: 'end_timestamp', value: 0 } }); + return; + } + handleFilterName({ target: { name: 'end_timestamp', value: value.unix() * 1000 } }); + }} + slotProps={{ + actionBar: { + actions: ['clear', 'today', 'accept'] + } + }} + /> + + + + + ); +} + +TableToolBar.propTypes = { + filterName: PropTypes.object, + handleFilterName: PropTypes.func, + userIsAdmin: PropTypes.bool +}; diff --git a/web/src/views/Midjourney/index.js b/web/src/views/Midjourney/index.js new file mode 100644 index 00000000..ee799f18 --- /dev/null +++ b/web/src/views/Midjourney/index.js @@ -0,0 +1,247 @@ +import { useState, useEffect, useCallback } from 'react'; +import { showError } from 'utils/common'; + +import Table from '@mui/material/Table'; +import TableBody from '@mui/material/TableBody'; +import TableContainer from '@mui/material/TableContainer'; +import PerfectScrollbar from 'react-perfect-scrollbar'; +import TablePagination from '@mui/material/TablePagination'; +import LinearProgress from '@mui/material/LinearProgress'; +import ButtonGroup from '@mui/material/ButtonGroup'; +import Toolbar from '@mui/material/Toolbar'; + +import { Button, Card, Stack, Container, Typography, Box } from '@mui/material'; +import LogTableRow from './component/TableRow'; +import KeywordTableHead from 'ui-component/TableHead'; +import TableToolBar from './component/TableToolBar'; +import { API } from 'utils/api'; +import { isAdmin } from 'utils/common'; +import { ITEMS_PER_PAGE } from 'constants'; +import { IconRefresh, IconSearch } from '@tabler/icons-react'; +import dayjs from 'dayjs'; + +export default function Log() { + const originalKeyword = { + p: 0, + channel_id: '', + mj_id: '', + start_timestamp: 0, + end_timestamp: dayjs().unix() * 1000 + 3600 + }; + + const [page, setPage] = useState(0); + const [order, setOrder] = useState('desc'); + const [orderBy, setOrderBy] = useState('id'); + const [rowsPerPage, setRowsPerPage] = useState(ITEMS_PER_PAGE); + const [listCount, setListCount] = useState(0); + const [searching, setSearching] = useState(false); + const [toolBarValue, setToolBarValue] = useState(originalKeyword); + const [searchKeyword, setSearchKeyword] = useState(originalKeyword); + const [refreshFlag, setRefreshFlag] = useState(false); + + const [logs, setLogs] = useState([]); + const userIsAdmin = isAdmin(); + + const handleSort = (event, id) => { + const isAsc = orderBy === id && order === 'asc'; + if (id !== '') { + setOrder(isAsc ? 'desc' : 'asc'); + setOrderBy(id); + } + }; + + const handleChangePage = (event, newPage) => { + setPage(newPage); + }; + + const handleChangeRowsPerPage = (event) => { + setPage(0); + setRowsPerPage(parseInt(event.target.value, 10)); + }; + + const searchLogs = async () => { + setPage(0); + setSearchKeyword(toolBarValue); + }; + + const handleToolBarValue = (event) => { + setToolBarValue({ ...toolBarValue, [event.target.name]: event.target.value }); + }; + + const fetchData = useCallback( + async (page, rowsPerPage, keyword, order, orderBy) => { + setSearching(true); + try { + if (orderBy) { + orderBy = order === 'desc' ? '-' + orderBy : orderBy; + } + const url = userIsAdmin ? '/api/mj/' : '/api/mj/self/'; + if (!userIsAdmin) { + delete keyword.channel_id; + } + + const res = await API.get(url, { + params: { + page: page + 1, + size: rowsPerPage, + order: orderBy, + ...keyword + } + }); + const { success, message, data } = res.data; + if (success) { + setListCount(data.total_count); + setLogs(data.data); + } else { + showError(message); + } + } catch (error) { + console.error(error); + } + setSearching(false); + }, + [userIsAdmin] + ); + + // 处理刷新 + const handleRefresh = async () => { + setOrderBy('id'); + setOrder('desc'); + setToolBarValue(originalKeyword); + setSearchKeyword(originalKeyword); + setRefreshFlag(!refreshFlag); + }; + + useEffect(() => { + fetchData(page, rowsPerPage, searchKeyword, order, orderBy); + }, [page, rowsPerPage, searchKeyword, order, orderBy, fetchData, refreshFlag]); + + return ( + <> + + Midjourney + + + + + + theme.spacing(0, 1, 0, 3) + }} + > + + + + + + + + + {searching && } + + + + + + {logs.map((row, index) => ( + + ))} + +
+
+
+ +
+ + ); +} diff --git a/web/src/views/Midjourney/type/Type.js b/web/src/views/Midjourney/type/Type.js new file mode 100644 index 00000000..8ec625ec --- /dev/null +++ b/web/src/views/Midjourney/type/Type.js @@ -0,0 +1,33 @@ +export const ACTION_TYPE = { + IMAGINE: { value: 'IMAGINE', text: '绘图', color: 'primary' }, + UPSCALE: { value: 'UPSCALE', text: '放大', color: 'orange' }, + VARIATION: { value: 'VARIATION', text: '变换', color: 'default' }, + HIGH_VARIATION: { value: 'HIGH_VARIATION', text: '强变换', color: 'default' }, + LOW_VARIATION: { value: 'LOW_VARIATION', text: '弱变换', color: 'default' }, + PAN: { value: 'PAN', text: '平移', color: 'secondary' }, + DESCRIBE: { value: 'DESCRIBE', text: '图生文', color: 'secondary' }, + BLEND: { value: 'BLEND', text: '图混合', color: 'secondary' }, + SHORTEN: { value: 'SHORTEN', text: '缩词', color: 'secondary' }, + REROLL: { value: 'REROLL', text: '重绘', color: 'secondary' }, + INPAINT: { value: 'INPAINT', text: '局部重绘-提交', color: 'secondary' }, + ZOOM: { value: 'ZOOM', text: '变焦', color: 'secondary' }, + CUSTOM_ZOOM: { value: 'CUSTOM_ZOOM', text: '自定义变焦-提交', color: 'secondary' }, + MODAL: { value: 'MODAL', text: '窗口处理', color: 'secondary' }, + SWAP_FACE: { value: 'SWAP_FACE', text: '换脸', color: 'secondary' } +}; + +export const CODE_TYPE = { + 1: { value: 1, text: '已提交', color: 'primary' }, + 21: { value: 21, text: '等待中', color: 'orange' }, + 22: { value: 22, text: '重复提交', color: 'default' }, + 0: { value: 0, text: '未提交', color: 'default' } +}; + +export const STATUS_TYPE = { + SUCCESS: { value: 'SUCCESS', text: '成功', color: 'success' }, + NOT_START: { value: 'NOT_START', text: '未启动', color: 'default' }, + SUBMITTED: { value: 'SUBMITTED', text: '队列中', color: 'secondary' }, + IN_PROGRESS: { value: 'IN_PROGRESS', text: '执行中', color: 'primary' }, + FAILURE: { value: 'FAILURE', text: '失败', color: 'orange' }, + MODAL: { value: 'MODAL', text: '窗口等待', color: 'default' } +}; diff --git a/web/src/views/Setting/component/OperationSetting.js b/web/src/views/Setting/component/OperationSetting.js index 85cfe0ef..6712e6ca 100644 --- a/web/src/views/Setting/component/OperationSetting.js +++ b/web/src/views/Setting/component/OperationSetting.js @@ -29,7 +29,8 @@ const OperationSetting = () => { DisplayTokenStatEnabled: '', ApproximateTokenEnabled: '', RetryTimes: 0, - RetryCooldownSeconds: 0 + RetryCooldownSeconds: 0, + MjNotifyEnabled: '' }); const [originInputs, setOriginInputs] = useState({}); let [loading, setLoading] = useState(false); @@ -278,6 +279,22 @@ const OperationSetting = () => { + + + + } + /> + + +