✨ feat: add Midjourney (#138)
* 🚧 stash * ✨ feat: add Midjourney * 📝 doc: update readme
This commit is contained in:
parent
87bfecf3e9
commit
c1fc32add7
10
README.md
10
README.md
@ -58,6 +58,16 @@ _本项目是基于[one-api](https://github.com/songquanpeng/one-api)二次开
|
||||
|
||||
请查看[文档](https://github.com/MartialBE/one-api/wiki)
|
||||
|
||||
## 感谢
|
||||
|
||||
- 本程序使用了以下开源项目
|
||||
- [one-api](https://github.com/songquanpeng/one-api)为本项目的基础
|
||||
- [Berry Free React Admin Template](https://github.com/codedthemes/berry-free-react-admin-template)为本项目的前端界面
|
||||
- [minimal-ui-kit](https://github.com/minimal-ui-kit/material-kit-react),使用了其中的部分样式
|
||||
- [new api](https://github.com/Calcium-Ion/new-api),Midjourney 模块的代码来源于此
|
||||
|
||||
感谢以上项目的作者和贡献者
|
||||
|
||||
## 其他
|
||||
|
||||
<a href="https://next.ossinsight.io/widgets/official/analyze-repo-stars-history?repo_id=689214770" target="_blank" style="display: block" align="center">
|
||||
|
@ -37,6 +37,9 @@ var WeChatAuthEnabled = false
|
||||
var TurnstileCheckEnabled = false
|
||||
var RegisterEnabled = true
|
||||
|
||||
// mj
|
||||
var MjNotifyEnabled = false
|
||||
|
||||
var EmailDomainRestrictionEnabled = false
|
||||
var EmailDomainWhitelist = []string{
|
||||
"gmail.com",
|
||||
@ -161,6 +164,7 @@ const (
|
||||
ChannelTypeGroq = 31
|
||||
ChannelTypeBedrock = 32
|
||||
ChannelTypeLingyi = 33
|
||||
ChannelTypeMidjourney = 34
|
||||
)
|
||||
|
||||
var ChannelBaseURLs = []string{
|
||||
@ -198,6 +202,7 @@ var ChannelBaseURLs = []string{
|
||||
"https://api.groq.com/openai", //31
|
||||
"", //32
|
||||
"https://api.lingyiwanwu.com", //33
|
||||
"", //34
|
||||
}
|
||||
|
||||
const (
|
||||
|
32
common/go-channel.go
Normal file
32
common/go-channel.go
Normal file
@ -0,0 +1,32 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime/debug"
|
||||
)
|
||||
|
||||
func SafeGoroutine(f func()) {
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
SysError(fmt.Sprintf("child goroutine panic occured: error: %v, stack: %s", r, string(debug.Stack())))
|
||||
}
|
||||
}()
|
||||
f()
|
||||
}()
|
||||
}
|
||||
|
||||
func SafeSend(ch chan bool, value bool) (closed bool) {
|
||||
defer func() {
|
||||
// Recover from panic if one occured. A panic would mean the channel was closed.
|
||||
if recover() != nil {
|
||||
closed = true
|
||||
}
|
||||
}()
|
||||
|
||||
// This will panic if the channel is closed.
|
||||
ch <- value
|
||||
|
||||
// If the code reaches here, then the channel was not closed.
|
||||
return false
|
||||
}
|
@ -106,7 +106,10 @@ func logHelper(ctx context.Context, level string, msg string) {
|
||||
if level == loggerINFO {
|
||||
writer = gin.DefaultWriter
|
||||
}
|
||||
id := ctx.Value(RequestIdKey)
|
||||
id, ok := ctx.Value(RequestIdKey).(string)
|
||||
if !ok {
|
||||
id = "unknown"
|
||||
}
|
||||
now := time.Now()
|
||||
_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg)
|
||||
logCount++ // we don't need accurate count, so no lock here
|
||||
|
@ -23,6 +23,7 @@ type HTTPRequester struct {
|
||||
CreateFormBuilder func(io.Writer) FormBuilder
|
||||
ErrorHandler HttpErrorHandler
|
||||
proxyAddr string
|
||||
Context context.Context
|
||||
}
|
||||
|
||||
// NewHTTPRequester 创建一个新的 HTTPRequester 实例。
|
||||
@ -37,6 +38,7 @@ func NewHTTPRequester(proxyAddr string, errorHandler HttpErrorHandler) *HTTPRequ
|
||||
},
|
||||
ErrorHandler: errorHandler,
|
||||
proxyAddr: proxyAddr,
|
||||
Context: context.Background(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -47,18 +49,18 @@ type requestOptions struct {
|
||||
|
||||
type requestOption func(*requestOptions)
|
||||
|
||||
func (r *HTTPRequester) getContext() context.Context {
|
||||
func (r *HTTPRequester) setProxy() context.Context {
|
||||
if r.proxyAddr == "" {
|
||||
return context.Background()
|
||||
return r.Context
|
||||
}
|
||||
|
||||
// 如果是以 socks5:// 开头的地址,那么使用 socks5 代理
|
||||
if strings.HasPrefix(r.proxyAddr, "socks5://") {
|
||||
return context.WithValue(context.Background(), ProxySock5AddrKey, r.proxyAddr)
|
||||
return context.WithValue(r.Context, ProxySock5AddrKey, r.proxyAddr)
|
||||
}
|
||||
|
||||
// 否则使用 http 代理
|
||||
return context.WithValue(context.Background(), ProxyHTTPAddrKey, r.proxyAddr)
|
||||
return context.WithValue(r.Context, ProxyHTTPAddrKey, r.proxyAddr)
|
||||
|
||||
}
|
||||
|
||||
@ -71,7 +73,7 @@ func (r *HTTPRequester) NewRequest(method, url string, setters ...requestOption)
|
||||
for _, setter := range setters {
|
||||
setter(args)
|
||||
}
|
||||
req, err := r.requestBuilder.Build(r.getContext(), method, url, args.body, args.header)
|
||||
req, err := r.requestBuilder.Build(r.setProxy(), method, url, args.body, args.header)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
285
controller/midjourney.go
Normal file
285
controller/midjourney.go
Normal file
@ -0,0 +1,285 @@
|
||||
// Author: Calcium-Ion
|
||||
// GitHub: https://github.com/Calcium-Ion/new-api
|
||||
// Path: controller/midjourney.go
|
||||
package controller
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/common/requester"
|
||||
"one-api/model"
|
||||
provider "one-api/providers/midjourney"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
var activeMidjourneyTask = make(chan bool, 1)
|
||||
|
||||
func InitMidjourneyTask() {
|
||||
common.SafeGoroutine(func() {
|
||||
midjourneyTask()
|
||||
})
|
||||
|
||||
ActivateUpdateMidjourneyTaskBulk()
|
||||
}
|
||||
|
||||
func midjourneyTask() {
|
||||
for {
|
||||
select {
|
||||
case <-activeMidjourneyTask:
|
||||
UpdateMidjourneyTaskBulk()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func ActivateUpdateMidjourneyTaskBulk() {
|
||||
if len(activeMidjourneyTask) == 0 {
|
||||
activeMidjourneyTask <- true
|
||||
}
|
||||
}
|
||||
|
||||
func UpdateMidjourneyTaskBulk() {
|
||||
ctx := context.WithValue(context.Background(), common.RequestIdKey, "MidjourneyTask")
|
||||
for {
|
||||
common.LogInfo(ctx, "running")
|
||||
|
||||
tasks := model.GetAllUnFinishTasks()
|
||||
|
||||
// 如果没有未完成的任务,则等待
|
||||
if len(tasks) == 0 {
|
||||
for len(activeMidjourneyTask) > 0 {
|
||||
<-activeMidjourneyTask
|
||||
}
|
||||
common.LogInfo(ctx, "no tasks, waiting...")
|
||||
return
|
||||
}
|
||||
|
||||
common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
|
||||
taskChannelM := make(map[int][]string)
|
||||
taskM := make(map[string]*model.Midjourney)
|
||||
nullTaskIds := make([]int, 0)
|
||||
for _, task := range tasks {
|
||||
if task.MjId == "" {
|
||||
// 统计失败的未完成任务
|
||||
nullTaskIds = append(nullTaskIds, task.Id)
|
||||
continue
|
||||
}
|
||||
taskM[task.MjId] = task
|
||||
taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], task.MjId)
|
||||
}
|
||||
if len(nullTaskIds) > 0 {
|
||||
err := model.MjBulkUpdateByTaskIds(nullTaskIds, map[string]any{
|
||||
"status": "FAILURE",
|
||||
"progress": "100%",
|
||||
})
|
||||
if err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Fix null mj_id task error: %v", err))
|
||||
} else {
|
||||
common.LogInfo(ctx, fmt.Sprintf("Fix null mj_id task success: %v", nullTaskIds))
|
||||
}
|
||||
}
|
||||
if len(taskChannelM) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
for channelId, taskIds := range taskChannelM {
|
||||
common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
|
||||
if len(taskIds) == 0 {
|
||||
continue
|
||||
}
|
||||
midjourneyChannel := model.ChannelGroup.GetChannel(channelId)
|
||||
if midjourneyChannel == nil {
|
||||
err := model.MjBulkUpdate(taskIds, map[string]any{
|
||||
"fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId),
|
||||
"status": "FAILURE",
|
||||
"progress": "100%",
|
||||
})
|
||||
common.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err))
|
||||
continue
|
||||
}
|
||||
requestUrl := fmt.Sprintf("%s/mj/task/list-by-condition", *midjourneyChannel.BaseURL)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"ids": taskIds,
|
||||
})
|
||||
req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(body))
|
||||
if err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Get Task error: %v", err))
|
||||
continue
|
||||
}
|
||||
// 设置超时时间
|
||||
timeout := time.Second * 5
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
// 使用带有超时的 context 创建新的请求
|
||||
req = req.WithContext(ctx)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("mj-api-secret", midjourneyChannel.Key)
|
||||
resp, err := requester.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
|
||||
continue
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
|
||||
continue
|
||||
}
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err))
|
||||
continue
|
||||
}
|
||||
var responseItems []provider.MidjourneyDto
|
||||
err = json.Unmarshal(responseBody, &responseItems)
|
||||
if err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
|
||||
continue
|
||||
}
|
||||
resp.Body.Close()
|
||||
req.Body.Close()
|
||||
cancel()
|
||||
|
||||
for _, responseItem := range responseItems {
|
||||
task := taskM[responseItem.MjId]
|
||||
|
||||
useTime := (time.Now().UnixNano() / int64(time.Millisecond)) - task.SubmitTime
|
||||
// 如果时间超过一小时,且进度不是100%,则认为任务失败
|
||||
if useTime > 3600000 && task.Progress != "100%" {
|
||||
responseItem.FailReason = "上游任务超时(超过1小时)"
|
||||
responseItem.Status = "FAILURE"
|
||||
}
|
||||
if !checkMjTaskNeedUpdate(task, responseItem) {
|
||||
continue
|
||||
}
|
||||
task.Code = 1
|
||||
task.Progress = responseItem.Progress
|
||||
task.PromptEn = responseItem.PromptEn
|
||||
task.State = responseItem.State
|
||||
task.SubmitTime = responseItem.SubmitTime
|
||||
task.StartTime = responseItem.StartTime
|
||||
task.FinishTime = responseItem.FinishTime
|
||||
task.ImageUrl = responseItem.ImageUrl
|
||||
task.Status = responseItem.Status
|
||||
task.FailReason = responseItem.FailReason
|
||||
if responseItem.Properties != nil {
|
||||
propertiesStr, _ := json.Marshal(responseItem.Properties)
|
||||
task.Properties = string(propertiesStr)
|
||||
}
|
||||
if responseItem.Buttons != nil {
|
||||
buttonStr, _ := json.Marshal(responseItem.Buttons)
|
||||
task.Buttons = string(buttonStr)
|
||||
}
|
||||
|
||||
if (task.Progress != "100%" && responseItem.FailReason != "") || (task.Progress == "100%" && task.Status == "FAILURE") {
|
||||
common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
|
||||
task.Progress = "100%"
|
||||
err = model.CacheUpdateUserQuota(task.UserId)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "error update user quota cache: "+err.Error())
|
||||
} else {
|
||||
quota := task.Quota
|
||||
if quota != 0 {
|
||||
err = model.IncreaseUserQuota(task.UserId, quota)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "fail to increase user quota: "+err.Error())
|
||||
}
|
||||
logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, common.LogQuota(quota))
|
||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||
}
|
||||
}
|
||||
}
|
||||
err = task.Update()
|
||||
if err != nil {
|
||||
common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
time.Sleep(time.Duration(15) * time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask provider.MidjourneyDto) bool {
|
||||
if oldTask.Code != 1 {
|
||||
return true
|
||||
}
|
||||
if oldTask.Progress != newTask.Progress {
|
||||
return true
|
||||
}
|
||||
if oldTask.PromptEn != newTask.PromptEn {
|
||||
return true
|
||||
}
|
||||
if oldTask.State != newTask.State {
|
||||
return true
|
||||
}
|
||||
if oldTask.SubmitTime != newTask.SubmitTime {
|
||||
return true
|
||||
}
|
||||
if oldTask.StartTime != newTask.StartTime {
|
||||
return true
|
||||
}
|
||||
if oldTask.FinishTime != newTask.FinishTime {
|
||||
return true
|
||||
}
|
||||
if oldTask.ImageUrl != newTask.ImageUrl {
|
||||
return true
|
||||
}
|
||||
if oldTask.Status != newTask.Status {
|
||||
return true
|
||||
}
|
||||
if oldTask.FailReason != newTask.FailReason {
|
||||
return true
|
||||
}
|
||||
if oldTask.FinishTime != newTask.FinishTime {
|
||||
return true
|
||||
}
|
||||
if oldTask.Progress != "100%" && newTask.FailReason != "" {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func GetAllMidjourney(c *gin.Context) {
|
||||
var params model.TaskQueryParams
|
||||
if err := c.ShouldBindQuery(¶ms); err != nil {
|
||||
common.APIRespondWithError(c, http.StatusOK, err)
|
||||
return
|
||||
}
|
||||
|
||||
midjourneys, err := model.GetAllTasks(¶ms)
|
||||
if err != nil {
|
||||
common.APIRespondWithError(c, http.StatusOK, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": midjourneys,
|
||||
})
|
||||
}
|
||||
|
||||
func GetUserMidjourney(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
|
||||
var params model.TaskQueryParams
|
||||
if err := c.ShouldBindQuery(¶ms); err != nil {
|
||||
common.APIRespondWithError(c, http.StatusOK, err)
|
||||
return
|
||||
}
|
||||
|
||||
midjourneys, err := model.GetAllUserTask(userId, ¶ms)
|
||||
if err != nil {
|
||||
common.APIRespondWithError(c, http.StatusOK, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": midjourneys,
|
||||
})
|
||||
}
|
@ -40,6 +40,7 @@ func GetStatus(c *gin.Context) {
|
||||
"quota_per_unit": common.QuotaPerUnit,
|
||||
"display_in_currency": common.DisplayInCurrencyEnabled,
|
||||
"telegram_bot": telegram_bot,
|
||||
"mj_notify_enabled": common.MjNotifyEnabled,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
2
main.go
2
main.go
@ -45,6 +45,8 @@ func main() {
|
||||
// Initialize Telegram bot
|
||||
telegram.InitTelegramBot()
|
||||
|
||||
controller.InitMidjourneyTask()
|
||||
|
||||
initHttpServer()
|
||||
}
|
||||
|
||||
|
@ -83,43 +83,54 @@ func RootAuth() func(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func TokenAuth() func(c *gin.Context) {
|
||||
return func(c *gin.Context) {
|
||||
key := c.Request.Header.Get("Authorization")
|
||||
key = strings.TrimPrefix(key, "Bearer ")
|
||||
key = strings.TrimPrefix(key, "sk-")
|
||||
parts := strings.Split(key, "-")
|
||||
key = parts[0]
|
||||
token, err := model.ValidateUserToken(key)
|
||||
if err != nil {
|
||||
abortWithMessage(c, http.StatusUnauthorized, err.Error())
|
||||
return
|
||||
}
|
||||
userEnabled, err := model.CacheIsUserEnabled(token.UserId)
|
||||
if err != nil {
|
||||
abortWithMessage(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
if !userEnabled {
|
||||
abortWithMessage(c, http.StatusForbidden, "用户已被封禁")
|
||||
return
|
||||
}
|
||||
c.Set("id", token.UserId)
|
||||
c.Set("token_id", token.Id)
|
||||
c.Set("token_name", token.Name)
|
||||
if len(parts) > 1 {
|
||||
if model.IsAdmin(token.UserId) {
|
||||
channelId := common.String2Int(parts[1])
|
||||
if channelId == 0 {
|
||||
abortWithMessage(c, http.StatusForbidden, "无效的渠道 Id")
|
||||
return
|
||||
}
|
||||
c.Set("specific_channel_id", channelId)
|
||||
} else {
|
||||
abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
|
||||
func tokenAuth(c *gin.Context, key string) {
|
||||
key = strings.TrimPrefix(key, "Bearer ")
|
||||
key = strings.TrimPrefix(key, "sk-")
|
||||
parts := strings.Split(key, "-")
|
||||
key = parts[0]
|
||||
token, err := model.ValidateUserToken(key)
|
||||
if err != nil {
|
||||
abortWithMessage(c, http.StatusUnauthorized, err.Error())
|
||||
return
|
||||
}
|
||||
userEnabled, err := model.CacheIsUserEnabled(token.UserId)
|
||||
if err != nil {
|
||||
abortWithMessage(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
if !userEnabled {
|
||||
abortWithMessage(c, http.StatusForbidden, "用户已被封禁")
|
||||
return
|
||||
}
|
||||
c.Set("id", token.UserId)
|
||||
c.Set("token_id", token.Id)
|
||||
c.Set("token_name", token.Name)
|
||||
if len(parts) > 1 {
|
||||
if model.IsAdmin(token.UserId) {
|
||||
channelId := common.String2Int(parts[1])
|
||||
if channelId == 0 {
|
||||
abortWithMessage(c, http.StatusForbidden, "无效的渠道 Id")
|
||||
return
|
||||
}
|
||||
c.Set("specific_channel_id", channelId)
|
||||
} else {
|
||||
abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
|
||||
func OpenaiAuth() func(c *gin.Context) {
|
||||
return func(c *gin.Context) {
|
||||
key := c.Request.Header.Get("Authorization")
|
||||
tokenAuth(c, key)
|
||||
}
|
||||
}
|
||||
|
||||
func MjAuth() func(c *gin.Context) {
|
||||
return func(c *gin.Context) {
|
||||
key := c.Request.Header.Get("mj-api-secret")
|
||||
tokenAuth(c, key)
|
||||
}
|
||||
}
|
||||
|
@ -114,6 +114,17 @@ func (cc *ChannelsChooser) GetGroupModels(group string) ([]string, error) {
|
||||
return models, nil
|
||||
}
|
||||
|
||||
func (cc *ChannelsChooser) GetChannel(channelId int) *Channel {
|
||||
cc.RLock()
|
||||
defer cc.RUnlock()
|
||||
|
||||
if choice, ok := cc.Channels[channelId]; ok {
|
||||
return choice.Channel
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var ChannelGroup = ChannelsChooser{}
|
||||
|
||||
func (cc *ChannelsChooser) Load() {
|
||||
|
@ -139,6 +139,10 @@ func InitDB() (err error) {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = db.AutoMigrate(&Midjourney{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
common.SysLog("database migrated")
|
||||
err = createRootAccountIfNeed()
|
||||
return err
|
||||
|
182
model/midjourney.go
Normal file
182
model/midjourney.go
Normal file
@ -0,0 +1,182 @@
|
||||
// Copyright (c) 2024 Calcium-Ion
|
||||
//
|
||||
// Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
// of this software and associated documentation files (the "Software"), to deal
|
||||
// in the Software without restriction, including without limitation the rights
|
||||
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
// copies of the Software, and to permit persons to whom the Software is
|
||||
// furnished to do so, subject to the following conditions:
|
||||
//
|
||||
// The above copyright notice and this permission notice shall be included in all
|
||||
// copies or substantial portions of the Software.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
// SOFTWARE.
|
||||
//
|
||||
// Author: Calcium-Ion
|
||||
// GitHub: https://github.com/Calcium-Ion/new-api
|
||||
|
||||
package model
|
||||
|
||||
type Midjourney struct {
|
||||
Id int `json:"id"`
|
||||
Code int `json:"code"`
|
||||
UserId int `json:"user_id" gorm:"index"`
|
||||
Action string `json:"action" gorm:"type:varchar(40);index"`
|
||||
MjId string `json:"mj_id" gorm:"index"`
|
||||
Prompt string `json:"prompt"`
|
||||
PromptEn string `json:"prompt_en"`
|
||||
Description string `json:"description"`
|
||||
State string `json:"state"`
|
||||
SubmitTime int64 `json:"submit_time" gorm:"index"`
|
||||
StartTime int64 `json:"start_time" gorm:"index"`
|
||||
FinishTime int64 `json:"finish_time" gorm:"index"`
|
||||
ImageUrl string `json:"image_url"`
|
||||
Status string `json:"status" gorm:"type:varchar(20);index"`
|
||||
Progress string `json:"progress" gorm:"type:varchar(30);index"`
|
||||
FailReason string `json:"fail_reason"`
|
||||
ChannelId int `json:"channel_id"`
|
||||
Quota int `json:"quota"`
|
||||
Buttons string `json:"buttons"`
|
||||
Properties string `json:"properties"`
|
||||
}
|
||||
|
||||
// TaskQueryParams 用于包含所有搜索条件的结构体,可以根据需求添加更多字段
|
||||
type TaskQueryParams struct {
|
||||
ChannelID int `form:"channel_id"`
|
||||
MjID string `form:"mj_id"`
|
||||
StartTimestamp int `form:"start_timestamp"`
|
||||
EndTimestamp int `form:"end_timestamp"`
|
||||
PaginationParams
|
||||
}
|
||||
|
||||
var allowedMidjourneyOrderFields = map[string]bool{
|
||||
"id": true,
|
||||
"user_id": true,
|
||||
"code": true,
|
||||
"action": true,
|
||||
"mj_id": true,
|
||||
"submit_time": true,
|
||||
"start_time": true,
|
||||
"finish_time": true,
|
||||
"status": true,
|
||||
"channel_id": true,
|
||||
}
|
||||
|
||||
func GetAllUserTask(userId int, params *TaskQueryParams) (*DataResult[Midjourney], error) {
|
||||
var tasks []*Midjourney
|
||||
|
||||
// 初始化查询构建器
|
||||
query := DB.Where("user_id = ?", userId)
|
||||
|
||||
if params.MjID != "" {
|
||||
query = query.Where("mj_id = ?", params.MjID)
|
||||
}
|
||||
if params.StartTimestamp != 0 {
|
||||
// 假设您已将前端传来的时间戳转换为数据库所需的时间格式,并处理了时间戳的验证和解析
|
||||
query = query.Where("submit_time >= ?", params.StartTimestamp)
|
||||
}
|
||||
if params.EndTimestamp != 0 {
|
||||
query = query.Where("submit_time <= ?", params.EndTimestamp)
|
||||
}
|
||||
|
||||
return PaginateAndOrder(query, ¶ms.PaginationParams, &tasks, allowedMidjourneyOrderFields)
|
||||
}
|
||||
|
||||
func GetAllTasks(params *TaskQueryParams) (*DataResult[Midjourney], error) {
|
||||
var tasks []*Midjourney
|
||||
|
||||
// 初始化查询构建器
|
||||
query := DB
|
||||
|
||||
// 添加过滤条件
|
||||
if params.ChannelID != 0 {
|
||||
query = query.Where("channel_id = ?", params.ChannelID)
|
||||
}
|
||||
if params.MjID != "" {
|
||||
query = query.Where("mj_id = ?", params.MjID)
|
||||
}
|
||||
if params.StartTimestamp != 0 {
|
||||
query = query.Where("submit_time >= ?", params.StartTimestamp)
|
||||
}
|
||||
if params.EndTimestamp != 0 {
|
||||
query = query.Where("submit_time <= ?", params.EndTimestamp)
|
||||
}
|
||||
|
||||
return PaginateAndOrder(query, ¶ms.PaginationParams, &tasks, allowedMidjourneyOrderFields)
|
||||
}
|
||||
|
||||
func GetAllUnFinishTasks() []*Midjourney {
|
||||
var tasks []*Midjourney
|
||||
// get all tasks progress is not 100%
|
||||
err := DB.Where("progress != ?", "100%").Find(&tasks).Error
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return tasks
|
||||
}
|
||||
|
||||
func GetByOnlyMJId(mjId string) *Midjourney {
|
||||
var mj *Midjourney
|
||||
err := DB.Where("mj_id = ?", mjId).First(&mj).Error
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return mj
|
||||
}
|
||||
|
||||
func GetByMJId(userId int, mjId string) *Midjourney {
|
||||
var mj *Midjourney
|
||||
err := DB.Where("user_id = ? and mj_id = ?", userId, mjId).First(&mj).Error
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return mj
|
||||
}
|
||||
|
||||
func GetByMJIds(userId int, mjIds []string) []*Midjourney {
|
||||
var mj []*Midjourney
|
||||
err := DB.Where("user_id = ? and mj_id in (?)", userId, mjIds).Find(&mj).Error
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return mj
|
||||
}
|
||||
|
||||
func GetMjByuId(id int) *Midjourney {
|
||||
var mj *Midjourney
|
||||
err := DB.Where("id = ?", id).First(&mj).Error
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return mj
|
||||
}
|
||||
|
||||
func UpdateProgress(id int, progress string) error {
|
||||
return DB.Model(&Midjourney{}).Where("id = ?", id).Update("progress", progress).Error
|
||||
}
|
||||
|
||||
func (midjourney *Midjourney) Insert() error {
|
||||
return DB.Create(midjourney).Error
|
||||
}
|
||||
|
||||
func (midjourney *Midjourney) Update() error {
|
||||
return DB.Save(midjourney).Error
|
||||
}
|
||||
|
||||
func MjBulkUpdate(mjIds []string, params map[string]any) error {
|
||||
return DB.Model(&Midjourney{}).
|
||||
Where("mj_id in (?)", mjIds).
|
||||
Updates(params).Error
|
||||
}
|
||||
|
||||
func MjBulkUpdateByTaskIds(taskIDs []int, params map[string]any) error {
|
||||
return DB.Model(&Midjourney{}).
|
||||
Where("id in (?)", taskIDs).
|
||||
Updates(params).Error
|
||||
}
|
@ -74,6 +74,8 @@ func InitOptionMap() {
|
||||
common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes)
|
||||
common.OptionMap["RetryCooldownSeconds"] = strconv.Itoa(common.RetryCooldownSeconds)
|
||||
|
||||
common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(common.MjNotifyEnabled)
|
||||
|
||||
common.OptionMapRWMutex.Unlock()
|
||||
loadOptionsFromDatabase()
|
||||
}
|
||||
@ -138,6 +140,7 @@ var optionBoolMap = map[string]*bool{
|
||||
"LogConsumeEnabled": &common.LogConsumeEnabled,
|
||||
"DisplayInCurrencyEnabled": &common.DisplayInCurrencyEnabled,
|
||||
"DisplayTokenStatEnabled": &common.DisplayTokenStatEnabled,
|
||||
"MjNotifyEnabled": &common.MjNotifyEnabled,
|
||||
}
|
||||
|
||||
var optionStringMap = map[string]*string{
|
||||
|
@ -301,5 +301,33 @@ func GetDefaultPrice() []*Price {
|
||||
})
|
||||
}
|
||||
|
||||
var DefaultMJPrice = map[string]float64{
|
||||
"mj_imagine": 50,
|
||||
"mj_variation": 50,
|
||||
"mj_reroll": 50,
|
||||
"mj_blend": 50,
|
||||
"mj_modal": 50,
|
||||
"mj_zoom": 50,
|
||||
"mj_shorten": 50,
|
||||
"mj_high_variation": 50,
|
||||
"mj_low_variation": 50,
|
||||
"mj_pan": 50,
|
||||
"mj_inpaint": 0,
|
||||
"mj_custom_zoom": 0,
|
||||
"mj_describe": 25,
|
||||
"mj_upscale": 25,
|
||||
"swap_face": 25,
|
||||
}
|
||||
|
||||
for model, mjPrice := range DefaultMJPrice {
|
||||
prices = append(prices, &Price{
|
||||
Model: model,
|
||||
Type: TimesPriceType,
|
||||
ChannelType: common.ChannelTypeMidjourney,
|
||||
Input: mjPrice,
|
||||
Output: mjPrice,
|
||||
})
|
||||
}
|
||||
|
||||
return prices
|
||||
}
|
||||
|
121
providers/midjourney/base.go
Normal file
121
providers/midjourney/base.go
Normal file
@ -0,0 +1,121 @@
|
||||
package midjourney
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/common/requester"
|
||||
"one-api/model"
|
||||
"one-api/providers/base"
|
||||
"time"
|
||||
)
|
||||
|
||||
// 定义供应商工厂
|
||||
type MidjourneyProviderFactory struct{}
|
||||
|
||||
// 创建 MidjourneyProvider
|
||||
func (f MidjourneyProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
|
||||
return &MidjourneyProvider{
|
||||
BaseProvider: base.BaseProvider{
|
||||
Config: getConfig(),
|
||||
Channel: channel,
|
||||
Requester: requester.NewHTTPRequester(*channel.Proxy, nil),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func getConfig() base.ProviderConfig {
|
||||
return base.ProviderConfig{
|
||||
BaseURL: "",
|
||||
}
|
||||
}
|
||||
|
||||
type MidjourneyProvider struct {
|
||||
base.BaseProvider
|
||||
}
|
||||
|
||||
func (p *MidjourneyProvider) Send(timeout int, requestURL string) (*MidjourneyResponseWithStatusCode, []byte, error) {
|
||||
var nullBytes []byte
|
||||
var mapResult map[string]interface{}
|
||||
if p.Context.Request.Method != "GET" {
|
||||
err := json.NewDecoder(p.Context.Request.Body).Decode(&mapResult)
|
||||
if err != nil {
|
||||
return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "read_request_body_failed", http.StatusInternalServerError), nullBytes, err
|
||||
}
|
||||
delete(mapResult, "accountFilter")
|
||||
if !common.MjNotifyEnabled {
|
||||
delete(mapResult, "notifyHook")
|
||||
}
|
||||
}
|
||||
|
||||
reqBody, err := json.Marshal(mapResult)
|
||||
if err != nil {
|
||||
return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "marshal_request_body_failed", http.StatusInternalServerError), nullBytes, err
|
||||
}
|
||||
|
||||
fullRequestURL := p.GetFullRequestURL(requestURL, "")
|
||||
|
||||
var cancel context.CancelFunc
|
||||
p.Requester.Context, cancel = context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second)
|
||||
headers := p.GetRequestHeaders()
|
||||
defer cancel()
|
||||
|
||||
req, err := p.Requester.NewRequest(p.Context.Request.Method, fullRequestURL, p.Requester.WithBody(bytes.NewBuffer(reqBody)), p.Requester.WithHeader(headers))
|
||||
if err != nil {
|
||||
return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "create_request_failed", http.StatusInternalServerError), nullBytes, err
|
||||
}
|
||||
|
||||
resp, errWith := p.Requester.SendRequestRaw(req)
|
||||
if errWith != nil {
|
||||
common.SysError("do request failed: " + errWith.Error())
|
||||
return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "do_request_failed", http.StatusInternalServerError), nullBytes, err
|
||||
}
|
||||
statusCode := resp.StatusCode
|
||||
err = req.Body.Close()
|
||||
if err != nil {
|
||||
return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err
|
||||
}
|
||||
err = p.Context.Request.Body.Close()
|
||||
if err != nil {
|
||||
return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err
|
||||
}
|
||||
var midjResponse MidjourneyResponse
|
||||
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "read_response_body_failed", statusCode), nullBytes, err
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "close_response_body_failed", statusCode), responseBody, err
|
||||
}
|
||||
respStr := string(responseBody)
|
||||
log.Printf("responseBody: %s", respStr)
|
||||
if respStr == "" {
|
||||
return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "empty_response_body", statusCode), responseBody, nil
|
||||
} else {
|
||||
err = json.Unmarshal(responseBody, &midjResponse)
|
||||
if err != nil {
|
||||
return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "unmarshal_response_body_failed", statusCode), responseBody, err
|
||||
}
|
||||
}
|
||||
|
||||
return &MidjourneyResponseWithStatusCode{
|
||||
StatusCode: statusCode,
|
||||
Response: midjResponse,
|
||||
}, responseBody, nil
|
||||
|
||||
}
|
||||
|
||||
func (p *MidjourneyProvider) GetRequestHeaders() (headers map[string]string) {
|
||||
headers = make(map[string]string)
|
||||
headers["mj-api-secret"] = p.Channel.Key
|
||||
headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type")
|
||||
headers["Accept"] = p.Context.Request.Header.Get("Accept")
|
||||
|
||||
return headers
|
||||
}
|
69
providers/midjourney/constant.go
Normal file
69
providers/midjourney/constant.go
Normal file
@ -0,0 +1,69 @@
|
||||
// Author: Calcium-Ion
|
||||
// GitHub: https://github.com/Calcium-Ion/new-api
|
||||
// Path: relay/constant/relay_mode.go
|
||||
package midjourney
|
||||
|
||||
const (
|
||||
RelayModeUnknown = iota
|
||||
RelayModeMidjourneyImagine
|
||||
RelayModeMidjourneyDescribe
|
||||
RelayModeMidjourneyBlend
|
||||
RelayModeMidjourneyChange
|
||||
RelayModeMidjourneySimpleChange
|
||||
RelayModeMidjourneyNotify
|
||||
RelayModeMidjourneyTaskFetch
|
||||
RelayModeMidjourneyTaskImageSeed
|
||||
RelayModeMidjourneyTaskFetchByCondition
|
||||
RelayModeAudioSpeech
|
||||
RelayModeAudioTranscription
|
||||
RelayModeAudioTranslation
|
||||
RelayModeMidjourneyAction
|
||||
RelayModeMidjourneyModal
|
||||
RelayModeMidjourneyShorten
|
||||
RelayModeMidjourneySwapFace
|
||||
)
|
||||
|
||||
// Author: Calcium-Ion
|
||||
// GitHub: https://github.com/Calcium-Ion/new-api
|
||||
// Path: constant/midjourney.go
|
||||
|
||||
const (
|
||||
MjErrorUnknown = 5
|
||||
MjRequestError = 4
|
||||
)
|
||||
|
||||
const (
|
||||
MjActionImagine = "IMAGINE"
|
||||
MjActionDescribe = "DESCRIBE"
|
||||
MjActionBlend = "BLEND"
|
||||
MjActionUpscale = "UPSCALE"
|
||||
MjActionVariation = "VARIATION"
|
||||
MjActionReRoll = "REROLL"
|
||||
MjActionInPaint = "INPAINT"
|
||||
MjActionModal = "MODAL"
|
||||
MjActionZoom = "ZOOM"
|
||||
MjActionCustomZoom = "CUSTOM_ZOOM"
|
||||
MjActionShorten = "SHORTEN"
|
||||
MjActionHighVariation = "HIGH_VARIATION"
|
||||
MjActionLowVariation = "LOW_VARIATION"
|
||||
MjActionPan = "PAN"
|
||||
MjActionSwapFace = "SWAP_FACE"
|
||||
)
|
||||
|
||||
var MidjourneyModel2Action = map[string]string{
|
||||
"mj_imagine": MjActionImagine,
|
||||
"mj_describe": MjActionDescribe,
|
||||
"mj_blend": MjActionBlend,
|
||||
"mj_upscale": MjActionUpscale,
|
||||
"mj_variation": MjActionVariation,
|
||||
"mj_reroll": MjActionReRoll,
|
||||
"mj_modal": MjActionModal,
|
||||
"mj_inpaint": MjActionInPaint,
|
||||
"mj_zoom": MjActionZoom,
|
||||
"mj_custom_zoom": MjActionCustomZoom,
|
||||
"mj_shorten": MjActionShorten,
|
||||
"mj_high_variation": MjActionHighVariation,
|
||||
"mj_low_variation": MjActionLowVariation,
|
||||
"mj_pan": MjActionPan,
|
||||
"swap_face": MjActionSwapFace,
|
||||
}
|
18
providers/midjourney/error.go
Normal file
18
providers/midjourney/error.go
Normal file
@ -0,0 +1,18 @@
|
||||
// Author: Calcium-Ion
|
||||
// GitHub: https://github.com/Calcium-Ion/new-api
|
||||
// Path: service/error.go
|
||||
package midjourney
|
||||
|
||||
func MidjourneyErrorWithStatusCodeWrapper(code int, desc string, statusCode int) *MidjourneyResponseWithStatusCode {
|
||||
return &MidjourneyResponseWithStatusCode{
|
||||
StatusCode: statusCode,
|
||||
Response: *MidjourneyErrorWrapper(code, desc),
|
||||
}
|
||||
}
|
||||
|
||||
func MidjourneyErrorWrapper(code int, desc string) *MidjourneyResponse {
|
||||
return &MidjourneyResponse{
|
||||
Code: code,
|
||||
Description: desc,
|
||||
}
|
||||
}
|
92
providers/midjourney/type.go
Normal file
92
providers/midjourney/type.go
Normal file
@ -0,0 +1,92 @@
|
||||
// Author: Calcium-Ion
|
||||
// GitHub: https://github.com/Calcium-Ion/new-api
|
||||
// Path: dto/midjourney.go
|
||||
package midjourney
|
||||
|
||||
type SwapFaceRequest struct {
|
||||
SourceBase64 string `json:"sourceBase64"`
|
||||
TargetBase64 string `json:"targetBase64"`
|
||||
}
|
||||
|
||||
type MidjourneyRequest struct {
|
||||
Prompt string `json:"prompt"`
|
||||
CustomId string `json:"customId"`
|
||||
BotType string `json:"botType"`
|
||||
NotifyHook string `json:"notifyHook"`
|
||||
Action string `json:"action"`
|
||||
Index int `json:"index"`
|
||||
State string `json:"state"`
|
||||
TaskId string `json:"taskId"`
|
||||
Base64Array []string `json:"base64Array"`
|
||||
Content string `json:"content"`
|
||||
MaskBase64 string `json:"maskBase64"`
|
||||
}
|
||||
|
||||
type MidjourneyResponse struct {
|
||||
Code int `json:"code"`
|
||||
Description string `json:"description"`
|
||||
Properties interface{} `json:"properties"`
|
||||
Result string `json:"result"`
|
||||
Type string `json:"type,omitempty"`
|
||||
}
|
||||
|
||||
type MidjourneyResponseWithStatusCode struct {
|
||||
StatusCode int `json:"statusCode"`
|
||||
Response MidjourneyResponse
|
||||
}
|
||||
|
||||
type MidjourneyDto struct {
|
||||
MjId string `json:"id"`
|
||||
Action string `json:"action"`
|
||||
CustomId string `json:"customId"`
|
||||
BotType string `json:"botType"`
|
||||
Prompt string `json:"prompt"`
|
||||
PromptEn string `json:"promptEn"`
|
||||
Description string `json:"description"`
|
||||
State string `json:"state"`
|
||||
SubmitTime int64 `json:"submitTime"`
|
||||
StartTime int64 `json:"startTime"`
|
||||
FinishTime int64 `json:"finishTime"`
|
||||
ImageUrl string `json:"imageUrl"`
|
||||
Status string `json:"status"`
|
||||
Progress string `json:"progress"`
|
||||
FailReason string `json:"failReason"`
|
||||
Buttons any `json:"buttons"`
|
||||
MaskBase64 string `json:"maskBase64"`
|
||||
Properties *Properties `json:"properties"`
|
||||
}
|
||||
|
||||
type MidjourneyStatus struct {
|
||||
Status int `json:"status"`
|
||||
}
|
||||
type MidjourneyWithoutStatus struct {
|
||||
Id int `json:"id"`
|
||||
Code int `json:"code"`
|
||||
UserId int `json:"user_id" gorm:"index"`
|
||||
Action string `json:"action"`
|
||||
MjId string `json:"mj_id" gorm:"index"`
|
||||
Prompt string `json:"prompt"`
|
||||
PromptEn string `json:"prompt_en"`
|
||||
Description string `json:"description"`
|
||||
State string `json:"state"`
|
||||
SubmitTime int64 `json:"submit_time"`
|
||||
StartTime int64 `json:"start_time"`
|
||||
FinishTime int64 `json:"finish_time"`
|
||||
ImageUrl string `json:"image_url"`
|
||||
Progress string `json:"progress"`
|
||||
FailReason string `json:"fail_reason"`
|
||||
ChannelId int `json:"channel_id"`
|
||||
}
|
||||
|
||||
type ActionButton struct {
|
||||
CustomId any `json:"customId"`
|
||||
Emoji any `json:"emoji"`
|
||||
Label any `json:"label"`
|
||||
Type any `json:"type"`
|
||||
Style any `json:"style"`
|
||||
}
|
||||
|
||||
type Properties struct {
|
||||
FinalPrompt string `json:"finalPrompt"`
|
||||
FinalZhPrompt string `json:"finalZhPrompt"`
|
||||
}
|
@ -14,6 +14,7 @@ import (
|
||||
"one-api/providers/deepseek"
|
||||
"one-api/providers/gemini"
|
||||
"one-api/providers/groq"
|
||||
"one-api/providers/midjourney"
|
||||
"one-api/providers/minimax"
|
||||
"one-api/providers/mistral"
|
||||
"one-api/providers/openai"
|
||||
@ -52,6 +53,7 @@ func init() {
|
||||
providerFactories[common.ChannelTypeMistral] = mistral.MistralProviderFactory{}
|
||||
providerFactories[common.ChannelTypeGroq] = groq.GroqProviderFactory{}
|
||||
providerFactories[common.ChannelTypeBedrock] = bedrock.BedrockProviderFactory{}
|
||||
providerFactories[common.ChannelTypeMidjourney] = midjourney.MidjourneyProviderFactory{}
|
||||
|
||||
}
|
||||
|
||||
|
@ -27,7 +27,7 @@ type RelayBaseInterface interface {
|
||||
}
|
||||
|
||||
func (r *relayBase) setProvider(modelName string) error {
|
||||
provider, modelName, fail := getProvider(r.c, modelName)
|
||||
provider, modelName, fail := GetProvider(r.c, modelName)
|
||||
if fail != nil {
|
||||
return fail
|
||||
}
|
||||
|
@ -45,7 +45,7 @@ func Path2Relay(c *gin.Context, path string) RelayBaseInterface {
|
||||
return nil
|
||||
}
|
||||
|
||||
func getProvider(c *gin.Context, modeName string) (provider providersBase.ProviderInterface, newModelName string, fail error) {
|
||||
func GetProvider(c *gin.Context, modeName string) (provider providersBase.ProviderInterface, newModelName string, fail error) {
|
||||
channel, fail := fetchChannel(c, modeName)
|
||||
if fail != nil {
|
||||
return
|
||||
|
19
relay/midjourney/LICENSE
Normal file
19
relay/midjourney/LICENSE
Normal file
@ -0,0 +1,19 @@
|
||||
Copyright (c) 2024 Calcium-Ion
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
578
relay/midjourney/relay-mj.go
Normal file
578
relay/midjourney/relay-mj.go
Normal file
@ -0,0 +1,578 @@
|
||||
// Author: Calcium-Ion
|
||||
// GitHub: https://github.com/Calcium-Ion/new-api
|
||||
// Path: relay/relay-mj.go
|
||||
package midjourney
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/controller"
|
||||
"one-api/model"
|
||||
providersBase "one-api/providers/base"
|
||||
provider "one-api/providers/midjourney"
|
||||
"one-api/relay"
|
||||
"one-api/relay/util"
|
||||
"one-api/types"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func RelayMidjourneyImage(c *gin.Context) {
|
||||
taskId := c.Param("id")
|
||||
midjourneyTask := model.GetByOnlyMJId(taskId)
|
||||
if midjourneyTask == nil {
|
||||
c.JSON(400, gin.H{
|
||||
"error": "midjourney_task_not_found",
|
||||
})
|
||||
return
|
||||
}
|
||||
resp, err := http.Get(midjourneyTask.ImageUrl)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "http_get_image_failed",
|
||||
})
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
responseBody, _ := io.ReadAll(resp.Body)
|
||||
c.JSON(resp.StatusCode, gin.H{
|
||||
"error": string(responseBody),
|
||||
})
|
||||
return
|
||||
}
|
||||
// 从Content-Type头获取MIME类型
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if contentType == "" {
|
||||
// 如果无法确定内容类型,则默认为jpeg
|
||||
contentType = "image/jpeg"
|
||||
}
|
||||
// 设置响应的内容类型
|
||||
c.Writer.Header().Set("Content-Type", contentType)
|
||||
// 将图片流式传输到响应体
|
||||
_, err = io.Copy(c.Writer, resp.Body)
|
||||
if err != nil {
|
||||
log.Println("Failed to stream image:", err)
|
||||
}
|
||||
}
|
||||
|
||||
func RelayMidjourneyNotify(c *gin.Context) *provider.MidjourneyResponse {
|
||||
var midjRequest provider.MidjourneyDto
|
||||
err := common.UnmarshalBodyReusable(c, &midjRequest)
|
||||
if err != nil {
|
||||
return &provider.MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "bind_request_body_failed",
|
||||
Properties: nil,
|
||||
Result: "",
|
||||
}
|
||||
}
|
||||
midjourneyTask := model.GetByOnlyMJId(midjRequest.MjId)
|
||||
if midjourneyTask == nil {
|
||||
return &provider.MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "midjourney_task_not_found",
|
||||
Properties: nil,
|
||||
Result: "",
|
||||
}
|
||||
}
|
||||
midjourneyTask.Progress = midjRequest.Progress
|
||||
midjourneyTask.PromptEn = midjRequest.PromptEn
|
||||
midjourneyTask.State = midjRequest.State
|
||||
midjourneyTask.SubmitTime = midjRequest.SubmitTime
|
||||
midjourneyTask.StartTime = midjRequest.StartTime
|
||||
midjourneyTask.FinishTime = midjRequest.FinishTime
|
||||
midjourneyTask.ImageUrl = midjRequest.ImageUrl
|
||||
midjourneyTask.Status = midjRequest.Status
|
||||
midjourneyTask.FailReason = midjRequest.FailReason
|
||||
err = midjourneyTask.Update()
|
||||
if err != nil {
|
||||
return &provider.MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "update_midjourney_task_failed",
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func coverMidjourneyTaskDto(originTask *model.Midjourney) (midjourneyTask provider.MidjourneyDto) {
|
||||
midjourneyTask.MjId = originTask.MjId
|
||||
midjourneyTask.Progress = originTask.Progress
|
||||
midjourneyTask.PromptEn = originTask.PromptEn
|
||||
midjourneyTask.State = originTask.State
|
||||
midjourneyTask.SubmitTime = originTask.SubmitTime
|
||||
midjourneyTask.StartTime = originTask.StartTime
|
||||
midjourneyTask.FinishTime = originTask.FinishTime
|
||||
midjourneyTask.ImageUrl = ""
|
||||
if originTask.ImageUrl != "" {
|
||||
midjourneyTask.ImageUrl = common.ServerAddress + "/mj/image/" + originTask.MjId
|
||||
if originTask.Status != "SUCCESS" {
|
||||
midjourneyTask.ImageUrl += "?rand=" + strconv.FormatInt(time.Now().UnixNano(), 10)
|
||||
}
|
||||
}
|
||||
midjourneyTask.Status = originTask.Status
|
||||
midjourneyTask.FailReason = originTask.FailReason
|
||||
midjourneyTask.Action = originTask.Action
|
||||
midjourneyTask.Description = originTask.Description
|
||||
midjourneyTask.Prompt = originTask.Prompt
|
||||
if originTask.Buttons != "" {
|
||||
var buttons []provider.ActionButton
|
||||
err := json.Unmarshal([]byte(originTask.Buttons), &buttons)
|
||||
if err == nil {
|
||||
midjourneyTask.Buttons = buttons
|
||||
}
|
||||
}
|
||||
if originTask.Properties != "" {
|
||||
var properties provider.Properties
|
||||
err := json.Unmarshal([]byte(originTask.Properties), &properties)
|
||||
if err == nil {
|
||||
midjourneyTask.Properties = &properties
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func RelaySwapFace(c *gin.Context) *provider.MidjourneyResponse {
|
||||
mjProvider, errWithMJ := getMJProvider(c, provider.RelayModeMidjourneySwapFace, 0, nil)
|
||||
if errWithMJ != nil {
|
||||
return errWithMJ
|
||||
}
|
||||
|
||||
startTime := time.Now().UnixNano() / int64(time.Millisecond)
|
||||
userId := c.GetInt("id")
|
||||
var swapFaceRequest provider.SwapFaceRequest
|
||||
err := common.UnmarshalBodyReusable(c, &swapFaceRequest)
|
||||
if err != nil {
|
||||
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "bind_request_body_failed")
|
||||
}
|
||||
if swapFaceRequest.SourceBase64 == "" || swapFaceRequest.TargetBase64 == "" {
|
||||
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "sour_base64_and_target_base64_is_required")
|
||||
}
|
||||
|
||||
quotaInstance, errWithOA := getQuota(c, mjProvider.GetOriginalModel())
|
||||
if errWithOA != nil {
|
||||
return &provider.MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: errWithOA.Message,
|
||||
}
|
||||
}
|
||||
requestURL := getMjRequestPath(c.Request.URL.String())
|
||||
mjResp, _, err := mjProvider.Send(60, requestURL)
|
||||
if err != nil {
|
||||
quotaInstance.Undo(c)
|
||||
return &mjResp.Response
|
||||
}
|
||||
if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
|
||||
quotaInstance.Consume(c, &types.Usage{CompletionTokens: 0, PromptTokens: 1000, TotalTokens: 1000})
|
||||
} else {
|
||||
quotaInstance.Undo(c)
|
||||
}
|
||||
|
||||
quota := int(quotaInstance.GetInputRatio() * 1000)
|
||||
|
||||
midjResponse := &mjResp.Response
|
||||
midjourneyTask := &model.Midjourney{
|
||||
UserId: userId,
|
||||
Code: midjResponse.Code,
|
||||
Action: provider.MjActionSwapFace,
|
||||
MjId: midjResponse.Result,
|
||||
Prompt: "InsightFace",
|
||||
PromptEn: "",
|
||||
Description: midjResponse.Description,
|
||||
State: "",
|
||||
SubmitTime: startTime,
|
||||
StartTime: time.Now().UnixNano() / int64(time.Millisecond),
|
||||
FinishTime: 0,
|
||||
ImageUrl: "",
|
||||
Status: "",
|
||||
Progress: "0%",
|
||||
FailReason: "",
|
||||
ChannelId: c.GetInt("channel_id"),
|
||||
Quota: quota,
|
||||
}
|
||||
err = midjourneyTask.Insert()
|
||||
if err != nil {
|
||||
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "insert_midjourney_task_failed")
|
||||
}
|
||||
// 开始激活任务
|
||||
controller.ActivateUpdateMidjourneyTaskBulk()
|
||||
|
||||
c.Writer.WriteHeader(mjResp.StatusCode)
|
||||
respBody, err := json.Marshal(midjResponse)
|
||||
if err != nil {
|
||||
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "unmarshal_response_body_failed")
|
||||
}
|
||||
_, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
|
||||
if err != nil {
|
||||
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "copy_response_body_failed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func RelayMidjourneyTaskImageSeed(c *gin.Context) *provider.MidjourneyResponse {
|
||||
taskId := c.Param("id")
|
||||
userId := c.GetInt("id")
|
||||
originTask := model.GetByMJId(userId, taskId)
|
||||
if originTask == nil {
|
||||
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "task_no_found")
|
||||
}
|
||||
|
||||
mjProvider, errWithMJ := getMJProvider(c, provider.RelayModeMidjourneyTaskImageSeed, originTask.ChannelId, nil)
|
||||
if errWithMJ != nil {
|
||||
return errWithMJ
|
||||
}
|
||||
|
||||
requestURL := getMjRequestPath(c.Request.URL.String())
|
||||
midjResponseWithStatus, _, err := mjProvider.Send(30, requestURL)
|
||||
if err != nil {
|
||||
return &midjResponseWithStatus.Response
|
||||
}
|
||||
midjResponse := &midjResponseWithStatus.Response
|
||||
c.Writer.WriteHeader(midjResponseWithStatus.StatusCode)
|
||||
respBody, err := json.Marshal(midjResponse)
|
||||
if err != nil {
|
||||
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "unmarshal_response_body_failed")
|
||||
}
|
||||
_, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
|
||||
if err != nil {
|
||||
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "copy_response_body_failed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func RelayMidjourneyTask(c *gin.Context, relayMode int) *provider.MidjourneyResponse {
|
||||
userId := c.GetInt("id")
|
||||
var err error
|
||||
var respBody []byte
|
||||
switch relayMode {
|
||||
case provider.RelayModeMidjourneyTaskFetch:
|
||||
taskId := c.Param("id")
|
||||
originTask := model.GetByMJId(userId, taskId)
|
||||
if originTask == nil {
|
||||
return &provider.MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "task_no_found",
|
||||
}
|
||||
}
|
||||
midjourneyTask := coverMidjourneyTaskDto(originTask)
|
||||
respBody, err = json.Marshal(midjourneyTask)
|
||||
if err != nil {
|
||||
return &provider.MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "unmarshal_response_body_failed",
|
||||
}
|
||||
}
|
||||
case provider.RelayModeMidjourneyTaskFetchByCondition:
|
||||
var condition = struct {
|
||||
IDs []string `json:"ids"`
|
||||
}{}
|
||||
err = c.BindJSON(&condition)
|
||||
if err != nil {
|
||||
return &provider.MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "do_request_failed",
|
||||
}
|
||||
}
|
||||
var tasks []provider.MidjourneyDto
|
||||
if len(condition.IDs) != 0 {
|
||||
originTasks := model.GetByMJIds(userId, condition.IDs)
|
||||
for _, originTask := range originTasks {
|
||||
midjourneyTask := coverMidjourneyTaskDto(originTask)
|
||||
tasks = append(tasks, midjourneyTask)
|
||||
}
|
||||
}
|
||||
if tasks == nil {
|
||||
tasks = make([]provider.MidjourneyDto, 0)
|
||||
}
|
||||
respBody, err = json.Marshal(tasks)
|
||||
if err != nil {
|
||||
return &provider.MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "unmarshal_response_body_failed",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
|
||||
_, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
|
||||
if err != nil {
|
||||
return &provider.MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "copy_response_body_failed",
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func RelayMidjourneySubmit(c *gin.Context, relayMode int) *provider.MidjourneyResponse {
|
||||
channelId := 0
|
||||
userId := c.GetInt("id")
|
||||
consumeQuota := true
|
||||
var midjRequest provider.MidjourneyRequest
|
||||
err := common.UnmarshalBodyReusable(c, &midjRequest)
|
||||
if err != nil {
|
||||
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "bind_request_body_failed")
|
||||
}
|
||||
|
||||
if relayMode == provider.RelayModeMidjourneyAction { // midjourney plus,需要从customId中获取任务信息
|
||||
mjErr := CoverPlusActionToNormalAction(&midjRequest)
|
||||
if mjErr != nil {
|
||||
return mjErr
|
||||
}
|
||||
relayMode = provider.RelayModeMidjourneyChange
|
||||
}
|
||||
|
||||
if relayMode == provider.RelayModeMidjourneyImagine { //绘画任务,此类任务可重复
|
||||
if midjRequest.Prompt == "" {
|
||||
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "prompt_is_required")
|
||||
}
|
||||
midjRequest.Action = provider.MjActionImagine
|
||||
} else if relayMode == provider.RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复
|
||||
midjRequest.Action = provider.MjActionDescribe
|
||||
} else if relayMode == provider.RelayModeMidjourneyShorten { //缩短任务,此类任务可重复,plus only
|
||||
midjRequest.Action = provider.MjActionShorten
|
||||
} else if relayMode == provider.RelayModeMidjourneyBlend { //绘画任务,此类任务可重复
|
||||
midjRequest.Action = provider.MjActionBlend
|
||||
} else if midjRequest.TaskId != "" { //放大、变换任务,此类任务,如果重复且已有结果,远端api会直接返回最终结果
|
||||
mjId := ""
|
||||
if relayMode == provider.RelayModeMidjourneyChange {
|
||||
if midjRequest.TaskId == "" {
|
||||
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "task_id_is_required")
|
||||
} else if midjRequest.Action == "" {
|
||||
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "action_is_required")
|
||||
} else if midjRequest.Index == 0 {
|
||||
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "index_is_required")
|
||||
}
|
||||
//action = midjRequest.Action
|
||||
mjId = midjRequest.TaskId
|
||||
} else if relayMode == provider.RelayModeMidjourneySimpleChange {
|
||||
if midjRequest.Content == "" {
|
||||
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "content_is_required")
|
||||
}
|
||||
params := ConvertSimpleChangeParams(midjRequest.Content)
|
||||
if params == nil {
|
||||
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "content_parse_failed")
|
||||
}
|
||||
mjId = params.TaskId
|
||||
midjRequest.Action = params.Action
|
||||
} else if relayMode == provider.RelayModeMidjourneyModal {
|
||||
//if midjRequest.MaskBase64 == "" {
|
||||
// return provider.MidjourneyErrorWrapper(provider.MjRequestError, "mask_base64_is_required")
|
||||
//}
|
||||
mjId = midjRequest.TaskId
|
||||
midjRequest.Action = provider.MjActionModal
|
||||
}
|
||||
|
||||
originTask := model.GetByMJId(userId, mjId)
|
||||
if originTask == nil {
|
||||
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "task_not_found")
|
||||
} else if originTask.Status != "SUCCESS" && relayMode != provider.RelayModeMidjourneyModal {
|
||||
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "task_status_not_success")
|
||||
} else { //原任务的Status=SUCCESS,则可以做放大UPSCALE、变换VARIATION等动作,此时必须使用原来的请求地址才能正确处理
|
||||
channelId = originTask.ChannelId
|
||||
log.Printf("检测到此操作为放大、变换、重绘,获取原channel信息: %d", originTask.ChannelId)
|
||||
}
|
||||
midjRequest.Prompt = originTask.Prompt
|
||||
|
||||
//if channelType == common.ChannelTypeMidjourneyPlus {
|
||||
// // plus
|
||||
//} else {
|
||||
// // 普通版渠道
|
||||
//
|
||||
//}
|
||||
}
|
||||
|
||||
if midjRequest.Action == provider.MjActionInPaint || midjRequest.Action == provider.MjActionCustomZoom {
|
||||
consumeQuota = false
|
||||
}
|
||||
|
||||
mjProvider, errWithMJ := getMJProvider(c, relayMode, channelId, &midjRequest)
|
||||
if errWithMJ != nil {
|
||||
return errWithMJ
|
||||
}
|
||||
|
||||
//baseURL := common.ChannelBaseURLs[channelType]
|
||||
requestURL := getMjRequestPath(c.Request.URL.String())
|
||||
|
||||
//midjRequest.NotifyHook = "http://127.0.0.1:3000/mj/notify"
|
||||
|
||||
quotaInstance, errWithOA := getQuota(c, mjProvider.GetOriginalModel())
|
||||
if errWithOA != nil {
|
||||
return &provider.MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: errWithOA.Message,
|
||||
}
|
||||
}
|
||||
|
||||
midjResponseWithStatus, responseBody, err := mjProvider.Send(60, requestURL)
|
||||
if err != nil {
|
||||
quotaInstance.Undo(c)
|
||||
return &midjResponseWithStatus.Response
|
||||
}
|
||||
|
||||
if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
|
||||
quotaInstance.Consume(c, &types.Usage{CompletionTokens: 0, PromptTokens: 1, TotalTokens: 1})
|
||||
} else {
|
||||
quotaInstance.Undo(c)
|
||||
}
|
||||
quota := int(quotaInstance.GetInputRatio() * 1000)
|
||||
|
||||
midjResponse := &midjResponseWithStatus.Response
|
||||
|
||||
// 文档:https://github.com/novicezk/midjourney-proxy/blob/main/docs/api.md
|
||||
//1-提交成功
|
||||
// 21-任务已存在(处理中或者有结果了) {"code":21,"description":"任务已存在","result":"0741798445574458","properties":{"status":"SUCCESS","imageUrl":"https://xxxx"}}
|
||||
// 22-排队中 {"code":22,"description":"排队中,前面还有1个任务","result":"0741798445574458","properties":{"numberOfQueues":1,"discordInstanceId":"1118138338562560102"}}
|
||||
// 23-队列已满,请稍后再试 {"code":23,"description":"队列已满,请稍后尝试","result":"14001929738841620","properties":{"discordInstanceId":"1118138338562560102"}}
|
||||
// 24-prompt包含敏感词 {"code":24,"description":"可能包含敏感词","properties":{"promptEn":"nude body","bannedWord":"nude"}}
|
||||
// other: 提交错误,description为错误描述
|
||||
midjourneyTask := &model.Midjourney{
|
||||
UserId: userId,
|
||||
Code: midjResponse.Code,
|
||||
Action: midjRequest.Action,
|
||||
MjId: midjResponse.Result,
|
||||
Prompt: midjRequest.Prompt,
|
||||
PromptEn: "",
|
||||
Description: midjResponse.Description,
|
||||
State: "",
|
||||
SubmitTime: time.Now().UnixNano() / int64(time.Millisecond),
|
||||
StartTime: 0,
|
||||
FinishTime: 0,
|
||||
ImageUrl: "",
|
||||
Status: "",
|
||||
Progress: "0%",
|
||||
FailReason: "",
|
||||
ChannelId: c.GetInt("channel_id"),
|
||||
Quota: quota,
|
||||
}
|
||||
|
||||
if midjResponse.Code != 1 && midjResponse.Code != 21 && midjResponse.Code != 22 {
|
||||
//非1-提交成功,21-任务已存在和22-排队中,则记录错误原因
|
||||
midjourneyTask.FailReason = midjResponse.Description
|
||||
consumeQuota = false
|
||||
}
|
||||
|
||||
if midjResponse.Code == 21 { //21-任务已存在(处理中或者有结果了)
|
||||
// 将 properties 转换为一个 map
|
||||
properties, ok := midjResponse.Properties.(map[string]interface{})
|
||||
if ok {
|
||||
imageUrl, ok1 := properties["imageUrl"].(string)
|
||||
status, ok2 := properties["status"].(string)
|
||||
if ok1 && ok2 {
|
||||
midjourneyTask.ImageUrl = imageUrl
|
||||
midjourneyTask.Status = status
|
||||
if status == "SUCCESS" {
|
||||
midjourneyTask.Progress = "100%"
|
||||
midjourneyTask.StartTime = time.Now().UnixNano() / int64(time.Millisecond)
|
||||
midjourneyTask.FinishTime = time.Now().UnixNano() / int64(time.Millisecond)
|
||||
midjResponse.Code = 1
|
||||
}
|
||||
}
|
||||
}
|
||||
//修改返回值
|
||||
if midjRequest.Action != provider.MjActionInPaint && midjRequest.Action != provider.MjActionCustomZoom {
|
||||
newBody := strings.Replace(string(responseBody), `"code":21`, `"code":1`, -1)
|
||||
responseBody = []byte(newBody)
|
||||
}
|
||||
}
|
||||
|
||||
err = midjourneyTask.Insert()
|
||||
if err != nil {
|
||||
return &provider.MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "insert_midjourney_task_failed",
|
||||
}
|
||||
}
|
||||
// 开始激活任务
|
||||
controller.ActivateUpdateMidjourneyTaskBulk()
|
||||
|
||||
if midjResponse.Code == 22 { //22-排队中,说明任务已存在
|
||||
//修改返回值
|
||||
newBody := strings.Replace(string(responseBody), `"code":22`, `"code":1`, -1)
|
||||
responseBody = []byte(newBody)
|
||||
}
|
||||
|
||||
//resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||
bodyReader := io.NopCloser(bytes.NewBuffer(responseBody))
|
||||
|
||||
//for k, v := range resp.Header {
|
||||
// c.Writer.Header().Set(k, v[0])
|
||||
//}
|
||||
c.Writer.WriteHeader(midjResponseWithStatus.StatusCode)
|
||||
|
||||
_, err = io.Copy(c.Writer, bodyReader)
|
||||
if err != nil {
|
||||
return &provider.MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "copy_response_body_failed",
|
||||
}
|
||||
}
|
||||
err = bodyReader.Close()
|
||||
if err != nil {
|
||||
return &provider.MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "close_response_body_failed",
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getMjRequestPath(path string) string {
|
||||
requestURL := path
|
||||
if strings.Contains(requestURL, "/mj-") {
|
||||
urls := strings.Split(requestURL, "/mj/")
|
||||
if len(urls) < 2 {
|
||||
return requestURL
|
||||
}
|
||||
requestURL = "/mj/" + urls[1]
|
||||
}
|
||||
return requestURL
|
||||
}
|
||||
|
||||
func getQuota(c *gin.Context, modelName string) (*util.Quota, *types.OpenAIErrorWithStatusCode) {
|
||||
// modelName = CoverActionToModelName(modelName)
|
||||
|
||||
return util.NewQuota(c, modelName, 1000)
|
||||
}
|
||||
|
||||
func getMJProvider(c *gin.Context, relayMode, channel_id int, request *provider.MidjourneyRequest) (*provider.MidjourneyProvider, *provider.MidjourneyResponse) {
|
||||
var baseProvider providersBase.ProviderInterface
|
||||
modelName := ""
|
||||
if channel_id > 0 {
|
||||
c.Set("specific_channel_id", channel_id)
|
||||
}
|
||||
|
||||
if request != nil {
|
||||
midjourneyModel, mjErr, _ := GetMjRequestModel(relayMode, request)
|
||||
if mjErr != nil {
|
||||
return nil, MidjourneyErrorFromInternal(mjErr.Code, mjErr.Description)
|
||||
}
|
||||
if midjourneyModel == "" {
|
||||
return nil, MidjourneyErrorFromInternal(provider.MjErrorUnknown, "无效的请求, 无法解析模型")
|
||||
}
|
||||
|
||||
modelName = midjourneyModel
|
||||
}
|
||||
|
||||
var err error
|
||||
baseProvider, _, err = relay.GetProvider(c, modelName)
|
||||
if err != nil {
|
||||
return nil, MidjourneyErrorFromInternal(provider.MjErrorUnknown, "无法获取provider:"+err.Error())
|
||||
}
|
||||
|
||||
mjProvider, ok := baseProvider.(*provider.MidjourneyProvider)
|
||||
if !ok {
|
||||
return nil, MidjourneyErrorFromInternal(provider.MjErrorUnknown, "无效的请求, 无法获取midjourney provider")
|
||||
}
|
||||
|
||||
return mjProvider, nil
|
||||
}
|
95
relay/midjourney/relay.go
Normal file
95
relay/midjourney/relay.go
Normal file
@ -0,0 +1,95 @@
|
||||
// Author: Calcium-Ion
|
||||
// GitHub: https://github.com/Calcium-Ion/new-api
|
||||
// Path: controller/relay.go
|
||||
package midjourney
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
provider "one-api/providers/midjourney"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func RelayMidjourney(c *gin.Context) {
|
||||
relayMode := Path2RelayModeMidjourney(c.Request.URL.Path)
|
||||
var err *provider.MidjourneyResponse
|
||||
switch relayMode {
|
||||
case provider.RelayModeMidjourneyNotify:
|
||||
err = RelayMidjourneyNotify(c)
|
||||
case provider.RelayModeMidjourneyTaskFetch, provider.RelayModeMidjourneyTaskFetchByCondition:
|
||||
err = RelayMidjourneyTask(c, relayMode)
|
||||
case provider.RelayModeMidjourneyTaskImageSeed:
|
||||
err = RelayMidjourneyTaskImageSeed(c)
|
||||
case provider.RelayModeMidjourneySwapFace:
|
||||
err = RelaySwapFace(c)
|
||||
default:
|
||||
err = RelayMidjourneySubmit(c, relayMode)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
statusCode := http.StatusBadRequest
|
||||
if err.Code == 30 {
|
||||
err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
|
||||
statusCode = http.StatusTooManyRequests
|
||||
}
|
||||
|
||||
typeMsg := "upstream_error"
|
||||
if err.Type != "" {
|
||||
typeMsg = err.Type
|
||||
}
|
||||
c.JSON(statusCode, gin.H{
|
||||
"description": fmt.Sprintf("%s %s", err.Description, err.Result),
|
||||
"type": typeMsg,
|
||||
"code": err.Code,
|
||||
})
|
||||
channelId := c.GetInt("channel_id")
|
||||
common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, fmt.Sprintf("%s %s", err.Description, err.Result)))
|
||||
}
|
||||
}
|
||||
|
||||
func MidjourneyErrorFromInternal(code int, description string) *provider.MidjourneyResponse {
|
||||
return &provider.MidjourneyResponse{
|
||||
Code: code,
|
||||
Description: description,
|
||||
Type: "internal_error",
|
||||
}
|
||||
}
|
||||
|
||||
func Path2RelayModeMidjourney(path string) int {
|
||||
relayMode := provider.RelayModeUnknown
|
||||
if strings.HasSuffix(path, "/mj/submit/action") {
|
||||
// midjourney plus
|
||||
relayMode = provider.RelayModeMidjourneyAction
|
||||
} else if strings.HasSuffix(path, "/mj/submit/modal") {
|
||||
// midjourney plus
|
||||
relayMode = provider.RelayModeMidjourneyModal
|
||||
} else if strings.HasSuffix(path, "/mj/submit/shorten") {
|
||||
// midjourney plus
|
||||
relayMode = provider.RelayModeMidjourneyShorten
|
||||
} else if strings.HasSuffix(path, "/mj/insight-face/swap") {
|
||||
// midjourney plus
|
||||
relayMode = provider.RelayModeMidjourneySwapFace
|
||||
} else if strings.HasSuffix(path, "/mj/submit/imagine") {
|
||||
relayMode = provider.RelayModeMidjourneyImagine
|
||||
} else if strings.HasSuffix(path, "/mj/submit/blend") {
|
||||
relayMode = provider.RelayModeMidjourneyBlend
|
||||
} else if strings.HasSuffix(path, "/mj/submit/describe") {
|
||||
relayMode = provider.RelayModeMidjourneyDescribe
|
||||
} else if strings.HasSuffix(path, "/mj/notify") {
|
||||
relayMode = provider.RelayModeMidjourneyNotify
|
||||
} else if strings.HasSuffix(path, "/mj/submit/change") {
|
||||
relayMode = provider.RelayModeMidjourneyChange
|
||||
} else if strings.HasSuffix(path, "/mj/submit/simple-change") {
|
||||
relayMode = provider.RelayModeMidjourneyChange
|
||||
} else if strings.HasSuffix(path, "/fetch") {
|
||||
relayMode = provider.RelayModeMidjourneyTaskFetch
|
||||
} else if strings.HasSuffix(path, "/image-seed") {
|
||||
relayMode = provider.RelayModeMidjourneyTaskImageSeed
|
||||
} else if strings.HasSuffix(path, "/list-by-condition") {
|
||||
relayMode = provider.RelayModeMidjourneyTaskFetchByCondition
|
||||
}
|
||||
return relayMode
|
||||
}
|
148
relay/midjourney/service.go
Normal file
148
relay/midjourney/service.go
Normal file
@ -0,0 +1,148 @@
|
||||
// Author: Calcium-Ion
|
||||
// GitHub: https://github.com/Calcium-Ion/new-api
|
||||
// Path: service/midjourney.go
|
||||
package midjourney
|
||||
|
||||
import (
|
||||
mjProvider "one-api/providers/midjourney"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func CoverActionToModelName(mjAction string) string {
|
||||
modelName := "mj_" + strings.ToLower(mjAction)
|
||||
if mjAction == mjProvider.MjActionSwapFace {
|
||||
modelName = "swap_face"
|
||||
}
|
||||
return modelName
|
||||
}
|
||||
|
||||
func GetMjRequestModel(relayMode int, midjRequest *mjProvider.MidjourneyRequest) (string, *mjProvider.MidjourneyResponse, bool) {
|
||||
action := ""
|
||||
if relayMode == mjProvider.RelayModeMidjourneyAction {
|
||||
// plus request
|
||||
err := CoverPlusActionToNormalAction(midjRequest)
|
||||
if err != nil {
|
||||
return "", err, false
|
||||
}
|
||||
action = midjRequest.Action
|
||||
} else {
|
||||
switch relayMode {
|
||||
case mjProvider.RelayModeMidjourneyImagine:
|
||||
action = mjProvider.MjActionImagine
|
||||
case mjProvider.RelayModeMidjourneyDescribe:
|
||||
action = mjProvider.MjActionDescribe
|
||||
case mjProvider.RelayModeMidjourneyBlend:
|
||||
action = mjProvider.MjActionBlend
|
||||
case mjProvider.RelayModeMidjourneyShorten:
|
||||
action = mjProvider.MjActionShorten
|
||||
case mjProvider.RelayModeMidjourneyChange:
|
||||
action = midjRequest.Action
|
||||
case mjProvider.RelayModeMidjourneyModal:
|
||||
action = mjProvider.MjActionModal
|
||||
case mjProvider.RelayModeMidjourneySwapFace:
|
||||
action = mjProvider.MjActionSwapFace
|
||||
case mjProvider.RelayModeMidjourneySimpleChange:
|
||||
params := ConvertSimpleChangeParams(midjRequest.Content)
|
||||
if params == nil {
|
||||
return "", mjProvider.MidjourneyErrorWrapper(mjProvider.MjRequestError, "invalid_request"), false
|
||||
}
|
||||
action = params.Action
|
||||
case mjProvider.RelayModeMidjourneyTaskFetch, mjProvider.RelayModeMidjourneyTaskFetchByCondition, mjProvider.RelayModeMidjourneyNotify:
|
||||
return "", nil, true
|
||||
default:
|
||||
return "", mjProvider.MidjourneyErrorWrapper(mjProvider.MjRequestError, "unknown_relay_action"), false
|
||||
}
|
||||
}
|
||||
|
||||
modelName := CoverActionToModelName(action)
|
||||
return modelName, nil, true
|
||||
}
|
||||
|
||||
func CoverPlusActionToNormalAction(midjRequest *mjProvider.MidjourneyRequest) *mjProvider.MidjourneyResponse {
|
||||
// "customId": "MJ::JOB::upsample::2::3dbbd469-36af-4a0f-8f02-df6c579e7011"
|
||||
customId := midjRequest.CustomId
|
||||
if customId == "" {
|
||||
return mjProvider.MidjourneyErrorWrapper(mjProvider.MjRequestError, "custom_id_is_required")
|
||||
}
|
||||
splits := strings.Split(customId, "::")
|
||||
var action string
|
||||
if splits[1] == "JOB" {
|
||||
action = splits[2]
|
||||
} else {
|
||||
action = splits[1]
|
||||
}
|
||||
|
||||
if action == "" {
|
||||
return mjProvider.MidjourneyErrorWrapper(mjProvider.MjRequestError, "unknown_action")
|
||||
}
|
||||
if strings.Contains(action, "upsample") {
|
||||
index, err := strconv.Atoi(splits[3])
|
||||
if err != nil {
|
||||
return mjProvider.MidjourneyErrorWrapper(mjProvider.MjRequestError, "index_parse_failed")
|
||||
}
|
||||
midjRequest.Index = index
|
||||
midjRequest.Action = mjProvider.MjActionUpscale
|
||||
} else if strings.Contains(action, "variation") {
|
||||
midjRequest.Index = 1
|
||||
if action == "variation" {
|
||||
index, err := strconv.Atoi(splits[3])
|
||||
if err != nil {
|
||||
return mjProvider.MidjourneyErrorWrapper(mjProvider.MjRequestError, "index_parse_failed")
|
||||
}
|
||||
midjRequest.Index = index
|
||||
midjRequest.Action = mjProvider.MjActionVariation
|
||||
} else if action == "low_variation" {
|
||||
midjRequest.Action = mjProvider.MjActionLowVariation
|
||||
} else if action == "high_variation" {
|
||||
midjRequest.Action = mjProvider.MjActionHighVariation
|
||||
}
|
||||
} else if strings.Contains(action, "pan") {
|
||||
midjRequest.Action = mjProvider.MjActionPan
|
||||
midjRequest.Index = 1
|
||||
} else if strings.Contains(action, "reroll") {
|
||||
midjRequest.Action = mjProvider.MjActionReRoll
|
||||
midjRequest.Index = 1
|
||||
} else if action == "Outpaint" {
|
||||
midjRequest.Action = mjProvider.MjActionZoom
|
||||
midjRequest.Index = 1
|
||||
} else if action == "CustomZoom" {
|
||||
midjRequest.Action = mjProvider.MjActionCustomZoom
|
||||
midjRequest.Index = 1
|
||||
} else if action == "Inpaint" {
|
||||
midjRequest.Action = mjProvider.MjActionInPaint
|
||||
midjRequest.Index = 1
|
||||
} else {
|
||||
return mjProvider.MidjourneyErrorWrapper(mjProvider.MjRequestError, "unknown_action:"+customId)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ConvertSimpleChangeParams(content string) *mjProvider.MidjourneyRequest {
|
||||
split := strings.Split(content, " ")
|
||||
if len(split) != 2 {
|
||||
return nil
|
||||
}
|
||||
|
||||
action := strings.ToLower(split[1])
|
||||
changeParams := &mjProvider.MidjourneyRequest{}
|
||||
changeParams.TaskId = split[0]
|
||||
|
||||
if action[0] == 'u' {
|
||||
changeParams.Action = "UPSCALE"
|
||||
} else if action[0] == 'v' {
|
||||
changeParams.Action = "VARIATION"
|
||||
} else if action == "r" {
|
||||
changeParams.Action = "REROLL"
|
||||
return changeParams
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
|
||||
index, err := strconv.Atoi(action[1:2])
|
||||
if err != nil || index < 1 || index > 4 {
|
||||
return nil
|
||||
}
|
||||
changeParams.Index = index
|
||||
return changeParams
|
||||
}
|
@ -170,3 +170,7 @@ func (q *Quota) Consume(c *gin.Context, usage *types.Usage) {
|
||||
}
|
||||
}(c.Request.Context())
|
||||
}
|
||||
|
||||
func (q *Quota) GetInputRatio() float64 {
|
||||
return q.inputRatio
|
||||
}
|
||||
|
@ -7,22 +7,23 @@ var ModelOwnedBy map[int]string
|
||||
|
||||
func init() {
|
||||
ModelOwnedBy = map[int]string{
|
||||
common.ChannelTypeOpenAI: "OpenAI",
|
||||
common.ChannelTypeAnthropic: "Anthropic",
|
||||
common.ChannelTypeBaidu: "Baidu",
|
||||
common.ChannelTypePaLM: "Google PaLM",
|
||||
common.ChannelTypeGemini: "Google Gemini",
|
||||
common.ChannelTypeZhipu: "Zhipu",
|
||||
common.ChannelTypeAli: "Ali",
|
||||
common.ChannelTypeXunfei: "Xunfei",
|
||||
common.ChannelType360: "360",
|
||||
common.ChannelTypeTencent: "Tencent",
|
||||
common.ChannelTypeBaichuan: "Baichuan",
|
||||
common.ChannelTypeMiniMax: "MiniMax",
|
||||
common.ChannelTypeDeepseek: "Deepseek",
|
||||
common.ChannelTypeMoonshot: "Moonshot",
|
||||
common.ChannelTypeMistral: "Mistral",
|
||||
common.ChannelTypeGroq: "Groq",
|
||||
common.ChannelTypeLingyi: "Lingyiwanwu",
|
||||
common.ChannelTypeOpenAI: "OpenAI",
|
||||
common.ChannelTypeAnthropic: "Anthropic",
|
||||
common.ChannelTypeBaidu: "Baidu",
|
||||
common.ChannelTypePaLM: "Google PaLM",
|
||||
common.ChannelTypeGemini: "Google Gemini",
|
||||
common.ChannelTypeZhipu: "Zhipu",
|
||||
common.ChannelTypeAli: "Ali",
|
||||
common.ChannelTypeXunfei: "Xunfei",
|
||||
common.ChannelType360: "360",
|
||||
common.ChannelTypeTencent: "Tencent",
|
||||
common.ChannelTypeBaichuan: "Baichuan",
|
||||
common.ChannelTypeMiniMax: "MiniMax",
|
||||
common.ChannelTypeDeepseek: "Deepseek",
|
||||
common.ChannelTypeMoonshot: "Moonshot",
|
||||
common.ChannelTypeMistral: "Mistral",
|
||||
common.ChannelTypeGroq: "Groq",
|
||||
common.ChannelTypeLingyi: "Lingyiwanwu",
|
||||
common.ChannelTypeMidjourney: "Midjourney",
|
||||
}
|
||||
}
|
||||
|
@ -145,6 +145,9 @@ func SetApiRouter(router *gin.Engine) {
|
||||
|
||||
}
|
||||
|
||||
mjRoute := apiRouter.Group("/mj")
|
||||
mjRoute.GET("/self", middleware.UserAuth(), controller.GetUserMidjourney)
|
||||
mjRoute.GET("/", middleware.AdminAuth(), controller.GetAllMidjourney)
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -1,10 +1,11 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"github.com/gin-contrib/gzip"
|
||||
"github.com/gin-gonic/gin"
|
||||
"one-api/controller"
|
||||
"one-api/middleware"
|
||||
|
||||
"github.com/gin-contrib/gzip"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func SetDashboardRouter(router *gin.Engine) {
|
||||
@ -12,7 +13,7 @@ func SetDashboardRouter(router *gin.Engine) {
|
||||
apiRouter := router.Group("/")
|
||||
apiRouter.Use(gzip.Gzip(gzip.DefaultCompression))
|
||||
apiRouter.Use(middleware.GlobalAPIRateLimit())
|
||||
apiRouter.Use(middleware.TokenAuth())
|
||||
apiRouter.Use(middleware.OpenaiAuth())
|
||||
{
|
||||
apiRouter.GET("/dashboard/billing/subscription", controller.GetSubscription)
|
||||
apiRouter.GET("/v1/dashboard/billing/subscription", controller.GetSubscription)
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"one-api/controller"
|
||||
"one-api/middleware"
|
||||
"one-api/relay"
|
||||
"one-api/relay/midjourney"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@ -11,14 +12,19 @@ import (
|
||||
func SetRelayRouter(router *gin.Engine) {
|
||||
router.Use(middleware.CORS())
|
||||
// https://platform.openai.com/docs/api-reference/introduction
|
||||
setOpenAIRouter(router)
|
||||
setMJRouter(router)
|
||||
}
|
||||
|
||||
func setOpenAIRouter(router *gin.Engine) {
|
||||
modelsRouter := router.Group("/v1/models")
|
||||
modelsRouter.Use(middleware.TokenAuth(), middleware.Distribute())
|
||||
modelsRouter.Use(middleware.OpenaiAuth(), middleware.Distribute())
|
||||
{
|
||||
modelsRouter.GET("", relay.ListModels)
|
||||
modelsRouter.GET("/:model", relay.RetrieveModel)
|
||||
}
|
||||
relayV1Router := router.Group("/v1")
|
||||
relayV1Router.Use(middleware.RelayPanicRecover(), middleware.TokenAuth(), middleware.Distribute())
|
||||
relayV1Router.Use(middleware.RelayPanicRecover(), middleware.OpenaiAuth(), middleware.Distribute())
|
||||
{
|
||||
relayV1Router.POST("/completions", relay.Relay)
|
||||
relayV1Router.POST("/chat/completions", relay.Relay)
|
||||
@ -71,3 +77,34 @@ func SetRelayRouter(router *gin.Engine) {
|
||||
relayV1Router.GET("/threads/:id/runs/:runsId/steps", controller.RelayNotImplemented)
|
||||
}
|
||||
}
|
||||
|
||||
func setMJRouter(router *gin.Engine) {
|
||||
relayMjRouter := router.Group("/mj")
|
||||
registerMjRouterGroup(relayMjRouter)
|
||||
|
||||
relayMjModeRouter := router.Group("/:mode/mj")
|
||||
registerMjRouterGroup(relayMjModeRouter)
|
||||
}
|
||||
|
||||
// Author: Calcium-Ion
|
||||
// GitHub: https://github.com/Calcium-Ion/new-api
|
||||
// Path: router/relay-router.go
|
||||
func registerMjRouterGroup(relayMjRouter *gin.RouterGroup) {
|
||||
relayMjRouter.GET("/image/:id", midjourney.RelayMidjourneyImage)
|
||||
relayMjRouter.Use(middleware.MjAuth(), middleware.Distribute())
|
||||
{
|
||||
relayMjRouter.POST("/submit/action", midjourney.RelayMidjourney)
|
||||
relayMjRouter.POST("/submit/shorten", midjourney.RelayMidjourney)
|
||||
relayMjRouter.POST("/submit/modal", midjourney.RelayMidjourney)
|
||||
relayMjRouter.POST("/submit/imagine", midjourney.RelayMidjourney)
|
||||
relayMjRouter.POST("/submit/change", midjourney.RelayMidjourney)
|
||||
relayMjRouter.POST("/submit/simple-change", midjourney.RelayMidjourney)
|
||||
relayMjRouter.POST("/submit/describe", midjourney.RelayMidjourney)
|
||||
relayMjRouter.POST("/submit/blend", midjourney.RelayMidjourney)
|
||||
relayMjRouter.POST("/notify", midjourney.RelayMidjourney)
|
||||
relayMjRouter.GET("/task/:id/fetch", midjourney.RelayMidjourney)
|
||||
relayMjRouter.GET("/task/:id/image-seed", midjourney.RelayMidjourney)
|
||||
relayMjRouter.POST("/task/list-by-condition", midjourney.RelayMidjourney)
|
||||
relayMjRouter.POST("/insight-face/swap", midjourney.RelayMidjourney)
|
||||
}
|
||||
}
|
||||
|
@ -7,7 +7,7 @@
|
||||
使用了以下开源项目作为我们项目的一部分:
|
||||
|
||||
- [Berry Free React Admin Template](https://github.com/codedthemes/berry-free-react-admin-template)
|
||||
- [minimal-ui-kit](minimal-ui-kit)
|
||||
- [minimal-ui-kit](https://github.com/minimal-ui-kit/material-kit-react)
|
||||
|
||||
## 许可证
|
||||
|
||||
|
@ -132,6 +132,13 @@ export const CHANNEL_OPTIONS = {
|
||||
color: 'primary',
|
||||
url: 'https://platform.lingyiwanwu.com/details'
|
||||
},
|
||||
34: {
|
||||
key: 34,
|
||||
text: 'Midjourney',
|
||||
value: 34,
|
||||
color: 'orange',
|
||||
url: ''
|
||||
},
|
||||
24: {
|
||||
key: 24,
|
||||
text: 'Azure Speech',
|
||||
|
@ -11,7 +11,8 @@ import {
|
||||
IconUserScan,
|
||||
IconActivity,
|
||||
IconBrandTelegram,
|
||||
IconReceipt2
|
||||
IconReceipt2,
|
||||
IconBrush
|
||||
} from '@tabler/icons-react';
|
||||
|
||||
// constant
|
||||
@ -27,7 +28,8 @@ const icons = {
|
||||
IconUserScan,
|
||||
IconActivity,
|
||||
IconBrandTelegram,
|
||||
IconReceipt2
|
||||
IconReceipt2,
|
||||
IconBrush
|
||||
};
|
||||
|
||||
// ==============================|| DASHBOARD MENU ITEMS ||============================== //
|
||||
@ -96,6 +98,14 @@ const panel = {
|
||||
icon: icons.IconGardenCart,
|
||||
breadcrumbs: false
|
||||
},
|
||||
{
|
||||
id: 'midjourney',
|
||||
title: 'Midjourney',
|
||||
type: 'item',
|
||||
url: '/panel/midjourney',
|
||||
icon: icons.IconBrush,
|
||||
breadcrumbs: false
|
||||
},
|
||||
{
|
||||
id: 'user',
|
||||
title: '用户',
|
||||
|
@ -16,6 +16,7 @@ const NotFoundView = Loadable(lazy(() => import('views/Error')));
|
||||
const Analytics = Loadable(lazy(() => import('views/Analytics')));
|
||||
const Telegram = Loadable(lazy(() => import('views/Telegram')));
|
||||
const Pricing = Loadable(lazy(() => import('views/Pricing')));
|
||||
const Midjourney = Loadable(lazy(() => import('views/Midjourney')));
|
||||
|
||||
// dashboard routing
|
||||
const Dashboard = Loadable(lazy(() => import('views/Dashboard')));
|
||||
@ -81,6 +82,10 @@ const MainRoutes = {
|
||||
{
|
||||
path: 'pricing',
|
||||
element: <Pricing />
|
||||
},
|
||||
{
|
||||
path: 'midjourney',
|
||||
element: <Midjourney />
|
||||
}
|
||||
]
|
||||
};
|
||||
|
@ -12,15 +12,7 @@ export default function componentStyleOverrides(theme) {
|
||||
}
|
||||
}
|
||||
},
|
||||
MuiMenuItem: {
|
||||
styleOverrides: {
|
||||
root: {
|
||||
'&:hover': {
|
||||
backgroundColor: theme.colors?.grey100
|
||||
}
|
||||
}
|
||||
}
|
||||
}, //MuiAutocomplete-popper MuiPopover-root
|
||||
//MuiAutocomplete-popper MuiPopover-root
|
||||
MuiAutocomplete: {
|
||||
styleOverrides: {
|
||||
popper: {
|
||||
@ -247,7 +239,7 @@ export default function componentStyleOverrides(theme) {
|
||||
MuiTooltip: {
|
||||
styleOverrides: {
|
||||
tooltip: {
|
||||
color: theme.paper,
|
||||
color: theme.colors.paper,
|
||||
background: theme.colors?.grey700
|
||||
}
|
||||
}
|
||||
@ -266,6 +258,9 @@ export default function componentStyleOverrides(theme) {
|
||||
.apexcharts-menu {
|
||||
background: ${theme.backgroundDefault} !important
|
||||
}
|
||||
.apexcharts-gridline, .apexcharts-xaxistooltip-background, .apexcharts-yaxistooltip-background {
|
||||
stroke: ${theme.divider} !important;
|
||||
}
|
||||
`
|
||||
}
|
||||
};
|
||||
|
@ -19,14 +19,14 @@ const Footer = () => {
|
||||
{siteInfo.system_name} {process.env.REACT_APP_VERSION}{' '}
|
||||
</Link>
|
||||
由{' '}
|
||||
<Link href="https://github.com/songquanpeng" target="_blank">
|
||||
JustSong
|
||||
</Link>{' '}
|
||||
构建,
|
||||
<Link href="https://github.com/MartialBE" target="_blank">
|
||||
MartialBE
|
||||
</Link>
|
||||
修改,源代码遵循
|
||||
开发,基于
|
||||
<Link href="https://github.com/songquanpeng" target="_blank">
|
||||
JustSong
|
||||
</Link>{' '}
|
||||
One API,源代码遵循
|
||||
<Link href="https://opensource.org/licenses/mit-license.php"> MIT 协议</Link>
|
||||
</>
|
||||
)}
|
||||
|
@ -234,6 +234,33 @@ const typeConfig = {
|
||||
test_model: 'yi-34b-chat-0205'
|
||||
},
|
||||
modelGroup: 'Lingyiwanwu'
|
||||
},
|
||||
34: {
|
||||
input: {
|
||||
models: [
|
||||
'mj_imagine',
|
||||
'mj_variation',
|
||||
'mj_reroll',
|
||||
'mj_blend',
|
||||
'mj_modal',
|
||||
'mj_zoom',
|
||||
'mj_shorten',
|
||||
'mj_high_variation',
|
||||
'mj_low_variation',
|
||||
'mj_pan',
|
||||
'mj_inpaint',
|
||||
'mj_custom_zoom',
|
||||
'mj_describe',
|
||||
'mj_upscale',
|
||||
'swap_face'
|
||||
]
|
||||
},
|
||||
prompt: {
|
||||
key: '密钥填写midjourney-proxy的密钥,如果没有设置密钥,可以随便填',
|
||||
base_url: '地址填写midjourney-proxy部署的地址',
|
||||
test_model: ''
|
||||
},
|
||||
modelGroup: 'Midjourney'
|
||||
}
|
||||
};
|
||||
|
||||
|
174
web/src/views/Midjourney/component/TableRow.js
Normal file
174
web/src/views/Midjourney/component/TableRow.js
Normal file
@ -0,0 +1,174 @@
|
||||
import PropTypes from 'prop-types';
|
||||
|
||||
import { useState } from 'react';
|
||||
import {
|
||||
TableRow,
|
||||
TableCell,
|
||||
Button,
|
||||
Dialog,
|
||||
DialogActions,
|
||||
DialogContent,
|
||||
ButtonGroup,
|
||||
Popover,
|
||||
MenuItem,
|
||||
MenuList,
|
||||
Tooltip
|
||||
} from '@mui/material';
|
||||
|
||||
import { timestamp2string, copy } from 'utils/common';
|
||||
import Label from 'ui-component/Label';
|
||||
import { ACTION_TYPE, CODE_TYPE, STATUS_TYPE } from '../type/Type';
|
||||
import { IconCaretDownFilled, IconCopy, IconDownload, IconExternalLink } from '@tabler/icons-react';
|
||||
|
||||
function renderType(types, type) {
|
||||
const typeOption = types[type];
|
||||
if (typeOption) {
|
||||
return (
|
||||
<Label variant="filled" color={typeOption.color}>
|
||||
{' '}
|
||||
{typeOption.text}{' '}
|
||||
</Label>
|
||||
);
|
||||
} else {
|
||||
return (
|
||||
<Label variant="filled" color="error">
|
||||
{' '}
|
||||
未知{' '}
|
||||
</Label>
|
||||
);
|
||||
}
|
||||
}
|
||||
async function downloadImage(url, filename) {
|
||||
const response = await fetch(url);
|
||||
const blob = await response.blob();
|
||||
const blobUrl = URL.createObjectURL(blob);
|
||||
const link = document.createElement('a');
|
||||
link.href = blobUrl;
|
||||
link.download = filename;
|
||||
link.click();
|
||||
URL.revokeObjectURL(blobUrl);
|
||||
}
|
||||
|
||||
function TruncatedText(text) {
|
||||
const truncatedText = text.length > 30 ? text.substring(0, 100) + '...' : text;
|
||||
|
||||
return (
|
||||
<Tooltip
|
||||
placement="top"
|
||||
title={text}
|
||||
onClick={() => {
|
||||
copy(text, '');
|
||||
}}
|
||||
>
|
||||
<span>{truncatedText}</span>
|
||||
</Tooltip>
|
||||
);
|
||||
}
|
||||
|
||||
export default function LogTableRow({ item, userIsAdmin }) {
|
||||
const [open, setOpen] = useState(false);
|
||||
const [menuOpen, setMenuOpen] = useState(null);
|
||||
const handleClickOpen = () => {
|
||||
setOpen(true);
|
||||
};
|
||||
|
||||
const handleClose = () => {
|
||||
setOpen(false);
|
||||
};
|
||||
|
||||
const handleOpenMenu = (event) => {
|
||||
setMenuOpen(event.currentTarget);
|
||||
};
|
||||
|
||||
const handleCloseMenu = () => {
|
||||
setMenuOpen(null);
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
<TableRow tabIndex={item.id}>
|
||||
<TableCell>{item.mj_id}</TableCell>
|
||||
<TableCell>{timestamp2string(item.submit_time / 1000)}</TableCell>
|
||||
|
||||
{userIsAdmin && <TableCell>{item.channel_id || ''}</TableCell>}
|
||||
{userIsAdmin && <TableCell>{item.user_id || ''}</TableCell>}
|
||||
|
||||
<TableCell>{renderType(ACTION_TYPE, item.action)}</TableCell>
|
||||
{userIsAdmin && <TableCell>{renderType(CODE_TYPE, item.code)}</TableCell>}
|
||||
{userIsAdmin && <TableCell>{renderType(STATUS_TYPE, item.status)}</TableCell>}
|
||||
<TableCell>{item.progress}</TableCell>
|
||||
<TableCell>
|
||||
{item.image_url == '' ? (
|
||||
'无'
|
||||
) : (
|
||||
<ButtonGroup size="small" aria-label="split button">
|
||||
<Button color="primary" onClick={handleClickOpen}>
|
||||
显示
|
||||
</Button>
|
||||
<Button onClick={handleOpenMenu}>
|
||||
<IconCaretDownFilled size={'16px'} />
|
||||
</Button>
|
||||
</ButtonGroup>
|
||||
)}
|
||||
</TableCell>
|
||||
<TableCell>{TruncatedText(item.prompt)}</TableCell>
|
||||
<TableCell>{TruncatedText(item.prompt_en)}</TableCell>
|
||||
<TableCell>{TruncatedText(item.fail_reason)}</TableCell>
|
||||
</TableRow>
|
||||
<Dialog open={open} onClose={handleClose}>
|
||||
<DialogContent>
|
||||
<img src={item.image_url} alt="item" style={{ maxWidth: '100%', maxHeight: '100%' }} />
|
||||
</DialogContent>
|
||||
<DialogActions>
|
||||
<Button onClick={handleClose} color="primary">
|
||||
关闭
|
||||
</Button>
|
||||
</DialogActions>
|
||||
</Dialog>
|
||||
|
||||
<Popover
|
||||
open={!!menuOpen}
|
||||
anchorEl={menuOpen}
|
||||
onClose={handleCloseMenu}
|
||||
anchorOrigin={{ vertical: 'top', horizontal: 'left' }}
|
||||
transformOrigin={{ vertical: 'top', horizontal: 'right' }}
|
||||
PaperProps={{
|
||||
sx: { width: 140 }
|
||||
}}
|
||||
>
|
||||
<MenuList>
|
||||
<MenuItem
|
||||
onClick={() => {
|
||||
handleCloseMenu();
|
||||
copy(item.image_url, '图片地址');
|
||||
}}
|
||||
>
|
||||
<IconCopy style={{ marginRight: '16px' }} />
|
||||
复制地址
|
||||
</MenuItem>
|
||||
|
||||
<MenuItem
|
||||
onClick={async () => {
|
||||
handleCloseMenu();
|
||||
await downloadImage(item.image_url, item.mj_id + '.png');
|
||||
}}
|
||||
>
|
||||
<IconDownload style={{ marginRight: '16px' }} /> 下载图片{' '}
|
||||
</MenuItem>
|
||||
<MenuItem
|
||||
onClick={() => {
|
||||
handleCloseMenu();
|
||||
}}
|
||||
>
|
||||
<IconExternalLink style={{ marginRight: '16px' }} /> 新窗口打开{' '}
|
||||
</MenuItem>
|
||||
</MenuList>
|
||||
</Popover>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
LogTableRow.propTypes = {
|
||||
item: PropTypes.object,
|
||||
userIsAdmin: PropTypes.bool
|
||||
};
|
113
web/src/views/Midjourney/component/TableToolBar.js
Normal file
113
web/src/views/Midjourney/component/TableToolBar.js
Normal file
@ -0,0 +1,113 @@
|
||||
import PropTypes from 'prop-types';
|
||||
import { useTheme } from '@mui/material/styles';
|
||||
import { IconBroadcast, IconCalendarEvent } from '@tabler/icons-react';
|
||||
import { InputAdornment, OutlinedInput, Stack, FormControl, InputLabel } from '@mui/material';
|
||||
import { LocalizationProvider, DateTimePicker } from '@mui/x-date-pickers';
|
||||
import { AdapterDayjs } from '@mui/x-date-pickers/AdapterDayjs';
|
||||
import dayjs from 'dayjs';
|
||||
require('dayjs/locale/zh-cn');
|
||||
// ----------------------------------------------------------------------
|
||||
|
||||
export default function TableToolBar({ filterName, handleFilterName, userIsAdmin }) {
|
||||
const theme = useTheme();
|
||||
const grey500 = theme.palette.grey[500];
|
||||
|
||||
return (
|
||||
<>
|
||||
<Stack direction={{ xs: 'column', sm: 'row' }} spacing={{ xs: 3, sm: 2, md: 4 }} padding={'24px'} paddingBottom={'0px'}>
|
||||
{userIsAdmin && (
|
||||
<FormControl>
|
||||
<InputLabel htmlFor="channel-channel_id-label">渠道ID</InputLabel>
|
||||
<OutlinedInput
|
||||
id="channel_id"
|
||||
name="channel_id"
|
||||
sx={{
|
||||
minWidth: '100%'
|
||||
}}
|
||||
label="渠道ID"
|
||||
value={filterName.channel_id}
|
||||
onChange={handleFilterName}
|
||||
placeholder="渠道ID"
|
||||
startAdornment={
|
||||
<InputAdornment position="start">
|
||||
<IconBroadcast stroke={1.5} size="20px" color={grey500} />
|
||||
</InputAdornment>
|
||||
}
|
||||
/>
|
||||
</FormControl>
|
||||
)}
|
||||
<FormControl>
|
||||
<InputLabel htmlFor="channel-mj_id-label">任务ID</InputLabel>
|
||||
<OutlinedInput
|
||||
id="mj_id"
|
||||
name="mj_id"
|
||||
sx={{
|
||||
minWidth: '100%'
|
||||
}}
|
||||
label="任务ID"
|
||||
value={filterName.mj_id}
|
||||
onChange={handleFilterName}
|
||||
placeholder="任务ID"
|
||||
startAdornment={
|
||||
<InputAdornment position="start">
|
||||
<IconCalendarEvent stroke={1.5} size="20px" color={grey500} />
|
||||
</InputAdornment>
|
||||
}
|
||||
/>
|
||||
</FormControl>
|
||||
|
||||
<FormControl>
|
||||
<LocalizationProvider dateAdapter={AdapterDayjs} adapterLocale={'zh-cn'}>
|
||||
<DateTimePicker
|
||||
label="起始时间"
|
||||
ampm={false}
|
||||
name="start_timestamp"
|
||||
value={filterName.start_timestamp === 0 ? null : dayjs.unix(filterName.start_timestamp / 1000)}
|
||||
onChange={(value) => {
|
||||
if (value === null) {
|
||||
handleFilterName({ target: { name: 'start_timestamp', value: 0 } });
|
||||
return;
|
||||
}
|
||||
handleFilterName({ target: { name: 'start_timestamp', value: value.unix() * 1000 } });
|
||||
}}
|
||||
slotProps={{
|
||||
actionBar: {
|
||||
actions: ['clear', 'today', 'accept']
|
||||
}
|
||||
}}
|
||||
/>
|
||||
</LocalizationProvider>
|
||||
</FormControl>
|
||||
|
||||
<FormControl>
|
||||
<LocalizationProvider dateAdapter={AdapterDayjs} adapterLocale={'zh-cn'}>
|
||||
<DateTimePicker
|
||||
label="结束时间"
|
||||
name="end_timestamp"
|
||||
ampm={false}
|
||||
value={filterName.end_timestamp === 0 ? null : dayjs.unix(filterName.end_timestamp / 1000)}
|
||||
onChange={(value) => {
|
||||
if (value === null) {
|
||||
handleFilterName({ target: { name: 'end_timestamp', value: 0 } });
|
||||
return;
|
||||
}
|
||||
handleFilterName({ target: { name: 'end_timestamp', value: value.unix() * 1000 } });
|
||||
}}
|
||||
slotProps={{
|
||||
actionBar: {
|
||||
actions: ['clear', 'today', 'accept']
|
||||
}
|
||||
}}
|
||||
/>
|
||||
</LocalizationProvider>
|
||||
</FormControl>
|
||||
</Stack>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
TableToolBar.propTypes = {
|
||||
filterName: PropTypes.object,
|
||||
handleFilterName: PropTypes.func,
|
||||
userIsAdmin: PropTypes.bool
|
||||
};
|
247
web/src/views/Midjourney/index.js
Normal file
247
web/src/views/Midjourney/index.js
Normal file
@ -0,0 +1,247 @@
|
||||
import { useState, useEffect, useCallback } from 'react';
|
||||
import { showError } from 'utils/common';
|
||||
|
||||
import Table from '@mui/material/Table';
|
||||
import TableBody from '@mui/material/TableBody';
|
||||
import TableContainer from '@mui/material/TableContainer';
|
||||
import PerfectScrollbar from 'react-perfect-scrollbar';
|
||||
import TablePagination from '@mui/material/TablePagination';
|
||||
import LinearProgress from '@mui/material/LinearProgress';
|
||||
import ButtonGroup from '@mui/material/ButtonGroup';
|
||||
import Toolbar from '@mui/material/Toolbar';
|
||||
|
||||
import { Button, Card, Stack, Container, Typography, Box } from '@mui/material';
|
||||
import LogTableRow from './component/TableRow';
|
||||
import KeywordTableHead from 'ui-component/TableHead';
|
||||
import TableToolBar from './component/TableToolBar';
|
||||
import { API } from 'utils/api';
|
||||
import { isAdmin } from 'utils/common';
|
||||
import { ITEMS_PER_PAGE } from 'constants';
|
||||
import { IconRefresh, IconSearch } from '@tabler/icons-react';
|
||||
import dayjs from 'dayjs';
|
||||
|
||||
export default function Log() {
|
||||
const originalKeyword = {
|
||||
p: 0,
|
||||
channel_id: '',
|
||||
mj_id: '',
|
||||
start_timestamp: 0,
|
||||
end_timestamp: dayjs().unix() * 1000 + 3600
|
||||
};
|
||||
|
||||
const [page, setPage] = useState(0);
|
||||
const [order, setOrder] = useState('desc');
|
||||
const [orderBy, setOrderBy] = useState('id');
|
||||
const [rowsPerPage, setRowsPerPage] = useState(ITEMS_PER_PAGE);
|
||||
const [listCount, setListCount] = useState(0);
|
||||
const [searching, setSearching] = useState(false);
|
||||
const [toolBarValue, setToolBarValue] = useState(originalKeyword);
|
||||
const [searchKeyword, setSearchKeyword] = useState(originalKeyword);
|
||||
const [refreshFlag, setRefreshFlag] = useState(false);
|
||||
|
||||
const [logs, setLogs] = useState([]);
|
||||
const userIsAdmin = isAdmin();
|
||||
|
||||
const handleSort = (event, id) => {
|
||||
const isAsc = orderBy === id && order === 'asc';
|
||||
if (id !== '') {
|
||||
setOrder(isAsc ? 'desc' : 'asc');
|
||||
setOrderBy(id);
|
||||
}
|
||||
};
|
||||
|
||||
const handleChangePage = (event, newPage) => {
|
||||
setPage(newPage);
|
||||
};
|
||||
|
||||
const handleChangeRowsPerPage = (event) => {
|
||||
setPage(0);
|
||||
setRowsPerPage(parseInt(event.target.value, 10));
|
||||
};
|
||||
|
||||
const searchLogs = async () => {
|
||||
setPage(0);
|
||||
setSearchKeyword(toolBarValue);
|
||||
};
|
||||
|
||||
const handleToolBarValue = (event) => {
|
||||
setToolBarValue({ ...toolBarValue, [event.target.name]: event.target.value });
|
||||
};
|
||||
|
||||
const fetchData = useCallback(
|
||||
async (page, rowsPerPage, keyword, order, orderBy) => {
|
||||
setSearching(true);
|
||||
try {
|
||||
if (orderBy) {
|
||||
orderBy = order === 'desc' ? '-' + orderBy : orderBy;
|
||||
}
|
||||
const url = userIsAdmin ? '/api/mj/' : '/api/mj/self/';
|
||||
if (!userIsAdmin) {
|
||||
delete keyword.channel_id;
|
||||
}
|
||||
|
||||
const res = await API.get(url, {
|
||||
params: {
|
||||
page: page + 1,
|
||||
size: rowsPerPage,
|
||||
order: orderBy,
|
||||
...keyword
|
||||
}
|
||||
});
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
setListCount(data.total_count);
|
||||
setLogs(data.data);
|
||||
} else {
|
||||
showError(message);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
}
|
||||
setSearching(false);
|
||||
},
|
||||
[userIsAdmin]
|
||||
);
|
||||
|
||||
// 处理刷新
|
||||
const handleRefresh = async () => {
|
||||
setOrderBy('id');
|
||||
setOrder('desc');
|
||||
setToolBarValue(originalKeyword);
|
||||
setSearchKeyword(originalKeyword);
|
||||
setRefreshFlag(!refreshFlag);
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
fetchData(page, rowsPerPage, searchKeyword, order, orderBy);
|
||||
}, [page, rowsPerPage, searchKeyword, order, orderBy, fetchData, refreshFlag]);
|
||||
|
||||
return (
|
||||
<>
|
||||
<Stack direction="row" alignItems="center" justifyContent="space-between" mb={5}>
|
||||
<Typography variant="h4">Midjourney</Typography>
|
||||
</Stack>
|
||||
<Card>
|
||||
<Box component="form" noValidate>
|
||||
<TableToolBar filterName={toolBarValue} handleFilterName={handleToolBarValue} userIsAdmin={userIsAdmin} />
|
||||
</Box>
|
||||
<Toolbar
|
||||
sx={{
|
||||
textAlign: 'right',
|
||||
height: 50,
|
||||
display: 'flex',
|
||||
justifyContent: 'space-between',
|
||||
p: (theme) => theme.spacing(0, 1, 0, 3)
|
||||
}}
|
||||
>
|
||||
<Container>
|
||||
<ButtonGroup variant="outlined" aria-label="outlined small primary button group">
|
||||
<Button onClick={handleRefresh} startIcon={<IconRefresh width={'18px'} />}>
|
||||
刷新/清除搜索条件
|
||||
</Button>
|
||||
|
||||
<Button onClick={searchLogs} startIcon={<IconSearch width={'18px'} />}>
|
||||
搜索
|
||||
</Button>
|
||||
</ButtonGroup>
|
||||
</Container>
|
||||
</Toolbar>
|
||||
{searching && <LinearProgress />}
|
||||
<PerfectScrollbar component="div">
|
||||
<TableContainer sx={{ overflow: 'unset' }}>
|
||||
<Table sx={{ minWidth: 800 }}>
|
||||
<KeywordTableHead
|
||||
order={order}
|
||||
orderBy={orderBy}
|
||||
onRequestSort={handleSort}
|
||||
headLabel={[
|
||||
{
|
||||
id: 'mj_id',
|
||||
label: '任务ID',
|
||||
disableSort: false
|
||||
},
|
||||
{
|
||||
id: 'submit_time',
|
||||
label: '提交时间',
|
||||
disableSort: false
|
||||
},
|
||||
{
|
||||
id: 'channel_id',
|
||||
label: '渠道',
|
||||
disableSort: false,
|
||||
hide: !userIsAdmin
|
||||
},
|
||||
{
|
||||
id: 'user_id',
|
||||
label: '用户',
|
||||
disableSort: false,
|
||||
hide: !userIsAdmin
|
||||
},
|
||||
{
|
||||
id: 'action',
|
||||
label: '类型',
|
||||
disableSort: false
|
||||
},
|
||||
{
|
||||
id: 'code',
|
||||
label: '提交结果',
|
||||
disableSort: false,
|
||||
hide: !userIsAdmin
|
||||
},
|
||||
{
|
||||
id: 'status',
|
||||
label: '任务状态',
|
||||
disableSort: false,
|
||||
hide: !userIsAdmin
|
||||
},
|
||||
{
|
||||
id: 'progress',
|
||||
label: '进度',
|
||||
disableSort: true
|
||||
},
|
||||
{
|
||||
id: 'image_url',
|
||||
label: '结果图片',
|
||||
disableSort: true,
|
||||
width: '120px'
|
||||
},
|
||||
{
|
||||
id: 'prompt',
|
||||
label: 'Prompt',
|
||||
disableSort: true
|
||||
},
|
||||
{
|
||||
id: 'prompt_en',
|
||||
label: 'PromptEn',
|
||||
disableSort: true
|
||||
},
|
||||
{
|
||||
id: 'fail_reason',
|
||||
label: '失败原因',
|
||||
disableSort: true
|
||||
}
|
||||
]}
|
||||
/>
|
||||
<TableBody>
|
||||
{logs.map((row, index) => (
|
||||
<LogTableRow item={row} key={`${row.id}_${index}`} userIsAdmin={userIsAdmin} />
|
||||
))}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</TableContainer>
|
||||
</PerfectScrollbar>
|
||||
<TablePagination
|
||||
page={page}
|
||||
component="div"
|
||||
count={listCount}
|
||||
rowsPerPage={rowsPerPage}
|
||||
onPageChange={handleChangePage}
|
||||
rowsPerPageOptions={[10, 25, 30]}
|
||||
onRowsPerPageChange={handleChangeRowsPerPage}
|
||||
showFirstButton
|
||||
showLastButton
|
||||
/>
|
||||
</Card>
|
||||
</>
|
||||
);
|
||||
}
|
33
web/src/views/Midjourney/type/Type.js
Normal file
33
web/src/views/Midjourney/type/Type.js
Normal file
@ -0,0 +1,33 @@
|
||||
export const ACTION_TYPE = {
|
||||
IMAGINE: { value: 'IMAGINE', text: '绘图', color: 'primary' },
|
||||
UPSCALE: { value: 'UPSCALE', text: '放大', color: 'orange' },
|
||||
VARIATION: { value: 'VARIATION', text: '变换', color: 'default' },
|
||||
HIGH_VARIATION: { value: 'HIGH_VARIATION', text: '强变换', color: 'default' },
|
||||
LOW_VARIATION: { value: 'LOW_VARIATION', text: '弱变换', color: 'default' },
|
||||
PAN: { value: 'PAN', text: '平移', color: 'secondary' },
|
||||
DESCRIBE: { value: 'DESCRIBE', text: '图生文', color: 'secondary' },
|
||||
BLEND: { value: 'BLEND', text: '图混合', color: 'secondary' },
|
||||
SHORTEN: { value: 'SHORTEN', text: '缩词', color: 'secondary' },
|
||||
REROLL: { value: 'REROLL', text: '重绘', color: 'secondary' },
|
||||
INPAINT: { value: 'INPAINT', text: '局部重绘-提交', color: 'secondary' },
|
||||
ZOOM: { value: 'ZOOM', text: '变焦', color: 'secondary' },
|
||||
CUSTOM_ZOOM: { value: 'CUSTOM_ZOOM', text: '自定义变焦-提交', color: 'secondary' },
|
||||
MODAL: { value: 'MODAL', text: '窗口处理', color: 'secondary' },
|
||||
SWAP_FACE: { value: 'SWAP_FACE', text: '换脸', color: 'secondary' }
|
||||
};
|
||||
|
||||
export const CODE_TYPE = {
|
||||
1: { value: 1, text: '已提交', color: 'primary' },
|
||||
21: { value: 21, text: '等待中', color: 'orange' },
|
||||
22: { value: 22, text: '重复提交', color: 'default' },
|
||||
0: { value: 0, text: '未提交', color: 'default' }
|
||||
};
|
||||
|
||||
export const STATUS_TYPE = {
|
||||
SUCCESS: { value: 'SUCCESS', text: '成功', color: 'success' },
|
||||
NOT_START: { value: 'NOT_START', text: '未启动', color: 'default' },
|
||||
SUBMITTED: { value: 'SUBMITTED', text: '队列中', color: 'secondary' },
|
||||
IN_PROGRESS: { value: 'IN_PROGRESS', text: '执行中', color: 'primary' },
|
||||
FAILURE: { value: 'FAILURE', text: '失败', color: 'orange' },
|
||||
MODAL: { value: 'MODAL', text: '窗口等待', color: 'default' }
|
||||
};
|
@ -29,7 +29,8 @@ const OperationSetting = () => {
|
||||
DisplayTokenStatEnabled: '',
|
||||
ApproximateTokenEnabled: '',
|
||||
RetryTimes: 0,
|
||||
RetryCooldownSeconds: 0
|
||||
RetryCooldownSeconds: 0,
|
||||
MjNotifyEnabled: ''
|
||||
});
|
||||
const [originInputs, setOriginInputs] = useState({});
|
||||
let [loading, setLoading] = useState(false);
|
||||
@ -278,6 +279,22 @@ const OperationSetting = () => {
|
||||
</Button>
|
||||
</Stack>
|
||||
</SubCard>
|
||||
<SubCard title="其他设置">
|
||||
<Stack justifyContent="flex-start" alignItems="flex-start" spacing={2}>
|
||||
<Stack
|
||||
direction={{ sm: 'column', md: 'row' }}
|
||||
spacing={{ xs: 3, sm: 2, md: 4 }}
|
||||
justifyContent="flex-start"
|
||||
alignItems="flex-start"
|
||||
>
|
||||
<FormControlLabel
|
||||
sx={{ marginLeft: '0px' }}
|
||||
label="Midjourney 允许回调(会泄露服务器ip地址)"
|
||||
control={<Checkbox checked={inputs.MjNotifyEnabled === 'true'} onChange={handleInputChange} name="MjNotifyEnabled" />}
|
||||
/>
|
||||
</Stack>
|
||||
</Stack>
|
||||
</SubCard>
|
||||
<SubCard title="日志设置">
|
||||
<Stack direction="column" justifyContent="flex-start" alignItems="flex-start" spacing={2}>
|
||||
<FormControlLabel
|
||||
|
Loading…
Reference in New Issue
Block a user